
Stein Latent Optimization for GANs (SLOGAN)

(Figure: overview of the SLOGAN model)

Stein Latent Optimization for Generative Adversarial Networks (ICLR 2022)
Uiwon Hwang, Heeseung Kim, Dahuin Jung, Hyemi Jang, Hyungyu Lee, Sungroh Yoon
Seoul National University

Paper: https://openreview.net/forum?id=2-mkiUs9Jx7

Abstract: Generative adversarial networks (GANs) with clustered latent spaces can perform conditional generation in a completely unsupervised manner. In the real world, the salient attributes of unlabeled data can be imbalanced. However, most existing unsupervised conditional GANs cannot properly cluster the attributes of such data in their latent spaces because they assume uniform distributions of the attributes. To address this problem, we theoretically derive Stein latent optimization, which provides reparameterizable gradient estimations of the latent distribution parameters assuming a Gaussian mixture prior in a continuous latent space. Structurally, we introduce an encoder network and a novel unsupervised conditional contrastive loss to ensure that data generated from a single mixture component represent a single attribute. We confirm that the proposed method, named Stein Latent Optimization for GANs (SLOGAN), successfully learns balanced or imbalanced attributes and achieves state-of-the-art unsupervised conditional generation performance even in the absence of attribute information (e.g., the imbalance ratio). Moreover, we demonstrate that the attributes to be learned can be manipulated using a small amount of probe data.
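The abstract above describes latents drawn from a Gaussian mixture prior whose component samples are reparameterized, so that gradients can flow back to the mixture parameters. This is a minimal NumPy sketch of that sampling scheme only (not the authors' implementation; the function name, the 2-component prior, and the 7:3 mixing weights are illustrative assumptions):

```python
import numpy as np

def sample_gmm_latents(n, means, log_stds, pi, rng):
    """Draw n latent vectors from a Gaussian mixture prior.

    means: (K, D) component means; log_stds: (K, D) log standard deviations;
    pi: (K,) mixing proportions, which may be imbalanced (e.g. 7:3).
    Each sample is reparameterized as mu_k + sigma_k * eps, so in an
    autodiff framework gradients could flow to means and log_stds.
    """
    n_components, dim = means.shape
    ks = rng.choice(n_components, size=n, p=pi)   # pick a component per sample
    eps = rng.standard_normal((n, dim))           # standard Gaussian noise
    z = means[ks] + np.exp(log_stds[ks]) * eps    # reparameterized sample
    return z, ks

# Hypothetical 2-component, 2-D prior with imbalanced attributes (7:3)
rng = np.random.default_rng(0)
means = np.array([[-2.0, 0.0], [2.0, 0.0]])
log_stds = np.zeros((2, 2))
pi = np.array([0.7, 0.3])

z, ks = sample_gmm_latents(10000, means, log_stds, pi, rng)
print(z.shape)                      # (10000, 2)
print(np.bincount(ks) / len(ks))    # component usage close to [0.7, 0.3]
```

In SLOGAN the mixture parameters themselves are learned (via Stein's lemma rather than this plain Monte Carlo draw), but the sketch shows how an imbalanced prior yields correspondingly imbalanced conditional samples.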


A TensorFlow implementation of SLOGAN

Requirements

conda env create --file environment.yaml
conda activate slogan

Synthetic dataset

Model training

SLOGAN can be trained on the synthetic dataset using the following command:

python slogan_synthetic.py --gpu "GPU_NUMBER"

Generated data are stored in './logs/synthetic'.


CIFAR-2 dataset

Pretrained models

We release pretrained model weights and training log files of CIFAR-2 (7:3) and CIFAR-2 (5:5).

You can download the pretrained model files from this URL and place them in './logs/cifar2/3/pretrained/' for CIFAR-2 (7:3) or './logs/cifar2/5/pretrained/' for CIFAR-2 (5:5).

Training logs and generated images of the pretrained models can be viewed using the following command:

tensorboard --logdir ./logs/cifar2/"3 or 5"/pretrained

Pretrained model weights can be loaded and used to compute the evaluation metrics using the following command:

python slogan_cifar2.py --gpu "GPU_NUMBER" --pretrained ./logs/cifar2/"3 or 5"/pretrained/model-100000

Model training

SLOGAN can be trained on the CIFAR-2 (7:3) dataset using the following command:

python slogan_cifar2.py --gpu "GPU_NUMBER" --ratio_plane 3

Log files are stored in './logs/cifar2/"RATIO_PLANE"', and training logs and generated images can be viewed using TensorBoard.


Citation

@inproceedings{hwang2022stein,
    title={Stein Latent Optimization for Generative Adversarial Networks},
    author={Uiwon Hwang and Heeseung Kim and Dahuin Jung and Hyemi Jang and Hyungyu Lee and Sungroh Yoon},
    booktitle={International Conference on Learning Representations},
    year={2022}
}
