GAN Lab

This repository is a lab to learn about and train variants of generative adversarial networks (GANs). There are also examples of training GANs on Cloud ML Engine (a cheap option for training models).

Why learn GANs?

GANs are interesting because they can learn even complex data distributions and generate (fake) samples from those distributions. For example:

Fake MNIST images generated by DCGAN after 4000 training iterations.

How to train GANs

The basic framework consists of 2 neural networks:

  • G or Generator and
  • D or Discriminator

G takes a noise vector as input and outputs a sample with the same shape as the real samples (we call these fake samples). D takes a sample (real or fake) as input and outputs the probability that the input sample is real. Mathematically, G and D play a min-max game over a value function V:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim X}[\log D(x)] + \mathbb{E}_{z \sim Z}[\log(1 - D(G(z)))]
$$

The above equation says that D tries to maximize V(D, G) while G tries to minimize it. The log terms can also be replaced by other losses, such as the mean-squared error used in LSGAN; the formulation above corresponds to the binary cross-entropy loss.
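
To make the two players concrete, a minimal pair of MNIST-sized G and D models might look like the sketch below (an illustration only; the layer sizes, optimizer and the tf.keras import style are assumptions, not necessarily the exact models used in this repo):

# Sketch: G maps a noise vector to a 28x28x1 image, D maps an image to P(real)
from tensorflow.keras import layers, models, optimizers

noise_dim = 100

G = models.Sequential([
    layers.Dense(7 * 7 * 128, activation="relu", input_dim=noise_dim),
    layers.Reshape((7, 7, 128)),
    layers.Conv2DTranspose(64, 5, strides=2, padding="same", activation="relu"),
    layers.Conv2DTranspose(1, 5, strides=2, padding="same", activation="sigmoid"),
])

D = models.Sequential([
    layers.Conv2D(64, 5, strides=2, padding="same", activation="relu", input_shape=(28, 28, 1)),
    layers.Conv2D(128, 5, strides=2, padding="same", activation="relu"),
    layers.Flatten(),
    layers.Dense(1, activation="sigmoid"),
])
D.compile(loss="binary_crossentropy", optimizer=optimizers.Adam(1e-4))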

However, this objective gives G a very small gradient when D confidently rejects fake samples (D(G(z)) close to 0), which is exactly the situation early in training, so the generator is hard to train in the initial phase. We therefore split and modify the objective as:

$$
\max_D V(D, G) = \mathbb{E}_{x \sim X}[\log D(x)] + \mathbb{E}_{z \sim Z}[\log(1 - D(G(z)))]
$$

and

$$
\max_G V(D, G) = \mathbb{E}_{z \sim Z}[\log D(G(z))]
$$
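
A quick way to see why the original generator loss saturates (a short added note, not from the original README): compare the slopes of the two candidate generator losses with respect to D(G(z)) when D confidently rejects fakes, i.e. when D(G(z)) is close to 0:

$$
\frac{\partial}{\partial D}\log(1 - D) = -\frac{1}{1 - D} \approx -1 \quad \text{near } D = 0,
\qquad
\frac{\partial}{\partial D}\log D = \frac{1}{D} \to \infty \quad \text{as } D \to 0^{+}
$$

So early in training, when D(G(z)) is close to 0, maximizing log D(G(z)) gives G a much stronger gradient signal than minimizing log(1 - D(G(z))).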

In practice, we first construct G, D and AM, the adversarial model (D stacked on top of G). Then, at each training step, we (1) train D to differentiate between real and fake samples and (2) train AM to generate fake samples that fool D. D must be set to non-trainable for the second part, so that only G's weights are updated through AM.

The idea is that with enough training, G will be able to fool D, which means we can generate real-looking samples using G.
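
With G and D in hand, AM might be assembled roughly as follows (a sketch under the same assumptions as the model definitions above; the repo's actual code may differ):

# Stack D on top of G to form the adversarial model AM(z) = D(G(z))
D.trainable = False  # freeze D inside AM; D was compiled above, so D.train_on_batch still updates it directly
AM = models.Sequential([G, D])
AM.compile(loss="binary_crossentropy", optimizer=optimizers.Adam(1e-4))

Because D is frozen inside AM, a gradient step on AM only updates G's weights, pushing G towards samples that D classifies as real.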

The pattern for the above looks like this:

"""Part 1 - Training discriminator(D)""
# Generate a noise of appropriate shape
noise = np.random.normal(0, 1, size=[batch_size, noise_dim])

# Generate a fake sample from this random noise
fake = G.predict(noise)

# Take a sample from real data
real = data[np.random.randint(0, data.shape[0], batch_size), :, :, :]

# train_x is the concatenation of real and fake samples
train_x = np.concatenate((real, fake))

# and train_y contains label 1 for the real samples and 0 for the fake samples
train_y = np.concatenate((np.ones((batch_size, 1)), np.zeros((batch_size, 1))))

# Train D on this batch of train_x and train_y
d_loss = D.train_on_batch(train_x, train_y)
# (In this repo, this step is actually done with 2 separate train_on_batch calls: one for real and one for fake samples)
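# For reference, that two-call variant might look roughly like this (a sketch, not copied verbatim from the repo):
# d_loss_real = D.train_on_batch(real, np.ones((batch_size, 1)))
# d_loss_fake = D.train_on_batch(fake, np.zeros((batch_size, 1)))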

"""Part 2 - Training adversarial model(G+D)"""
# Generate a noisy sample of appropriate shape
noise = np.random.normal(0, 1, size=[batch_size, noise_dim])

# Here's the trick: forcing the Adversarial model to have output class 1 for fake samples
Y = np.ones((batch_size, 1))

# Train AM (G+D) on this batch of noise and Y
a_loss = AM.train_on_batch(noise, Y)

The above steps are repeated for a pre-specified number of steps; after enough iterations, G is able to generate samples that are visually hard to distinguish from real images.
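
Putting the two parts together, the outer training loop might look roughly like this (a sketch; the step count and logging interval are placeholders, with 4000 chosen only to match the DCGAN example above):

# Sketch of the full training loop combining Part 1 and Part 2
num_steps = 4000
for step in range(num_steps):
    # Part 1: train D on a half-real, half-fake batch
    noise = np.random.normal(0, 1, size=[batch_size, noise_dim])
    fake = G.predict(noise)
    real = data[np.random.randint(0, data.shape[0], batch_size), :, :, :]
    train_x = np.concatenate((real, fake))
    train_y = np.concatenate((np.ones((batch_size, 1)), np.zeros((batch_size, 1))))
    d_loss = D.train_on_batch(train_x, train_y)

    # Part 2: train AM (G through the frozen D) to label fresh noise as "real"
    noise = np.random.normal(0, 1, size=[batch_size, noise_dim])
    a_loss = AM.train_on_batch(noise, np.ones((batch_size, 1)))

    if step % 100 == 0:
        print("step %d: d_loss=%s, a_loss=%s" % (step, d_loss, a_loss))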

About

Experimenting with GANs in Tensorflow/Keras
