This is a PyTorch implementation of StyleGAN (https://arxiv.org/abs/1812.04948) that can generate 1024x1024 images. Progressive training up to 1024x1024 is also supported. A GTX 1080 Ti is recommended for faster training.
Dependencies are listed in conda.yaml. Please note that I have CUDA 10.0 installed; if you use a different CUDA version, change conda.yaml accordingly.
This step is not needed if the generation command downloads the pretrained weights successfully. If the download from Google Drive fails, download the files manually:
Then place them in the ./pretrained directory.
Download the CelebA dataset and unzip the .zip file into the ./data/celeba directory.
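If you prefer to script the extraction, here is a minimal sketch using Python's standard zipfile module. The archive name img_align_celeba.zip is an assumption (it is the aligned-image archive commonly distributed with CelebA), and the folder layout the training code expects inside ./data/celeba is not specified here, so check it against the data-loading code.

```python
import zipfile
from pathlib import Path


def extract_celeba(zip_path, dest="./data/celeba"):
    """Unzip a CelebA archive into dest and return the number of files extracted."""
    dest_dir = Path(dest)
    dest_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as zf:
        # Directory entries end with "/"; count only real files.
        members = [name for name in zf.namelist() if not name.endswith("/")]
        zf.extractall(dest_dir)
    return len(members)
```

For example, `extract_celeba("img_align_celeba.zip")` unpacks the (assumed) archive into ./data/celeba.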
Run the command:
mlflow run -e generate . -P dataset=cats
The default random seed is 77. To generate different images in a different grid layout, change the seed and the grid size (note that the number of images you can generate at once is limited by your GPU memory):
mlflow run -e generate . -P dataset=cats -P random-seed=777 -P nrow=2 -P ncol=5
This will generate 10 images at once.
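The role of the seed and grid parameters can be illustrated with a small sketch: a fixed seed makes the latent vectors (and therefore the images) reproducible, and nrow * ncol latents are drawn and laid out in a grid. This uses Python's random module purely for illustration; the repository presumably draws latents with PyTorch's RNG, the 512-dimensional latent size comes from the StyleGAN paper rather than from this repo's code, and reading nrow as rows and ncol as columns is an assumption.

```python
import random

LATENT_DIM = 512  # size of StyleGAN's input latent z (from the paper)


def sample_latents(seed, n_images, dim=LATENT_DIM):
    """Draw n_images standard-normal latent vectors from an RNG seeded with seed."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n_images)]


def grid_position(index, ncol):
    """Return (row, col) of the index-th image in a row-major grid with ncol columns."""
    return divmod(index, ncol)


nrow, ncol = 2, 5
latents = sample_latents(seed=777, n_images=nrow * ncol)
print(len(latents))  # 10
```

Calling sample_latents twice with the same seed returns identical latents, which is why the default seed 77 always produces the same grid.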
To generate images from checkpoints you trained yourself rather than the official ones, run:
mlflow run -e generate . \
-P use_official_checkpoints=False \
-P g_checkpoint=[path_to_generator_checkpoint] \
-P target_resolution=128 \
-P nrow=2 \
-P ncol=2
This generates a 2x2 grid of 128x128 images from checkpoints trained with this code.
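When generating from several checkpoints or settings, it can help to assemble the command from Python. The sketch below only builds the argument list from the parameters shown above; the checkpoint path is hypothetical, and target_resolution should presumably match the resolution the checkpoint was trained to.

```python
import shlex


def mlflow_generate_cmd(g_checkpoint, target_resolution=128, nrow=2, ncol=2):
    """Build the argument list for generating images from a self-trained generator."""
    params = {
        "use_official_checkpoints": "False",
        "g_checkpoint": g_checkpoint,
        "target_resolution": str(target_resolution),
        "nrow": str(nrow),
        "ncol": str(ncol),
    }
    cmd = ["mlflow", "run", "-e", "generate", "."]
    for key, value in params.items():
        cmd += ["-P", f"{key}={value}"]
    return cmd


cmd = mlflow_generate_cmd("checkpoints/G_128.pth")  # hypothetical checkpoint path
print(shlex.join(cmd))
```

Passing the list to `subprocess.run(cmd, check=True)` launches the run; building a list instead of a single string avoids shell-quoting problems with unusual paths.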
To train from scratch, run:
mlflow run -e train . -P resume=False
This starts training toward a 128x128 target resolution on the CelebA dataset. During training, model checkpoints are saved under ./checkpoints, and sample fake images are written to ./checks/fake_imgs for monitoring. Training is progressive, starting from 8x8, so you will see 8x8 images at the beginning and 128x128 images at the end of training.
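The progressive schedule can be sketched as follows: the resolution doubles from 8x8 until it reaches the target, and each new higher-resolution layer is faded in by linearly blending its output with the upsampled output of the previous stage (the mechanism StyleGAN inherits from progressive GANs). This is a conceptual sketch on plain lists that assumes power-of-two resolutions; the exact fade-in schedule and the number of images shown per stage are defined in train.py and are not reproduced here.

```python
def resolution_schedule(start=8, target=128):
    """List the resolutions visited by progressive growing, doubling each stage."""
    resolutions = [start]
    while resolutions[-1] < target:
        resolutions.append(resolutions[-1] * 2)
    return resolutions


def fade_in(low_res_upsampled, high_res, alpha):
    """Blend the upsampled previous-stage output with the new layer's output.

    alpha ramps from 0 (only the old path) to 1 (only the new layer).
    """
    return [(1.0 - alpha) * a + alpha * b for a, b in zip(low_res_upsampled, high_res)]


print(resolution_schedule())          # [8, 16, 32, 64, 128]
print(resolution_schedule(8, 1024))   # [8, 16, 32, 64, 128, 256, 512, 1024]
print(fade_in([0.0, 1.0], [1.0, 0.0], alpha=0.25))  # [0.25, 0.75]
```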
To resume training:
mlflow run -e train . \
-P resume=True \
-P g_checkpoint=[path_to_generator_checkpoint] \
-P d_checkpoint=[path_to_discriminator_checkpoint]
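To avoid copying checkpoint paths by hand when resuming, a small helper can pick the most recently written files from ./checkpoints. The G_*.pth and D_*.pth patterns are hypothetical; adjust them to the file names this code actually writes.

```python
from pathlib import Path


def latest_checkpoint(directory, pattern):
    """Return the most recently modified file in directory matching pattern, or None."""
    base = Path(directory)
    if not base.is_dir():
        return None
    candidates = [p for p in base.glob(pattern) if p.is_file()]
    return max(candidates, key=lambda p: p.stat().st_mtime, default=None)


# Hypothetical file-name patterns for generator and discriminator checkpoints.
g_ckpt = latest_checkpoint("./checkpoints", "G_*.pth")
d_ckpt = latest_checkpoint("./checkpoints", "D_*.pth")
```

If both are found, pass `str(g_ckpt)` and `str(d_ckpt)` as g_checkpoint and d_checkpoint. Selecting by modification time rather than parsing file names keeps the helper independent of the naming scheme.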
For other training options, see the MLproject file. For hyperparameters, see train.py and NVIDIA's official implementation.
- Add the truncation trick
- Add and experiment with other loss functions (some are already in the repo but untested)
- Add TensorBoard support
- Add a moving average of the generator's weights
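Two of the planned additions above come down to simple update rules. In the truncation trick from the StyleGAN paper, an intermediate latent w is pulled toward the average latent w_avg (the mean of the mapping network's outputs) by a factor psi; the weight moving average keeps an exponentially smoothed copy of the generator's parameters for evaluation. The sketch below shows both on plain lists of floats; psi=0.7 and beta=0.999 are assumed defaults rather than values from this repo, and a real implementation would apply the same formulas to PyTorch tensors.

```python
def truncate(w, w_avg, psi=0.7):
    """Truncation trick: w' = w_avg + psi * (w - w_avg).

    psi=1 leaves w unchanged; psi=0 collapses every sample to w_avg.
    """
    return [wa + psi * (wi - wa) for wi, wa in zip(w, w_avg)]


def ema_update(ema_params, params, beta=0.999):
    """Moving average of generator weights: ema <- beta * ema + (1 - beta) * current."""
    return [beta * e + (1.0 - beta) * p for e, p in zip(ema_params, params)]


print(truncate([2.0, -2.0], [0.0, 0.0], psi=0.5))  # [1.0, -1.0]
print(ema_update([1.0], [0.0], beta=0.5))          # [0.5]
```

Lower psi trades sample diversity for image quality, while a beta close to 1 makes the averaged generator change slowly and smooths out training noise.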
Multi-GPU support has been added but not tested due to hardware limitations.
This project is licensed under the BSD 3-Clause License.