
dbaranchuk/memory-efficient-maml


Memory Efficient MAML

Overview

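PyTorch implementation of Model-Agnostic Meta-Learning [1] with gradient checkpointing [2]. Instead of keeping the intermediate results of every inner-loop step in memory, gradient checkpointing stores only a few checkpoints and recomputes the rest during the backward pass, trading extra computation for memory. This lets you perform far more MAML steps (roughly 10-100x) within the same GPU memory budget.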
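The idea can be illustrated without the library. The sketch below is a hand-rolled illustration, not torch_maml's actual code: it computes the MAML meta-gradient with respect to the initial weights of a hypothetical linear-regression task. The forward pass keeps only one detached parameter snapshot per inner step; the backward pass recomputes one step at a time in reverse and chains vector-Jacobian products, so only a single step's graph is alive at once. The names loss_fn, inner_step, meta_grad_unrolled and meta_grad_checkpointed are illustrative assumptions.

```python
import torch


def loss_fn(w, x, y):
    # Mean squared error of a linear model (hypothetical toy task)
    return ((x @ w - y) ** 2).mean()


def inner_step(w, x, y, lr):
    # One SGD step; create_graph=True keeps the step differentiable w.r.t. w
    (g,) = torch.autograd.grad(loss_fn(w, x, y), w, create_graph=True)
    return w - lr * g


def meta_grad_unrolled(w0, x, y, x_val, y_val, lr, steps):
    # Reference: keeps the graph of every inner step in memory at once
    w0 = w0.clone().requires_grad_()
    w = w0
    for _ in range(steps):
        w = inner_step(w, x, y, lr)
    (grad,) = torch.autograd.grad(loss_fn(w, x_val, y_val), w0)
    return grad


def meta_grad_checkpointed(w0, x, y, x_val, y_val, lr, steps):
    # Forward: run the inner loop, storing only detached parameter snapshots
    snapshots = [w0.detach()]
    for _ in range(steps):
        w_in = snapshots[-1].clone().requires_grad_()
        snapshots.append(inner_step(w_in, x, y, lr).detach())

    # Gradient of the validation loss w.r.t. the final adapted weights
    w_final = snapshots[-1].clone().requires_grad_()
    (grad,) = torch.autograd.grad(loss_fn(w_final, x_val, y_val), w_final)

    # Backward: recompute each step in reverse and apply its vector-Jacobian product
    for w_k in reversed(snapshots[:-1]):
        w_in = w_k.clone().requires_grad_()
        w_next = inner_step(w_in, x, y, lr)
        (grad,) = torch.autograd.grad(w_next, w_in, grad_outputs=grad)
    return grad


torch.manual_seed(0)
x, y = torch.randn(32, 3, dtype=torch.float64), torch.randn(32, dtype=torch.float64)
x_val, y_val = torch.randn(16, 3, dtype=torch.float64), torch.randn(16, dtype=torch.float64)
w0 = torch.zeros(3, dtype=torch.float64)

g_full = meta_grad_unrolled(w0, x, y, x_val, y_val, lr=0.1, steps=20)
g_ckpt = meta_grad_checkpointed(w0, x, y, x_val, y_val, lr=0.1, steps=20)
print(torch.allclose(g_full, g_ckpt))  # both routes give the same meta-gradient
```

In this sketch, peak memory holds steps + 1 parameter copies plus the graph of a single step, instead of the graphs of all steps; the price is that every inner step is computed twice.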

Install

For a regular installation, run pip install torch_maml

For a development installation, clone the repository and run python setup.py develop

How to use:

See examples in example.ipynb


Tips and tricks

  1. Make sure that your model has no implicit state updates, such as the running statistics of torch.nn.BatchNorm2d with track_running_stats=True. With gradient checkpointing, the forward pass runs twice (once normally and once during recomputation in the backward pass), so these updates are applied twice. If you still want these updates, take a look at torch_maml.utils.disable_batchnorm_stats. Note that we already support this for vanilla BatchNorm{1-3}d.

  2. When backpropagating through many MAML steps (e.g., 100 or 1000), watch out for vanishing and exploding gradients through the inner-loop optimizer, just as in RNNs. This implementation supports gradient clipping to mitigate the exploding part of the problem.

  3. When you run a large number of MAML steps, also be aware of numerical error that accumulates due to floating-point precision, particularly in cuDNN operations. We recommend setting torch.backends.cudnn.deterministic = True. Small errors make the gradients slightly noisy, and backpropagating through many MAML steps can amplify this noise dramatically.

  4. You could also consider Implicit Gradient MAML (iMAML) [3] as a memory-efficient alternative. While it requires even less memory, it assumes that the inner optimization has converged to an optimum. Therefore, it is not applicable if your task has not converged by the time you start backpropagating. In contrast, our implementation lets you meta-learn even from a partially converged state.
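The double update described in tip 1 can be observed with PyTorch's built-in torch.utils.checkpoint. This illustration uses that generic utility rather than torch_maml's own checkpointing, which is an assumption made only for the demo. A BatchNorm layer in training mode increments its num_batches_tracked buffer on every forward call, so reading it before and after backward reveals the recomputation:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)  # track_running_stats=True by default
bn.train()
x = torch.randn(8, 4, requires_grad=True)

# The reentrant variant runs the wrapped forward without a graph,
# then runs it again during backward to rebuild the graph
out = checkpoint(bn, x, use_reentrant=True)
after_forward = bn.num_batches_tracked.item()

out.pow(2).sum().backward()
after_backward = bn.num_batches_tracked.item()

print(after_forward, after_backward)  # 1 2: the running stats were updated a second time
```

If the double update matters and you are not using the library helper, one option is to build the layers with track_running_stats=False so they always normalize with batch statistics.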
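To see why tip 2 compares the inner loop to an RNN, consider SGD with learning rate α on the one-dimensional quadratic loss h·θ²/2. Each step maps θ to (1 - αh)θ, so after N steps the derivative of the final parameter with respect to the initial one is (1 - αh)^N, which vanishes or explodes exponentially. The sketch checks this numerically and adds a differentiable global-norm clipping helper. clip_by_global_norm is a hypothetical name, and this is a generic clipping technique, not necessarily the scheme torch_maml uses.

```python
import torch


def unrolled_sgd_jacobian(h, lr, steps):
    # d(theta_N)/d(theta_0) for SGD on the 1-D quadratic loss 0.5 * h * theta**2
    theta0 = torch.tensor(1.0, dtype=torch.float64, requires_grad=True)
    theta = theta0
    for _ in range(steps):
        (g,) = torch.autograd.grad(0.5 * h * theta**2, theta, create_graph=True)
        theta = theta - lr * g
    (jac,) = torch.autograd.grad(theta, theta0)
    return jac.item()


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    # Rescale a list of gradients so their joint L2 norm is at most max_norm.
    # Built only from differentiable ops, so it can sit inside a create_graph=True loop.
    total_norm = torch.sqrt(sum((g**2).sum() for g in grads))
    scale = torch.clamp(max_norm / (total_norm + eps), max=1.0)
    return [g * scale for g in grads]


print(unrolled_sgd_jacobian(h=1.0, lr=0.5, steps=10))  # (1 - 0.5)**10: vanishing
print(unrolled_sgd_jacobian(h=1.0, lr=3.0, steps=10))  # (1 - 3.0)**10: exploding

big = [torch.full((3,), 100.0), torch.full((2,), 100.0)]
clipped = clip_by_global_norm(big, max_norm=1.0)
print(torch.sqrt(sum((g**2).sum() for g in clipped)).item())  # close to 1.0
```

Clipping bounds the exploding case; it does nothing for the vanishing case, which matches the tip's remark that clipping addresses only the explosive part of the problem.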
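For tip 3, a minimal setup might look like the following. torch.backends.cudnn.deterministic and torch.backends.cudnn.benchmark are standard PyTorch flags; the helper name make_reproducible, and the choice to also fix the seed and disable cuDNN autotuning, are our own assumptions rather than part of torch_maml.

```python
import torch


def make_reproducible(seed=0):
    # Hypothetical helper: fix the RNG state and ask cuDNN for deterministic kernels
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True  # prefer deterministic cuDNN algorithms
    torch.backends.cudnn.benchmark = False     # disable autotuning, which may pick different kernels per run


make_reproducible(seed=0)
print(torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark)
```

Deterministic kernels remove run-to-run noise but not ordinary floating-point rounding; if the error still grows over many steps, computing in float64 is another option at the cost of speed.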
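The convergence assumption in tip 4 comes from how iMAML [3] computes its meta-gradient. Writing θ for the meta-parameters, φ for the task-adapted parameters, L̂ for the inner training loss and λ for the proximal regularization strength (notation chosen here for illustration), iMAML defines φ* as the exact minimizer of a regularized inner objective and differentiates its stationarity condition with the implicit function theorem:

```latex
\begin{aligned}
\phi^\star(\theta) &= \arg\min_{\phi}\ \hat{\mathcal{L}}(\phi) + \tfrac{\lambda}{2}\lVert \phi - \theta \rVert^2
&&\Longrightarrow\quad \nabla \hat{\mathcal{L}}(\phi^\star) + \lambda\,(\phi^\star - \theta) = 0, \\
0 &= \nabla^2 \hat{\mathcal{L}}(\phi^\star)\,\frac{d\phi^\star}{d\theta} + \lambda\Bigl(\frac{d\phi^\star}{d\theta} - I\Bigr)
&&\Longrightarrow\quad \frac{d\phi^\star}{d\theta} = \Bigl(I + \tfrac{1}{\lambda}\,\nabla^2 \hat{\mathcal{L}}(\phi^\star)\Bigr)^{-1}.
\end{aligned}
```

The result depends only on the final φ*, not on the optimization path, which is why no per-step storage is needed; but it is exact only when the inner solver actually reaches the stationary point. Backpropagating through the unrolled steps makes no such assumption.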

References

[1] Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks

[2] Gradient checkpointing technique (GitHub)

[3] Meta-Learning with Implicit Gradients