
bouracha/generative_imputation


Generative Imputation

Dependencies

Older versions may work, but we used the following:

  • CUDA 10.1 (the usable batch size depends on GPU memory)
  • Python 3.6.9
  • PyTorch 1.6.0
  • progress 1.5
  • TensorBoard
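A quick way to compare your environment against this list is a short Python check like the one below. It is a sketch, not part of the repository: the helper name installed_versions is hypothetical, it imports each package by its module name (torch for PyTorch), and it only reports what is installed; it does not enforce any version.

```python
import importlib
import sys


def installed_versions(modules=("torch", "progress", "tensorboard")):
    """Report installed versions of the listed dependencies (hypothetical helper).

    Returns a dict mapping each name to its version string, or None when
    the package cannot be imported. Nothing is enforced, since older
    versions may also work.
    """
    found = {"python": "%d.%d.%d" % sys.version_info[:3]}
    for name in modules:
        try:
            module = importlib.import_module(name)
        except ImportError:
            found[name] = None
            continue
        found[name] = getattr(module, "__version__", "unknown")
        if name == "torch":
            # CUDA version this PyTorch build was compiled against (None for CPU-only builds).
            found["cuda"] = module.version.cuda
    return found


if __name__ == "__main__":
    # Versions used by the authors: Python 3.6.9, PyTorch 1.6.0, CUDA 10.1, progress 1.5.
    for key, value in installed_versions().items():
        print(key, value)
```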

Get the data

Human3.6M in exponential map format can be downloaded from here.

AMASS was obtained from its official repository; you need to create an account to download it.

Once downloaded, the datasets should be placed in the datasets folder, as in the example below.

Example

You also need to create a saved_models folder, because each trained model produces many checkpoints and other data as it trains. When training several models it is cleaner to keep each model's output in its own sub-folder, so saving checkpoints to sub-folders within saved_models is hardcoded.
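The two folders can be created with a few lines of Python. This is a minimal sketch assuming both folders live at the repository root next to main.py; the README does not state the exact location or the expected layout inside datasets, and prepare_folders is a hypothetical helper, not part of the code base.

```python
from pathlib import Path


def prepare_folders(repo_root="."):
    """Create the datasets and saved_models folders if they are missing.

    Assumes both folders sit at the repository root (hypothetical layout).
    Returns the two paths.
    """
    root = Path(repo_root)
    datasets = root / "datasets"
    saved_models = root / "saved_models"
    for folder in (datasets, saved_models):
        # exist_ok=True makes the call safe to repeat; existing contents are untouched.
        folder.mkdir(parents=True, exist_ok=True)
    return datasets, saved_models
```

Calling prepare_folders() from the repository root creates both folders if needed; the downloaded Human3.6M and AMASS data then go inside datasets.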

Training commands

To train HG-VAE as in the paper:

python3 main.py --name "HGVAE" --lr 0.0001 --warmup_time 200 --beta 0.0001 --n_epochs 500 --variational --output_variance --train_batch_size 800 --test_batch_size 800

See opt.py for all training options. By default, checkpoints are saved every 10 epochs. Training may be stopped and resumed using the --start_epoch flag; for example,

python3 main.py --start_epoch 31 --name "HGVAE" --lr 0.0001 --warmup_time 200 --beta 0.0001 --n_epochs 500 --variational --output_variance --train_batch_size 800 --test_batch_size 800

resumes training from the checkpoint saved after epoch 30. We also use the --start_epoch flag to select which checkpoint to load when using a trained model.
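The checkpoint arithmetic can be sketched as follows, assuming the default of one checkpoint every 10 epochs and that --start_epoch N resumes from the checkpoint saved after epoch N - 1, as in the example above. resume_start_epoch and resume_command are hypothetical helpers; the argument list simply copies the HG-VAE training command from this README.

```python
# Arguments of the HG-VAE training command given in this README.
HGVAE_ARGS = (
    "--name", "HGVAE", "--lr", "0.0001", "--warmup_time", "200",
    "--beta", "0.0001", "--n_epochs", "500", "--variational",
    "--output_variance", "--train_batch_size", "800", "--test_batch_size", "800",
)


def resume_start_epoch(last_finished_epoch, save_every=10):
    """Return the --start_epoch value that resumes from the newest checkpoint.

    Assumes a checkpoint is written after every save_every-th epoch and
    that --start_epoch N loads the checkpoint saved after epoch N - 1.
    """
    if last_finished_epoch < save_every:
        raise ValueError("no checkpoint has been saved yet")
    # Round down to the most recent multiple of save_every.
    last_checkpoint = (last_finished_epoch // save_every) * save_every
    return last_checkpoint + 1


def resume_command(last_finished_epoch, args=HGVAE_ARGS, save_every=10):
    """Build the argv list for resuming training (hypothetical helper)."""
    start = resume_start_epoch(last_finished_epoch, save_every)
    return ["python3", "main.py", "--start_epoch", str(start)] + list(args)
```

For example, if training was interrupted after epoch 37 finished, the newest checkpoint is the one from epoch 30, so resume_command(37) begins with --start_epoch 31. The list can be run with subprocess.run(cmd, check=True) from the repository root.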

Licence

MIT

Paper

If you use our code, please cite:

IN REVIEW
