
IShengFang/Relativistic-average-GAN-Keras


Relativistic average GAN with Keras

An implementation of Relativistic average GAN with Keras.

[paper] [blog] [original code (PyTorch)]

How to Run?

Python3 Script

mkdir result
python RaGAN_CustomLoss.py --dataset [dataset] --loss [loss] 
python RaGAN_CustomLayers.py --dataset [dataset] --loss [loss] 

[dataset]: mnist, fashion_mnist, cifar10

[loss]: BXE for Binary Crossentropy, LS for Least Squares

Italic arguments are the defaults.

Jupyter notebook

Custom Loss [Colab][NBViewer]

Custom Layer [Colab][NBViewer]

Result

[Figure: results after 1, 10, 50 and 100 epochs, plus loss plots, for Binary Cross Entropy and Least Square on MNIST, Fashion MNIST and CIFAR10]

What is Relativistic average GAN?

TL;DR

What is different from the original GAN?

For better rendering of the math equations, check out the HackMD version.

GAN

A GAN is a two-player minimax game with the following objective:

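$$\min_G \max_D V(D, G) = \mathbb{E}_{x\sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z\sim p_z(z)}[\log(1 - D(G(z)))]$$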

$V(D, G)$ is the value function (aka loss or cost function), $G$ is the generator, and $z$ is a noise sample drawn from a known distribution (usually a multidimensional Gaussian). $G(z)$ is the fake data produced by the generator; we want $G(z)$ to follow the real data distribution. $D$ is the discriminator, which decides whether $x$ is real data (output 1) or fake data (output 0). In each training iteration, we first train one network (usually the discriminator) and then the other. After many iterations, we expect the final generator to map the multidimensional Gaussian distribution onto the real data distribution.
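
To make the value function concrete, here is a minimal plain-Python sketch (an illustration, not code from this repository) that estimates $V(D, G)$ on one minibatch, assuming we already have the discriminator's probabilities for real and generated samples. Batch averages stand in for the expectations.

```python
import math


def value_function(d_real, d_fake):
    """Minibatch estimate of V(D, G).

    d_real: discriminator outputs D(x) on real samples, each in (0, 1).
    d_fake: discriminator outputs D(G(z)) on generated samples, each in (0, 1).
    Batch averages stand in for the expectations.
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term


# D tries to push V up (D(x) -> 1, D(G(z)) -> 0); G tries to push it down.
sharp = value_function([0.9, 0.95], [0.05, 0.1])
confused = value_function([0.5, 0.5], [0.5, 0.5])
print(sharp > confused)  # True
print(round(confused, 4))  # -1.3863, i.e. 2 * log(0.5)
```

A discriminator that cannot tell the two apart (every output 0.5) scores $2\log 0.5 = -\log 4$, which is the value of the original GAN game at its optimum.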

Relativistic average GAN (RaGAN)

RaGAN's loss function does not train the discriminator to classify data as real or fake. Instead, RaGAN's discriminator judges whether "real data is more realistic than average fake data" or "fake data is less realistic than average real data".

As the paper puts it (Section 4.1), "the discriminator estimates the probability that the given real data is more realistic than a randomly sampled fake data."
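
The quoted idea fits in a few lines of plain Python. This sketch follows the paper's pairwise relativistic discriminator, $\operatorname{sigmoid}(C(x_{real}) - C(x_{fake}))$, where $C$ is the raw (pre-sigmoid) discriminator output; RaGAN then replaces the single fake sample with an average. The function names are illustrative, not taken from this repository.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def more_realistic_prob(c_real, c_fake):
    """Probability that a real sample is more realistic than a fake one.

    c_real, c_fake: raw (pre-sigmoid) critic scores C(x_real), C(x_fake).
    Only the difference between the two scores matters.
    """
    return sigmoid(c_real - c_fake)


print(more_realistic_prob(2.0, 2.0))  # 0.5: neither looks more realistic
print(more_realistic_prob(3.0, -1.0) > 0.5)  # True
# Swapping the arguments gives the complementary probability.
print(round(more_realistic_prob(3.0, -1.0) + more_realistic_prob(-1.0, 3.0), 6))  # 1.0
```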

Given the discriminator output $D(x) = \operatorname{sigmoid}(C(x))$, where $C(x)$ is the raw (non-transformed) output, the original GAN loss is as below:

$$L_D = -\mathbb{E}_{x_{real}\sim\mathbb{P}_{real}}[\log D(x_{real})] - \mathbb{E}_{x_{fake}\sim\mathbb{P}_{fake}}[\log(1 - D(x_{fake}))]$$

$$L_G = -\mathbb{E}_{x_{fake}\sim\mathbb{P}_{fake}}[\log D(x_{fake})]$$
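
As a plain-Python sketch (not the repository's Keras code), the two original GAN losses can be computed from raw discriminator scores $C(x)$, assuming $D(x) = \operatorname{sigmoid}(C(x))$ and using batch means in place of the expectations:

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def mean(values):
    return sum(values) / len(values)


def gan_losses(c_real, c_fake):
    """Original GAN (L_D, L_G) from lists of raw critic scores C(x).

    D(x) = sigmoid(C(x)); L_G is the non-saturating loss -E[log D(x_fake)].
    """
    d_real = [sigmoid(c) for c in c_real]
    d_fake = [sigmoid(c) for c in c_fake]
    loss_d = -mean([math.log(d) for d in d_real]) - mean([math.log(1.0 - d) for d in d_fake])
    loss_g = -mean([math.log(d) for d in d_fake])
    return loss_d, loss_g


loss_d, loss_g = gan_losses([0.0, 0.0], [0.0, 0.0])
print(round(loss_d, 4), round(loss_g, 4))  # 1.3863 0.6931
```

Taking the log of a sigmoid loses precision for large $|C(x)|$; a framework implementation would typically compute binary cross-entropy directly from the logits instead.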

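The relativistic average outputs are

$$\tilde{D}(x_{real}) = \operatorname{sigmoid}\left(C(x_{real}) - \mathbb{E}_{x_{fake}\sim\mathbb{P}_{fake}}[C(x_{fake})]\right)$$

and

$$\tilde{D}(x_{fake}) = \operatorname{sigmoid}\left(C(x_{fake}) - \mathbb{E}_{x_{real}\sim\mathbb{P}_{real}}[C(x_{real})]\right)$$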
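
A minimal sketch of these two outputs in plain Python, with batch means standing in for the expectations. `c_real` and `c_fake` are hypothetical lists of raw scores $C(x)$; in a Keras model they would be tensors.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def mean(values):
    return sum(values) / len(values)


def relativistic_average_outputs(c_real, c_fake):
    """Return (D~(x_real), D~(x_fake)) for a minibatch of raw critic scores.

    D~(x_real) = sigmoid(C(x_real) - mean C(x_fake))
    D~(x_fake) = sigmoid(C(x_fake) - mean C(x_real))
    """
    mean_real, mean_fake = mean(c_real), mean(c_fake)
    d_real = [sigmoid(c - mean_fake) for c in c_real]
    d_fake = [sigmoid(c - mean_real) for c in c_fake]
    return d_real, d_fake


d_real, d_fake = relativistic_average_outputs([1.0, 3.0], [0.0, 2.0])
print([round(p, 4) for p in d_real])  # [0.5, 0.8808]
print([round(p, 4) for p in d_fake])  # [0.1192, 0.5]
```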

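RaGAN's loss is as below:

$$L_D = -\mathbb{E}_{x_{real}\sim\mathbb{P}_{real}}[\log \tilde{D}(x_{real})] - \mathbb{E}_{x_{fake}\sim\mathbb{P}_{fake}}[\log(1 - \tilde{D}(x_{fake}))]$$

$$L_G = -\mathbb{E}_{x_{fake}\sim\mathbb{P}_{fake}}[\log \tilde{D}(x_{fake})] - \mathbb{E}_{x_{real}\sim\mathbb{P}_{real}}[\log(1 - \tilde{D}(x_{real}))]$$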
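
Putting the pieces together, here is a plain-Python sketch of both RaGAN losses, under the same batch-mean assumption. It is an illustration, not the repository's implementation.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def mean(values):
    return sum(values) / len(values)


def ragan_losses(c_real, c_fake):
    """RaGAN (L_D, L_G) from lists of raw critic scores C(x)."""
    mean_real, mean_fake = mean(c_real), mean(c_fake)
    d_real = [sigmoid(c - mean_fake) for c in c_real]  # D~(x_real)
    d_fake = [sigmoid(c - mean_real) for c in c_fake]  # D~(x_fake)
    loss_d = -mean([math.log(p) for p in d_real]) - mean([math.log(1.0 - p) for p in d_fake])
    loss_g = -mean([math.log(p) for p in d_fake]) - mean([math.log(1.0 - p) for p in d_real])
    return loss_d, loss_g


c_real, c_fake = [2.0, 1.0], [0.0, -1.0]
loss_d, loss_g = ragan_losses(c_real, c_fake)
# L_G is L_D with the roles of real and fake swapped.
print(abs(loss_g - ragan_losses(c_fake, c_real)[0]) < 1e-12)  # True
```

Unlike the original generator loss, $L_G$ here also has a term over real data, which still depends on the generator through the average of $C(x_{fake})$.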

We can also apply the relativistic average to Least Squares GAN or any other GAN.

[Figure: modified by Jonathan Hui from the paper]
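
For the least-squares case, here is a sketch following the relativistic average LSGAN (RaLSGAN) losses given in the RaGAN paper, with targets of +1 and -1 on the relativistic scores. This is an assumption about the formulation; the repository's LS option may differ in details such as scaling.

```python
def mean(values):
    return sum(values) / len(values)


def ralsgan_losses(c_real, c_fake):
    """RaLSGAN (L_D, L_G) from lists of raw critic scores C(x).

    L_D = E[(C(x_r) - E C(x_f) - 1)^2] + E[(C(x_f) - E C(x_r) + 1)^2]
    L_G = E[(C(x_f) - E C(x_r) - 1)^2] + E[(C(x_r) - E C(x_f) + 1)^2]
    Batch means stand in for the expectations.
    """
    mean_real, mean_fake = mean(c_real), mean(c_fake)
    rel_real = [c - mean_fake for c in c_real]
    rel_fake = [c - mean_real for c in c_fake]
    loss_d = mean([(r - 1.0) ** 2 for r in rel_real]) + mean([(f + 1.0) ** 2 for f in rel_fake])
    loss_g = mean([(f - 1.0) ** 2 for f in rel_fake]) + mean([(r + 1.0) ** 2 for r in rel_real])
    return loss_d, loss_g


print(ralsgan_losses([1.0], [0.0]))  # (0.0, 8.0)
```

In the example, the real score beats the fake average by exactly 1, so the discriminator's loss is 0 while the generator's loss is large.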

How to implement with Keras?

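We have the loss, so let's just code it. 😄 Just kidding: there are two approaches to implementing RaGAN. The important part of the implementation is the discriminator output, $\tilde{D}(x_{real})$ and $\tilde{D}(x_{fake})$. We need to average $C(x_{real})$ and $C(x_{fake})$ over the batch, and we also need a subtraction to get $C(x_{real}) - \mathbb{E}[C(x_{fake})]$ and $C(x_{fake}) - \mathbb{E}[C(x_{real})]$.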

We can use keras.backend to handle this, but that means writing a custom loss. Alternatively, we can write custom layers that apply these computations inside the model, and then train the model with Keras's built-in losses.
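
The custom-layer approach relies on a simple identity: if a layer outputs the relativistic scores $C(x_{real}) - \mathbb{E}[C(x_{fake})]$ and $C(x_{fake}) - \mathbb{E}[C(x_{real})]$, then ordinary binary cross-entropy with label 1 for real and 0 for fake reproduces RaGAN's $L_D$, and swapping the labels gives $L_G$. The plain-Python sketch below checks this. `relativistic_logits` and `binary_crossentropy` are hypothetical stand-ins for a Keras custom layer and the built-in loss, not the repository's code.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def mean(values):
    return sum(values) / len(values)


def relativistic_logits(c_real, c_fake):
    """Stand-in for a custom layer: shift each score by the other batch's mean."""
    mean_real, mean_fake = mean(c_real), mean(c_fake)
    return [c - mean_fake for c in c_real], [c - mean_real for c in c_fake]


def binary_crossentropy(labels, probs):
    """Stand-in for a built-in loss: batch-averaged binary cross-entropy."""
    return -mean([y * math.log(p) + (1 - y) * math.log(1.0 - p) for y, p in zip(labels, probs)])


c_real, c_fake = [1.5, 0.5], [-0.5, 0.0]
rel_real, rel_fake = relativistic_logits(c_real, c_fake)
p_real = [sigmoid(r) for r in rel_real]  # D~(x_real)
p_fake = [sigmoid(f) for f in rel_fake]  # D~(x_fake)

# Discriminator: real labelled 1, fake labelled 0.
loss_d = binary_crossentropy([1] * len(p_real), p_real) + binary_crossentropy([0] * len(p_fake), p_fake)
# Generator: the same outputs with the labels swapped.
loss_g = binary_crossentropy([0] * len(p_real), p_real) + binary_crossentropy([1] * len(p_fake), p_fake)

# Same value as writing RaGAN's L_D out directly.
direct_d = -mean([math.log(p) for p in p_real]) - mean([math.log(1.0 - p) for p in p_fake])
print(abs(loss_d - direct_d) < 1e-12)  # True
```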

  1. Custom layer

    • Pros:
      • Train RaGAN easily with Keras's built-in losses.
    • Cons:
      • Requires writing custom layers to implement it.
    • [Colab][NBViewer]
  2. Custom Loss

    • Pros:
      • No custom layers are needed; instead, the loss is written with keras.backend.
      • With a custom loss, it is easy to switch to a different loss.
    • Cons:
      • Requires writing the custom loss with keras.backend.
    • [Colab][NBViewer]
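
To illustrate why a custom loss makes swapping losses easy, here is a plain-Python sketch in which one function computes the relativistic scores and a small table picks the base loss by the same names as the `--loss` option (`BXE`, `LS`). `relativistic_average_loss` and `BASE_LOSSES` are hypothetical names; the repository's custom loss is written with keras.backend and may be organized differently.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def mean(values):
    return sum(values) / len(values)


def _bxe(rel_pos, rel_neg):
    # -E[log sigmoid(rel_pos)] - E[log(1 - sigmoid(rel_neg))]
    return -mean([math.log(sigmoid(r)) for r in rel_pos]) - mean([math.log(1.0 - sigmoid(r)) for r in rel_neg])


def _ls(rel_pos, rel_neg):
    # E[(rel_pos - 1)^2] + E[(rel_neg + 1)^2]
    return mean([(r - 1.0) ** 2 for r in rel_pos]) + mean([(r + 1.0) ** 2 for r in rel_neg])


# Hypothetical lookup table keyed by the README's loss names.
BASE_LOSSES = {"BXE": _bxe, "LS": _ls}


def relativistic_average_loss(c_real, c_fake, loss):
    """Return (L_D, L_G) for the chosen base loss from raw critic scores."""
    base = BASE_LOSSES[loss]
    mean_real, mean_fake = mean(c_real), mean(c_fake)
    rel_real = [c - mean_fake for c in c_real]
    rel_fake = [c - mean_real for c in c_fake]
    loss_d = base(rel_real, rel_fake)  # real should beat the average fake
    loss_g = base(rel_fake, rel_real)  # roles swapped for the generator
    return loss_d, loss_g


print(relativistic_average_loss([1.0], [0.0], "LS"))  # (0.0, 8.0)
```

Because the relativistic shift is shared, adding another variant only needs a new entry in the table.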

Code

Custom Loss

[Colab][NBViewer][python script]

Custom Layer

[Colab][NBViewer][python script]