Skip to content

guillaumejs2403/DiME

Repository files navigation

DiME's official code

This is the codebase for the ACCV 2022 paper Diffusion Models for Counterfactual Explanations.

Environment

Through anaconda, install our environment:

conda env create -f env.yaml
conda activate dime

Data preparation

Please download and uncompress the CelebA dataset here. There is no need for any post-processing. The final folder structure should be:

PATH ---- img_align_celeba ---- xxxxxx.jpg
      |
      --- list_attr_celeba.csv
      |
      --- list_eval_partition.csv

Downloading pre-trained models

To use our trained models, you must download them first. Please extract them to the folder models. Our code provides the CelebA diffusion model, the classifier under observation, and the trained oracle. Download the VGGFace2 model throught this github repo. Download the resnet50_ft model.

Download Link:

Extracting Counterfactual Explanations

To create the counterfactual explanations, please use the main.py script as follows:

MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 500 --learn_sigma True --noise_schedule linear --num_channels 128 --num_heads 4 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
SAMPLE_FLAGS="--batch_size 50 --timestep_respacing 200"
DATAPATH=/path/to/dataset
MODELPATH=/path/to/model.pt
CLASSIFIERPATH=/path/to/classifier.pt
ORACLEPATH=/path/to/oracle.pt
OUTPUT_PATH=/path/to/output
EXPNAME=exp/name

# parameters of the sampling
GPU=0
S=60
SEED=4
USE_LOGITS=True
CLASS_SCALES='8,10,15'
LAYER=18
PERC=30
L1=0.05
QUERYLABEL=31
TARGETLABEL=-1
IMAGESIZE=128  # dataset shape

python -W ignore main.py $MODEL_FLAGS $SAMPLE_FLAGS \
  --query_label $QUERYLABEL --target_label $TARGETLABEL \
  --output_path $OUTPUT_PATH --num_batches $NUMBATCHES \
  --start_step $S --dataset 'CelebAMV' \
  --exp_name $EXPNAME --gpu $GPU \
  --model_path $MODELPATH --classifier_scales $CLASS_SCALES \
  --classifier_path $CLASSIFIERPATH --seed $SEED \
  --oracle_path $ORACLEPATH \
  --l1_loss $L1 --use_logits $USE_LOGITS \
  --l_perc $PERC --l_perc_layer $LAYER \
  --save_x_t True --save_z_t True \
  --use_sampling_on_x_t True \
  --save_images True --image_size $IMAGESIZE

Given that the sampling process may take much time, we've included a way to split the sampling into multiple processes. To use this feature, include the flag --num_chunks C, where C is the number of chunks to split the dataset. Then, run C times the code using the flag --chunk c, where c is the chunk to generate the evaluation (hence, c \in {0, 1, ..., C - 1}).

The results will be stored OUTPUT_PATH. This folder has the following structure:

OUTPUT_PATH ----- Original ---- Correct
              |             |
              |             --- Incorrect
              |
              |
              |
              --- Results ---- EXPNAME ---- (I/C)C ---- (I/C)CF ---- CF
                                                                 |
                                                                 --- Info
                                                                 |
                                                                 --- Noise
                                                                 |
                                                                 --- SM

We found this structure useful to experiment since we can change only the EXPNAME to refer to another experiment without changing the original images. The folder Original contains the correctly classified (misclassified) images in Correct (Incorrect). We resume the structure of the counterfactuals explanations (Results/EXPNAME) as: (I/C)C: (In/correct) classification. (I/C)CF: (In/correct) counterfactual. CF: counterfactual images. Info: Useful information per instance. Noise: Noisy instance at timestep $\tau$ of the input data. SM: Difference between the input and its counterfactual. All files in all folders will have the same identifier.

Evaluation

We provide our evaluation protocol scripts to assess the performance of our method. All our evaluation codes use the folder structure presented before. Please look at the --help function flag for more information about their inputs.

  • FVA: compute_FVA.py.
  • MNAC: compute_MNAC.py.
  • $\sigma_L$: compute_LPIPS.py. Computes the variability metric.
  • CD: compute_CD.py. Computes our proposed metric, Correlation Difference.
  • FID: compute-fid.sh. The first input is the OUTPUT_PATH and the second one the EXPNAME.

Training the DDPM model from scratch

We provided a bash script to train the DDPM to generate the counterfactual explanations: celeba-train.sh. Nevertheless, the syntax to run the code base is:

MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 500 --image_size 128 --learn_sigma True --noise_schedule linear --num_channels 128 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
TRAIN_FLAGS="--batch_size 15 --lr 1e-4 --save_interval 30000 --weight_decay 0.05 --dropout 0.0"
mpiexec -n N python celeba-train-diffusion.py $TRAIN_FLAGS \
                                              $MODEL_FLAGS \
                                              --output_path OUTPUT_FOLDER \
                                              --gpus GPUS

Citation

If you found useful our code, please cite our work.

@inproceedings{Jeanneret_2022_ACCV,
    author    = {Jeanneret, Guillaume and Simon, Lo\"ic and Fr\'ed\'eric Jurie},
    title     = {Diffusion Models for Counterfactual Explanations},
    booktitle = {Proceedings of the Asian Conference on Computer Vision (ACCV)},
    month     = {December},
    year      = {2022}
}

Code Base

We based our repository on openai/guided-diffusion.

About

Official Code for the ACCV 2022 paper Diffusion Models for Counterfactual Explanations

Topics

Resources

License

Stars

Watchers

Forks

Packages

No packages published