AditMeh/Distillation

Implementing and experimenting with knowledge distillation

You can run experiments by doing:

python grid_search.py <config_path> <weights_path>

For example:

python grid_search.py grid_search_configs/baseline.json weights/Teacher_network_val_loss0.00011

To create your own configs, see grid_search_configs/baseline.json for a reference.
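For illustration only, a grid-search config presumably sweeps the hyperparameters reported in the tables below (learning rate, temperature, epochs, and the loss weight). The field names in this sketch are hypothetical; check baseline.json for the actual schema.

```json
{
  "_note": "hypothetical example only -- see grid_search_configs/baseline.json for the real schema",
  "lr": [0.001, 0.01],
  "temperature": [2.5, 10.0],
  "epochs": [70],
  "weight": [0.1, 0.9]
}
```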

Interesting results

This performance table is for a model trained on classes 0-5 and tested on classes 0-9. The student network (all dense layers) has the following architecture:

28*28 input -> layer 1 -> 14*14 output -> ReLU -> layer 2 -> (*, 10) output
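A minimal PyTorch sketch of such a student (for illustration; not necessarily the exact module in this repo):

```python
import torch.nn as nn

# Two dense layers: 784 -> 196 (14*14) -> ReLU -> 10 logits.
# A sketch of the architecture described above, not the repo's exact code.
student = nn.Sequential(
    nn.Flatten(),             # (N, 1, 28, 28) -> (N, 784)
    nn.Linear(28 * 28, 14 * 14),
    nn.ReLU(),
    nn.Linear(14 * 14, 10),   # raw logits, shape (N, 10)
)
```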


| lr | T | Epochs | Weight | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|------|------|--------|--------|--------|--------|--------|--------|-------|-------|--------|-----|--------|--------|
| 0.01 | 10.0 | 70 | 0.9 | 99.49% | 99.73% | 99.42% | 99.60% | 99.6% | 99.1% | 98.64% | 96% | 97.33% | 97.23% |

As you can see above, the student model is able to learn to classify images of a class that it never saw during training. Through these experiments, I have found that temperature values between 8 and 12 work best for this task, along with a learning rate of 0.01; 0.001 is too small.


The table below shows the results of training the student network on all classes in MNIST (0-9). The first row represents a weight of 0 given to the soft-label loss term, which is effectively a standard training loop for the student network using CE loss. The second row represents a student net trained on a mix of soft and hard labels with a temperature of 2.5. Both networks were trained with the same LR and number of epochs. As the table shows, the per-class accuracies of the net trained with the distillation loss are slightly higher than those of the net trained with plain CE loss.

| lr | T | epochs | weight | best_loss | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|------|-----|--------|--------|-----------|--------|--------|--------|--------|--------|--------|--------|--------|--------|--------|
| 0.01 | 1.0 | 70 | 1.0 | 0.04663 | 99.39% | 98.77% | 97.77% | 99.01% | 97.56% | 96.97% | 97.7% | 97.86% | 96.2% | 96.53% |
| 0.01 | 2.5 | 70 | 0.1 | 0.04606 | 99.49% | 99.38% | 98.84% | 99.31% | 99.19% | 98.54% | 99.06% | 99.03% | 99.28% | 97.62% |
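For reference, a minimal sketch of this kind of distillation objective, assuming torch >= 1.10's F.cross_entropy with probability targets (as noted in the checklist below) and the standard T^2 scaling from Hinton et al. Putting the weight w on the soft-label term is an assumption and may not match the weight column in the tables above.

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.5, w=0.9):
    """Weighted mix of soft-label and hard-label cross-entropy (sketch only)."""
    # Soft targets from the teacher, softened by temperature T.
    soft_targets = F.softmax(teacher_logits / T, dim=1)
    # torch >= 1.10: cross_entropy accepts class probabilities as targets.
    soft_loss = F.cross_entropy(student_logits / T, soft_targets)
    hard_loss = F.cross_entropy(student_logits, labels)
    # T**2 keeps the soft-term gradients on roughly the same scale as the hard term.
    return w * (T ** 2) * soft_loss + (1 - w) * hard_loss
```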

I also tested the above experiments using the smallest possible student model with the following architecture:

28*28 -> layer 1 -> (*, 10) output
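Roughly, in PyTorch terms (illustrative sketch):

```python
import torch.nn as nn

# Single dense layer: 784 -> 10 logits (a sketch, not the repo's exact code).
tiny_student = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
```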

However, this resulted in poor accuracies, in the 50% range, for each of the unseen classes. That is still better than the baseline, which gets 0.0% accuracy on every class it did not see during training, but it also shows that this small student model is not capable of learning from the implicit labels. I therefore chose a larger architecture that is able to fit my data (and maybe even overfit!).

So far:

  • Implement distillation loss using torch 1.10's soft-label cross-entropy loss (Kullback-Leibler divergence is the next best option)
  • Teacher and student models set up, along with their respective training loops
  • Basic visualization of training curves.
  • Custom MNIST dataloaders that leave out certain classes, to test implicit learning from soft labels (see the sketch after this list)
  • Train teacher network. 14-16 errors on the val set.
  • Grid searcher
  • Ablation tests for regular classification performance with and without soft labels, testing for changes in accuracy.
  • Test implicit learning from soft labels (IPR)
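As referenced in the dataloader item above, a minimal sketch of holding classes out of MNIST, using torchvision and a Subset (names here are illustrative, not the repo's actual loaders):

```python
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

def mnist_loader_without(excluded_classes, train=True, batch_size=64):
    """Dataloader that drops the given classes, e.g. {6, 7, 8, 9} (sketch only)."""
    ds = datasets.MNIST(
        root="data", train=train, download=True,
        transform=transforms.ToTensor(),
    )
    keep = [i for i, t in enumerate(ds.targets) if t.item() not in excluded_classes]
    return DataLoader(Subset(ds, keep), batch_size=batch_size, shuffle=train)
```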

To do:

  • Derive dC/dz for CE on softmax with temperature
  • Try on CIFAR10
  • Related to above, add flags for easy switching between datasets in the argparser
