
aryanasadianuoit/Distilling-Knowledge-via-Intermediate-Classifiers


Distilling Knowledge via Intermediate Classifiers


Introduction

Distilling Knowledge via Intermediate Classifiers (DIH) is a knowledge distillation framework that mitigates the negative impact of the capacity gap (i.e., the difference in model complexity between the teacher and the student) on knowledge distillation. This approach improves canonical knowledge distillation (KD) with the help of the teacher's intermediate representations (the outputs of some of its hidden layers).

DIH training pipeline

  1. First, k classifier heads are mounted on various intermediate layers of the teacher (see Table 1 for the value of k used for each model in this repository; the attachment locations are described next to that table).
  2. The added intermediate classifier heads then undergo a cheap fine-tuning step while the teacher backbone is frozen. This step is cheaper and more efficient than training a whole model (i.e., the fraction of the teacher up to the attachment point plus the added classifier head) from scratch, because only the parameters of the added heads are updated.
  3. The cohort of classifiers (all the mounted heads plus the teacher's final classifier) co-teaches the student simultaneously via knowledge distillation.
Our experiments on various teacher-student pairs of models and datasets have demonstrated that the proposed approach outperforms canonical KD and its extensions.
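The third step can be pictured with the following minimal NumPy sketch. It is not the code in train_dih.py: the function names are hypothetical, and it assumes that each of the k + 1 classifiers in the cohort contributes one temperature-scaled KD term, that these terms are averaged, and that α weights the distillation part against the student's cross-entropy on the true labels.

```python
import numpy as np


def log_softmax(logits, temperature=1.0):
    """Row-wise log-softmax of temperature-scaled logits, shifted for numerical stability."""
    z = logits / temperature
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))


def cross_entropy(logits, labels):
    """Mean cross-entropy against integer class labels."""
    return -log_softmax(logits)[np.arange(len(labels)), labels].mean()


def kd_term(student_logits, teacher_logits, temperature):
    """KL(teacher || student) between softened distributions, scaled by T^2, batch-averaged."""
    log_p_t = log_softmax(teacher_logits, temperature)
    log_p_s = log_softmax(student_logits, temperature)
    kl = (np.exp(log_p_t) * (log_p_t - log_p_s)).sum(axis=1)
    return temperature ** 2 * kl.mean()


def dih_loss(student_logits, cohort_logits, labels, temperature=5.0, alpha=0.1):
    """Hypothetical DIH objective: cross-entropy on the labels plus the KD term averaged
    over the cohort (k intermediate heads + the teacher's final classifier).
    Assumption: alpha weights the distillation part and (1 - alpha) the cross-entropy."""
    kd = np.mean([kd_term(student_logits, t, temperature) for t in cohort_logits])
    return (1 - alpha) * cross_entropy(student_logits, labels) + alpha * kd


# Toy usage: batch of 8, 100 classes (CIFAR-100), cohort of 3 heads + final classifier.
rng = np.random.default_rng(0)
student = rng.normal(size=(8, 100))
cohort = [rng.normal(size=(8, 100)) for _ in range(4)]
labels = rng.integers(0, 100, size=8)
print(dih_loss(student, cohort, labels))
```

For ResNet-110 (k = 3 in Table 1), the cohort holds four logit arrays: three from the fine-tuned intermediate heads and one from the final classifier.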


[Figure: Distilling Knowledge via Intermediate Classifiers (DIH)]

Requirements

  • torch 1.7.1: the project is built in PyTorch.
  • torchvision 0.8.2: used for datasets and data preprocessing.
  • tqdm 4.48.2: for visualizing training progress.
  • torchsummary: for inspecting the models' architectures.
  • numpy 1.19.4: used for preprocessing the datasets and showing examples.
  • argparse: for passing input arguments, for easy reproducibility.
  • os: for reading and writing the trained models' weights.
Install the requirements with:
pip3 install -r requirements.txt

Datasets

CIFAR-10 and CIFAR-100 contain 32x32-pixel RGB images from 10 and 100 classes, respectively. Each dataset consists of 50,000 training and 10,000 test images. Both the training and test sets are balanced (i.e., every class has the same number of images). For both datasets, the images are augmented with a combination of horizontal flips, 4-pixel padding, and random crops. We also normalized the images by their mean and standard deviation.

Tiny-ImageNet contains 64x64-pixel RGB images from 200 classes, subsampled from the ImageNet dataset. It consists of 100,000 training and 10,000 test images. Both the training and test sets are balanced (i.e., every class has the same number of images). We followed the same data augmentation as for CIFAR-10 and CIFAR-100, i.e., a combination of horizontal flips, 4-pixel padding, and random 64-pixel crops. We also normalized the images by their mean and standard deviation.
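The augmentation and normalization described above can be sketched in NumPy as follows. torchvision is listed for data preprocessing, and its equivalent transforms are transforms.RandomCrop(size, padding=4), transforms.RandomHorizontalFlip() and transforms.Normalize(mean, std). The sketch assumes zero padding, a flip probability of 0.5, and per-channel statistics computed on the training set; none of these details is stated explicitly in this README, and the function names are hypothetical.

```python
import numpy as np


def random_pad_crop_flip(image, rng, pad=4):
    """Zero-pad an (H, W, C) image by `pad` pixels on each side, take a random crop of the
    original size (32x32 for CIFAR, 64x64 for Tiny-ImageNet), and flip it horizontally with p=0.5."""
    h, w, _ = image.shape
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), mode="constant")
    top = rng.integers(0, 2 * pad + 1)
    left = rng.integers(0, 2 * pad + 1)
    crop = padded[top:top + h, left:left + w]
    if rng.random() < 0.5:
        crop = crop[:, ::-1]  # horizontal flip: reverse the width axis
    return crop


def channel_stats(train_images):
    """Per-channel mean and standard deviation of an (N, H, W, C) training set."""
    return train_images.mean(axis=(0, 1, 2)), train_images.std(axis=(0, 1, 2))


def normalize(images, mean, std):
    """Standardize each channel; broadcasting applies the statistics over the last axis."""
    return (images - mean) / std


rng = np.random.default_rng(0)
train = rng.random((16, 32, 32, 3))  # stand-in for CIFAR training images scaled to [0, 1]
mean, std = channel_stats(train)
sample = normalize(random_pad_crop_flip(train[0], rng), mean, std)
print(sample.shape)  # (32, 32, 3)
```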

Baselines

  1. Canonical Knowledge Distillation (KD): As one of the benchmarks, we use conventional KD (throughout the paper and its experiments, canonical knowledge distillation is referred to as KD). We used the same temperature (τ=5) and the same alpha weight (α=0.1) as in DIH.
  2. FitNets: FitNets is a two-stage knowledge distillation framework. The first stage transfers the knowledge of a fraction of a trained teacher, up to a selected layer (the hint layer), to a fraction of the student, up to a selected intermediate layer (the guided layer). This stage optimizes the chosen fraction of the student with an L2 loss. The second stage applies canonical knowledge distillation (KD) to transfer knowledge from the complete teacher to the entire student. In the first stage, we trained the selected fraction of the student for 40 epochs with the L2 loss; in the second stage, we used the same KD setting and trained the complete student for 200 epochs.
  3. Knowledge Distillation with Teacher Assistants (TAKD): We limited the number of teacher assistants to one for the experiments in Table 3 of the paper. The teacher assistant and the final student are trained with identical settings (the same setting as KD).
  4. Attention Distillation (AT): AT transfers the teacher's attention maps (i.e., channel-wise averaged activation maps) to the student's equivalent layer.
  5. Contrastive Representation Distillation (CRD): CRD improves canonical KD using contrastive learning. Its loss objective maximizes a lower bound on the mutual information between the teacher's and the student's representations. Under this framework, the student learns to produce representations that are close to the teacher's for positive pairs and far from them for negative pairs.
  6. Task-Oriented Feature Distillation (TOFD): Like our approach, TOFD tries to improve canonical KD with the help of intermediate classifier heads. However, TOFD equips both the teacher and the student with very deep and complex classifier modules containing multiple convolutional, batch normalization, and fully connected layers. Each classifier module mirrors the rest of the teacher backbone from the attachment point to the end of the model. For example, in a residual model with four residual stages, the classifier module attached to the first stage would comprise the three remaining residual stages followed by a fully connected layer. Besides its different classifier architectures, TOFD also uses a different set of loss objectives. Each student classifier is optimized with regular CE, canonical KD using the soft probabilities generated by the teacher's same-stage classifier, an L2 loss matching same-stage intermediate representations, and an orthogonal loss for reducing information loss (applied only to the feature-resizing layers).
  7. Multi-head Knowledge Distillation for Model Compression (MHKD): MHKD is similar to our approach, but in MHKD, as in TOFD, both the teacher and the student are equipped with multiple classifier heads built from convolutional, batch normalization, and ReLU layers followed by fully connected layers. Unlike TOFD, MHKD uses a fixed architecture for all classifier modules: two convolutional layers with batch normalization and ReLU, followed by two fully connected layers. In contrast, we use simpler classifier modules consisting only of fully connected layers. MHKD and TOFD also differ in their loss functions: MHKD optimizes the student's classifier heads with regular CE and canonical KD using the soft labels of the same-stage teacher classifier.
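As an illustration of the AT baseline, here is a minimal NumPy sketch of attention-map transfer. It follows a commonly used formulation (channel-wise mean of squared activations, flattened and L2-normalized per sample, compared with a mean squared error); the exponent p = 2 and the function names are assumptions, and the baseline code used for the paper may differ in details such as loss weighting.

```python
import numpy as np


def attention_map(features, p=2, eps=1e-12):
    """(N, C, H, W) activations -> (N, H*W): channel-wise mean of |A|^p,
    flattened and L2-normalized per sample."""
    am = (np.abs(features) ** p).mean(axis=1).reshape(features.shape[0], -1)
    return am / (np.linalg.norm(am, axis=1, keepdims=True) + eps)


def at_loss(student_features, teacher_features, p=2):
    """Mean squared difference between normalized attention maps.
    Channel counts may differ, but the spatial sizes (H, W) must match."""
    diff = attention_map(student_features, p) - attention_map(teacher_features, p)
    return (diff ** 2).mean()


rng = np.random.default_rng(0)
student_feats = rng.normal(size=(2, 16, 8, 8))   # e.g. a thin student stage
teacher_feats = rng.normal(size=(2, 64, 8, 8))   # the wider teacher stage at the same resolution
print(at_loss(student_feats, teacher_feats))
```

Because each map is normalized, the loss ignores the overall scale of the activations and compares only where each network concentrates its activation energy.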

Running The Experiments

  • First, the selected teacher model should be trained with regular cross-entropy with the hyper-parameters mentioned in Table 2.
  • All the reported values in the paper experiments are the average of three different runs.
  • For each selected teacher, the mounted intermediate classifier heads need to be fine-tuned. The number of intermediate heads added to each model is listed in Table 1. In this repository, we mounted an intermediate classifier head after every group of residual and/or bottleneck blocks in the ResNet-family models and after each max-pooling layer in the VGG-11 model (note: the VGG model is equipped with batch normalization).

    Table 1. The number of intermediate classifier heads mounted on each model used in this repository.
    Teacher model    # Intermediate heads (k)
    ResNet-34        4
    ResNet-18        4
    VGG-11           4
    WR-28-2          3
    ResNet-110       3
    ResNet-20        3
    ResNet-14        3
    ResNet-8         3

  • Files in this repository

    • dataload.py: builds the data loaders for training, validation, and testing on both datasets (CIFAR-10 and CIFAR-100).
    • models_repo: contains the model classes (two categories of ResNets, VGG, and the intermediate classifier module).
    • KD_Loss.py: the canonical knowledge distillation loss function.
    • dih_utlis.py: includes the function for loading the trained intermediate heads.
    • train_dih.py: contains the function for distillation via intermediate heads (DIH).
    • train_funcs.py: regular cross-entropy training and intermediate-head fine-tuning functions.
    • CRD/train_student.py: training function for contrastive representation distillation (CRD).
    • TOFD/tofd_train.py: training function for task-oriented feature distillation (TOFD).
    • MHKD/mhkd_training.py: training function for multi-head knowledge distillation (MHKD).
    • test.py: testing console for running the functions above.

    Hyper-parameters to set


    Table 2. List of all hyper-parameters, their argparse tags, and their default values.
    Hyper-parameter           argparse tag     Default value
    student model             student          res8
    teacher model             teacher          res110
    learning rate             lr               0.1
    weight decay              wd               5e-4
    epochs                    epochs           200
    schedule (LR milestones)  schedule         [60,120,180]
    γ (LR decay factor)       schedule_gamma   0.1
    temperature τ (KD)        kd_temperature   5
    α (KD)                    kd_alpha         0.1
    batch size                batch_size       64
    training type             training_type    dih
    seed                      seed             [30,50,67]
    dataset                   dataset          cifar100
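The schedule, γ, and learning-rate entries of Table 2 describe a step learning-rate decay. The small sketch below shows the resulting learning rate per epoch, assuming the scheduler is stepped once per epoch; in PyTorch this behaviour corresponds to torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 180], gamma=0.1). The helper name lr_at_epoch is hypothetical.

```python
def lr_at_epoch(epoch, base_lr=0.1, milestones=(60, 120, 180), gamma=0.1):
    """Learning rate in effect during `epoch` (0-indexed): base_lr is multiplied by gamma
    once for every milestone that has already been reached."""
    decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** decays


for epoch in (0, 59, 60, 119, 120, 179, 180, 199):
    print(epoch, lr_at_epoch(epoch))
```

With the defaults, the 200 training epochs run in four phases at learning rates of 0.1, 0.01, 0.001, and 0.0001 (up to floating-point rounding).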

    Example

    To train a model with regular cross-entropy, run the following template:
    python3 final_test.py --training_type ce --teacher res110 --path_to_save /home/teacher.pth --batch_size 64 --dataset cifar100 --epochs 200 --gpu_id cuda:0 --lr 0.1 --schedule 60 120 180 --wd 0.0005 --seed 3
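The flags in these commands are ordinary argparse options. The sketch below is a hypothetical reconstruction, not the repository's parser: it declares only flags that appear in this README, takes its defaults from Table 2, and shows how a multi-value flag such as --schedule 60 120 180 can be collected into a list with nargs="+". Table 2 lists three seeds, so --seed is left without a default here.

```python
import argparse


def build_parser():
    """Hypothetical parser covering flags used in this README; defaults follow Table 2."""
    p = argparse.ArgumentParser(description="DIH experiments (illustrative sketch)")
    p.add_argument("--training_type", default="dih")  # e.g. ce, fine_tune, fitnets, kd, dih
    p.add_argument("--student", default="res8")
    p.add_argument("--teacher", default="res110")
    p.add_argument("--dataset", default="cifar100")
    p.add_argument("--lr", type=float, default=0.1)
    p.add_argument("--wd", type=float, default=5e-4)
    p.add_argument("--epochs", type=int, default=200)
    p.add_argument("--schedule", type=int, nargs="+", default=[60, 120, 180])
    p.add_argument("--schedule_gamma", type=float, default=0.1)
    p.add_argument("--kd_temperature", type=float, default=5.0)
    p.add_argument("--kd_alpha", type=float, default=0.1)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--seed", type=int)
    p.add_argument("--gpu_id")
    p.add_argument("--saved_path")
    p.add_argument("--path_to_save")
    return p


argv = ("--training_type ce --teacher res110 --path_to_save /home/teacher.pth "
        "--batch_size 64 --schedule 60 120 180 --wd 0.0005 --seed 3").split()
args = build_parser().parse_args(argv)
print(args.training_type, args.schedule, args.wd)  # ce [60, 120, 180] 0.0005
```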

    Once the teacher is trained, fine-tune all of its intermediate classifier heads with the following command:
    python3 final_test.py --training_type fine_tune --teacher res110 --saved_path /home/teacher.pth --path_to_save /home/headers --batch_size 64 --dataset cifar100 --epochs 200 --gpu_id cuda:0 --lr 0.1 --schedule 60 120 180 --wd 0.0005 --seed 3
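Because the backbone stays frozen during this step, the activations at each attachment point are fixed inputs, and fine-tuning a head reduces to training a small classifier on them. The NumPy sketch below illustrates this under stated assumptions: the head is modelled as global average pooling followed by a single fully connected layer (consistent with DIH using only fully connected layers as intermediate classifiers, though the pooling step is an assumption), it is trained with full-batch gradient descent rather than the repository's optimizer, and all names are hypothetical.

```python
import numpy as np


def global_avg_pool(features):
    """(N, C, H, W) -> (N, C)."""
    return features.mean(axis=(2, 3))


def finetune_head(features, labels, num_classes, lr=0.1, steps=200, wd=5e-4, seed=0):
    """Fit a linear softmax head (W, b) on frozen backbone features with full-batch
    gradient descent on the mean cross-entropy. `features` is only read, never updated."""
    rng = np.random.default_rng(seed)
    x = global_avg_pool(features)
    n, c = x.shape
    W = rng.normal(scale=0.01, size=(c, num_classes))
    b = np.zeros(num_classes)
    onehot = np.eye(num_classes)[labels]
    for _ in range(steps):
        logits = x @ W + b
        logits -= logits.max(axis=1, keepdims=True)
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n            # gradient of the mean CE w.r.t. the logits
        W -= lr * (x.T @ grad + wd * W)        # weight decay applied to W only
        b -= lr * grad.sum(axis=0)
    return W, b


def head_accuracy(features, labels, W, b):
    preds = (global_avg_pool(features) @ W + b).argmax(axis=1)
    return float((preds == labels).mean())


# Toy "frozen" stage output: the class of each sample is encoded in one channel.
rng = np.random.default_rng(1)
labels = np.arange(60) % 3
feats = rng.normal(scale=0.1, size=(60, 8, 4, 4))
feats[np.arange(60), labels] += 1.0
W, b = finetune_head(feats, labels, num_classes=3)
print(head_accuracy(feats, labels, W, b))
```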

    As a baseline for comparison, we used FitNets. To train a student with this approach, run the following template:
    python3 final_test.py --student res8 --training_type fitnets --teacher res110 --saved_path /home/teacher.pth --path_to_save /home/stage_1.pth --epochs_fitnets_1 40 --nesterov_fitnets_1 True --momentum_fitnets_1 0.9 --lr_fitnets_1 0.1 --wd_fitnets_1 0.0005 --batch_size 64 --dataset cifar100 --epochs 200 --gpu_id cuda:0 --lr 0.1 --schedule 60 120 180 --wd 0.0005 --seed 3
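The first FitNets stage trained by this command (40 epochs, set by --epochs_fitnets_1) minimizes an L2 loss between the teacher's hint layer and the student's guided layer. When their channel counts differ, FitNets places a regressor on top of the guided layer; the NumPy sketch below assumes a 1x1 convolution regressor and matching spatial sizes, and its names are hypothetical rather than taken from the repository.

```python
import numpy as np


def conv1x1(x, weight):
    """1x1 convolution as a per-pixel linear map: (N, Cs, H, W), (Ct, Cs) -> (N, Ct, H, W)."""
    return np.einsum("ts,nshw->nthw", weight, x)


def hint_loss(student_guided, teacher_hint, regressor_weight):
    """Mean squared error between the teacher's hint-layer output and the regressed
    student guided-layer output. Only the student and the regressor would be trained."""
    pred = conv1x1(student_guided, regressor_weight)
    return ((pred - teacher_hint) ** 2).mean()


rng = np.random.default_rng(0)
guided = rng.normal(size=(2, 16, 8, 8))     # student guided layer: 16 channels
hint = rng.normal(size=(2, 64, 8, 8))       # teacher hint layer: 64 channels
regressor = rng.normal(scale=0.1, size=(64, 16))
print(conv1x1(guided, regressor).shape, hint_loss(guided, hint, regressor))
```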

    Canonical knowledge distillation (KD) can be run with the following command:
    python3 final_test.py --student res8 --training_type kd --teacher res110 --saved_path /home/teacher.pth --path_to_save /home/res8_kd.pth --batch_size 64 --dataset cifar100 --epochs 200 --gpu_id cuda:0 --lr 0.1 --schedule 60 120 180 --wd 0.0005 --kd_alpha 0.1 --seed 3 --kd_temperature 5
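For reference, the loss optimized in this run combines a temperature-softened KL term with the usual cross-entropy. The NumPy sketch below uses the values from the command (τ = 5 via --kd_temperature, α = 0.1 via --kd_alpha). It assumes that α weights the soft-target term and 1 - α the hard-label term, and that the KL term is scaled by τ² (a common convention that keeps gradient magnitudes comparable across temperatures); KD_Loss.py may weight the terms differently.

```python
import numpy as np


def log_softmax(logits, temperature=1.0):
    """Numerically stable row-wise log-softmax of logits / temperature."""
    z = logits / temperature
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))


def kd_loss(student_logits, teacher_logits, labels, temperature=5.0, alpha=0.1):
    """alpha * T^2 * KL(p_teacher^T || p_student^T) + (1 - alpha) * CE(student, labels)."""
    log_p_t = log_softmax(teacher_logits, temperature)
    log_p_s = log_softmax(student_logits, temperature)
    soft = temperature ** 2 * (np.exp(log_p_t) * (log_p_t - log_p_s)).sum(axis=1).mean()
    hard = -log_softmax(student_logits)[np.arange(len(labels)), labels].mean()
    return alpha * soft + (1 - alpha) * hard


rng = np.random.default_rng(0)
student_logits = rng.normal(size=(64, 100))   # batch_size 64, CIFAR-100
teacher_logits = rng.normal(size=(64, 100))
labels = rng.integers(0, 100, size=64)
print(kd_loss(student_logits, teacher_logits, labels))
```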

    The student ResNet-8 can be trained via CRD with the following command:
    python3 train_student.py --student res8 --teacher res110 --saved_path /home/teacher.pth --path_to_save /home/res8_kd.pth --batch_size 128 --dataset cifar100 --epochs 200 --gpu_id cuda:0 --lr 0.1 --schedule 60 120 180 --wd 0.0005 --alpha 0.1 --beta 0.03 --temperature 5
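CRD's full objective relies on noise-contrastive estimation with a large memory buffer of negative samples and learned embedding heads, which is beyond a short example. To illustrate only the positive/negative-pair idea, the sketch below swaps in a simpler in-batch InfoNCE loss: the teacher embedding of the same input is the positive, and the teacher embeddings of the other inputs in the batch are negatives. The temperature of 0.07 is an illustrative choice, not a value taken from this README, and the function names are hypothetical.

```python
import numpy as np


def l2_normalize(x, eps=1e-12):
    return x / (np.linalg.norm(x, axis=1, keepdims=True) + eps)


def info_nce(student_emb, teacher_emb, temperature=0.07):
    """In-batch contrastive loss. Row i of the similarity matrix scores student_i against every
    teacher embedding; the diagonal entry (same input) is the positive, the rest are negatives."""
    s = l2_normalize(student_emb)
    t = l2_normalize(teacher_emb)
    logits = s @ t.T / temperature                      # (N, N) scaled cosine similarities
    logits = logits - logits.max(axis=1, keepdims=True)
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.diag(log_prob).mean()                    # cross-entropy with targets 0..N-1


rng = np.random.default_rng(0)
teacher = rng.normal(size=(8, 32))
aligned = teacher + 0.1 * rng.normal(size=(8, 32))  # student close to the teacher per sample
print(info_nce(aligned, teacher), info_nce(aligned, teacher[::-1]))
```

The loss is low when each student embedding is closest to its own teacher embedding and high when the pairing is scrambled, as in the reversed-order call.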

    The student ResNet-8 can be trained via TOFD with the following command:
    python3 tofd_train.py --student res8 --teacher res110 --saved_path /home/teacher.pth --path_to_save /home/res8_kd.pth --batch_size 128 --dataset cifar100 --epochs 200 --gpu_id cuda:0 --lr 0.1 --schedule 60 120 180 --wd 0.0005 --alpha 0.1 --beta 0.03 --temperature 5
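Of the TOFD loss terms described among the baselines, the orthogonal loss is the least standard. A common way to express such a constraint is to penalize the deviation of a layer's Gram matrix from the identity; the NumPy sketch below does this for the weight of a feature-resizing layer. The exact form used by TOFD's code is not given in this README, so treat this as an assumed formulation with a hypothetical function name.

```python
import numpy as np


def orthogonal_penalty(weight):
    """Squared Frobenius norm of (Gram matrix - I) for a resizing layer's weight.
    `weight` has shape (out_channels, in_channels, ...) and is flattened to 2-D; the Gram matrix
    is taken over the smaller dimension so the penalty can reach zero for wide or tall layers."""
    w = weight.reshape(weight.shape[0], -1)
    gram = w @ w.T if w.shape[0] <= w.shape[1] else w.T @ w
    return float(((gram - np.eye(gram.shape[0])) ** 2).sum())


rng = np.random.default_rng(0)
q, _ = np.linalg.qr(rng.normal(size=(64, 16)))               # q has 16 orthonormal columns
print(orthogonal_penalty(q))                                  # tall (64 x 16) layer: close to 0
print(orthogonal_penalty(q.T.reshape(16, 64, 1, 1)))          # same rows as a 1x1 conv weight: close to 0
print(orthogonal_penalty(rng.normal(size=(16, 64, 1, 1))))   # random weight: clearly positive
```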

    The student ResNet-8 can be trained via MHKD with the following command:
    python3 train_mhkd.py --student res8 --teacher res110 --saved_path /home/teacher.pth --path_to_save /home/res8_kd.pth --batch_size 128 --dataset cifar100 --epochs 200 --gpu_id cuda:0 --lr 0.1 --schedule 60 120 180 --wd 0.0005 --alpha 0.1 --beta 0.03 --temperature 5

    The student ResNet-8 can be trained via DIH with the following command:
    python3 final_test.py --student res8 --teacher res110 --saved_path /home/teacher.pth --saved_intermediates_directory /home/saved_headers/ --alpha 0.1 --temperature 5 --batch_size 64 --dataset cifar100 --epochs 200 --gpu_id cuda:0 --lr 0.1 --schedule 60 120 180 --wd 0.0005 --seed 3 --path_to_save /home/dih_model.pth

    Reference

    arXiv link: http://arxiv.org/abs/2103.00497
    If you found this library useful in your research, please consider citing:

    @misc{asadian2021distilling,
    title={Distilling Knowledge via Intermediate Classifier Heads},
    author={Aryan Asadian and Amirali Salehi-Abari},
    year={2021},
    eprint={2103.00497},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
    }