This repository provides the code to reproduce the experimental results in the paper Augmented Sliced Wasserstein Distances by Xiongjie Chen, Yongxin Yang and Yunpeng Li.
To install the required Python packages, run the following command:
pip install -r requirements.txt
Two datasets are used in this repository, namely the CIFAR10 dataset and CELEBA dataset.
- The CIFAR10 dataset (64x64 pixels) will be automatically downloaded from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz when running the CIFAR10 experiment.
- The CELEBA dataset needs to be downloaded manually from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html. We use the cropped CELEBA dataset with 64x64 pixels.
To calculate the Fréchet Inception Distance (FID score), precalculated statistics for datasets
- CIFAR 10 (calculated on all training samples)
- cropped CelebA (64x64, calculated on all samples)
are provided at: http://bioinf.jku.at/research/ttur/.
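For reference, the FID compares two Gaussians fitted to Inception features, one for the generated images and one for the dataset. The precalculated statistics are the mean and covariance matrix for the dataset. The symbols below are notation chosen here for illustration, with $(\mu_g, \Sigma_g)$ for generated images and $(\mu_d, \Sigma_d)$ for the dataset:

```latex
\mathrm{FID} = \lVert \mu_g - \mu_d \rVert_2^2
             + \operatorname{Tr}\!\left( \Sigma_g + \Sigma_d - 2\,(\Sigma_g \Sigma_d)^{1/2} \right)
```

The score is zero when both means and both covariance matrices coincide, which is why lower values indicate generated images that are statistically closer to the dataset.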
Two experiments are included in this repository, with benchmarks taken from the papers Generalized Sliced Wasserstein Distances and Distributional Sliced-Wasserstein and Applications to Generative Modeling, respectively. The first is a sliced Wasserstein flow task, and the second is generative modelling with GANs. For more details and setups, please refer to the original paper Augmented Sliced Wasserstein Distances.
- ./result/ASWD/CIFAR/ : contains images generated by models trained with the ASWD on the CIFAR10 dataset.
- ./result/ASWD/CIFAR/fid/ : FID scores of images generated by models trained with the ASWD on the CIFAR10 dataset are saved in this folder.
- ./result/CIFAR/ : model weights and losses from the CIFAR10 experiment are stored in this directory.
Other setups follow the same naming rule.
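As an illustration of the naming rule, the output directories for another setup could be derived as in the sketch below. The helper `result_dirs` and its dictionary keys are hypothetical names introduced here, not functions from this repository, and the sketch assumes that the method and dataset names used in the paths match the values of the --model-type and --dataset options.

```python
import posixpath


def result_dirs(model_type, dataset, root="./result"):
    """Build the three output directories for one setup.

    Hypothetical helper: it only mirrors the layout described in this
    README and is not part of the repository code.
    """
    return {
        # Generated images for this metric and dataset.
        "images": posixpath.join(root, model_type, dataset) + "/",
        # FID scores of those generated images.
        "fid": posixpath.join(root, model_type, dataset, "fid") + "/",
        # Model weights and losses for this dataset.
        "checkpoints": posixpath.join(root, dataset) + "/",
    }


if __name__ == "__main__":
    for name, path in result_dirs("DSWD", "CELEBA").items():
        print(name, path)
```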
The sliced Wasserstein flow example can be found in the Jupyter notebook.
The following scripts belong to the generative modelling example:
- main.py : run this file to conduct experiments.
- utils.py : contains implementations of different sliced-based Wasserstein distances.
- TransformNet.py : edit this file to modify architectures of neural networks used to map samples.
- experiments.py : functions for generating and saving randomly generated images.
- DCGANAE.py : neural network architectures and optimization objective for training GANs.
- fid_score.py : functions for calculating statistics (mean & covariance matrix) of distributions of images and the FID score between two distributions of images.
- inception.py : download the pretrained InceptionV3 model and generate feature maps for FID evaluation.
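For readers new to these metrics, the following is a minimal pure-Python sketch of the plain sliced Wasserstein distance (SWD), the baseline that the other variants build on. It is not the implementation in utils.py: the function names, the list-of-lists data layout, and the equal-sample-size restriction are assumptions made to keep the example short. Samples are projected onto random unit directions, the one-dimensional Wasserstein distance is computed by sorting, and the results are averaged over projections.

```python
import math
import random


def random_direction(dim, rng):
    """Draw a direction uniformly from the unit sphere in R^dim."""
    while True:
        v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(x * x for x in v))
        if norm > 0.0:
            return [x / norm for x in v]


def sliced_wasserstein(xs, ys, num_projections=1000, p=2, seed=0):
    """Monte Carlo estimate of the sliced p-Wasserstein distance.

    xs and ys are equally sized lists of points (each a list of floats).
    This is an illustrative sketch, not the repository's implementation.
    """
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equally sized samples")
    dim = len(xs[0])
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_projections):
        theta = random_direction(dim, rng)
        # Project both samples onto the direction and sort the 1-D values.
        px = sorted(sum(a * t for a, t in zip(x, theta)) for x in xs)
        py = sorted(sum(a * t for a, t in zip(y, theta)) for y in ys)
        # For equally sized 1-D samples, the optimal coupling pairs
        # the sorted values in order.
        total += sum(abs(a - b) ** p for a, b in zip(px, py)) / len(px)
    return (total / num_projections) ** (1.0 / p)


if __name__ == "__main__":
    source = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
    target = [[3.0, 4.0], [4.0, 4.0], [3.0, 5.0]]
    print(sliced_wasserstein(source, target, num_projections=1000))
```

The `num_projections` argument plays the same role as the number of projections L used in the experiments: more projections reduce the variance of the estimate at a higher computational cost.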
The generative modelling experiment evaluates the performance of GANs trained with different sliced-based Wasserstein metrics. To train and evaluate the model, run the following command:
python main.py --model-type ASWD --dataset CIFAR --epochs 200 --num-projection 1000 --batch-size 512 --lr 0.0005
- --model-type : type of sliced-based Wasserstein metric used in the experiment, available options: ASWD, DSWD, SWD, MSWD, GSWD. Must be specified.
- --dataset : select from: CIFAR, CELEBA, default as CIFAR.
- --epochs : training epochs, default as 200.
- --num-projection : number of projections used in distance approximation, default as 1000.
- --batch-size : batch size for one iteration, default as 512.
- --lr : learning rate, default as 0.0005.
- --niter : number of iterations, available for the ASWD, MSWD and DSWD, default as 5.
- --lam : coefficient of regularization term, available for the ASWD and DSWD, default as 0.5.
- --r : parameter in the circular defining function, available for GSWD, default as 1000.
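The options and defaults listed above can be summarised as an argparse parser. The sketch below is only an illustration of that interface and is not copied from main.py. The function name `build_parser` and the argument types (for example, treating --r as a float) are assumptions.

```python
import argparse


def build_parser():
    """Parser mirroring the options documented in this README (a sketch)."""
    parser = argparse.ArgumentParser(
        description="Options for the generative modelling experiment (sketch)")
    # Basic options.
    parser.add_argument("--model-type", required=True,
                        choices=["ASWD", "DSWD", "SWD", "MSWD", "GSWD"])
    parser.add_argument("--dataset", default="CIFAR",
                        choices=["CIFAR", "CELEBA"])
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--num-projection", type=int, default=1000)
    parser.add_argument("--batch-size", type=int, default=512)
    parser.add_argument("--lr", type=float, default=0.0005)
    # Method-specific options; argument types are assumed here.
    parser.add_argument("--niter", type=int, default=5)      # ASWD, MSWD, DSWD
    parser.add_argument("--lam", type=float, default=0.5)    # ASWD, DSWD
    parser.add_argument("--r", type=float, default=1000.0)   # GSWD
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--model-type", "ASWD"])
    print(args.dataset, args.epochs, args.num_projection, args.batch_size, args.lr)
```

Note that argparse converts dashes in option names to underscores, so --num-projection is read as `args.num_projection`.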
Sliced Wasserstein flow
We conduct the sliced Wasserstein flow experiment on eight different datasets and the experimental results are presented in the following figure. The first and third columns in the figure below are target distributions. The second and fourth columns are log 2-Wasserstein distances between the target distribution and the source distribution. The horizontal axis shows the number of training iterations. Solid lines and shaded areas represent the average values and 95% confidence intervals of log 2-Wasserstein distances over 50 runs.
Generative modelling
The table below provides FID scores of generative models trained with different distance metrics. Lower scores indicate better image quality. In what follows, L is the number of projections. We run each experiment 10 times and report the average values and standard errors of FID scores for the CIFAR10 and CELEBA datasets. The running time per training iteration for one batch containing 512 samples is measured on a computer with an Intel (R) Xeon (R) Gold 5218 CPU 2.3 GHz and 16GB of RAM, and an RTX 6000 graphics card with 22GB of memory.
With L=1000 projections, the following figure shows the convergence rate of FID scores of generative models trained with different metrics on CIFAR10 and CELEBA datasets. The error bar represents the standard deviation of the FID scores at the specified training epoch among 10 simulation runs.
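The summary statistics reported here (average, standard deviation, and standard error over repeated runs) can be computed as in the following sketch. The function name `summarize_runs` is hypothetical, and the input values in the usage example are placeholders chosen for easy arithmetic, not FID scores from any experiment.

```python
import math
import statistics


def summarize_runs(values):
    """Return (mean, sample standard deviation, standard error) of the runs."""
    if len(values) < 2:
        raise ValueError("need at least two runs to estimate the spread")
    mean = statistics.mean(values)
    # Sample standard deviation (n - 1 in the denominator).
    std = statistics.stdev(values)
    # Standard error of the mean shrinks with the number of runs.
    stderr = std / math.sqrt(len(values))
    return mean, std, stderr


if __name__ == "__main__":
    placeholder_scores = [1.0, 2.0, 3.0, 4.0, 5.0]  # made-up values
    mean, std, stderr = summarize_runs(placeholder_scores)
    print(f"mean={mean:.3f} std={std:.3f} stderr={stderr:.3f}")
```

The standard deviation describes the spread between individual runs, while the standard error describes the uncertainty of the reported average, so the two are not interchangeable when comparing the table with the figure.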
The code for the generative modelling example is based on the implementation of DSWD by VinAI Research.
The PyTorch code for calculating the FID score is from https://github.com/mseitzer/pytorch-fid.
- Distributional Sliced-Wasserstein and Applications to Generative Modeling
- Generalized Sliced Wasserstein Distances
- Sliced Wasserstein Auto-Encoders
- Max-Sliced Wasserstein Distance and its Use for GANs
If you find this code useful for your research, please cite our paper:
@article{chen2020augmented,
title={Augmented Sliced Wasserstein Distances},
author={Chen, Xiongjie and Yang, Yongxin and Li, Yunpeng},
journal={arXiv preprint arXiv:2006.08812},
year={2020}
}