This repository contains the code for the paper Bidirectional Generative Modeling Using Adversarial Gradient Estimation.
pip install -r requirements.txt
- Train a bidirectional generative model (BGM) using AGES-ALL on Stacked MNIST:
sh run_stack_mnist_age_all.sh
- Train a BGM using AGES-KL on MoG:
sh run_mog_age_kl.sh
- Train a BGM using AGES-ALL on CelebA:
sh run_celeba_age_all.sh
- Train a BGM using AGES-ALL on ImageNet:
sh run_imagenet.sh
- Train a BGM using AGES-KL with scaling clipping on CelebA:
sh run_celeba_age_kl_sc.sh
- Train a unidirectional model using AGES-ALL on CelebA:
sh run_celeba_uni_age_all.sh
Each training run creates a directory ./results/<dataset>/<save_name> containing:
- model.sav: a Python dictionary containing the generator, encoder, and discriminator (see the loading sketch below).
- gen.png: generated images.
- recon.png: real images (odd columns) along with the reconstructions (even columns).
- log.txt: all losses computed during training.
- config.txt: training configurations.
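For example, the saved model could be loaded back and used for sampling or reconstruction roughly as follows. This is a minimal sketch, not code from the repository: the dictionary keys ('generator', 'encoder'), the save path, and the latent dimension of 128 are assumptions.

```python
import torch

# Load the saved dictionary (keys assumed to be 'generator', 'encoder', 'discriminator').
# Newer PyTorch versions may require weights_only=False to unpickle full modules.
ckpt = torch.load('./results/celeba/my_run/model.sav', map_location='cpu')
generator = ckpt['generator'].eval()
encoder = ckpt['encoder'].eval()

with torch.no_grad():
    # Sample from the prior and generate images (latent_dim assumed to be 128 here).
    z = torch.randn(16, 128)
    samples = generator(z)

    # Reconstruction: encode a batch of real images x, then decode the latent codes.
    # recon = generator(encoder(x))
```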
Important arguments:
Model elements:
--latent_dim dimension of the latent variable
--prior prior distribution p_z of the latent variable
--enc_dist {deterministic, gaussian, implicit}
distribution of the encoder p_e(z|x) (default: gaussian)
--dec_dist {deterministic, gaussian, implicit}
distribution of the generator p_g(x|z) (default: deterministic)
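To make the three distribution options concrete, here is a rough sketch of how such encoder heads are often implemented. This is a generic illustration under assumed conventions, not the repository's actual architecture; all names and shapes are hypothetical.

```python
import torch
import torch.nn as nn

class EncoderHead(nn.Module):
    """Illustrative encoder head for the three --enc_dist choices."""

    def __init__(self, feat_dim, latent_dim, enc_dist='gaussian'):
        super().__init__()
        self.enc_dist = enc_dist
        if enc_dist == 'gaussian':
            # Predict mean and log-variance, sample z by reparameterization.
            self.fc = nn.Linear(feat_dim, 2 * latent_dim)
        else:
            # 'deterministic': z is a direct function of x.
            # 'implicit': randomness is injected as an extra noise input.
            self.fc = nn.Linear(feat_dim, latent_dim)

    def forward(self, h, noise=None):
        if self.enc_dist == 'gaussian':
            mu, logvar = self.fc(h).chunk(2, dim=1)
            return mu + torch.randn_like(mu) * (0.5 * logvar).exp()
        if self.enc_dist == 'implicit':
            return self.fc(h + noise)  # noise shaped like h, supplied by the caller
        return self.fc(h)  # deterministic
```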
Objective:
--div {all, kl, js, hellinger, revkl}
which divergence to use as the objective for generative modeling
--unigen whether to train a unidirectional generative model (default: False)
--clip whether to use the scaling clipping technique (default: False); see the sketch after this block
--scale_lower lower bound of the scaling factor (default: 0.5)
--scale_upper upper bound of the scaling factor (default: None, in which case 1/scale_lower is used as the upper bound)
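As a point of reference, the interaction between --scale_lower and --scale_upper could look like the following minimal sketch. Where the clipped scaling factor enters the gradient estimate is not shown here, and the function name is hypothetical.

```python
import torch

def clip_scale(scale, scale_lower=0.5, scale_upper=None):
    # If no upper bound is given, use 1/scale_lower, matching the
    # documented default behaviour of --scale_upper.
    if scale_upper is None:
        scale_upper = 1.0 / scale_lower
    # Clamp the estimated scaling factor into [scale_lower, scale_upper].
    return torch.clamp(scale, min=scale_lower, max=scale_upper)
```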
Datasets:
--dataset {celeba, cifar, imagenet, mnist, mnist_stack, mog}
name of the dataset (default: celeba)
--data_dir directory of the dataset
--image_size resolution of the image (default: 64)
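Putting these arguments together, a custom run might look like the line below. The entry-point name main.py, the --save_name flag, and the value 128 for --latent_dim are assumptions for illustration (the provided run_*.sh scripts contain the actual invocations), and --clip is assumed to be a boolean switch.
python main.py --dataset celeba --data_dir /path/to/celeba --image_size 64 --latent_dim 128 --enc_dist gaussian --dec_dist deterministic --div kl --clip --scale_lower 0.5 --save_name my_run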
The code for SAGAN architectures is based on the PyTorch implementation of SAGAN from this repository.