Transfer Learning via Linearized Neural Networks

This repository contains a GPyTorch implementation of finite-width neural tangent kernels from the paper

Fast Adaptation with Linearized Neural Networks

by Wesley Maddox, Shuai Tang, Pablo Garcia Moreno, Andrew Gordon Wilson, and Andreas Damianou,

which appeared at AISTATS 2021. Please note that this is a revised and expanded version of the workshop paper On Transfer Learning with Linearised Neural Networks, which appeared at the 3rd MetaLearning Workshop at NeurIPS, 2019.
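For intuition, the finite-width NTK of a network f(x; θ) is the kernel k(x, x') = J(x) J(x')^T, where J(x) is the Jacobian of the network output with respect to its parameters θ. The following is a minimal dense sketch of that kernel for a tiny scalar-output MLP. It assumes PyTorch 2.x with `torch.func` (newer than the PyTorch 1.3.1 used in the experiments), is not code from this repository, and forms the full Jacobian explicitly, so it is only practical for small networks.

```python
import torch
from torch.func import functional_call, jacrev

torch.manual_seed(0)

# illustrative scalar-output MLP; sizes are arbitrary
net = torch.nn.Sequential(torch.nn.Linear(1, 30), torch.nn.ReLU(), torch.nn.Linear(30, 1))
params = {k: v.detach() for k, v in net.named_parameters()}

def f(p, x):
    # network outputs, shape (n,), as a function of an explicit parameter dict
    return functional_call(net, p, (x,)).squeeze(-1)

def jacobian_matrix(x):
    # jacrev differentiates w.r.t. the first argument (the dict): name -> (n, *param.shape)
    jac = jacrev(f)(params, x)
    # flatten each parameter block and concatenate into J(x) of shape (n, P)
    return torch.cat([jac[k].reshape(x.shape[0], -1) for k in params], dim=1)

def empirical_ntk(x1, x2):
    # finite-width NTK: k(x, x') = J(x) J(x')^T
    return jacobian_matrix(x1) @ jacobian_matrix(x2).T

x1, x2 = torch.randn(5, 1), torch.randn(4, 1)
K = empirical_ntk(x1, x2)
print(K.shape)  # torch.Size([5, 4])
```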

Introduction

Please cite our work if you find it useful:

@inproceedings{maddox2021fast,
  title={Fast Adaptation with Linearized Neural Networks},
  author={Maddox, Wesley and Tang, Shuai and Moreno, Pablo and Wilson, Andrew Gordon and Damianou, Andreas},
  booktitle={International Conference on Artificial Intelligence and Statistics},
  pages={2737--2745},
  year={2021},
  organization={PMLR}
}

Installation

python setup.py develop

See the requirements.txt file for the dependencies from our setup. We used PyTorch 1.3.1 and Python 3.6+ in our experiments.

Unless otherwise described, all experiments were run on a single GPU.

Minimal Example

import torch
import gpytorch
import finite_ntk

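data = torch.randn(300, 1)
# ExactGP expects 1-D training targets of shape (n,) for a single-output GP
response = torch.randn(300)

# randomly initialize a neural network
model = torch.nn.Sequential(torch.nn.Linear(1, 30),
                            torch.nn.ReLU(),
                            torch.nn.BatchNorm1d(30),  # BatchNorm1d requires num_features
                            torch.nn.Linear(30, 1))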

class ExactGPModel(gpytorch.models.ExactGP):
    # exact GP whose kernel is the finite-width NTK of `model` (not an RBF kernel)
    def __init__(self, train_x, train_y, likelihood, model):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = finite_ntk.lazy.NTK(
            model=model)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

gp_lh = gpytorch.likelihoods.GaussianLikelihood()
gp_model = ExactGPModel(data, response, gp_lh, model)

# draw a sample from the GP prior (plus likelihood noise) whose kernel is J(x) J(x')^T,
# where J is the Jacobian of the model outputs with respect to its parameters
zeromean_pred = gp_lh(gp_model(data)).sample()
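The GP above is closely tied to the linearized network f_lin(x; θ0 + δ) = f(x; θ0) + J(x) δ: if δ ~ N(0, I), then J(x) δ ~ N(0, J(x) J(x)^T), which is the NTK prior sampled in the example before the likelihood adds observation noise. A minimal sketch of the linearized forward pass using forward-mode autodiff, assuming PyTorch 2.x `torch.func`; the small network and the standard-normal perturbation are illustrative choices, not the repository's code:

```python
import torch
from torch.func import functional_call, jvp

torch.manual_seed(0)

# illustrative scalar-output MLP (no BatchNorm, so outputs for different inputs are independent)
net = torch.nn.Sequential(torch.nn.Linear(1, 30), torch.nn.ReLU(), torch.nn.Linear(30, 1))
theta0 = {k: v.detach() for k, v in net.named_parameters()}
x = torch.randn(8, 1)

def f(params):
    # network outputs at the fixed inputs x, as a function of the parameters
    return functional_call(net, params, (x,)).squeeze(-1)

# Gaussian parameter perturbation delta ~ N(0, I), same structure as theta0
delta = {k: torch.randn_like(v) for k, v in theta0.items()}

# one forward-mode pass returns f(x; theta0) and J(x) @ delta without forming J
f0, J_delta = jvp(f, (theta0,), (delta,))

# linearized network evaluated at theta0 + delta
f_lin = f0 + J_delta
print(f_lin.shape)  # torch.Size([8])
```

Because J_delta is linear in delta, its covariance over draws of delta is exactly J(x) J(x)^T, the kernel used by the GP in the example.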

References for Code Base

GPyTorch: the GPyTorch repository is the origin of this codebase.

Adam Paszke's gist for the Rop (R-operator, i.e., the Jacobian-vector product)
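The Rop returns the Jacobian-vector product J v. PyTorch 1.3.1 offers only reverse-mode autodiff, and a common workaround, the idea usually associated with the gist credited above, is the double-backward trick: for a dummy vector u, the gradient g(u) = J^T u is linear in u, so differentiating <g(u), v> with respect to u yields J v. A minimal sketch; the function name and signature are illustrative assumptions, not the gist's or `finite_ntk`'s API:

```python
import torch

def Rop(y, xs, vs):
    """Jacobian-vector product (dy/dxs) @ vs using only reverse-mode autograd.

    y:  output tensor computed from xs
    xs: list of input tensors with requires_grad=True
    vs: list of tangent tensors, one per entry of xs (same shapes)
    """
    # dummy cotangent u; its value is irrelevant because J^T u is linear in u
    u = torch.zeros_like(y, requires_grad=True)
    # first backward pass: g = J^T u, recorded as a differentiable function of u
    g = torch.autograd.grad(y, xs, grad_outputs=u, create_graph=True)
    # second backward pass: d<g, v>/du = J v
    (Jv,) = torch.autograd.grad(g, u, grad_outputs=vs)
    return Jv

# usage: y = tanh(W x), whose Jacobian is diag(1 - y^2) W
torch.manual_seed(0)
W = torch.randn(3, 4)
x = torch.randn(4, requires_grad=True)
y = torch.tanh(W @ x)
v = torch.randn(4)
Jv = Rop(y, [x], [v])
```

The extra backward pass costs roughly one more reverse sweep, which is why native forward-mode AD is preferable when it is available.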

We'd like to thank Max Balandat for providing a cleaned version of the malaria data files from Balandat et al., 2019, and Jacob Gardner and Marc Finzi for help with the Fisher vector products.
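For a Gaussian likelihood with fixed noise variance σ², the Fisher information of the network parameters is F = J(X)^T J(X) / σ² (here equal to the generalized Gauss-Newton matrix), so a Fisher vector product takes one Jacobian-vector product followed by one vector-Jacobian product, without forming J or F. A minimal sketch assuming PyTorch 2.x `torch.func`; the network, `noise_var`, and `fisher_vector_product` are hypothetical, and this is not necessarily how `finite_ntk` computes these products:

```python
import torch
from torch.func import functional_call, jvp, vjp

torch.manual_seed(0)

# illustrative network and data
net = torch.nn.Sequential(torch.nn.Linear(1, 30), torch.nn.ReLU(), torch.nn.Linear(30, 1))
params = {k: v.detach() for k, v in net.named_parameters()}
x = torch.randn(16, 1)
noise_var = 0.1  # hypothetical Gaussian observation noise variance

def f(p):
    # network outputs at the fixed inputs x, shape (16,)
    return functional_call(net, p, (x,)).squeeze(-1)

def fisher_vector_product(v):
    # forward mode: J v, shape (n,)
    _, Jv = jvp(f, (params,), (v,))
    # reverse mode: J^T (J v), returned with the same dict structure as params
    _, vjp_fn = vjp(f, params)
    (JtJv,) = vjp_fn(Jv)
    # scale by the inverse noise variance: F v = J^T J v / sigma^2
    return {k: t / noise_var for k, t in JtJv.items()}

# usage: v has the same structure as the parameters
v = {k: torch.randn_like(t) for k, t in params.items()}
Fv = fisher_vector_product(v)
```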

The Malaria Global Atlas data file can be downloaded at: https://wjmaddox.github.io/assets/data/malaria_df.hdf5

Authors

Wesley Maddox