NOTE: This repository is forked from google/compare_gan and also contains the implementations and experiments for our paper (below). Some new tips and workarounds have been added to help reproduce the training from scratch, mainly for users in mainland China. More information about the original google/compare_gan can be found in README-Google.md.
Noise Homogenization via Multi-Channel Wavelet Filtering for High-Fidelity Sample Generation in GANs
In a typical Generative Adversarial Network (GAN), noise is sampled after random initialization and turned into fake samples by a series of convolutional operations. However, current GANs rely only on the pixel space to sample the noise, which makes it harder to approach the target distribution. Fortunately, the long-established wavelet transform can decompose an image into multiple spectral components. In this work, we propose a novel multi-channel wavelet-based filtering method for GANs to address this problem. By embedding a wavelet deconvolution layer in the generator, the resulting GAN, called WaveletGAN, uses wavelet deconvolution to learn a filter with multiple channels (i.e., multiple convolutional filters), which efficiently homogenizes the sampled noise via an averaging operation and thus generates high-fidelity samples. (arXiv preprint arXiv:2005.06707)
Figure 1. WaveletGAN architecture using wavelet filtering to homogenize the generated noise.
# Final processing of the output.
output = self.batch_norm(
    output, z=z, y=y, is_training=is_training, name="final_norm")
output = tf.nn.relu(output)
output = ops.conv2d(output, output_dim=self._image_shape[2], k_h=3, k_w=3,
                    d_h=1, d_w=1, name="final_conv",
                    use_sn=self._spectral_norm)
if self._wavelet_deconv:  # Add WaveletDeconv layer
    output = ops.waveletDeconv(output)
The code can be found in resnet_mnist.py.
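The internals of `ops.waveletDeconv` are not reproduced here. As a rough, hypothetical illustration of the idea in the abstract (filter the output with several wavelet channels, then homogenize by averaging them), the pure-Python sketch below assumes Ricker (Mexican-hat) wavelets at a few fixed scales, applied separably along the rows and columns of a single-channel image. The wavelet family, the fixed scales, and all function names are assumptions; in WaveletGAN the multi-channel filter is learned during training.

```python
import math
import random
from typing import List

Image = List[List[float]]  # a single-channel image stored as rows of pixels


def ricker(t: float, scale: float) -> float:
    """Ricker (Mexican-hat) wavelet at position t for a given scale (assumed wavelet family)."""
    a = t / scale
    norm = 2.0 / (math.sqrt(3.0 * scale) * math.pi ** 0.25)
    return norm * (1.0 - a * a) * math.exp(-0.5 * a * a)


def wavelet_kernel(scale: float, width: int = 5) -> List[float]:
    """Sample an odd-length 1-D wavelet kernel centred at zero."""
    half = width // 2
    return [ricker(float(t), scale) for t in range(-half, half + 1)]


def filter_rows(image: Image, kernel: List[float]) -> Image:
    """Slide `kernel` along every row with zero padding, keeping the row length.

    This is cross-correlation; for the symmetric Ricker kernel it equals convolution.
    """
    half = len(kernel) // 2
    out = []
    for row in image:
        n = len(row)
        new_row = []
        for i in range(n):
            acc = 0.0
            for k, w in enumerate(kernel):
                j = i + k - half
                if 0 <= j < n:
                    acc += w * row[j]
            new_row.append(acc)
        out.append(new_row)
    return out


def transpose(image: Image) -> Image:
    return [list(col) for col in zip(*image)]


def homogenize(image: Image, scales: List[float]) -> Image:
    """One wavelet channel per scale (rows, then columns), then average the channels."""
    channels = []
    for s in scales:
        k = wavelet_kernel(s)
        by_rows = filter_rows(image, k)
        by_cols = transpose(filter_rows(transpose(by_rows), k))
        channels.append(by_cols)
    height, width = len(image), len(image[0])
    return [[sum(c[i][j] for c in channels) / len(channels) for j in range(width)]
            for i in range(height)]


if __name__ == "__main__":
    random.seed(0)
    noise = [[random.gauss(0.0, 1.0) for _ in range(8)] for _ in range(8)]
    smoothed = homogenize(noise, scales=[0.5, 1.0, 2.0])
    print(len(smoothed), len(smoothed[0]))  # 8 8
```

Because every channel is a linear filter, averaging the channel outputs is equivalent to filtering once with the average of the per-channel 2-D kernels.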
To install requirements:
pip install -e .
The .gin config files for our models are all in example_configs:
- resnet_fmnist_WaveletGAN.gin - our WaveletGAN using ResNet on Fashion-MNIST.
- resnet_kmnist_WaveletGAN.gin - our WaveletGAN using ResNet on KMNIST.
- resnet_svhn_WaveletGAN.gin - our WaveletGAN using ResNet on SVHN.
The option that enables WaveletDeconv in the generator is:
# Enable WaveletDeconv (True or False)
G.wavelet_deconv = True
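For an ablation run without WaveletDeconv, one option is to copy a config and flip only this binding. The helper below is a hypothetical sketch (`set_gin_binding` is not part of compare_gan or gin-config): it treats the .gin file as plain text, assumes the binding fits on a single unindented line, and either rewrites an existing `key = value` line or appends one.

```python
import re


def set_gin_binding(config_text: str, key: str, value: str) -> str:
    """Return `config_text` with `key = value` set (hypothetical helper, single-line bindings only)."""
    # Match the whole line that binds exactly `key`, e.g. "G.wavelet_deconv= True".
    pattern = re.compile(r"^" + re.escape(key) + r"\s*=.*$", re.MULTILINE)
    line = f"{key} = {value}"
    if pattern.search(config_text):
        # A function replacement avoids backslash escapes being interpreted in `line`.
        return pattern.sub(lambda _m: line, config_text)
    # No existing binding: append one, making sure it starts on its own line.
    sep = "" if not config_text or config_text.endswith("\n") else "\n"
    return config_text + sep + line + "\n"
```

For example, `set_gin_binding(text, "G.wavelet_deconv", "False")` applied to the contents of `resnet_fmnist_WaveletGAN.gin` gives a baseline config that can be saved under a new file name.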
The same directory contains more config files provided by Google.
To see all available options please run `python main.py --help`. Main options:
* To **train** the model use `--schedule=train` (default). Training is resumed
from the last saved checkpoint.
* To **evaluate** all checkpoints use `--schedule=continuous_eval
--eval_every_steps=0`. To evaluate only checkpoints whose step number is
divisible by 5000, use `--schedule=continuous_eval --eval_every_steps=5000`.
By default, 3 averaging runs are used to estimate the Inception Score and
the FID score. Keep in mind that when running locally on a single GPU it may
not be possible to run training and evaluation simultaneously due to memory
constraints.
* To **train and evaluate** the model use `--schedule=eval_after_train
--eval_every_steps=0`.
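The `--eval_every_steps` rule described above amounts to a simple filter over saved checkpoint steps. This is a hypothetical model of the documented behaviour, not compare_gan's own code; `checkpoints_to_evaluate` is an illustrative name.

```python
from typing import Iterable, List


def checkpoints_to_evaluate(steps: Iterable[int], eval_every_steps: int) -> List[int]:
    """Return the checkpoint steps that would be evaluated, in ascending order.

    eval_every_steps == 0 selects every checkpoint; otherwise only steps divisible by it.
    """
    if eval_every_steps < 0:
        raise ValueError("eval_every_steps must be non-negative")
    return sorted(s for s in steps
                  if eval_every_steps == 0 or s % eval_every_steps == 0)


if __name__ == "__main__":
    saved = [2500, 5000, 7500, 10000]
    print(checkpoints_to_evaluate(saved, 0))     # [2500, 5000, 7500, 10000]
    print(checkpoints_to_evaluate(saved, 5000))  # [5000, 10000]
```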
To train and evaluate our models, run the following commands in compare_gan/compare_gan:
- Fashion-MNIST:
python main.py --gin_config ../example_configs/resnet_fmnist_WaveletGAN.gin --model_dir ../resnet_fmnist_WaveletGAN --score_filename resnet_fmnist_WaveletGAN_score.csv --schedule eval_after_train
- KMNIST:
python main.py --gin_config ../example_configs/resnet_kmnist_WaveletGAN.gin --model_dir ../resnet_kmnist_WaveletGAN --score_filename resnet_kmnist_WaveletGAN_score.csv --schedule eval_after_train
- SVHN:
python main.py --gin_config ../example_configs/resnet_svhn_WaveletGAN.gin --model_dir ../resnet_svhn_WaveletGAN --score_filename resnet_svhn_WaveletGAN_score.csv --schedule eval_after_train
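The three commands differ only in the dataset name, so they can be generated from one pattern. The sketch below (`build_command` and `run_all` are illustrative names) builds the same argument lists and, with `dry_run=False`, runs them one after another; it assumes it is launched from compare_gan/compare_gan and that `python` on the PATH is the interpreter with the pinned TensorFlow.

```python
import subprocess
from typing import List

DATASETS = ["fmnist", "kmnist", "svhn"]  # Fashion-MNIST, KMNIST, SVHN


def build_command(dataset: str) -> List[str]:
    """Argument list for one dataset, mirroring the commands above."""
    name = f"resnet_{dataset}_WaveletGAN"
    return [
        "python", "main.py",
        "--gin_config", f"../example_configs/{name}.gin",
        "--model_dir", f"../{name}",
        "--score_filename", f"{name}_score.csv",
        "--schedule", "eval_after_train",
    ]


def run_all(dry_run: bool = True) -> None:
    """Print (dry run) or execute the commands sequentially; stop on the first failure."""
    for dataset in DATASETS:
        cmd = build_command(dataset)
        if dry_run:
            print(" ".join(cmd))
        else:
            subprocess.run(cmd, check=True)
```

Calling `run_all()` only prints the commands; `run_all(dry_run=False)` actually starts training and evaluation for each dataset in turn.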
- Baidu Netdisk: https://pan.baidu.com/s/1uQu-2i_NLxAgZvrbi6dgqw (extraction code: 8k8f)
- Google Drive: https://drive.google.com/open?id=187WCyYEvT_9VDkk36nWLc-HxHr2btXqC
- The generated samples
Figure 2. The real and generated samples from each dataset.
- FIDs
Please refer to https://github.com/google/compare_gan.
The goals are to:
- Make it work in mainland China.
- Fix some issues locally.
- Discuss with others working in the same environment.
- Contribute to cutting-edge GAN research.
After installing the prerequisite libraries via
pip install -e .
make sure you are using the following tool versions:
sudo apt install cuda-10-0
pip install tensorflow-gpu==1.13.1
Install a newer version of tensorflow-datasets for KMNIST:
pip install tensorflow-datasets==1.0.2
However, make sure to use 1.0.1 for manually preparing ImageNet:
pip install tensorflow-datasets==1.0.1
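A quick way to confirm the pins before a long run is to compare each module's `__version__` with the expected string. This is an assumed sketch: it relies on `tensorflow` and `tensorflow_datasets` exposing `__version__` (if a module does not, it is reported as `None`), it checks the KMNIST pin 1.0.2 (change it to 1.0.1 when preparing ImageNet), and it does not check the CUDA version.

```python
import importlib
from typing import Dict, Optional

# Pins from the instructions above (tensorflow-gpu installs the `tensorflow` module).
EXPECTED = {"tensorflow": "1.13.1", "tensorflow_datasets": "1.0.2"}


def installed_version(module_name: str) -> Optional[str]:
    """Return module.__version__ if the module imports and defines it, else None."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, "__version__", None)


def version_mismatches(expected: Dict[str, str],
                       installed: Dict[str, Optional[str]]) -> Dict[str, Optional[str]]:
    """Map each module whose installed version differs from its pin to the version found."""
    return {name: installed.get(name) for name, want in expected.items()
            if installed.get(name) != want}
```

Usage: `version_mismatches(EXPECTED, {name: installed_version(name) for name in EXPECTED})` returns an empty dict when everything matches.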
- See the commit that fixes the following error:
TypeError: '<=' not supported between instances of 'int' and 'str'