Stein Latent Optimization for Generative Adversarial Networks (ICLR 2022)
Uiwon Hwang, Heeseung Kim, Dahuin Jung, Hyemi Jang, Hyungyu Lee, Sungroh Yoon
Seoul National University
Paper: https://openreview.net/forum?id=2-mkiUs9Jx7
Abstract: Generative adversarial networks (GANs) with clustered latent spaces can perform conditional generation in a completely unsupervised manner. In the real world, the salient attributes of unlabeled data can be imbalanced. However, most existing unsupervised conditional GANs cannot properly cluster the attributes of such data in their latent spaces because they assume uniform distributions of the attributes. To address this problem, we theoretically derive Stein latent optimization, which provides reparameterizable gradient estimations of the latent distribution parameters assuming a Gaussian mixture prior in a continuous latent space. Structurally, we introduce an encoder network and a novel unsupervised conditional contrastive loss to ensure that data generated from a single mixture component represent a single attribute. We confirm that the proposed method, named Stein Latent Optimization for GANs (SLOGAN), successfully learns balanced or imbalanced attributes and achieves state-of-the-art unsupervised conditional generation performance even in the absence of attribute information (e.g., the imbalance ratio). Moreover, we demonstrate that the attributes to be learned can be manipulated using a small amount of probe data.
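As background for the abstract's "reparameterizable gradient estimations" under a Gaussian mixture prior, here is one standard identity of this kind. It is not presented as the paper's exact estimator. For a 1-D mixture p(z) = sum_j pi_j N(z; mu_j, sigma_j^2), Bonnet's theorem gives d/dmu_k E_{N_k}[f(z)] = E_{N_k}[f'(z)], so d/dmu_k E_{z~p}[f(z)] = pi_k E_{N_k}[f'(z)] = E_{z~p}[gamma_k(z) f'(z)], where gamma_k(z) = pi_k N(z; mu_k, sigma_k^2) / sum_j pi_j N(z; mu_j, sigma_j^2) is the component responsibility. The gradient with respect to a component mean thus becomes a responsibility-weighted expectation over samples from the whole mixture, and the discrete component choice never has to be differentiated. The sketch below checks this numerically with f(z) = z^2, for which the exact value is 2 * pi_k * mu_k. All function names are hypothetical and are not part of the SLOGAN code.

```python
import math
import random


def normal_pdf(z, mu, sigma):
    """Density of N(mu, sigma^2) evaluated at z."""
    return math.exp(-0.5 * ((z - mu) / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))


def responsibilities(z, pis, mus, sigmas):
    """Posterior probability gamma_k(z) that z was drawn from component k."""
    weighted = [p * normal_pdf(z, m, s) for p, m, s in zip(pis, mus, sigmas)]
    total = sum(weighted)
    return [w / total for w in weighted]


def sample_gmm(rng, pis, mus, sigmas):
    """Pick a component by its weight, then reparameterize: z = mu_k + sigma_k * eps."""
    k = rng.choices(range(len(pis)), weights=pis)[0]
    eps = rng.gauss(0.0, 1.0)
    return mus[k] + sigmas[k] * eps


def grad_mu_estimate(f_prime, k, pis, mus, sigmas, n=100_000, seed=0):
    """Monte Carlo estimate of E_{z ~ mixture}[gamma_k(z) * f'(z)], i.e. d/dmu_k E[f(z)]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        z = sample_gmm(rng, pis, mus, sigmas)
        total += responsibilities(z, pis, mus, sigmas)[k] * f_prime(z)
    return total / n


if __name__ == "__main__":
    # Imbalanced 2-component prior (illustrative values only).
    pis, mus, sigmas = [0.7, 0.3], [-2.0, 3.0], [1.0, 0.5]
    for k in range(len(pis)):
        estimate = grad_mu_estimate(lambda z: 2.0 * z, k, pis, mus, sigmas)
        exact = 2.0 * pis[k] * mus[k]  # d/dmu_k of sum_j pi_j (mu_j^2 + sigma_j^2)
        print(f"component {k}: estimate={estimate:.3f}, exact={exact:.3f}")
```

The estimate should agree with the exact value up to Monte Carlo noise; increasing n tightens the agreement.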
Create and activate the conda environment:
conda env create --file environment.yaml
conda activate slogan
SLOGAN can be trained on the synthetic dataset using the following command:
python slogan_synthetic.py --gpu "GPU_NUMBER"
Generated data are stored in './logs/synthetic'.
We release pretrained model weights and training log files of CIFAR-2 (7:3) and CIFAR-2 (5:5).
Download the pretrained model files from this URL and place them in './logs/cifar2/3/pretrained/' for CIFAR-2 (7:3) or './logs/cifar2/5/pretrained/' for CIFAR-2 (5:5).
Training logs and generated images of pretrained models can be viewed using the following command:
tensorboard --logdir ./logs/cifar2/"3 or 5"/pretrained
Pretrained model weights can be loaded and used to calculate evaluation metrics using the following command:
python slogan_cifar2.py --gpu "GPU_NUMBER" --pretrained ./logs/cifar2/"3 or 5"/pretrained/model-100000
SLOGAN can be trained with the CIFAR-2 (7:3) dataset using the following command:
python slogan_cifar2.py --gpu "GPU_NUMBER" --ratio_plane 3
Log files are stored in './logs/cifar2/"RATIO_PLANE"'; training logs and generated images can be viewed with TensorBoard.
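In CIFAR-2 (7:3) and CIFAR-2 (5:5), the numbers in parentheses appear to give the ratio between the two classes (5:5 being the balanced case). The helper below is a hypothetical sketch of how such a two-class subset could be drawn from a labeled dataset. It is not the repository's data pipeline. This README does not state which two classes are used, which one is the minority, or exactly how --ratio_plane maps onto the ratio, so the class IDs and toy labels below are placeholders.

```python
import random


def make_imbalanced_indices(labels, class_a, class_b, ratio_a, ratio_b, seed=0):
    """Hypothetical helper: return shuffled indices of a two-class subset whose
    class counts follow ratio_a : ratio_b, keeping the limiting class whole."""
    idx_a = [i for i, y in enumerate(labels) if y == class_a]
    idx_b = [i for i, y in enumerate(labels) if y == class_b]
    n_a, n_b = len(idx_a), len(idx_b)
    # Integer arithmetic avoids floating-point rounding drift in the counts.
    if n_a * ratio_b <= n_b * ratio_a:
        count_a, count_b = n_a, n_a * ratio_b // ratio_a
    else:
        count_a, count_b = n_b * ratio_a // ratio_b, n_b
    rng = random.Random(seed)
    chosen = rng.sample(idx_a, count_a) + rng.sample(idx_b, count_b)
    rng.shuffle(chosen)
    return chosen


if __name__ == "__main__":
    # Toy stand-in labels: three classes with 5000 examples each.
    labels = [0] * 5000 + [1] * 5000 + [2] * 5000
    for ratio in [(7, 3), (5, 5)]:
        idx = make_imbalanced_indices(labels, 0, 1, *ratio)
        counts = (sum(labels[i] == 0 for i in idx), sum(labels[i] == 1 for i in idx))
        print(ratio, counts)
```

Keeping the limiting class whole and subsampling only the other one maximizes the subset size for a given ratio.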
@inproceedings{hwang2022stein,
title={Stein Latent Optimization for Generative Adversarial Networks},
author={Uiwon Hwang and Heeseung Kim and Dahuin Jung and Hyemi Jang and Hyungyu Lee and Sungroh Yoon},
booktitle={International Conference on Learning Representations},
year={2022}
}