facebookresearch_pytorch-gan-zoo_pgan.md

layout: hub_detail
background-class: hub-background
body-class: hub
title: Progressive Growing of GANs (PGAN)
summary: High-quality image generation of fashion, celebrity faces
category: researchers
image: pganlogo.png
author: FAIR HDGAN
tags: [vision, generative]
github-link:
github-id: facebookresearch/pytorch_GAN_zoo
featured_image_1: pgan_mix.jpg
featured_image_2: pgan_celebaHQ.jpg
accelerator: cuda-optional
order: 10
demo-model-link:

import torch
use_gpu = torch.cuda.is_available()

# trained on the high-quality celebrity faces dataset (CelebA-HQ)
# this model outputs 512 x 512 pixel images
model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub',
                       'PGAN', model_name='celebAHQ-512',
                       pretrained=True, useGPU=use_gpu)
# this model outputs 256 x 256 pixel images
# model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub',
#                        'PGAN', model_name='celebAHQ-256',
#                        pretrained=True, useGPU=use_gpu)

The input to the model is a noise vector of shape (N, 512), where N is the number of images to generate. It can be constructed with the model's .buildNoiseData method. The model's .test method takes the noise vector and returns the generated images.

num_images = 4
noise, _ = model.buildNoiseData(num_images)
with torch.no_grad():
    generated_images = model.test(noise)

# let's plot these images using torchvision and matplotlib
import matplotlib.pyplot as plt
import torchvision
grid = torchvision.utils.make_grid(generated_images.clamp(min=-1, max=1), scale_each=True, normalize=True)
plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
# plt.show()

You should see an image similar to the one on the left.
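
The plotting call above clamps the generator output to [-1, 1] and, because normalize=True and scale_each=True are set, rescales each image to [0, 1] using that image's own minimum and maximum. The following is a minimal pure-Python sketch of that per-image rescaling, not torchvision's actual code. It assumes a flat list of pixel values stands in for one image tensor; the helper name clamp_and_rescale and the eps guard value are illustrative assumptions.

```python
def clamp_and_rescale(pixels, low=-1.0, high=1.0, eps=1e-5):
    """Clamp values to [low, high], then min-max rescale them to [0, 1].

    Hypothetical helper for illustration; `pixels` is a flat list of floats
    standing in for a single image.
    """
    clamped = [min(max(p, low), high) for p in pixels]
    lo, hi = min(clamped), max(clamped)
    # Guard against division by zero when every pixel has the same value.
    span = max(hi - lo, eps)
    return [(p - lo) / span for p in clamped]


# Values outside [-1, 1] are clamped before rescaling.
example = clamp_and_rescale([-1.7, -1.0, 0.0, 0.5, 1.3])
print(example)
```

Because the rescaling is per image, a dark image and a bright image both end up spanning the full displayable range, which is usually what you want for a quick visual check.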

If you want to train your own Progressive GAN and other GANs from scratch, have a look at PyTorch GAN Zoo.

Model Description

In computer vision, generative models are networks trained to create images from a given input. Here we consider a specific kind of generative network: GANs (Generative Adversarial Networks), which learn to map a random vector to a realistic image.

Progressive Growing of GANs is a method developed by Karras et al. [1] in 2017 that allows generation of high-resolution images. To do so, the generative network is trained slice by slice. At first, the model is trained to build very low-resolution images. Once it converges, new layers are added and the output resolution doubles. The process continues until the desired resolution is reached.
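
The doubling described above determines how many training stages a target resolution needs. A minimal sketch of that schedule follows, assuming training starts at 4 x 4 pixels (the starting resolution used in [1]; the pretrained hub models may have been configured differently). The function name resolution_schedule is hypothetical and not part of PyTorch GAN Zoo.

```python
def resolution_schedule(final_resolution, start_resolution=4):
    """List the output resolution of each training stage.

    Illustrative only: each stage doubles the previous resolution until
    `final_resolution` is reached. The default start of 4 is an assumption.
    """
    if start_resolution <= 0 or final_resolution < start_resolution:
        raise ValueError("resolutions must satisfy 0 < start <= final")
    resolutions = [start_resolution]
    while resolutions[-1] < final_resolution:
        resolutions.append(resolutions[-1] * 2)
    if resolutions[-1] != final_resolution:
        # Doubling can only reach start_resolution * 2**k exactly.
        raise ValueError("final resolution must be start_resolution * 2**k")
    return resolutions


stages = resolution_schedule(512)
print(stages)  # one entry per stage, ending at 512
```

Under this assumption, the 512 x 512 model corresponds to eight stages and the 256 x 256 model to seven.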

Requirements

  • Currently only supports Python 3

References

[1] Tero Karras et al., "Progressive Growing of GANs for Improved Quality, Stability, and Variation", https://arxiv.org/abs/1710.10196