
S2FGAN PyTorch Implementation

Dependency

  • python 3.7.4
  • numpy 1.18.1
  • Pillow 7.0.0
  • opencv-python 4.2.0.32
  • torch 1.5.1
  • torchvision 0.5.0
  • albumentations 0.4.6
  • cudnn 7.6.5
  • CUDA 10.1

At least one GPU is required. Please install the libraries with CUDA and C++ support on a Linux system.
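
As a quick sanity check (illustrative only, not part of the repository), the following snippet prints the installed versions against the dependency list above and confirms that a CUDA device is visible:

import torch
import torchvision
import numpy
import PIL
import cv2
import albumentations

# Compare installed versions with the dependency list above.
print("torch          :", torch.__version__)          # expect 1.5.1
print("torchvision    :", torchvision.__version__)    # expect 0.5.0
print("numpy          :", numpy.__version__)          # expect 1.18.1
print("Pillow         :", PIL.__version__)            # expect 7.0.0
print("opencv-python  :", cv2.__version__)            # expect 4.2.0.32
print("albumentations :", albumentations.__version__) # expect 0.4.6

# At least one CUDA-capable GPU is required for training.
assert torch.cuda.is_available(), "no CUDA device found"
print("CUDA runtime   :", torch.version.cuda)             # expect 10.1
print("cuDNN          :", torch.backends.cudnn.version()) # expect 7605 (7.6.5)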

Dataset

Notice

  • S2FGAN has only been tested at 256x256 resolution.
  • We refactored the current implementation from our original training code. If you find any implementation error, please do not hesitate to contact us.
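
The training flags below expect the face photographs and HED sketches to be packed into zip archives. As a rough sketch of how such an archive can be read (the archive name, internal folder, and file name here are assumptions; substitute your own data), see:

import io
import zipfile
from PIL import Image

# Hypothetical layout matching the --imageZip / --imagePath flags below.
archive = zipfile.ZipFile("CelebAMask-HQ.zip")       # --imageZip (assumed name)
data = archive.read("CelebA-HQ-img/0.jpg")           # --imagePath + file name (assumed)
image = Image.open(io.BytesIO(data)).convert("RGB")
image = image.resize((256, 256))                     # S2FGAN is tested at 256x256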

Train S2FGAN

  • Validation images will be saved in the sample folder, model checkpoints in the checkpoint folder, and the training log in log.txt.

  • For training, please run train.py with the parameters set properly.

python train.py --help

--iter                                 #total training iterations
--batch                                #batch size
--r1                                   #weight of the r1 regularization
--d_reg_every                          #interval for applying r1 regularization to the discriminator
--lr                                   #learning rate
--augment                              #apply discriminator augmentation
--augment_p                            #probability of applying discriminator augmentation. 0 = use adaptive augmentation
--ada_target                           #target augmentation probability for adaptive augmentation
--ada_length                           #target duration to reach the augmentation probability for adaptive augmentation
--ada_every                            #probability update interval of the adaptive augmentation
--img_height                           #image height
--img_width                            #image width
--NumberOfImage                        #number of images in the zip
--imageZip                             #input image zip
--hedEdgeZip                           #HED sketch zip
--hedEdgePath                          #path of HED sketches in the zip, e.g. hed_edge_256
--imagePath                            #path of images in the zip
--TORCH_HOME                           #directory storing the pretrained PyTorch model; "None" loads it from the default directory
--label_path                           #attribute annotation text file of CelebAMask-HQ
--selected_attrs                       #selected attributes for the CelebAMask-HQ dataset
--ATMDTT                               #attributes to manipulate at testing time
--model_type                           #0 - S2F-DIS, 1 - S2F-DEC
  • Train S2F-DIS
python3 train.py --model_type 0 #Please set the data paths properly.
  • Train S2F-DEC
python3 train.py --model_type 1 #Please set the data paths properly.
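
For illustration, a fuller invocation might combine the flags above as follows; every path and count here is a placeholder (an assumption), not a value shipped with the repository:

python3 train.py --model_type 0 --iter 800000 --batch 8 \
    --img_height 256 --img_width 256 \
    --imageZip CelebAMask-HQ.zip --imagePath CelebA-HQ-img \
    --hedEdgeZip hed_edge.zip --hedEdgePath hed_edge_256 \
    --NumberOfImage 30000 \
    --label_path CelebAMask-HQ-attribute-anno.txt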

Code for Related Work

Evaluation metrics

Todo

  • Upload pretrained checkpoints
  • Upload testing script

If you urgently need the checkpoints, please drop me an email.

License

The Equalized layer, Modulated layer, PixelNorm, and CUDA kernels are from the official StyleGAN2. For more details, please refer to the repository: https://github.com/NVlabs/stylegan2

Thanks to rosinality's StyleGAN2 PyTorch implementation. S2FGAN builds on this template: https://github.com/rosinality/stylegan2-pytorch

The dataloader is based on eriklindernoren's repository: https://github.com/eriklindernoren/PyTorch-GAN

The AttGAN implementation can be found at https://github.com/elvisyjlin/AttGAN-PyTorch

The data prefetcher is based on the implementation from NVIDIA Apex: https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py#L256
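
For readers unfamiliar with the pattern, the Apex-style prefetcher stages the next batch's host-to-device copies on a side CUDA stream so they overlap with computation on the main stream. A minimal paraphrase of the linked example (not the exact code used in this repository):

import torch

class DataPrefetcher:
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            # Copies run on the side stream while the main stream computes.
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        # Block the main stream until the staged copies have finished.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = (self.next_input, self.next_target)
        self.preload()
        return batch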

The HED detector loading and Crop layer implementation are from Rosebrock: https://www.pyimagesearch.com/2019/03/04/holistically-nested-edge-detection-with-opencv-and-deep-learning/
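
Since training consumes pre-computed HED sketches (--hedEdgeZip), a sketch of producing them with OpenCV, following the linked tutorial, is shown below; the deploy.prototxt and hed_pretrained_bsds.caffemodel file names come from the OpenCV HED sample, not from this repository:

import cv2

class CropLayer(object):
    # Center-crops the first input blob to the spatial size of the second,
    # as required by HED's upsampled side outputs.
    def __init__(self, params, blobs):
        self.xstart = self.xend = self.ystart = self.yend = 0

    def getMemoryShapes(self, inputs):
        input_shape, target_shape = inputs[0], inputs[1]
        batch, channels = input_shape[0], input_shape[1]
        height, width = target_shape[2], target_shape[3]
        self.ystart = (input_shape[2] - height) // 2
        self.xstart = (input_shape[3] - width) // 2
        self.yend, self.xend = self.ystart + height, self.xstart + width
        return [[batch, channels, height, width]]

    def forward(self, inputs):
        return [inputs[0][:, :, self.ystart:self.yend, self.xstart:self.xend]]

cv2.dnn_registerLayer("Crop", CropLayer)
net = cv2.dnn.readNetFromCaffe("deploy.prototxt", "hed_pretrained_bsds.caffemodel")

image = cv2.imread("face.jpg")                       # placeholder input image
h, w = image.shape[:2]
blob = cv2.dnn.blobFromImage(image, scalefactor=1.0, size=(w, h),
                             mean=(104.00698793, 116.66876762, 122.67891434),
                             swapRB=False, crop=False)
net.setInput(blob)
hed = net.forward()[0, 0]                            # edge probabilities in [0, 1]
cv2.imwrite("face_hed.png", (255 * hed).astype("uint8"))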

Demo Video for Attribute Editing - Click to Play


Citation

If you find S2FGAN useful in your research work, please consider citing:

    @ARTICLE{s2fgan,
      author = {Yang, Yan and Hossain, Md Zakir and Gedeon, Tom and Rahman, Shafin},
      year = {2020},
      month = {11},
      pages = {},
      title = {S2FGAN: Semantically Aware Interactive Sketch-to-Face Translation}
    }