gwang-kim/DiffusionCLIP

DiffusionCLIP: Text-Guided Diffusion Models for Robust Image Manipulation (CVPR 2022)

DiffusionCLIP: Text-Guided Diffusion Models for Robust Image Manipulation
Gwanghyun Kim, Taesung Kwon, Jong Chul Ye
CVPR 2022

Abstract:
Recently, GAN inversion methods combined with Contrastive Language-Image Pretraining (CLIP) enable zero-shot image manipulation guided by text prompts. However, their applications to diverse real images are still difficult due to the limited GAN inversion capability. Specifically, these approaches often have difficulties in reconstructing images with novel poses, views, and highly variable contents compared to the training data, altering object identity, or producing unwanted image artifacts. To mitigate these problems and enable faithful manipulation of real images, we propose a novel method, dubbed DiffusionCLIP, that performs text-driven image manipulation using diffusion models. Based on full inversion capability and high-quality image generation power of recent diffusion models, our method performs zero-shot image manipulation successfully even between unseen domains and takes another step towards general application by manipulating images from a widely varying ImageNet dataset. Furthermore, we propose a novel noise combination method that allows straightforward multi-attribute manipulation. Extensive experiments and human evaluation confirmed robust and superior manipulation performance of our methods compared to the existing baselines.

Description

This repo includes the official PyTorch implementation of DiffusionCLIP: Text-Guided Diffusion Models for Robust Image Manipulation. DiffusionCLIP resolves the critical issues in zero-shot manipulation with the following contributions:

  • We reveal that diffusion models are well suited for image manipulation thanks to their nearly perfect inversion capability, an important advantage over GAN-based models that had not been analyzed in depth before our detailed comparison.
  • Our novel sampling strategies for fine-tuning can preserve perfect reconstruction at increased speed.
  • In terms of empirical results, our method enables accurate in- and out-of-domain manipulation, minimizes unintended changes, and significantly outperforms SOTA baselines.
  • Our method takes another step towards general application by manipulating images from a widely varying ImageNet dataset.
  • Finally, our zero-shot translation between unseen domains and multi-attribute transfer can effectively reduce manual intervention.

The training process is illustrated in the following figure. Once the diffusion model is fine-tuned, any image from the pretrained domain can be manipulated into the domain corresponding to the target text without re-training.

We also propose two fine-tuning schemes: quick original fine-tuning and GPU-efficient fine-tuning. For more details, please refer to Sec. B.1 in the Supplementary Material.
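The paper fine-tunes the model with a directional CLIP loss. The shift from the original image's CLIP embedding to the edited image's embedding should point the same way as the shift from the source text embedding to the target text embedding. The sketch below shows only that arithmetic, using plain Python lists as stand-ins for CLIP features, together with the weighted sum against the L1 and ID losses. The function names are hypothetical. The default weights just mirror the fine-tuning example in this README: CLIP weight 3, l1_loss_w 1, and id_loss_w 0. The actual implementation works on CLIP image and text features, not lists.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def directional_clip_loss(img_ref, img_gen, txt_src, txt_trg) -> float:
    """1 - cos(delta_I, delta_T), with embeddings passed as plain lists (stand-ins).

    img_ref / img_gen: embeddings of the original and the edited image.
    txt_src / txt_trg: embeddings of the source and the target text.
    """
    delta_i = [g - r for g, r in zip(img_gen, img_ref)]
    delta_t = [t - s for t, s in zip(txt_trg, txt_src)]
    return 1.0 - cosine_similarity(delta_i, delta_t)


def total_loss(clip_loss, l1_loss, id_loss, l1_loss_w=1.0, id_loss_w=0.0, clip_w=3.0):
    """Weighted sum of the three losses; clip_w=3 matches the CLIP weight quoted in the options."""
    return clip_w * clip_loss + l1_loss_w * l1_loss + id_loss_w * id_loss


if __name__ == "__main__":
    # The image moved exactly along the text direction, so the loss is (close to) zero.
    print(directional_clip_loss([0, 0], [1, 1], [0, 0], [2, 2]))
```

A loss near 0 means the edit moved the image in the same CLIP-space direction as the text change. A loss near 1 means the two directions are orthogonal.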

Getting Started

Installation

We recommend running our code using:

  • NVIDIA GPU + CUDA, CuDNN
  • Python 3, Anaconda

To install our implementation, clone our repository and run the following commands to install the necessary packages:

conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=<CUDA_VERSION>
pip install -r requirements.txt
pip install git+https://github.com/openai/CLIP.git
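After installing, you may want to confirm that the packages can be imported. The sketch below is a hypothetical helper, not something shipped with the repository. It checks torch, torchvision and clip (the import name of the openai/CLIP package) using importlib.util.find_spec. If torch is present, it also reports whether a CUDA device is visible.

```python
import importlib.util

# Import names assumed from the install commands above.
REQUIRED_MODULES = ("torch", "torchvision", "clip")


def missing_modules(names=REQUIRED_MODULES):
    """Return the subset of `names` that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def cuda_available():
    """True only if torch is installed and reports a usable CUDA device."""
    if importlib.util.find_spec("torch") is None:
        return False
    import torch  # imported lazily so the check works without torch installed
    return torch.cuda.is_available()


if __name__ == "__main__":
    missing = missing_modules()
    print("missing modules:", ", ".join(missing) if missing else "none")
    print("CUDA available:", cuda_available())
```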

Resources

  • Original fine-tuning requires 24 GB+ of VRAM for 256x256 images.
  • GPU-efficient fine-tuning requires 12 GB+ of VRAM for 256x256 images and 24 GB+ for 512x512 images.
  • Inference requires 6 GB+ of VRAM for 256x256 images and 9 GB+ for 512x512 images.
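The requirements above can be turned into a small decision helper. It picks between --clip_finetune and --clip_finetune_eff for a given GPU. This is a hypothetical convenience function, and the numbers are transcribed from the list above. No requirement is listed for original fine-tuning at 512x512, so the sketch treats that case as unknown and falls back to the GPU-efficient mode.

```python
from typing import Optional

# Minimum VRAM in GB per (mode, image size), transcribed from the Resources list.
VRAM_REQUIREMENTS_GB = {
    ("original", 256): 24,
    ("efficient", 256): 12,
    ("efficient", 512): 24,
    ("inference", 256): 6,
    ("inference", 512): 9,
}


def pick_finetune_flag(vram_gb: float, image_size: int) -> Optional[str]:
    """Return the fine-tuning flag that fits in `vram_gb`, or None if neither fits."""
    need_orig = VRAM_REQUIREMENTS_GB.get(("original", image_size))
    if need_orig is not None and vram_gb >= need_orig:
        return "--clip_finetune"
    need_eff = VRAM_REQUIREMENTS_GB.get(("efficient", image_size))
    if need_eff is not None and vram_gb >= need_eff:
        return "--clip_finetune_eff"
    return None


def can_infer(vram_gb: float, image_size: int) -> bool:
    """True if the listed inference requirement is met for this image size."""
    need = VRAM_REQUIREMENTS_GB.get(("inference", image_size))
    return need is not None and vram_gb >= need


if __name__ == "__main__":
    print(pick_finetune_flag(16, 256))  # a 16 GB card at 256x256
```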

Pretrained Models for DiffusionCLIP Fine-tuning

To manipulate source images into images in the CLIP-guided domain, the pretrained diffusion models are required.

Image Type to Edit | Size    | Pretrained Model          | Dataset      | Reference Repo.
-------------------|---------|---------------------------|--------------|-----------------
Human face         | 256×256 | Diffusion (Auto), IR-SE50 | CelebA-HQ    | SDEdit, TreB1eN
Church             | 256×256 | Diffusion (Auto)          | LSUN-Church  | SDEdit
Bedroom            | 256×256 | Diffusion (Auto)          | LSUN-Bedroom | SDEdit
Dog face           | 256×256 | Diffusion                 | AFHQ-Dog     | ILVR
ImageNet           | 512×512 | Diffusion                 | ImageNet     | Guided Diffusion
  • The pretrained diffusion models for 256x256 images from CelebA-HQ, LSUN-Church, and LSUN-Bedroom are downloaded automatically by the code.
  • In contrast, you need to download the models pretrained on AFHQ-Dog-256 or ImageNet-512 listed in the table and put them in the ./pretrained directory.
  • In addition, to use the ID loss for preserving human face identity, you need to download the pretrained IR-SE50 model from TreB1eN and put it in the ./pretrained directory.

Datasets

To precompute latents and fine-tune the diffusion models, you need about 30+ images in the source domain. You can use either images sampled from the pretrained models or real source images from the pretraining dataset. If you want to use real source images, download them with the following commands:

# CelebA-HQ 256x256
bash data_download.sh celeba_hq .

# AFHQ-Dog 256x256
bash data_download.sh afhq .

If you want to use custom paths, you can simply modify ./configs/paths_config.py.
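Before precomputing latents, you can check that your source folder holds enough images (about 30+). The helper below is a hypothetical sketch. It counts .png/.jpg/.jpeg files directly inside a folder and does not recurse into subfolders. Point it at wherever your images actually live, for example the folder the download script created or the one you configured in ./configs/paths_config.py.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}
MIN_SOURCE_IMAGES = 30  # "about 30+ images in the source domain"


def count_source_images(folder) -> int:
    """Count image files directly inside `folder` (non-recursive); 0 if it does not exist."""
    root = Path(folder)
    if not root.is_dir():
        return 0
    return sum(
        1 for p in root.iterdir() if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )


def has_enough_images(folder, minimum: int = MIN_SOURCE_IMAGES) -> bool:
    """True if the folder contains at least `minimum` images."""
    return count_source_images(folder) >= minimum
```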

Colab Notebook

We provide a Colab notebook for you to play with DiffusionCLIP! Due to the 12 GB VRAM limit in Colab, we only provide code for inference and applications with fine-tuned DiffusionCLIP models, not the fine-tuning code. We provide a wide range of edit types, and you can also upload your own fine-tuned models following the instructions in the notebook and test them.

DiffusionCLIP Fine-tuning

To fine-tune the pretrained Diffusion model guided by CLIP, run the following commands:

python main.py --clip_finetune          \
               --config celeba.yml      \
               --exp ./runs/test        \
               --edit_attr neanderthal  \
               --do_train 1             \
               --do_test 1              \
               --n_train_img 50         \
               --n_test_img 10          \
               --n_iter 5               \
               --t_0 500                \
               --n_inv_step 40          \
               --n_train_step 6         \
               --n_test_step 40         \
               --lr_clip_finetune 8e-6  \
               --id_loss_w 0            \
               --l1_loss_w 1            
  • You can use --clip_finetune_eff instead of --clip_finetune to save GPU memory.
  • config: celeba.yml for human face, bedroom.yml for bedroom, church.yml for church, afhq.yml for dog face, and imagenet.yml for images from ImageNet.
  • exp: Experiment name.
  • edit_attr: Attribute to edit. You can use the predefined source-target text pairs in ./utils/text_dic.py or define a new pair.
    • Instead, you can use --src_txts and --trg_txts.
  • do_train, do_test: To finish training quickly without checking the outputs in the middle of training, you can set do_test to 0.
  • n_train_img, n_test_img: # of images in the trained domain for training and test.
  • n_iter: # of iterations of a generative process with n_train_img images.
  • t_0: Return step in [0, 1000). A high t_0 enables more drastic changes but may lose more of the identity or semantics of the original image.
  • n_inv_step, n_train_step, n_test_step: # of steps in the generative process for inversion, training, and testing, respectively. Each must be in [0, t_0]. We usually use 40, 6, and 40 for n_inv_step, n_train_step, and n_test_step, respectively.
    • We found that manipulation quality is better when n_***_step does not divide t_0, so we usually use 301, 401, 500, or 601 for t_0.
  • lr_clip_finetune: Initial learning rate for CLIP-guided fine-tuning.
  • id_loss_w, l1_loss_w: Weights of the ID loss and the L1 loss when the CLIP loss weight is 3.
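The constraints on t_0 and the step counts can be checked before launching a long run. The sketch below is hypothetical and not part of main.py. It assumes the steps are spaced uniformly from 0 to t_0 (a linspace-style schedule), which is an assumption about the implementation. It reads "does not divide t_0" literally as t_0 % n_step != 0, and it reports warnings rather than rejecting settings.

```python
def timestep_schedule(t_0: int, n_step: int) -> list:
    """n_step integer timesteps spaced evenly from 0 to t_0 inclusive (assumed spacing)."""
    if n_step < 2:
        raise ValueError("n_step must be at least 2")
    return [int(i * t_0 / (n_step - 1)) for i in range(n_step)]


def check_finetune_args(t_0: int, n_inv_step: int, n_train_step: int, n_test_step: int) -> list:
    """Return warnings for the constraints described in the fine-tuning options."""
    warnings = []
    if not 0 <= t_0 < 1000:
        warnings.append(f"t_0={t_0} is outside [0, 1000)")
    steps = (("n_inv_step", n_inv_step), ("n_train_step", n_train_step), ("n_test_step", n_test_step))
    for name, n in steps:
        if not 0 <= n <= t_0:
            warnings.append(f"{name}={n} is outside [0, t_0]")
        elif n > 0 and t_0 % n == 0:
            warnings.append(f"{name}={n} divides t_0={t_0}; quality is reported to be better otherwise")
    return warnings


if __name__ == "__main__":
    # The values from the fine-tuning example above produce no warnings.
    print(check_finetune_args(500, 40, 6, 40))
    print(timestep_schedule(500, 6))
```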

Novel Applications

Models fine-tuned with DiffusionCLIP can be used for several novel applications.

Manipulation of Images in Trained Domain & to Unseen Domain

You can edit one image into the CLIP-guided domain by running the following command:

python main.py --edit_one_image            \
               --config celeba.yml         \
               --exp ./runs/test           \
               --t_0 500                   \
               --n_inv_step 40             \
               --n_test_step 40            \
               --n_iter 1                  \
               --img_path imgs/celeb1.png  \
               --model_path  checkpoint/neanderthal.pth
  • img_path: Path of an image to edit
  • model_path: Finetuned model path to use
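Editing one image with the command above follows the paper's deterministic DDIM scheme. The input is first inverted up to t_0 in n_inv_step steps, then regenerated with the fine-tuned model in n_test_step steps. The sketch below shows a single deterministic (eta = 0) DDIM update on plain lists. The same formula moves toward noise (inversion) or toward the image (generation), depending on which alpha_bar comes next. The alpha_bar values are passed in directly. A fixed noise vector stands in for the model's prediction eps_theta(x_t, t); this is a simplification for illustration only.

```python
import math


def predict_x0(x_t, eps, alpha_bar_t):
    """Estimate the clean image from x_t and the predicted noise."""
    return [(x - math.sqrt(1 - alpha_bar_t) * e) / math.sqrt(alpha_bar_t) for x, e in zip(x_t, eps)]


def ddim_step(x_t, eps, alpha_bar_t, alpha_bar_next):
    """Deterministic DDIM update from alpha_bar_t to alpha_bar_next.

    Smaller alpha_bar_next (larger t) = inversion; larger alpha_bar_next = generation.
    """
    x0 = predict_x0(x_t, eps, alpha_bar_t)
    return [math.sqrt(alpha_bar_next) * a + math.sqrt(1 - alpha_bar_next) * e for a, e in zip(x0, eps)]


if __name__ == "__main__":
    x = [0.2, -0.5, 1.0]
    eps = [0.1, 0.3, -0.2]  # stand-in for the model's noise prediction
    x_inv = ddim_step(x, eps, 0.99, 0.5)      # invert toward noise
    x_rec = ddim_step(x_inv, eps, 0.5, 0.99)  # generate back
    print(x_rec)  # matches x up to floating-point error
```

Because each update is deterministic, inverting and regenerating with the same noise predictions reconstructs the input. In a real edit, the regeneration uses the fine-tuned model, so the predictions differ and the image moves toward the target text.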

You can edit multiple images from the dataset into the CLIP-guided domain by running the following command:

python main.py --edit_images_from_dataset  \
               --config celeba.yml         \
               --exp ./runs/test           \
               --n_test_img 50             \
               --t_0 500                   \
               --n_inv_step 40             \
               --n_test_step 40            \
               --model_path checkpoint/neanderthal.pth
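To compare several settings, the command above can be generated from Python. The sketch below sweeps the t_0 values suggested in the fine-tuning options (301, 401, 500, 601). It only uses flags that appear in this README. The build_command helper and the ./runs/t0_* experiment directories are hypothetical. By default it only prints the commands; set DRY_RUN to False to actually launch them.

```python
import shlex
import subprocess

DRY_RUN = True  # set to False to actually run main.py for each setting


def build_command(mode: str, **options) -> list:
    """Build an argv list for main.py: a mode flag followed by --key value pairs."""
    argv = ["python", "main.py", f"--{mode}"]
    for key, value in options.items():
        argv += [f"--{key}", str(value)]
    return argv


if __name__ == "__main__":
    for t_0 in (301, 401, 500, 601):
        argv = build_command(
            "edit_images_from_dataset",
            config="celeba.yml",
            exp=f"./runs/t0_{t_0}",  # hypothetical per-run experiment directory
            n_test_img=50,
            t_0=t_0,
            n_inv_step=40,
            n_test_step=40,
            model_path="checkpoint/neanderthal.pth",
        )
        if DRY_RUN:
            print(shlex.join(argv))
        else:
            subprocess.run(argv, check=True)
```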

Image Translation from Unseen Domain into Another Unseen Domain

Generation of Images in Unseen Domain from Strokes

You can translate images from an unseen domain to another unseen domain (e.g., Stroke/Anime ➝ Neanderthal) using the following command:

python main.py --unseen2unseen          \
               --config celeba.yml      \
               --exp ./runs/test        \
               --t_0 500                \
               --bs_test 4              \
               --n_iter 10              \
               --n_inv_step 40          \
               --n_test_step 40         \
               --img_path imgs/stroke1.png \
               --model_path  checkpoint/neanderthal.pth
  • img_path: Stroke image or source image in the unseen domain, e.g., a portrait.
  • n_iter: # of iterations of the stochastic forward and generative processes used to translate an unseen source image into an image in the trained domain. It must be larger than 8.
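Each of the n_iter iterations starts with a stochastic forward step: the image is noised to t_0 by sampling x_t ~ N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I). The model then denoises it, gradually pulling the unseen-domain input toward the trained domain. The sketch below shows only the forward step on plain lists. It assumes the common DDPM linear beta schedule (1e-4 to 0.02 over 1000 steps). The schedule each pretrained model actually uses is set by its config and may differ. The printout also shows why a higher t_0 keeps less of the original image: the signal coefficient sqrt(alpha_bar_t) shrinks as t grows.

```python
import math
import random


def linear_alpha_bars(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative products alpha_bar_t for a linear beta schedule (assumed DDPM default)."""
    alpha_bars, prod = [], 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def forward_noise(x0, t, alpha_bars, rng=random):
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(ab_t) * x_0, (1 - ab_t) * I), element-wise."""
    ab = alpha_bars[t]
    return [math.sqrt(ab) * x + math.sqrt(1 - ab) * rng.gauss(0.0, 1.0) for x in x0]


if __name__ == "__main__":
    ab = linear_alpha_bars()
    for t_0 in (301, 401, 500, 601):
        print(f"t_0={t_0}: signal coefficient {math.sqrt(ab[t_0]):.3f}")
    print(forward_noise([0.5, -0.5, 1.0], 500, ab, rng=random.Random(0)))
```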

Multiple Attribute Changes

You can change multiple attributes through only one generative process by mixing the noise from multiple fine-tuned models.

  1. Set HYBRID_MODEL_PATHS and HYBRID_CONFIG in ./configs/paths_config.py.
  2. Run the commands from Manipulation of Images in Trained Domain & to Unseen Domain above with --hybrid_noise 1.
HYBRID_MODEL_PATHS = [
	'curly_hair.pth',
	'makeup.pth',
]

HYBRID_CONFIG = \
	{ 300: [1, 0],
	    0: [0.3, 0.7]}

The keys and values of the HYBRID_CONFIG dictionary correspond to thresholds and ratios for the noise mixing process using multiple models. The following pseudo-code represents the noise mixing process. The full code is in ./utils/diffusion_utils.py.

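# models: list of the fine-tuned diffusion models, in the same order as HYBRID_MODEL_PATHS
# xt: noisy image at timestep t; et: mixed noise prediction

# Thresholds are checked in dictionary order, so list the keys from largest to smallest.
for thr in list(HYBRID_CONFIG.keys()):
    if t >= thr:
        et = 0
        for i, ratio in enumerate(HYBRID_CONFIG[thr]):
            ratio /= sum(HYBRID_CONFIG[thr])  # normalize the ratios so they sum to 1
            et_i = models[i](xt, t)           # noise predicted by the i-th model
            et += ratio * et_i
        break  # only the first matching threshold is used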
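The pseudo-code above can be turned into a small runnable sketch. Stand-in "models" here are plain functions that return noise as Python lists; the real models return tensors. The sketch differs from the pseudo-code in two ways. It sorts the thresholds in descending order instead of relying on dictionary insertion order. It also raises an error if no threshold applies, a case where the pseudo-code would leave et undefined. The names mix_noise, curly_hair and makeup are hypothetical.

```python
def mix_noise(models, xt, t, hybrid_config):
    """Mix noise predictions from several fine-tuned models at timestep t.

    hybrid_config maps a threshold to one ratio per model. The largest threshold
    with t >= thr is used, and its ratios are normalized to sum to 1.
    """
    for thr in sorted(hybrid_config, reverse=True):
        if t >= thr:
            ratios = hybrid_config[thr]
            total = sum(ratios)
            et = [0.0] * len(xt)
            for model, ratio in zip(models, ratios):
                et_i = model(xt, t)  # this model's noise prediction
                et = [acc + (ratio / total) * e for acc, e in zip(et, et_i)]
            return et
    raise ValueError(f"no threshold in hybrid_config applies to t={t}")


if __name__ == "__main__":
    # Stand-ins for 'curly_hair.pth' and 'makeup.pth' that predict constant noise.
    curly_hair = lambda xt, t: [1.0] * len(xt)
    makeup = lambda xt, t: [2.0] * len(xt)
    config = {300: [1, 0], 0: [0.3, 0.7]}
    print(mix_noise([curly_hair, makeup], [0.0, 0.0], 500, config))  # curly_hair only
    print(mix_noise([curly_hair, makeup], [0.0, 0.0], 100, config))  # 0.3/0.7 blend
```

With the example HYBRID_CONFIG, timesteps at or above 300 use only the first model. Below 300, the noise predictions are blended 0.3/0.7.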

Fine-tuned Models Using DiffusionCLIP

We provide a Google Drive folder containing several models fine-tuned with DiffusionCLIP: Human Face, Dog Face, Church, Bedroom, ImageNet Style Transfer, and ImageNet Tennis Ball.

Related Works

The use of CLIP guidance to manipulate images is motivated by StyleCLIP and StyleGAN-NADA. Image translation from an unseen domain to the trained domain using diffusion models was introduced in SDEdit and ILVR. DDIM sampling and its reversal for generating and inverting images were introduced in DDIM and Diffusion Models Beat GANs on Image Synthesis.

Our code structure is based on the official code of SDEdit and StyleGAN-NADA. We used pretrained models from SDEdit and ILVR.

Citation

If you find DiffusionCLIP useful in your research, please consider citing:

@InProceedings{Kim_2022_CVPR,
    author    = {Kim, Gwanghyun and Kwon, Taesung and Ye, Jong Chul},
    title     = {DiffusionCLIP: Text-Guided Diffusion Models for Robust Image Manipulation},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2022},
    pages     = {2426-2435}
}

Additional Results

Here, we show more manipulations of real images from diverse datasets using DiffusionCLIP, where the original pretrained models are trained on AFHQ-Dog, LSUN-Bedroom, and ImageNet.

About

[CVPR 2022] Official PyTorch Implementation for DiffusionCLIP: Text-guided Image Manipulation Using Diffusion Models
