
SrishtiGautam/ProtoVAE


ProtoVAE


The official PyTorch implementation of "ProtoVAE: A Trustworthy Self-Explainable Prototypical Variational Model" (NeurIPS 2022, https://nips.cc/Conferences/2022/ScheduleMultitrack?event=53023) by Srishti Gautam, Ahcene Boubekki, Stine Hansen, Suaiba Amina Salahuddin, Robert Jenssen, Marina MC Höhne, Michael Kampffmeyer.

The code is built upon ProtoPNet's official implementation (https://github.com/cfchen-duke/ProtoPNet) and the layer-wise relevance propagation (LRP) implementation from https://github.com/AlexBinder/LRP_Pytorch_Resnets_Densenet.
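LRP explains a prediction by redistributing the model's output score backward through the network, one layer at a time, until every input feature holds a share of the "relevance". As a rough illustration of that idea, here is a minimal sketch of the generic LRP-epsilon rule for a single fully connected layer in plain Python. It is not code from the linked LRP repository or from ProtoVAE: the function name `lrp_epsilon_dense`, the list-based weights, and the omission of a bias term are assumptions made for this example, and the LRP variants actually used for ProtoVAE's explanation maps may differ.

```python
from typing import List


def lrp_epsilon_dense(
    activations: List[float],
    weights: List[List[float]],
    relevance_out: List[float],
    eps: float = 1e-6,
) -> List[float]:
    """Redistribute output relevance R_j onto the layer inputs (LRP-epsilon rule).

    activations[i]:   input activation a_i of the layer
    weights[i][j]:    weight w_ij from input i to output j (no bias, by assumption)
    relevance_out[j]: relevance R_j already assigned to output j
    """
    n_in, n_out = len(activations), len(relevance_out)
    # Pre-activations z_j = sum_i a_i * w_ij
    z = [sum(activations[i] * weights[i][j] for i in range(n_in)) for j in range(n_out)]
    # Stabilised denominator z_j + eps * sign(z_j), treating sign(0) as +1
    denom = [zj + (eps if zj >= 0 else -eps) for zj in z]
    # R_i = sum_j (a_i * w_ij / denom_j) * R_j
    return [
        sum(activations[i] * weights[i][j] / denom[j] * relevance_out[j] for j in range(n_out))
        for i in range(n_in)
    ]


if __name__ == "__main__":
    # Toy layer: 2 inputs, 2 outputs, one unit of relevance on each output
    relevance_in = lrp_epsilon_dense([1.0, 2.0], [[1.0, -1.0], [0.5, 2.0]], [1.0, 1.0])
    print([round(r, 4) for r in relevance_in])  # [0.1667, 1.8333]
    # Relevance is approximately conserved: the input total is close to the output total (2.0)
    print(round(sum(relevance_in), 4))  # 2.0
```

In a full network the same kind of rule is applied layer by layer from the output back to the input, using the activations stored during the forward pass; the relevances that reach the input pixels form the explanation heatmap. With no bias and a small `eps`, the total relevance is approximately conserved at each step, as the toy example shows.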

Setup

Create and activate a new conda environment:

conda env create -f requirements.yml
conda activate protovae

Usage

Hyperparameter settings for each dataset can be changed in settings.py.

Command example: python main.py -data=mnist -mode=test -model_file=saved_models/MNIST/mnist.pth -expl=True
(1) -data: name of the dataset. Supports mnist, fmnist, quickdraw, cifar10 and svhn.
(2) -mode: train/test for training or testing mode respectively.
(3) -model_file: path to the saved model; only required if -mode=test. In test mode, the script reports the test accuracy and produces visualizations of the class prototypes.
(4) -expl: True/False; set to True to generate local, layer-wise relevance propagation (LRP) based prototypical explanation maps for 100 test images. Only works when -mode=test.
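The flags above use a single-dash `-name=value` style, with `-expl` taking the strings True or False. The following is a hypothetical sketch of how such a command line could be parsed and validated with Python's standard `argparse` module. It is not the repository's `main.py`: the helper names (`str2bool`, `build_parser`, `parse_args`), making `-data` and `-mode` required, and the error messages are assumptions for illustration only.

```python
import argparse

# Dataset names listed in this README
SUPPORTED_DATASETS = ["mnist", "fmnist", "quickdraw", "cifar10", "svhn"]


def str2bool(value: str) -> bool:
    """Convert 'True'/'False' style strings (case-insensitive) to a bool."""
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected True or False, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Illustrative ProtoVAE-style command line")
    parser.add_argument("-data", choices=SUPPORTED_DATASETS, required=True)
    parser.add_argument("-mode", choices=["train", "test"], required=True)
    parser.add_argument("-model_file", default=None)
    parser.add_argument("-expl", type=str2bool, default=False)
    return parser


def parse_args(argv=None) -> argparse.Namespace:
    parser = build_parser()
    args = parser.parse_args(argv)
    # A saved model is only needed (and then mandatory) in test mode
    if args.mode == "test" and args.model_file is None:
        parser.error("-model_file is required when -mode=test")
    # Explanation maps are only produced in test mode
    if args.expl and args.mode != "test":
        parser.error("-expl=True only works with -mode=test")
    return args


if __name__ == "__main__":
    # The example command from this README, passed as an explicit argument list
    example = parse_args(
        ["-data=mnist", "-mode=test", "-model_file=saved_models/MNIST/mnist.pth", "-expl=True"]
    )
    print(example.expl)  # True
```

A dedicated `str2bool` converter is used because `type=bool` would not work here: `bool("False")` is `True`, since any non-empty string is truthy.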
