zwang84/zsdb3kd

Zero-Shot Knowledge Distillation from a Decision-Based Black-Box Model

Introduction

This is the code and data associated with our ICML 2021 paper.

https://icml.cc/virtual/2021/poster/10257

https://arxiv.org/abs/2106.03310

Requirements

NumPy
PyTorch (tested on 1.9.1)
Torchvision

Code structure

.
├── data                     # datasets are downloaded or saved here
├── generated_samples        # generated pseudo samples are saved here
├── labels                   # generated soft labels are saved here
├── models                   # trained teacher models are saved here
├── train_model_ce.py        # standard training (e.g., of a teacher) with cross-entropy loss
├── models.py                # all model definitions and the wrapper for sample robustness calculation
├── get_soft_labels.py       # calculate soft labels from sample robustness
├── sample_robustness.py     # methods for calculating sample robustness
├── train_model_kd.py        # training with KD
├── get_pseudo_samples.py    # generate pseudo samples with ZSDB3KD
├── untargeted_mbd.py        # calculate the untargeted distances from a noise input to the decision boundary
└── README.MD                # readme file

Usage

1. Train a teacher model in a standard way.

Train a LeNet5 teacher with the MNIST dataset:

python train_model_ce.py --mode teacher --dataset MNIST --architecture LeNet5

Train a LeNet5 teacher with the FashionMNIST dataset:

python train_model_ce.py --mode teacher --dataset FASHIONMNIST --architecture LeNet5

Train an AlexNet teacher with the CIFAR10 dataset:

python train_model_ce.py --mode teacher --dataset CIFAR10 --architecture AlexNet

PS: train_model_ce.py can also be used for training/evaluating the student models (e.g., LeNet5-half, LeNet5-fifth, etc.) with the cross-entropy loss only.

2. Construct soft labels by calculating sample robustness with the pre-trained teacher models.

Set --sr_mode to one of the three sample robustness measures: sd (sample distance), bd (boundary distance), or mbd (minimal boundary distance).

LeNet-5 with MNIST:

python get_soft_labels.py --dataset MNIST --sr_mode {sd/bd/mbd} --model_dir ./models/teacher_LeNet5_MNIST

LeNet-5 with FashionMNIST:

python get_soft_labels.py --dataset FASHIONMNIST --sr_mode {sd/bd/mbd} --model_dir ./models/teacher_LeNet5_FASHIONMNIST

AlexNet with CIFAR10:

python get_soft_labels.py --dataset CIFAR10 --sr_mode {sd/bd/mbd} --model_dir ./models/teacher_AlexNet_CIFAR10
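
To give a feel for what a boundary distance means when the teacher only returns a class label, here is a minimal sketch of a binary search along the segment between a sample and a differently classified point. It is not the code in sample_robustness.py: `teacher_label` is a hypothetical toy classifier standing in for the black-box teacher, and `boundary_distance` and its `tol` parameter are names assumed for illustration.

```python
import math


def teacher_label(x):
    """Hypothetical hard-label black box: a toy 2-D linear classifier.

    Only the predicted class is exposed, never logits or probabilities.
    """
    return 1 if x[0] + x[1] > 1.0 else 0


def boundary_distance(predict, x, x_other, tol=1e-6):
    """Approximate the distance from x to the decision boundary along x -> x_other.

    predict(x) and predict(x_other) must differ, so the segment crosses
    the boundary at least once.
    """
    y = predict(x)
    if predict(x_other) == y:
        raise ValueError("x_other must be classified differently from x")

    seg_len = math.dist(x, x_other)
    lo, hi = 0.0, 1.0  # interpolation weights: lo keeps label y, hi flips it
    while (hi - lo) * seg_len > tol:
        mid = (lo + hi) / 2.0
        point = [a + mid * (b - a) for a, b in zip(x, x_other)]
        if predict(point) == y:
            lo = mid  # still on the side of x, move outward
        else:
            hi = mid  # already across the boundary, move inward
    return hi * seg_len


# Toy usage: the boundary x0 + x1 = 1 is crossed halfway along the diagonal.
d = boundary_distance(teacher_label, [0.0, 0.0], [1.0, 1.0])
```

Each iteration costs one label query, so the number of teacher calls grows only logarithmically with the requested precision.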

3. Train a student model with KD, using the generated soft labels.

python train_model_kd.py --dataset {MNIST/FASHIONMNIST/CIFAR10} --mode {small/tiny} --logits PATH_OF_SAVED_LOGITS
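
As a rough illustration of what training against soft labels involves, the sketch below computes a cross-entropy between a soft label and the student's softmax output. It is deliberately framework-free so the arithmetic is visible, and it is not taken from train_model_kd.py: the function names, the `temperature` parameter, and the example numbers are all assumptions made for illustration.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax with an optional temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def soft_label_cross_entropy(student_logits, soft_label, temperature=1.0):
    """Cross-entropy between a soft label and the student's prediction.

    soft_label is assumed to be a probability vector (non-negative, sums to 1).
    """
    probs = softmax(student_logits, temperature)
    return -sum(t * math.log(p) for t, p in zip(soft_label, probs) if t > 0.0)


# Made-up soft label that puts most of its mass on class 0.
soft_label = [0.8, 0.15, 0.05]

# A student that agrees with the soft label gets a lower loss
# than one that favours the wrong class.
loss_match = soft_label_cross_entropy([2.0, 0.5, -1.0], soft_label)
loss_wrong = soft_label_cross_entropy([-1.0, 0.5, 2.0], soft_label)
```

In a PyTorch training loop the same quantity would normally be computed on batched tensors, with the gradient taken with respect to the student's parameters.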

4. Generate pseudo samples (ZSDB3KD)

python get_pseudo_samples.py --dataset MNIST --batch_size 200 --model_dir PATH_OF_SAVED_TEACHER_MODEL

To test ZSDB3KD, use the generated pseudo samples in place of the original training data in steps 2 and 3: first compute soft labels for the pseudo samples, then train the student with KD on them.

Citation

If you find this code useful, please consider citing the following work. Thank you!

@inproceedings{wang2021zero,
  title={Zero-shot knowledge distillation from a decision-based black-box model},
  author={Wang, Zi},
  booktitle={International Conference on Machine Learning},
  pages={10675--10685},
  year={2021},
  organization={PMLR}
}

About

Knowledge distillation (KD) from a decision-based black-box (DB3) teacher without training data.
