
Multi Person Pose Estimation with PyTorch

HigherHRNet architecture implemented and trained from scratch on the ImageNet and COCO datasets. The model is trained in two steps:

  1. Classification backbone pretraining on the ImageNet dataset (ClassificationHRNet model)
  2. Human pose model training on the COCO dataset (HigherHRNet model)

📜 Table of Contents

  • 💻 Environment
  • 🎒 Prerequisites
  • 📊 Inference
  • 📉 Training from scratch
  • 📚 Training code guide

💻 Environment

Environment management is handled with poetry. To install the virtual environment:

  1. Clone the repository
git clone https://github.com/thawro/pytorch-human-pose.git
  2. Move to the repository root (<project_root>)
cd pytorch-human-pose
  3. Install poetry - follow the documentation
  4. Install the virtual environment and activate it (the script runs poetry install and poetry shell)
make env
  5. Create directories for training/inference purposes
make dirs
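
Optionally, a quick sanity check that the environment works (a minimal sketch; it assumes the environment is active and, for GPU training, a CUDA-capable machine - CPU-only setups will simply report False):

# check_env.py - hypothetical helper, not part of the repository
import torch

print(f"torch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"device: {torch.cuda.get_device_name(0)}")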

🎒 Prerequisites

NOTE: If you have installed the environment already (with make env or poetry install) you can activate it with poetry shell.

Data preparation

NOTE: The data preparation scripts use tqdm to show progress bars for file unzipping, so make sure to install and activate the Environment first.

ImageNet

  1. Download the dataset from Kaggle
    1. Go to the link
    2. Sign in to Kaggle
    3. Scroll down and click the "Download All" button
  2. Move the downloaded imagenet-object-localization-challenge.zip file to the <project_root>/data directory
  3. Run the ImageNet preparation script from the <project_root> directory (it may take a while)
make imagenet

The script will unzip the downloaded imagenet-object-localization-challenge.zip file, remove it, create the ImageNet directory and move the unzipped files from the ILSVRC/Data/CLS-LOC directory to the ImageNet directory. Then it will move the val image files to separate directories (named by wordnet labels) using this script, and it will download the json mapping for ImageNet labels (from this source).
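
For reference, the val reorganization amounts to grouping images by their WordNet id. A minimal sketch of that step is shown below; the LOC_val_solution.csv file name and its ImageId/PredictionString columns are assumptions about the Kaggle archive layout, and the repository's own script may work differently:

# sketch: group ImageNet val images into per-wordnet-label directories
import csv
import shutil
from pathlib import Path

val_dir = Path("data/ImageNet/val")
with open("data/ImageNet/LOC_val_solution.csv") as f:
    for row in csv.DictReader(f):
        label = row["PredictionString"].split()[0]   # first token is the wordnet id, e.g. n01440764
        src = val_dir / f"{row['ImageId']}.JPEG"
        dst_dir = val_dir / label
        dst_dir.mkdir(exist_ok=True)
        if src.exists():
            shutil.move(str(src), str(dst_dir / src.name))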

After these steps there should be a directory data/ImageNet with the following directory structure:

data/ImageNet
├── wordnet_labels.yaml
├── train
|   ├── n01440764
|   ├── n01443537
|   ...
|   └── n15075141
└── val
    ├── n01440764
    ├── n01443537
    ...
    └── n15075141

COCO

  1. Run COCO preparation scripts from the <project_root> directory (it may take a while)
make coco
make save_coco_annots

The make coco script will create the data/COCO directory, download files from the COCO website (2017 Train images [118K/18GB], 2017 Val images [5K/1GB], 2017 Test images [41K/6GB], 2017 Train/Val annotations [241MB]) to the data/COCO directory, unzip them, move them to the images and annotations subdirectories and remove the redundant zip files. The make save_coco_annots script will parse the COCO annotation .json files and save the per-sample annotations to .yaml files and the per-sample crowd masks (used in the loss function) to .npy files.

After these steps there should be a directory data/COCO with the following directory structure:

data/COCO
├── annotations
│   ├── captions_train2017.json
│   ├── captions_val2017.json
│   ├── instances_train2017.json
│   ├── instances_val2017.json
│   ├── person_keypoints_train2017.json
│   └── person_keypoints_val2017.json
└── images
    ├── test2017
    ├── train2017
    └── val2017

Virtual environment installation

Install the poetry virtual environment following Environment steps.

Checkpoints with trained models

The checkpoints are available at Google Drive:

  • hrnet_32.pt - backbone pretrained on ImageNet
  • higher_hrnet_32.pt - pose estimation model trained on COCO

After download, place the checkpoints inside the pretrained directory.

📊 Inference

NOTE: Checkpoints must be present in the pretrained directory to perform inference.

NOTE: You must first install and activate the Environment to perform the inference.

Classification (ClassificationHRNet)

Inference using the ClassificationHRNet model trained on the ImageNet dataset (1000 classes). Parameters configurable via the CLI:

  • --inference.input_size - the smaller edge of the image will be matched to this number (default: 256); see the preprocessing sketch below the list
  • --inference.ckpt_path - checkpoint path (default: pretrained/hrnet_32.pt)
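
The input_size semantics match the standard torchvision Resize behaviour when given a single integer. A minimal preprocessing sketch (illustrative only; the example file path is hypothetical and the repository's actual transform may include additional steps such as normalization):

# sketch: resize so that the smaller image edge equals input_size, keeping aspect ratio
from PIL import Image
from torchvision import transforms

input_size = 256
transform = transforms.Compose([
    transforms.Resize(input_size),   # single int -> smaller edge matched to input_size
    transforms.ToTensor(),
])
img = Image.open("data/examples/classification/some_image.jpg").convert("RGB")  # hypothetical path
x = transform(img).unsqueeze(0)      # shape (1, 3, H, W) with min(H, W) == input_size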

ImageNet data

NOTE: ImageNet data must be prepared to perform inference on it.

Run inference on ImageNet val split with default input_size (256)

python src/classification/bin/inference.py --mode "val"

with changed input size

python src/classification/bin/inference.py --mode "val" --inference.input_size=512

Custom data

python src/classification/bin/inference.py --mode "custom" --dirpath "data/examples/classification"

Example outputs (top-5 probs):

coyote

👉 more examples

fox shark whale

Human Pose (HigherHRNet)

Inference using the HigherHRNet model trained on the COCO keypoints dataset (17 keypoints). Parameters configurable via the CLI:

  • --inference.input_size - the smaller edge of the image will be matched to this number (default: 256)
  • --inference.ckpt_path - checkpoint path (default: pretrained/higher_hrnet_32.pt)
  • --inference.det_thr - detection threshold used in grouping (default: 0.05)
  • --inference.tag_thr - associative embedding tags threshold used in grouping (default: 0.5)
  • --inference.use_flip - whether to use horizontal flip and average the results (default: False); see the flip-averaging sketch below the list
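
use_flip refers to the usual horizontal test-time augmentation for bottom-up pose models. A conceptual sketch for heatmaps only (the model call, the output shape and the FLIP_INDEX left/right pairing are illustrative assumptions; the repository additionally handles associative-embedding tags and multiple output resolutions):

# sketch: average heatmaps from the original and the horizontally flipped image
import torch

# COCO left/right keypoint pairing used to swap channels after flipping (illustrative)
FLIP_INDEX = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]

def flip_averaged_heatmaps(model, image: torch.Tensor) -> torch.Tensor:
    """image: (1, 3, H, W). Returns averaged keypoint heatmaps of shape (1, 17, h, w)."""
    heatmaps = model(image)
    flipped = model(torch.flip(image, dims=[3]))             # flip along the width axis
    flipped = torch.flip(flipped, dims=[3])[:, FLIP_INDEX]   # flip back and swap L/R channels
    return (heatmaps + flipped) / 2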

COCO data

NOTE: COCO data must be prepared to perform inference on it.

Run inference on COCO val split with default inference parameters

python src/keypoints/bin/inference.py --mode "val"

with changed input_size, use_flip and det_thr

python src/keypoints/bin/inference.py --mode "val" --inference.input_size=256 --inference.use_flip=True --inference.det_thr=0.1

Custom data

python src/keypoints/bin/inference.py --mode "custom" --path "data/examples/keypoints/"

Video

python src/keypoints/bin/inference.py --mode "custom" --path "data/examples/keypoints/simple_3.mp4"

Example outputs (images)

Each sample is composed of a Connections plot, an Associative Embeddings visualization (after grouping) and a Heatmaps plot.

  1. Baseball AE_baseball HM_baseball
👉 More examples
  2. Jump AE_jump HM_jump

  3. Park AE_park HM_park

Example outputs (videos)

Each sample is shown in two input_size variants

  1. Two people (size: 256)
two_256.mp4
👉 More examples
  2. Two people (size: 512)
two_512.mp4
  3. Three people (size: 256)
three_256.mp4
  4. Three people (size: 512)
three_512.mp4

📉 Training from scratch

NOTE: You must first install and activate the Environment to perform the training.

IMPORTANT: MLFlow logging is enabled by default, so before every training run you must run make mlflow_server to start the server.

Most important configurable training CLI parameters (others can be checked in the config python files):

  • setup.ckpt_path - Path to a checkpoint file saved during training (for training resume)
  • setup.pretrained_ckpt_path - Path to a checkpoint file with pretrained network weights
  • trainer.accelerator - Device for training (cpu or gpu)
  • trainer.limit_batches - How many batches are used for training. Used to run a debug experiment: when limit_batches > 0, the experiment is considered a debug run
  • trainer.use_DDP - Whether to run Distributed Data Parallel (DDP) training on multiple GPUs
  • trainer.sync_batchnorm - Whether to use the SyncBatchNorm class for DDP training (see the sketch below the list)
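
For context, use_DDP and sync_batchnorm correspond to the standard PyTorch distributed APIs. A minimal sketch of what those flags control (an illustration of the relevant PyTorch calls, not the repository's trainer code; torchrun sets the LOCAL_RANK environment variable):

# sketch: DDP wrapping with synchronized batch norm (intended to be launched via torchrun)
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_for_ddp(model: torch.nn.Module, sync_batchnorm: bool = True) -> torch.nn.Module:
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])   # set by torchrun
    torch.cuda.set_device(local_rank)
    model = model.cuda(local_rank)
    if sync_batchnorm:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return DDP(model, device_ids=[local_rank])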

Backbone

NOTE: ImageNet data must be prepared to train the backbone model.

Pretraining ClassificationHRNet on ImageNet

Using single GPU

python src/classification/bin/train.py --setup.ckpt_path=None --trainer.use_DDP=False

--setup.ckpt_path=None ensures that a new experiment is created; --trainer.use_DDP=False ensures that a single GPU is used

Using multiple GPUs - use torchrun

torchrun --standalone --nproc_per_node=2 src/classification/bin/train.py --setup.ckpt_path=None --trainer.use_DDP=True

Evaluation of ClassificationHRNet on ImageNet

TODO

Human Pose

NOTE: COCO data must be prepared to train the human pose model.

NOTE: During initialization, src/keypoints/datasets/coco/CocoKeypointsDataset runs its _save_annots_to_files method (only if it wasn't already executed for the particular split). It parses the COCO .json annotation files and saves per-sample .yaml annotation files and .npy crowd masks (used in the loss function) to data/COCO/annotations/person_keypoints_<split>/<sample_id>.yaml and data/COCO/masks/person_keypoints_<split>/<sample_id>.npy. It is executed only if the annotations directory (data/COCO/annotations/person_keypoints_<split>) isn't present.
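
The caching step can be pictured with a short pycocotools sketch (an illustration of the behaviour described above, not the repository's exact implementation):

# sketch: cache per-sample annotations (.yaml) and crowd masks (.npy) for a COCO split
from pathlib import Path
import numpy as np
import yaml
from pycocotools.coco import COCO

def save_annots_to_files(split: str = "val2017") -> None:
    annots_dir = Path(f"data/COCO/annotations/person_keypoints_{split}")
    masks_dir = Path(f"data/COCO/masks/person_keypoints_{split}")
    if annots_dir.exists():                      # already cached for this split -> skip
        return
    annots_dir.mkdir(parents=True)
    masks_dir.mkdir(parents=True, exist_ok=True)
    coco = COCO(f"data/COCO/annotations/person_keypoints_{split}.json")
    for img_id in coco.getImgIds():
        img = coco.loadImgs(img_id)[0]
        anns = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        # crowd mask: union of the masks of annotations marked as iscrowd (used in the loss)
        crowd_mask = np.zeros((img["height"], img["width"]), dtype=bool)
        for ann in anns:
            if ann["iscrowd"]:
                crowd_mask |= coco.annToMask(ann).astype(bool)
        with open(annots_dir / f"{img_id}.yaml", "w") as f:
            yaml.safe_dump(anns, f)
        np.save(masks_dir / f"{img_id}.npy", crowd_mask)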

Training HigherHRNet on COCO

Using single GPU

python src/keypoints/bin/train.py --setup.ckpt_path=None --trainer.use_DDP=False --setup.pretrained_ckpt_path="pretrained/hrnet_32.pt"

--setup.ckpt_path=None ensures that a new experiment is created; --trainer.use_DDP=False ensures that a single GPU is used; --setup.pretrained_ckpt_path loads the pretrained backbone model from the hrnet_32.pt file

Using multiple GPUs - use torchrun

torchrun --standalone --nproc_per_node=2 src/keypoints/bin/train.py --setup.ckpt_path=None --trainer.use_DDP=True --setup.pretrained_ckpt_path="pretrained/hrnet_32.pt"                

Evaluation of HigherHRNet on COCO (val2017)

NOTE: Before running the evaluation script you must ensure that the correct run_path is defined inside the script. run_path must point to the directory where the training checkpoint (.pt file) and config (.yaml file) are present.

python src/keypoints/bin/eval.py

After running this script, an evaluation_results directory is created (inside the run_path directory) with the evaluation output files:

  • coco_output.txt - file with the txt output from pycocotools (the table; a re-scoring sketch follows the list)
  • config.yaml - config of the evaluated run
  • val2017_results.json - json file with the results (predicted keypoints coordinates)
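
The numbers in coco_output.txt come from the standard pycocotools keypoint evaluation. Re-scoring the saved val2017_results.json can be sketched as follows (assuming the results file follows the usual COCO results format; the <run_path> placeholder mirrors the note above):

# sketch: re-score saved keypoint predictions with pycocotools
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

coco_gt = COCO("data/COCO/annotations/person_keypoints_val2017.json")
coco_dt = coco_gt.loadRes("<run_path>/evaluation_results/val2017_results.json")
evaluator = COCOeval(coco_gt, coco_dt, iouType="keypoints")
evaluator.evaluate()
evaluator.accumulate()
evaluator.summarize()   # prints the AP/AR table saved in coco_output.txt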

Evaluation results obtained for inference parameters:

  • --inference.input_size=512
  • --inference.use_flip=True
Metric name                                Area      Max Dets   Metric value
Average Precision (AP) @IoU=0.50:0.95      all       20         0.673
Average Precision (AP) @IoU=0.50           all       20         0.870
Average Precision (AP) @IoU=0.75           all       20         0.733
Average Precision (AP) @IoU=0.50:0.95      medium    20         0.615
Average Precision (AP) @IoU=0.50:0.95      large     20         0.761
Average Recall (AR) @IoU=0.50:0.95         all       20         0.722
Average Recall (AR) @IoU=0.50              all       20         0.896
Average Recall (AR) @IoU=0.75              all       20         0.770
Average Recall (AR) @IoU=0.50:0.95         medium    20         0.652
Average Recall (AR) @IoU=0.50:0.95         large     20         0.819

📚 Training code guide

Code structure

.
├── data                    # datasets files
│   ├── COCO                #   COCO dataset
│   ├── examples            #   example inputs for inference
│   └── ImageNet            #   ImageNet dataset
|
├── experiments             # experiments configs - files needed to perform training/inference
│   ├── classification      #   configs for ClassificationHRNet
│   └── keypoints           #   configs for HigherHRNet
|
├── inference_out           # directory with output from inference
│   ├── classification      #   classification inference output
│   └── keypoints           #   keypoints inference output
|
├── Makefile                # Makefile for cleaner script usage
|
├── mlflow                  # mlflow files
│   ├── artifacts           #   artifacts saved during training
│   ├── mlruns.db           #   database for mlflow metrics saved during training
│   └── test_experiment.py  #   script for some mlflow server testing
|
├── poetry.lock             # file updated during poetry environment management
|
├── pretrained              # directory with trained checkpoints
│   ├── higher_hrnet_32.pt  #   HigherHRNet checkpoint - COCO human pose model
│   └── hrnet_32.pt         #   ClassificationHRNet checkpoint - ImageNet classification model
|
├── pyproject.toml          # definition of poetry environment
|
├── README.md               # project README
|
├── RESEARCH.md             # my sidenotes for the human pose estimation task
|
├── results                 # directory with training results/logs
│   ├── classification      #   classification experiment results
│   ├── debug               #   debug experiments results
│   └── keypoints           #   keypoints experiment results
|
├── scripts                 # directory with useful scripts
│   ├── prepare_coco.sh     #   prepares COCO dataset - can be used without any other actions
│   ├── prepare_dirs.sh     #   creates needed directories
│   ├── prepare_env.sh      #   installs and activates poetry environment
│   ├── prepare_imagenet.sh #   prepares ImageNet dataset - requires the ImageNet zip file to be downloaded beforehand
│   └── run_mlflow.sh       #   runs mlflow server (locally)
|
└── src                     # project modules
    ├── base                #   base module - defines interfaces, abstract classes and useful training loops
    ├── classification      #   classification-related subclasses
    ├── keypoints           #   keypoints-related subclasses
    ├── logger              #   logging functionalities (monitoring and training loggers)
    └── utils               #   utility functions (file loading, image manipulation, config parsing, etc.)

Code Guide

Configs

Training and inference are parametrized using configs. Configs are defined in the experiments directory using .yaml files, which are parsed with dataclasses tailored for this purpose. The classification and keypoints configs share some custom implementations defined in src/base/config.py; task-specific configs are implemented in src/classification/config.py and src/keypoints/config.py. The Config dataclasses allow overwriting the config parameters loaded from .yaml files by adding extra arguments to script calls using the following notation: --<field_name>.<nested_field_name>=<new_value>, for example:

python src/keypoints/bin/train.py --setup.ckpt_path=None --trainer.use_DDP=False --setup.pretrained_ckpt_path=None

overwrites the setup.ckpt_path, trainer.use_DDP and setup.pretrained_ckpt_path attributes.
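
The dotted-override idea can be pictured with a small standalone sketch on plain dataclasses (illustrative only; the repository's actual parsing lives in src/base/config.py and handles more cases, e.g. type coercion):

# sketch: apply --a.b=value style overrides to nested dataclasses
import sys
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class SetupConfig:
    ckpt_path: Optional[str] = None
    pretrained_ckpt_path: Optional[str] = None

@dataclass
class TrainerConfig:
    accelerator: str = "gpu"
    use_DDP: bool = True

@dataclass
class Config:
    setup: SetupConfig = field(default_factory=SetupConfig)
    trainer: TrainerConfig = field(default_factory=TrainerConfig)

def apply_overrides(cfg: Config, argv: list) -> Config:
    for arg in argv:
        if not (arg.startswith("--") and "=" in arg):
            continue
        dotted, value = arg[2:].split("=", 1)
        *parents, leaf = dotted.split(".")
        target = cfg
        for name in parents:
            target = getattr(target, name)   # walk e.g. cfg.setup
        # a real implementation would coerce value to the field's declared type
        setattr(target, leaf, None if value == "None" else value)
    return cfg

cfg = apply_overrides(Config(), sys.argv[1:])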

The Config dataclasses are also responsible for creating the training and inference related objects with the use of the following methods (a composition sketch follows the list):

  • create_net (task-specific) - create neural network object (torch.nn.Module)
  • create_datamodule (task-specific) - create datamodule (object used for loading train/val/test data into batches)
  • create_module (task-specific) - create training module (object used to handle training and validation steps)
  • create_inference_model (task-specific) - create model tailored for inference purposes
  • create_callbacks - create callbacks (objects used during the training, each with special hooks)
  • create_logger - create logger (object used for logging purposes)
  • create_trainer - create trainer (object used to manage the whole training pipeline)
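
Put together, a training entry point built on these factories looks roughly like the sketch below (schematic only; the factory names come from the list above, but their argument lists and the trainer.fit call are assumptions made for illustration):

# sketch: how the create_* factories compose into a training script
def train(cfg):
    net = cfg.create_net()                  # torch.nn.Module
    datamodule = cfg.create_datamodule()    # train/val/test dataloaders
    module = cfg.create_module(net)         # training/validation step logic
    logger = cfg.create_logger()
    callbacks = cfg.create_callbacks()
    trainer = cfg.create_trainer(logger=logger, callbacks=callbacks)
    trainer.fit(module, datamodule, ckpt_path=cfg.setup.ckpt_path)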

Training

IMPORTANT: You must ensure that the environment is active (poetry shell) and the mlflow server is running (make mlflow_server) before training.

During training, the results directory is populated with useful info about runs (logs, metrics, evaluation examples, etc.). The structure of the populated results directory is as follows:

results
└── <experiment_name>               # run experiment_name (e.g. classification)
    └── <run_name>                  # run run_name (e.g. 03-21_11:05__ImageNet_ClassificationHRNet)
        ├── <timestamp_1>           # run timestamp (e.g. 03-21_11:05)
        |   ├── checkpoints         # saved checkpoints
        |   ├── config.yaml         # config used for current run
        |   ├── data_examples       # examples of data produced by datasets defined in datamodule
        |   ├── epoch_metrics.html  # plots with metrics returned by module class (html)
        |   ├── epoch_metrics.jpg   # plots with metrics returned by module class (jpg)
        |   ├── epoch_metrics.yaml  # yaml with metrics
        |   ├── eval_examples       # example evaluation results (plots produced by results classes)
        |   ├── logs                # per-device logs and system monitoring metrics
        |   └── model               # model-related files (ONNX if saved, layers summary, etc.)
        └── <timestamp_2>           # resumed run timestamp (e.g. 03-22_12:10)
            ├── checkpoints
            ...
            └── model

Each training run is parametrized by a yaml config. The names shown in <> are defined by:

  • setup.experiment_name defines the <experiment_name> directory name,
  • setup.run_name defines the <run_name> directory name. If set to null (default), <run_name> is generated automatically as <timestamp>__<setup.dataset>_<setup.architecture>

For each new run a new results directory is created (named by the current timestamp). If a run is resumed (the same <run_name> is used), a new subrun directory (based on the timestamp) is added instead.
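
A minimal sketch of the automatic naming, assuming the timestamp format visible in the examples above (03-21_11:05):

# sketch: auto-generated run_name as <timestamp>__<setup.dataset>_<setup.architecture>
from datetime import datetime

timestamp = datetime.now().strftime("%m-%d_%H:%M")
dataset, architecture = "ImageNet", "ClassificationHRNet"   # values come from the setup config
run_name = f"{timestamp}__{dataset}_{architecture}"         # e.g. "03-21_11:05__ImageNet_ClassificationHRNet"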

MLFlow

By default, MLFlow is used as the experiments logger (a local mlflow server under the http://127.0.0.1:5000/ address). The runs logged in mlflow are structured a bit differently than the ones present in the results directory. The main differences (a minimal server-usage sketch follows the list):

  • Resuming a run is equivalent to logging to the same run (no subrun directories are added),
  • There is an additional directory in a run's artifacts called history, where the logs and configs of each subrun are saved in their corresponding <timestamp> directories,
  • Resuming a run overwrites the previously logged data_examples, logs, config.yaml, eval_examples and epoch_metrics artifacts.
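
For orientation, pointing an MLflow client at the local server started by make mlflow_server looks like the sketch below (generic mlflow usage; the experiment name, metric and paths are illustrative, and this is not the repository's logger implementation):

# sketch: logging to the local MLflow server used by the training scripts
import mlflow

mlflow.set_tracking_uri("http://127.0.0.1:5000")
mlflow.set_experiment("keypoints")              # experiment name is illustrative
with mlflow.start_run(run_name="example_run"):
    mlflow.log_params({"input_size": 512, "batch_size": 36})
    mlflow.log_metric("val/AP", 0.673, step=1)
    mlflow.log_artifact("results/keypoints/example_run/config.yaml")  # path is illustrative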

Training guide

NOTE: Read all previous chapters before running the commands listed below.

NOTE: Adjust settings like:

  • --dataloader.batch_size (default: 80 for hrnet, 36 for higher_hrnet)
  • --dataloader.num_workers (default: 4 for both tasks)

to your device capabilities.

Depending on what and how you would like to train the models, there are a few possibilities (listed below). All examples assume a single GPU (to train with multiple GPUs, use the torchrun commands from the previous chapters).

1. Only HRNet (classifier) training

First run:

python src/classification/bin/train.py --setup.ckpt_path=None --trainer.use_DDP=False --setup.experiment_name="classification_exp" --setup.run_name="only_hrnet_run"
Optionally if resuming is needed:

use checkpoint from previous run:

ckpt_path = "results/classification_exp/only_hrnet_run/<timestamp>/checkpoints/last.pt"

python src/classification/bin/train.py --setup.ckpt_path=<ckpt_path> --trainer.use_DDP=False

2. Only HigherHRNet (keypoints) training (without pre-trained HRNet backbone)

First run:

python src/keypoints/bin/train.py --setup.ckpt_path=None --trainer.use_DDP=False --setup.experiment_name="keypoints_exp" --setup.run_name="only_higherhrnet_run"
Optionally if resuming is needed:

use checkpoint from previous run:

ckpt_path = "results/keypoints_exp/only_higherhrnet_run/<timestamp>/checkpoints/last.pt"

python src/keypoints/bin/train.py --setup.ckpt_path=<ckpt_path> --trainer.use_DDP=False --setup.experiment_name="keypoints_exp" --setup.run_name="only_higherhrnet_run"

3. Only HigherHRNet (keypoints) training (with pre-trained HRNet backbone)

NOTE: Downloaded hrnet_32.pt checkpoint must be present in pretrained directory.

First run:

python src/keypoints/bin/train.py --setup.ckpt_path=None --trainer.use_DDP=False --setup.experiment_name="keypoints_exp" --setup.run_name="pretrained_higherhrnet_run" --setup.pretrained_ckpt_path="pretrained/hrnet_32.pt"
Optionally if resuming is needed:

NOTE: There is no need to pass the pretrained_ckpt_path when resuming the training since its weights were updated during training.

use checkpoint from previous run:

ckpt_path = "results/keypoints_exp/pretrained_higherhrnet_run/<timestamp>/checkpoints/last.pt"

python src/keypoints/bin/train.py --setup.ckpt_path=<ckpt_path> --trainer.use_DDP=False --setup.experiment_name="keypoints_exp" --setup.run_name="pretrained_higherhrnet_run"

4. Complete, "from scratch" training

The complete ("from scratch") training include pretraining of ClassificationHRNet and then using it as a backbone for HRNet.

  1. Train classification model (HRNet backbone)
python src/classification/bin/train.py --setup.ckpt_path=None --trainer.use_DDP=False --setup.experiment_name="classification_exp" --setup.run_name="from_scratch_hrnet_pretrain_run"
Optionally if resuming is needed:

ckpt_path = "results/classification_exp/from_scratch_hrnet_pretrain_run/<timestamp>/checkpoints/last.pt"

python src/classification/bin/train.py --setup.ckpt_path=<ckpt_path> --trainer.use_DDP=False --setup.experiment_name="classification_exp" --setup.run_name="from_scratch_hrnet_pretrain_run"
  2. Use the pretrained backbone and train the HigherHRNet keypoints estimation model

pretrained_ckpt_path = "results/classification_exp/from_scratch_hrnet_pretrain_run/<timestamp>/checkpoints/last.pt"

python src/keypoints/bin/train.py --setup.ckpt_path=None --trainer.use_DDP=False --setup.experiment_name="keypoints_exp" --setup.run_name="from_scratch_pretrained_higherhrnet_run" --setup.pretrained_ckpt_path=<pretrained_ckpt_path>
Optionally if resuming is needed:

ckpt_path = "results/keypoints_exp/from_scratch_pretrained_higherhrnet_run/<timestamp>/checkpoints/last.pt"

python src/keypoints/bin/train.py --setup.ckpt_path=<ckpt_path> --trainer.use_DDP=False --setup.experiment_name="keypoints_exp" --setup.run_name="from_scratch_pretrained_higherhrnet_run"

