
wwoods/adversarial-explanations-cifar


This code demonstrates the techniques from the paper "Adversarial Explanations for Understanding Image Classification Decisions and Improved Neural Network Robustness," a pre-print of which is available on ArXiv. Note that this is not the exact code used in the research, but a cleaned-up reproduction of the paper's key insights.

Installation

From scratch without a Python environment, installation takes 10-20 minutes. With Python already installed, installation takes only a few minutes.

Install PyTorch, torchvision, and click, for example via Miniconda with Python 3:

$ conda install -c pytorch pytorch torchvision
$ pip install click

Code was tested with:

  • Python 3.6
  • PyTorch 1.1 + torchvision 0.2.2
  • click 7.0
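To compare an existing environment against the tested versions above, a small check such as the following may help. It is a minimal sketch and not part of the repository: installed_versions and report are hypothetical helper names, and the check only reads each package's __version__ attribute.

```python
import importlib

# Versions listed above as tested; other versions are untested here.
TESTED_VERSIONS = {
    "torch": "1.1",
    "torchvision": "0.2.2",
    "click": "7.0",
}


def installed_versions(names):
    """Return {module_name: __version__ string, or None if missing or unknown}."""
    found = {}
    for name in names:
        try:
            module = importlib.import_module(name)
        except Exception:  # missing package, or a broken install raising other errors
            found[name] = None
            continue
        found[name] = getattr(module, "__version__", None)
    return found


def report(tested=TESTED_VERSIONS):
    """Print each package's installed version next to the tested version."""
    for name, version in installed_versions(tested).items():
        shown = version if version is not None else "not found"
        print(f"{name}: installed={shown}, tested={tested[name]}")
```

Call report() from the same Python environment you will use to run main.py.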

Any operating system supporting the above libraries should work; we tested using Ubuntu 18.04.

An NVIDIA GPU is not required, but one or more GPUs will greatly accelerate network training.

Usage

This repository contains several prebuilt networks, corresponding to the CIFAR-10 networks highlighted in the paper.

The application has two modes: explaining a trained model, and training a model from scratch.

When run, the application automatically downloads the CIFAR-10 dataset via the torchvision library. The download location must be specified with the environment variable CIFAR10_PATH.
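The exact lookup inside main.py is not shown here. The following is a minimal sketch, assuming only that CIFAR10_PATH holds a directory path, of how a script can read the variable and hand it to torchvision's CIFAR10 dataset class; cifar10_root and load_cifar10_test_set are hypothetical names.

```python
import os


def cifar10_root(environ=None):
    """Return the CIFAR-10 download directory named by CIFAR10_PATH.

    Raises a clear error instead of silently falling back to a default path.
    """
    environ = os.environ if environ is None else environ
    path = environ.get("CIFAR10_PATH")
    if not path:
        raise RuntimeError("Set CIFAR10_PATH to the directory for the CIFAR-10 download")
    return os.path.expanduser(path)


def load_cifar10_test_set(root):
    """Download (if needed) and return the CIFAR-10 test split via torchvision."""
    import torchvision  # imported lazily so cifar10_root() works without torchvision

    return torchvision.datasets.CIFAR10(root=root, train=False, download=True)
```

For example, in a POSIX shell (the directory is only an example):

$ export CIFAR10_PATH=$HOME/data/cifar10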

Prebuilt Networks

The repository contains four prebuilt networks:

  1. prebuilt/resnet44-standard.pt: A standard ResNet-44 with no special training.
  2. prebuilt/resnet44-adv-train.pt: A ResNet-44 trained with --adversarial-training.
  3. prebuilt/resnet44-all.pt: A ResNet-44 trained with --robust-additions, --adversarial-training, and --l2-min.
  4. prebuilt/resnet44-robust.pt: A ResNet-44 trained with --robust-additions.

These correspond to, but are not identical to, the networks denoted N1, N2, N3, and N4 in the paper. Training these networks produced the following statistics:

Network                 Final Training Loss  Final Psi  Test Accuracy  Attack ARA  BTR ARA
resnet44-standard.pt    0.0075               N/A        0.9384         0.0013      0.0015
resnet44-adv-train.pt   0.5313               N/A        0.8643         0.0100      0.0157
resnet44-all.pt         1.4799               14240      0.679          0.0188      0.0414
resnet44-robust.pt      1.4799               33778      0.6758         0.0142      0.0395

[Each network's row also shows example images with their explanations (input -> explanation) for the Ship, Frog, Cat, and Automobile classes; images not reproduced.]

See the paper or the "github-prebuilt-images" command in main.py for additional information on the above table and its images.

Calculate ARA

Attack and BTR ARAs can be calculated with the calculate-ara command. For example, to use the prebuilt network with both adversarial training and the robustness additions from the paper:

$ python main.py calculate-ara prebuilt/resnet44-all.pt [--n-images 1000] [--eps 20] [--steps 450] [--momentum 0.9]

Note that arguments in [brackets] are optional. The command prints the calculated attack and BTR ARAs, as defined in Section III.A of the paper; the resulting ARAs for all prebuilt networks are listed in the table above. With the default settings, which match the original paper, calculating both ARAs takes around 30 minutes per network, depending on the GPU.
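To reproduce the ARA columns of the table for every prebuilt network, calculate-ara can be run once per network. The sketch below assumes it is run from the repository root with CIFAR10_PATH set; calculate_ara_command and run_all are hypothetical helpers, and only options you pass explicitly are forwarded, so otherwise main.py's own defaults apply.

```python
import subprocess
import sys

# The four prebuilt networks listed in "Prebuilt Networks".
PREBUILT = [
    "prebuilt/resnet44-standard.pt",
    "prebuilt/resnet44-adv-train.pt",
    "prebuilt/resnet44-all.pt",
    "prebuilt/resnet44-robust.pt",
]


def calculate_ara_command(model_path, n_images=None, eps=None, steps=None, momentum=None):
    """Build the argv for `main.py calculate-ara`, adding only options that were given."""
    cmd = [sys.executable, "main.py", "calculate-ara", model_path]
    options = [
        ("--n-images", n_images),
        ("--eps", eps),
        ("--steps", steps),
        ("--momentum", momentum),
    ]
    for flag, value in options:
        if value is not None:
            cmd += [flag, str(value)]
    return cmd


def run_all(models=PREBUILT, **options):
    """Run calculate-ara on each network in turn; its output goes to the terminal."""
    for model in models:
        print(f"== {model}", flush=True)
        subprocess.run(calculate_ara_command(model, **options), check=True)
```

For example, run_all() evaluates all four networks in sequence; at around 30 minutes per network with the default settings, expect roughly two hours in total on comparable hardware. Keyword arguments such as momentum=0.9 are forwarded as --momentum 0.9.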

Explain

To generate explanations for the first 10 CIFAR-10 test images with a trained network, use the explain command. For example, to use the prebuilt network with both adversarial training and the robustness additions from the paper:

$ python main.py explain prebuilt/resnet44-all.pt [--eps 0.1]

This creates images in the output/ folder, designed to be viewed in alphabetical order. For example, output/0-cat will contain:

  • _input.png: the unmodified input image;
  • _real_was_xxx.png: an explanation using g_{explain+} from the paper on the real class (cat);
  • _second_dog_was_xxx.png: an explanation using g_{explain+} on the most confident class other than the correct class (here, dog);
  • 0_airplane_was_xxx.png, 1_automobile_was_xxx.png, 2_bird_was_xxx.png, ..., 9_truck_was_xxx.png: an explanation targeted at each CIFAR-10 class, as indicated in the filename.

In all cases, the xxx preceding the .png extension is the post-softmax confidence of that class on the original image. The images look like this:

[Example images for output/0-cat: _input (input image), _real (real target), _second (second target), and 0_airplane through 9_truck (explanations targeted at Airplane, Automobile, Bird, Cat, Deer, Dog, Frog, Horse, Ship, and Truck).]

Note that arguments in [brackets] are optional. --eps X specifies that the adversarial explanations should be built with rho=X. The process could be further optimized, but presently takes a minute or two.
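The confidence encoded in each filename can also be read programmatically, for example to tabulate how confident the network was in each targeted class. The sketch below assumes that xxx is written as a plain decimal number (such as 0.87) and that class names follow CIFAR-10's lowercase labels; parse_explanation_name and summarize are hypothetical helpers, not part of the repository.

```python
import re
from pathlib import Path

CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]

# Assumption: the confidence ("xxx") is a plain decimal such as 0.87.
_EXPLANATION = re.compile(
    r"^(?:(?P<index>\d)_(?P<target>[a-z]+)|_real|_second_(?P<second>[a-z]+))"
    r"_was_(?P<confidence>\d+(?:\.\d+)?)\.png$"
)


def parse_explanation_name(filename):
    """Return (kind, class name or None, confidence), or None for other files."""
    match = _EXPLANATION.match(filename)
    if match is None:
        return None  # e.g. _input.png, or an unexpected confidence format
    confidence = float(match.group("confidence"))
    if match.group("index") is not None:
        target = match.group("target")
        # The leading digit should agree with CIFAR-10's class ordering.
        if CIFAR10_CLASSES[int(match.group("index"))] != target:
            raise ValueError(f"index/class mismatch in {filename!r}")
        return ("targeted", target, confidence)
    if match.group("second") is not None:
        return ("second", match.group("second"), confidence)
    return ("real", None, confidence)


def summarize(example_dir):
    """Print one line per explanation image in a folder such as output/0-cat."""
    for path in sorted(Path(example_dir).glob("*.png")):
        parsed = parse_explanation_name(path.name)
        if parsed is None:
            continue
        kind, target, confidence = parsed
        label = kind if target is None else f"{kind} ({target})"
        print(f"{path.name}: {label}, confidence={confidence:.3f}")
```

For example, summarize("output/0-cat") prints one line per explanation image. Note that sorted() orders by code point, so the numbered class files come before the files starting with an underscore.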

Train

To train a new network:

$ python main.py train path/to/model.pt [--adversarial-training] [--robust-additions] [--l2-min]

See python main.py train --help for additional information on these options.
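To train networks configured like the prebuilt ones, the flags from the Prebuilt Networks list can be passed to the train command. The sketch below only prints the commands; PREBUILT_FLAGS and train_command are hypothetical names, and the resulting networks will likely differ from the prebuilt files, since random initialization and the CAPITAL_CASE settings in main.py both affect the outcome.

```python
import shlex
import sys

# Flags per prebuilt network, taken from the "Prebuilt Networks" list.
PREBUILT_FLAGS = {
    "resnet44-standard": [],
    "resnet44-adv-train": ["--adversarial-training"],
    "resnet44-all": ["--robust-additions", "--adversarial-training", "--l2-min"],
    "resnet44-robust": ["--robust-additions"],
}


def train_command(model_path, preset):
    """Return the argv that trains a new network with one preset's flags."""
    return [sys.executable, "main.py", "train", model_path, *PREBUILT_FLAGS[preset]]


for name in PREBUILT_FLAGS:
    # Print rather than run: each training run may take hours to days.
    cmd = train_command(f"path/to/{name}.pt", name)
    print(" ".join(shlex.quote(arg) for arg in cmd))
```

Copy any printed command into a shell to start that training run.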

Training time varies greatly based on available GPU(s). With both adversarial training and the robustness additions from the paper, training can take up to several days on a single computer. Turning off either adversarial training or robustness additions would lead to a significant speedup.

At the top of the main.py file are many CAPITAL_CASE variables which may be modified to affect the training process. Their definitions match those in the paper.

About

Code example for the paper, "Adversarial Explanations for Understanding Image Classification Decisions and Improved Neural Network Robustness."
