Explainable Models with Consistent Interpretations

Official PyTorch implementation for the AAAI 2021 paper 'Explainable Models with Consistent Interpretations'

Given the widespread deployment of black box deep neural networks in computer vision applications, the interpretability aspect of these black box systems has recently gained traction. Various methods have been proposed to explain the results of such deep neural networks. However, some recent works have shown that such explanation methods are biased and do not produce consistent interpretations. Hence, rather than introducing a novel explanation method, we learn models that are encouraged to be interpretable given an explanation method. We use Grad-CAM as the explanation algorithm and encourage the network to learn consistent interpretations along with maximizing the log-likelihood of the correct class. We show that our method outperforms the baseline on the pointing game evaluation on the ImageNet and MS-COCO datasets. We also introduce new evaluation metrics that penalize the saliency map if it lies outside the ground truth bounding box or segmentation mask, and show that our method outperforms the baseline on these metrics as well. Moreover, our model trained with interpretation consistency generalizes to other explanation algorithms on all the evaluation metrics.

To encourage interpretation consistency, we randomly sample three distractor images for a given input image and feed all four to the Image Gridding module, which creates a 2 × 2 grid and places the positive image and the three negative images in random cells. We also feed the Grad-CAM interpretation mask for the ground truth category (‘Keyboard’) to the Image Gridding module to obtain the ground truth Grad-CAM mask of this composite image for the positive image category. The negative quadrants of this mask are set to zero. We then penalize the network if the Grad-CAM heatmap of the composite image for the positive image category does not match the ground truth Grad-CAM mask.
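
A minimal sketch of this gridding step in PyTorch is shown below. It is an illustration rather than the repository's actual code: positive, negatives, and gradcam_mask are hypothetical inputs, and the closing comment shows one plausible way to combine the consistency penalty with the classification loss (lam corresponding to the --lambda flag in the training commands below).

import torch
import torch.nn.functional as F

def make_grid_and_mask(positive, negatives, gradcam_mask):
    """positive: (3, H, W); negatives: list of three (3, H, W) tensors;
    gradcam_mask: (H, W) Grad-CAM heatmap for the ground truth class."""
    cells = [positive] + list(negatives)
    order = torch.randperm(4).tolist()           # random cell assignment
    imgs = [cells[i] for i in order]
    top = torch.cat([imgs[0], imgs[1]], dim=2)   # concatenate along width
    bottom = torch.cat([imgs[2], imgs[3]], dim=2)
    composite = torch.cat([top, bottom], dim=1)  # (3, 2H, 2W) composite image

    # Ground truth mask: the Grad-CAM heatmap in the positive quadrant,
    # zeros in the three negative quadrants.
    H, W = gradcam_mask.shape
    target = torch.zeros(2 * H, 2 * W)
    pos_cell = order.index(0)                    # cell the positive landed in
    r, c = divmod(pos_cell, 2)
    target[r * H:(r + 1) * H, c * W:(c + 1) * W] = gradcam_mask
    return composite, target

# One plausible form of the training objective (lam is the --lambda flag):
# loss = F.cross_entropy(logits, y) + lam * F.l1_loss(cam_of_composite, target)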

Pre-requisites

  • PyTorch 1.3 - Please install PyTorch and CUDA if you don't have them installed.
  • pycocotools

Training

ImageNet

The following command can be used to train a ResNet18 model using our Grad-CAM consistency method:
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_gcam_grid_consistency.py <path_to_imagenet_dataset> -a resnet18 -b 256 -j 16 --lambda 25 --save_dir <path_to_checkpoint_dir>

The following command can be used to train a ResNet18 model with Global Max Pooling instead of Global Average Pooling, along with our Grad-CAM consistency method:
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_gcam_grid_consistency.py <path_to_imagenet_dataset> -a resnet18 -b 256 -j 16 --lambda 25 --maxpool --save_dir <path_to_checkpoint_dir>

MS-COCO

Since MS-COCO is a multi-label dataset, we randomly select one of the ground truth categories to compute the Grad-CAM heatmap for the original image and the composite image. Hence, we perform a pre-processing step to extract a dictionary containing a list of negative images for each ground truth category. We use the script extract_negative_image_list.py to create this dictionary, and the COCO dataloader uses it to create the composite images, as sketched below.
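
The following is a rough sketch of what this pre-processing could look like with pycocotools; the function name and return structure here are hypothetical, and extract_negative_image_list.py in the repository is the authoritative version.

from pycocotools.coco import COCO

def build_negative_lists(ann_file):
    """Map each COCO category id to the ids of images that do NOT
    contain that category (candidate distractors for the grid)."""
    coco = COCO(ann_file)
    all_img_ids = set(coco.getImgIds())
    negatives = {}
    for cat_id in coco.getCatIds():
        positives = set(coco.getImgIds(catIds=[cat_id]))
        negatives[cat_id] = sorted(all_img_ids - positives)
    return negatives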

The following command can be used to train a ResNet18 model using our Grad-CAM consistency method:
CUDA_VISIBLE_DEVICES=0,1 python train_gcam_multiclass_grid_consistency.py <path_to_coco_dataset> -a resnet18 --num-gpus 2 --lr 0.01 -b 256 -j 16 --lambda 1 --resume <path_to_imagenet_pretrained_model_checkpoint> --save_dir <path_to_checkpoint_dir>

The following command can be used to train a ResNet18 model with Global Max Pooling instead of Global Average Pooling, along with our Grad-CAM consistency method:
CUDA_VISIBLE_DEVICES=0,1 python train_gcam_multiclass_grid_consistency.py <path_to_coco_dataset> -a resnet18 --num-gpus 2 --lr 0.01 -b 256 -j 16 --lambda 1 --maxpool --resume <path_to_imagenet_pretrained_model_checkpoint> --save_dir <path_to_checkpoint_dir>

Evaluation

We use evaluation code adapted from the TorchRay framework. For the SPG metric introduced in our paper, we use a stochastic version of the pointing game: we sample 100 points from the 2D map of the normalized Grad-CAM interpretation heatmap and evaluate them against the bounding box annotations of the ImageNet validation set.
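
Below is a minimal sketch of how such a stochastic pointing game could be computed, assuming points are drawn in proportion to the normalized heatmap values and scored as hits when they fall inside the ground truth box; the adapted TorchRay code in this repository is the reference implementation, and all names here are hypothetical.

import torch

def stochastic_pointing_game(heatmap, bbox_mask, n_points=100):
    """heatmap: (H, W) non-negative Grad-CAM map; bbox_mask: (H, W) binary
    mask of the ground truth bounding box. Returns the fraction of sampled
    points that land inside the box."""
    probs = heatmap.flatten()
    probs = probs / probs.sum()                  # normalize to a distribution
    idx = torch.multinomial(probs, n_points, replacement=True)
    hits = bbox_mask.flatten()[idx]
    return hits.float().mean().item()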

  • Change directory to TorchRay and install the library. Please refer to the TorchRay repository for full documentation and instructions.

    • cd TorchRay
    • python setup.py install
  • Change directory to TorchRay/torchray/benchmark

    • cd torchray/benchmark

For the ImageNet dataset, this evaluation requires the following structure for the ImageNet validation images and bounding box XML annotations:

  • imagenet_root/val/*.JPEG - Flat list of 50000 validation images
  • imagenet_root/val/*.xml - Flat list of 50000 annotation xml files

Evaluation metrics for Interpretation Consistency:

  1. Pointing Game:
    CUDA_VISIBLE_DEVICES=0 python evaluate_imagenet_gradcam_pointinggame.py -j 0 -b 1 --resume <path_to_checkpoint> --input_resize 448
  2. Stochastic Pointing Game:
    CUDA_VISIBLE_DEVICES=0 python evaluate_imagenet_gradcam_stochastic_pointinggame.py -j 0 -b 1 --resume <path_to_checkpoint> --input_resize 448
  3. Content Heatmap (see the sketch after this list):
    CUDA_VISIBLE_DEVICES=0 python evaluate_imagenet_gradcam_energy_inside_bbox.py -j 0 -b 1 --resume <path_to_checkpoint> --input_resize 448
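
For reference, a hedged sketch of the Content Heatmap metric as suggested by the script name above (the fraction of the heatmap's total energy that falls inside the ground truth bounding box); the evaluation script is the reference implementation, and the names here are hypothetical.

import torch

def content_heatmap(heatmap, bbox_mask):
    """heatmap: (H, W) non-negative saliency map; bbox_mask: (H, W) binary
    mask of the ground truth box. Returns in-box energy fraction in [0, 1]."""
    return ((heatmap * bbox_mask).sum() / heatmap.sum()).item()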

Results

ImageNet

Architecture  Model Name      Top-1 Acc (%)  Pointing Game  Stochastic Pointing Game  Content Heatmap  Pre-trained
AlexNet       Baseline        56.51          72.80          53.45                     45.78            checkpoint
AlexNet       Ours: GC        56.16          73.70          61.15                     48.10            checkpoint
ResNet18      Baseline        69.43          79.80          60.50                     54.36            checkpoint
ResNet18      Ours: GC        67.74          80.00          65.85                     57.73            checkpoint
ResNet18      GMP             69.08          79.30          66.66                     62.89            checkpoint
ResNet18      Ours: GMP + GC  69.02          79.60          68.74                     65.35            checkpoint
ResNet50      Baseline        76.13          80.00          60.95                     54.78            checkpoint
ResNet50      Ours: GC        74.40          80.30          65.26                     59.42            checkpoint
ResNet50      GMP             74.63          79.80          66.29                     54.23            checkpoint
ResNet50      Ours: GMP + GC  74.14          79.60          69.51                     59.70            checkpoint

MS-COCO

Architecture  Model Name      F1-PerClass (%)  F1-Overall (%)  Pointing Game  Stochastic Pointing Game  Content Heatmap  Pre-trained
ResNet18      Baseline        69.43            69.43           79.80          60.50                     54.36            checkpoint
ResNet18      Ours: GC        67.74            69.43           80.00          65.85                     57.73            checkpoint
ResNet18      GMP             69.08            69.43           79.30          66.66                     62.89            checkpoint
ResNet18      Ours: GMP + GC  69.02            69.43           79.60          68.74                     65.35            checkpoint

Citation

If you find our paper, code or models useful, please cite us using:

@article{Pillai_Pirsiavash_2021,
   title={Explainable Models with Consistent Interpretations},
   volume={35},
   url={https://ojs.aaai.org/index.php/AAAI/article/view/16344},
   number={3},
   journal={Proceedings of the AAAI Conference on Artificial Intelligence},
   author={Pillai, Vipin and Pirsiavash, Hamed},
   year={2021},
   month={May},
   pages={2431-2439}
}

Acknowledgement

We would like to thank Ashley Rothballer and Dennis Fong for helpful discussions regarding this work. This material is based upon work partially supported by the United States Air Force under Contract No. FA8750-19-C-0098, funding from NSF grant number 1845216, SAP SE, and Northrop Grumman. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the authors and do not necessarily reflect the views of the United States Air Force, DARPA, or other funding agencies.

License

This project is licensed under the MIT license.