nonconvexopt/torch_variational

torch_variational

PyTorch implementations of the following variational techniques, plus wrapper classes that apply them to torch.nn.modules layers:

  • Flipout [1]
  • Local Reparameterization Trick [2]

Dependencies

PyTorch >= 1.0.0 (the lazy layers, e.g. nn.LazyLinear, require PyTorch >= 1.8)

Available modules

Variationalizers: wrappers for nn.Module layers (supports both lazy and standard linear and convolutional layers).

  • Flipout wrapper (Variational_Flipout)
  • Local Reparameterization wrapper (Variational_LRT)
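A wrapper keeps the original layer's weights as the variational means and adds the noise parameters around them. The class below is a hypothetical, minimal sketch of an LRT wrapper for nn.Linear only; `LRTLinearSketch` is not part of this package, and it assumes one learnable `log_alpha` per weight (Var[w] = alpha * theta²), a deterministic bias, and a standard-normal prior in `kld()`. The actual `Variational_LRT` may be implemented differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LRTLinearSketch(nn.Module):
    """Hypothetical LRT wrapper around an existing nn.Linear (not this package's class).

    Assumes Var[w] = alpha * theta**2 with one learnable log_alpha per weight,
    a deterministic bias, and a standard-normal prior for the KL term.
    """

    def __init__(self, linear: nn.Linear, init_log_alpha: float = -4.0):
        super().__init__()
        self.linear = linear  # the wrapped layer's weight serves as the mean theta
        self.log_alpha = nn.Parameter(torch.full_like(linear.weight, init_log_alpha))

    def weight_var(self) -> torch.Tensor:
        # Multiplicative variance: proportional to the squared mean
        return self.log_alpha.exp() * self.linear.weight.pow(2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Sample pre-activations from their Gaussian marginal instead of sampling weights
        mean_out = F.linear(x, self.linear.weight, self.linear.bias)  # gamma
        var_out = F.linear(x.pow(2), self.weight_var())                # delta
        eps = torch.randn_like(mean_out)
        # The small constant avoids an infinite gradient of sqrt at zero variance
        return mean_out + (var_out + 1e-8).sqrt() * eps

    def kld(self) -> torch.Tensor:
        # KL(N(theta, var) || N(0, 1)) summed over weights; 1e-8 keeps the log finite
        var_w = self.weight_var()
        theta = self.linear.weight
        return 0.5 * (var_w + theta.pow(2) - 1.0 - torch.log(var_w + 1e-8)).sum()
```

Sampling pre-activations rather than weights gives every example in the batch independent noise at the cost of one extra matrix product (for the variance).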

Stand-alone Flipout layers:

  • Conv2d_flipout
  • Linear_flipout

Install

Clone this repository, then run the following from the directory that contains the clone:

pip install -e torch_variational

Usage

Example usage for wrapper classes:

import torch
import torch.nn as nn
from torch_variational.wrapper import Variational_Flipout, Variational_LRT

# Wrap standard nn.Linear layers with variational weight perturbations
flipout_layer = Variational_Flipout(nn.Linear(in_features=10, out_features=10, bias=True))
lrt_layer = Variational_LRT(nn.Linear(in_features=10, out_features=10, bias=True))

# Each forward pass draws a fresh noise sample
flipout_output = flipout_layer(torch.randn(1, 10))
lrt_output = lrt_layer(torch.randn(1, 10))

# KL divergence terms for the variational objective
flipout_kld = flipout_layer.kld()
lrt_kld = lrt_layer.kld()
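In variational training, the `kld()` values are usually added to the data-fit loss. The helper below is a hypothetical sketch (not part of this package) of a mini-batch negative ELBO for classification, divided by the training-set size; it assumes `kld` is the total KL of the model (e.g. the sum of every variational layer's `kld()`) and `num_train` is the number of training examples. How the KL terms should be weighted is an assumption here.

```python
import torch
import torch.nn.functional as F


def negative_elbo(logits: torch.Tensor, targets: torch.Tensor,
                  kld: torch.Tensor, num_train: int) -> torch.Tensor:
    """Hypothetical helper: mini-batch estimate of -ELBO / num_train for classification."""
    # Mean negative log-likelihood over the mini-batch estimates the per-example data term
    nll = F.cross_entropy(logits, targets)
    # The KL term is counted once per dataset, so it is spread over all training examples
    return nll + kld / num_train
```

Dividing the whole objective by `num_train` keeps the loss on the same scale as an ordinary cross-entropy, so standard learning rates remain reasonable.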

Example usage for stand-alone Flipout layers:

import torch
from torch_variational import flipout  # module path assumed from the `flipout.` prefix below

layer = flipout.Linear_flipout(in_features=10, out_features=10, bias=True)
output, kld = layer(torch.randn(1, 10))  # stand-alone layers return the output and the KL term

Derivations

The derivations assume multiplicative weight variances: each weight w has mean θ and variance αθ², i.e. w = θ(1 + √α ε) with ε ~ N(0, 1).

Flipout

(Three Flipout equations, rendered as images in the original.)
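For reference, the standard Flipout estimator [1] for a linear layer can be written as follows, taking the mean weights θ ∈ R^(D_out × D_in) with the multiplicative Gaussian noise assumed above (variance αθ², elementwise). The notation is ours and may differ from the repository's own equations.

```latex
\begin{aligned}
\Delta W &= \sqrt{\alpha} \odot \theta \odot E,
  && E_{ji} \sim \mathcal{N}(0, 1) \text{ i.i.d., shared by the mini-batch} \\
r_n &\in \{-1, +1\}^{D_{\text{in}}}, \quad s_n \in \{-1, +1\}^{D_{\text{out}}},
  && \text{uniform random signs drawn per example } n \\
y_n &= \theta x_n + b + \big(\Delta W (x_n \odot r_n)\big) \odot s_n \\
    &= \big(\theta + \Delta W \odot s_n r_n^{\top}\big) x_n + b
\end{aligned}
```

Because ΔW is symmetric about zero, ΔW ⊙ s_n r_nᵀ has the same distribution as ΔW, so every example still sees a valid weight sample while the sign flips decorrelate the examples within a batch.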

Local Reparameterization Trick

(Three Local Reparameterization Trick equations, rendered as images in the original.)
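For reference, the local reparameterization trick [2] for a linear layer with independent weights w_ji ~ N(θ_ji, α_ji θ_ji²) samples each pre-activation from its exact Gaussian marginal instead of sampling W. The notation is ours and may differ from the repository's own equations.

```latex
\begin{aligned}
\gamma_{nj} &= \sum_{i} x_{ni}\,\theta_{ji} + b_j \\
\delta_{nj} &= \sum_{i} x_{ni}^{2}\,\alpha_{ji}\,\theta_{ji}^{2} \\
y_{nj} &= \gamma_{nj} + \sqrt{\delta_{nj}}\;\zeta_{nj}, \qquad \zeta_{nj} \sim \mathcal{N}(0, 1)
\end{aligned}
```

Each example gets its own noise ζ_nj, so the number of random draws scales with the layer's outputs per example rather than with the number of weights.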

References

[1] @inproceedings{DBLP:conf/iclr/WenVBTG18,
  author    = {Yeming Wen and Paul Vicol and Jimmy Ba and Dustin Tran and Roger B. Grosse},
  title     = {Flipout: Efficient Pseudo-Independent Weight Perturbations on Mini-Batches},
  booktitle = {6th International Conference on Learning Representations, {ICLR} 2018,
               Vancouver, BC, Canada, April 30 - May 3, 2018, Conference Track Proceedings},
  year      = {2018},
  url       = {https://openreview.net/forum?id=rJNpifWAb}
}
[2] @inproceedings{NIPS2015_bc731692,
  author    = {Kingma, Durk P and Salimans, Tim and Welling, Max},
  title     = {Variational Dropout and the Local Reparameterization Trick},
  booktitle = {Advances in Neural Information Processing Systems},
  volume    = {28},
  year      = {2015},
  url       = {https://proceedings.neurips.cc/paper/2015/file/bc7316929fe1545bf0b98d114ee3ecb8-Paper.pdf}
}