
Examples trained using the Python (PyTorch) package pro-gan-pth


akanimax/pro_gan_pytorch-examples


pro_gan_pytorch-examples

This repository contains examples trained using the Python package pro-gan-pth. You can find the GitHub repository for the project at github-repository and the PyPI package at pypi.

Two examples are presented here: one for the LFW dataset and one for the MNIST dataset. Refer to the following sections for how to train these models and/or load the provided trained weights.

Prior Setup

Before running any of the following training experiments, set up your virtualenv with the packages required for this project. In particular, install the progan package using $ pip install pro-gan-pth, along with the GPU or CPU build of PyTorch 0.4.0 that matches your machine. Once this is done, you can proceed with the following experiments.
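Because the examples target PyTorch 0.4.0, a quick sanity check of the installed version can save a confusing failure later. The sketch below is a hypothetical helper (is_supported_torch_version is not part of pro-gan-pth); it accepts any 0.4.x release, so tighten it if you need exactly 0.4.0.

```python
def is_supported_torch_version(version: str, required=(0, 4)) -> bool:
    """Return True if `version` (e.g. torch.__version__) matches the required major.minor.

    Hypothetical helper for checking the environment; not part of pro-gan-pth.
    """
    # Drop local build suffixes such as "+cu90" before parsing.
    core = version.split("+")[0]
    parts = core.split(".")
    try:
        major, minor = int(parts[0]), int(parts[1])
    except (IndexError, ValueError):
        return False
    return (major, minor) == tuple(required)


if __name__ == "__main__":
    try:
        import torch
        print("torch", torch.__version__,
              "supported:", is_supported_torch_version(torch.__version__))
    except ImportError:
        print("PyTorch is not installed in this environment")

    try:
        import pro_gan_pytorch.PRO_GAN  # noqa: F401  (module used by the snippet below)
        print("pro-gan-pth import OK")
    except ImportError:
        print("pro-gan-pth (pro_gan_pytorch.PRO_GAN) could not be imported")
```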

LFW Experiment

The configuration used for the LFW training experiment can be found in implementation/configs/lfw.conf in this repository. The training was performed using the wgan-gp loss function.
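For reference, the WGAN-GP objective as defined by Gulrajani et al. (2017) is written out below, where P_r is the data distribution, P_g the generator's distribution, and λ the gradient-penalty weight (10 in that paper). The ProGAN paper (Karras et al., 2018) also adds a small drift term ε_drift·E[D(x)²] with ε_drift = 0.001 to the critic loss. Whether pro-gan-pth uses exactly these coefficients is an assumption; check the package's loss implementation before relying on them.

```latex
\begin{aligned}
\hat{x} &= \epsilon x + (1 - \epsilon)\,\tilde{x}, \qquad x \sim P_r,\ \tilde{x} \sim P_g,\ \epsilon \sim U[0, 1] \\
L_D &= \mathbb{E}_{\tilde{x} \sim P_g}\big[D(\tilde{x})\big]
     - \mathbb{E}_{x \sim P_r}\big[D(x)\big]
     + \lambda\, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big] \\
L_G &= -\,\mathbb{E}_{\tilde{x} \sim P_g}\big[D(\tilde{x})\big]
\end{aligned}
```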

Examples:


Sample loss plot:


MNIST Experiment

The configuration used for the MNIST training experiment can be found in implementation/configs/mnist.conf in this repository. The training was performed using the lsgan loss function.
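As a rough illustration of what the lsgan loss computes, here is a plain-Python sketch of the least-squares GAN losses (Mao et al., with targets 1 for real and 0 for fake). It operates on lists of discriminator scores rather than tensors, and the function names are hypothetical; the package's own implementation may scale or batch these terms differently.

```python
from typing import Sequence


def _mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def lsgan_dis_loss(real_scores: Sequence[float], fake_scores: Sequence[float]) -> float:
    """Discriminator loss: push scores of real images toward 1 and fakes toward 0."""
    real_term = _mean([(s - 1.0) ** 2 for s in real_scores])
    fake_term = _mean([s ** 2 for s in fake_scores])
    return 0.5 * (real_term + fake_term)


def lsgan_gen_loss(fake_scores: Sequence[float]) -> float:
    """Generator loss: the generator wants its fakes scored as real (target 1)."""
    return 0.5 * _mean([(s - 1.0) ** 2 for s in fake_scores])


# A "perfect" discriminator scores reals as 1 and fakes as 0:
print(lsgan_dis_loss([1.0, 1.0], [0.0, 0.0]))  # 0.0
print(lsgan_gen_loss([0.0, 0.0]))              # 0.5
```

Unlike a log-loss, the squared error keeps penalising fakes that are on the "real" side of the decision boundary but still far from the target, which is the usual motivation given for LSGAN's more stable gradients.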

Examples:


Sample loss plot:


How to use:

Running the training script:

To run the training script, use the following commands:
$ cd implementation
$ python train_network.py --config=configs/mnist.conf

You can edit the configuration file to change the training behaviour (for example, pass --config=configs/lfw.conf to train the LFW model instead). The training script also demonstrates some of the use cases of the package's API.
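To run both experiments back to back, a small driver along these lines could work. It is a sketch: build_train_command and run_all are hypothetical helpers (not part of this repository), and the code assumes it is launched from the repository root so that implementation/ is a subdirectory containing train_network.py.

```python
import subprocess
import sys

# Config paths relative to the implementation/ directory, as in the command above.
CONFIGS = ["configs/mnist.conf", "configs/lfw.conf"]


def build_train_command(config_path: str, python: str = sys.executable) -> list:
    """Equivalent of `python train_network.py --config=<config_path>`."""
    return [python, "train_network.py", f"--config={config_path}"]


def run_all(configs=CONFIGS, workdir: str = "implementation") -> None:
    for config in configs:
        cmd = build_train_command(config)
        print("running:", " ".join(cmd))
        # check=True raises CalledProcessError if a training run fails,
        # so later experiments are not started on top of a broken setup.
        subprocess.run(cmd, cwd=workdir, check=True)
```

Call run_all() from the repository root; passing a single-element list such as ["configs/lfw.conf"] runs just one experiment.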

Generating loss plots:

The loss logs are written while training is in progress. To turn them into loss plots, use the provided script:
$ python generate_loss_plots.py --logdir=training_runs/mnist/losses/ \
                                --plotdir=training_runs/mnist/losses/loss_plots/
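If you want to inspect a log file yourself, the sketch below parses it and plots the two losses with matplotlib. The log format is an assumption, not taken from the package: each line is taken to hold whitespace-separated numbers with the discriminator and generator losses as the last two columns, and non-numeric lines (such as headers) are skipped. Check your actual files in training_runs/.../losses/ and adjust the column indices if they differ.

```python
from pathlib import Path


def parse_loss_log(text: str) -> list:
    """Parse whitespace-separated numeric rows; non-numeric lines (e.g. headers) are skipped."""
    rows = []
    for line in text.splitlines():
        fields = line.split()
        if not fields:
            continue
        try:
            rows.append([float(f) for f in fields])
        except ValueError:
            continue  # header or malformed line
    return rows


def plot_losses(log_file: str, out_file: str) -> None:
    import matplotlib
    matplotlib.use("Agg")  # render straight to a file, no display needed
    import matplotlib.pyplot as plt

    rows = [r for r in parse_loss_log(Path(log_file).read_text()) if len(r) >= 2]
    # ASSUMPTION: the last two columns are discriminator and generator loss.
    dis_losses = [r[-2] for r in rows]
    gen_losses = [r[-1] for r in rows]

    plt.figure()
    plt.plot(dis_losses, label="discriminator")
    plt.plot(gen_losses, label="generator")
    plt.xlabel("logged step")
    plt.ylabel("loss")
    plt.legend()
    plt.savefig(out_file)
    plt.close()
```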

Using trained model:

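Please refer to the following code snippet if you only want to use a trained model to generate samples:
import torch as th
import pro_gan_pytorch.PRO_GAN as pg
import matplotlib.pyplot as plt

# use the GPU if one is available
device = th.device("cuda" if th.cuda.is_available() else "cpu")

# depth, latent_size and use_eql must match the values used for training
gen = pg.Generator(depth=4, latent_size=128, use_eql=False).to(device)

# adjust the path to the saved_models directory of your experiment;
# map_location lets GPU-trained weights load on a CPU-only machine
gen.load_state_dict(
    th.load("training_runs/saved_models/GAN_GEN_3.pth", map_location=device)
)
gen.eval()

noise = th.randn(1, 128).to(device)

# depth=3 is the last (highest-resolution) stage of a depth-4 generator;
# alpha=1 uses that stage fully, without fade-in blending
with th.no_grad():
    sample_image = gen(noise, depth=3, alpha=1)

# map from [-1, 1] to [0, 1], clamp, and move to the CPU for matplotlib
image = (sample_image[0].permute(1, 2, 0) / 2 + 0.5).clamp(0, 1).cpu()
plt.imshow(image.numpy())
plt.show()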

The trained weights are in the saved_models directory inside each experiment's training_runs folder.
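The depth and alpha arguments in the snippet above follow the progressive-growing scheme of ProGAN: training starts at 4x4 and each new stage doubles the resolution, while alpha blends a newly added stage with the upsampled output of the previous one during its fade-in. The plain-Python sketch below illustrates that arithmetic under the assumption that pro-gan-pth follows this scheme (a 4x4 first block, 0-based depth indices); output_resolution and fade_in are illustrative helpers, not package functions.

```python
def output_resolution(depth_index: int, base: int = 4) -> int:
    """Side length produced at a 0-based depth index, assuming a 4x4 first block
    that doubles the resolution at every stage."""
    return base * 2 ** depth_index


def fade_in(alpha: float, new_stage_out: list, upsampled_prev_out: list) -> list:
    """ProGAN-style blend: alpha=1 uses only the newest stage, alpha=0 only the old one."""
    return [alpha * n + (1 - alpha) * p for n, p in zip(new_stage_out, upsampled_prev_out)]


# A Generator built with depth=4 has valid depth indices 0..3.
for d in range(4):
    print(d, output_resolution(d))  # 4, 8, 16, 32 pixels per side
```

Under these assumptions, gen(noise, depth=3, alpha=1) asks the depth-4 generator for its final 32x32 output with the last stage fully faded in.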

How to use on Google Colab Notebook:

This code can be run on Google Colaboratory with GPU acceleration. Colab offers a free GPU (a Tesla K80 with up to ~12 GB of VRAM at the time of writing). However, an instance only runs for a limited time before it is shut down, and all installed libraries and saved files are lost when that happens. A workaround is to save training results to Google Drive; the packages must be reinstalled after every instance reset.

Step-by-step instructions for running this on Google Colab are in the ProGAN Colaboratory Notebook.
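A minimal sketch of the Google Drive workaround is shown below. It assumes a Colab runtime, where google.colab.drive.mount is available and the mounted Drive appears under /content/drive/MyDrive; the backup folder name pro_gan_backup and both helper functions are hypothetical choices, so adapt them to your own layout.

```python
import shutil
from pathlib import Path


def mount_drive(mount_point: str = "/content/drive") -> None:
    # Only works inside a Colab runtime; prompts for authorisation.
    from google.colab import drive
    drive.mount(mount_point)


def backup_to_drive(src: str = "training_runs",
                    dst: str = "/content/drive/MyDrive/pro_gan_backup/training_runs") -> Path:
    """Copy training outputs (saved models, loss logs) into a folder on the mounted Drive."""
    src_path, dst_path = Path(src), Path(dst)
    if not src_path.is_dir():
        raise FileNotFoundError(f"no such directory: {src_path}")
    dst_path.parent.mkdir(parents=True, exist_ok=True)
    # dirs_exist_ok (Python 3.8+) lets repeated backups overwrite the previous copy.
    shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
    return dst_path
```

In a notebook, call mount_drive() once per session and backup_to_drive() after each training run (or periodically), so that saved_models survive an instance reset.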

Thanks:

Please feel free to open PRs here if you train on other datasets using this package.

Best regards,
@akanimax :)