lucfra/FAR-HO

FAR-HO

Gradient-based hyperparameter optimization and meta-learning package based on TensorFlow

This is the new package that implements the algorithms presented in the paper Forward and Reverse Gradient-Based Hyperparameter Optimization. For the older package, see RFHO. FAR-HO features simplified interfaces, additional capabilities and a tighter integration with TensorFlow.

  • Reverse hypergradient (ReverseHG), a generalization of the algorithms presented in Domke [2012] and Maclaurin et al. [2015] (without reversible dynamics and "reversible dtype"), including the truncated reverse version ReverseHG.truncated; see Truncated Back-propagation for Bilevel Optimization
  • Forward hypergradient (ForwardHG)
  • Online versions of the two previous algorithms: Real-Time HO (RTHO) and Truncated-Reverse HO (TRHO)
  • Implicit differentiation (ImplicitHG), which can be used to implement the HOAG algorithm [Pedregosa, 2016]

These algorithms compute, with different procedures, the (approximate) gradient of an outer objective, such as a validation error, with respect to the outer variables (e.g. hyperparameters). We call the gradient of the outer objective the hypergradient. The "online" algorithms may perform several updates of the outer variables before reaching the final iteration, and are in general much faster than their "batch" versions. This procedure is linked to warm restart for solving the inner optimization problem, but the hypergradient is, in general, biased.

IMPORTANT NOTE: This is not a plug-and-play hyperparameter optimization package, but rather a research package that collects some useful methods aimed at simplifying the creation of experiments in gradient-based hyperparameter optimization and related areas. Compared with other HPO packages, a more specific problem structure is required here. Furthermore, depending on the specific problem, the performance may be somewhat sensitive to algorithmic parameters. As an important example, the inner optimization dynamics should not diverge, otherwise the hypergradients will not yield useful information [ Troubleshooting section coming soon! ].

NOTE II: In Italian FARO means beacon or lighthouse (so... no "H", but the "H" in Italian is silent!).


These algorithms are also useful in meta-learning, where the parameters of various meta-learners effectively play the role of outer variables, as explained in the workshop paper A Bridge Between Hyperparameter Optimization and Learning-to-learn and in Bilevel Programming for Hyperparameter Optimization and Meta-Learning.

This package is also described in the workshop paper Far-HO: A Bilevel Programming Package for Hyperparameter Optimization and Meta-Learning, presented at AutoML 2018 at ICML.

Installation & Dependencies

Clone the repository and run the setup script.

git clone https://github.com/lucfra/FAR-HO.git
cd FAR-HO
python setup.py install

Besides the "usual" packages (numpy), FAR-HO is built upon TensorFlow. Some examples depend on the package experiment_manager, while automatic dataset download (Omniglot) requires datapackage.

Please note that required packages will not be installed automatically.

Overview

The aim of this package is to implement and develop gradient-based hyperparameter optimization (HO) techniques in TensorFlow, thus making them readily applicable to deep learning systems. These optimization techniques also find natural applications in the field of meta-learning and learning-to-learn. Feel free to open issues with comments, suggestions and feedback! You can email me at luca.franceschi@iit.it.

Quick Start

Core Steps

  • Create a model¹ with TensorFlow
  • Create the hyperparameters you wish to optimize² with the function get_hyperparameter (these can also be variables of your model)
  • Define an inner objective (e.g. a training error) and an outer objective (e.g. a validation error) as scalar tensorflow.Tensor objects
  • Create an instance of HyperOptimizer after choosing a hypergradient computation algorithm among ForwardHG, ReverseHG and ImplicitHG (see next section)
  • Call the function HyperOptimizer.minimize, passing the outer and inner objectives, as well as an optimizer for the outer problem (which can be any optimizer from TensorFlow) and an optimizer for the inner problem (which must be an optimizer contained in this package; at the moment gradient descent, gradient descent with momentum and Adam are available, but it should be quite straightforward to implement other optimizers, email me if you're interested!)
  • Execute the HyperOptimizer.run(T, ...) function inside a tensorflow.Session to optimize the inner variables (parameters) for T iterations and perform one optimization step on the outer variables (hyperparameters).

Two scripts in the folder autoMLDemos showcase typical usage of this package.

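import far_ho as far
import tensorflow as tf

# create_model and create_objective stand for user code
model = create_model(...)

# hyperparameters (outer variables)
lambda1 = far.get_hyperparameter('lambda1', ...)
lambda2 = far.get_hyperparameter('lambda2', ...)
io, oo = create_objective(...)  # inner and outer objectives (scalar tensors)

# the inner optimizer must come from far_ho; its learning rate is itself a hyperparameter
inner_problem_optimizer = far.GradientDescentOptimizer(lr=far.get_hyperparameter('lr', 0.1))
outer_problem_optimizer = tf.train.AdamOptimizer()

farho = far.HyperOptimizer()
ho_step = farho.minimize(oo, outer_problem_optimizer,
                         io, inner_problem_optimizer)

T = 100  # inner iterations per hyper-iteration
with tf.Session().as_default():
  for _ in range(100):  # hyper-iterations
    ho_step(T)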

¹ This is gradient-based optimization, and second-order derivatives of the training error show up in the computation of the hypergradients (even though no Hessian matrix is explicitly computed at any time); therefore, all the ops used in the model should have a second-order derivative registered in TensorFlow.

² For the hypergradients to make sense, hyperparameters should be real-valued. Moreover, while ReverseHG should handle generic r-rank tensor hyperparameters, ForwardHG requires scalar hyperparameters. Use the keyword argument scalar=True in get_hyperparameter to obtain a scalar splitting of a general tensor.

Which Algorithm Do I Choose?

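ForwardHG and ReverseHG compute the same hypergradient, so the choice is a matter of time versus memory! Forward mode needs memory that does not grow with the number of inner iterations, but its running time grows with the number of hyperparameters; reverse mode has a running time that is largely independent of the number of hyperparameters, but it must store the whole inner optimization trajectory.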
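To make the time-versus-memory trade-off concrete, here is a minimal pure-Python sketch on a scalar toy problem. It does not use the FAR-HO API: the inner objective, the constants ETA, A_TARGET and B_TARGET, and all function names are assumptions chosen for illustration. The forward version carries one running derivative along the trajectory, while the reverse version stores the trajectory and walks it backwards.

```python
# Toy constants (assumptions for illustration, not part of FAR-HO).
ETA, A_TARGET, B_TARGET = 0.1, 1.0, 0.5

def inner_step(w, lam):
    """One gradient-descent step on 0.5*(w - A_TARGET)**2 + 0.5*lam*w**2."""
    return w - ETA * ((w - A_TARGET) + lam * w)

def step_partials(w, lam):
    """Partial derivatives of inner_step with respect to w and to lam."""
    return 1.0 - ETA * (1.0 + lam), -ETA * w

def outer(w):
    """Outer objective: squared distance of w from B_TARGET."""
    return 0.5 * (w - B_TARGET) ** 2

def outer_grad(w):
    return w - B_TARGET

def forward_hg(lam, T, w0=0.0):
    """Forward mode: propagate z = dw/dlam together with w (constant memory)."""
    w, z = w0, 0.0
    for _ in range(T):
        a_t, b_t = step_partials(w, lam)
        z = a_t * z + b_t
        w = inner_step(w, lam)
    return outer_grad(w) * z

def reverse_hg(lam, T, w0=0.0):
    """Reverse mode: store the trajectory, then back-propagate through it."""
    traj = [w0]
    for _ in range(T):
        traj.append(inner_step(traj[-1], lam))
    alpha, g = outer_grad(traj[-1]), 0.0
    for t in range(T, 0, -1):
        a_t, b_t = step_partials(traj[t - 1], lam)
        g += alpha * b_t
        alpha *= a_t
    return g

def finite_difference_hg(lam, T, w0=0.0, h=1e-6):
    """Central finite difference of the outer objective, used as a sanity check."""
    def f(l):
        w = w0
        for _ in range(T):
            w = inner_step(w, l)
        return outer(w)
    return (f(lam + h) - f(lam - h)) / (2 * h)

print(forward_hg(0.3, 50), reverse_hg(0.3, 50), finite_difference_hg(0.3, 50))
```

With a single scalar hyperparameter the forward version is the cheaper one; the picture reverses as the number of hyperparameters grows, because forward mode needs one running derivative per hyperparameter.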


The online versions of the algorithms can dramatically speed up the optimization.
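The following pure-Python sketch illustrates the idea behind an online, forward-mode update in the spirit of RTHO. It is not the package's implementation: the scalar toy problem, the function name online_forward_ho and every default value are assumptions. The hyperparameter is updated after every inner step from the running derivative, and the inner variable is never restarted.

```python
def online_forward_ho(steps=2000, eta=0.1, hyper_lr=0.5, a=1.0, b=0.5):
    """Update lam after every inner step using the running derivative z = dw/dlam.

    Inner objective: 0.5*(w - a)**2 + 0.5*lam*w**2 (gradient descent, step eta).
    Outer objective: 0.5*(w - b)**2.
    """
    w, z, lam = 0.0, 0.0, 0.0
    for _ in range(steps):
        # partial derivatives of the inner step with respect to w and lam
        a_t = 1.0 - eta * (1.0 + lam)
        b_t = -eta * w
        z = a_t * z + b_t                  # forward-mode recursion
        w = w - eta * ((w - a) + lam * w)  # inner step, no restart of w
        lam -= hyper_lr * (w - b) * z      # step along the (biased) hypergradient
    return w, lam

w, lam = online_forward_ho()
print(w, lam)
```

In this toy problem the outer objective is minimized when the inner solution a / (1 + lam) equals b, that is at lam = a / b - 1, which gives a reference value to compare the returned lam against.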

The Idea Behind: Hyperparameter Optimization

The objective is to minimize some validation function E with respect to a vector of hyperparameters lambda. The validation error depends on the model output and thus on the model parameters w. Since w should be a minimizer of the training error, the hyperparameter optimization problem can be naturally formulated as a bilevel optimization problem. Since these problems are rather hard to tackle, we explicitly take into account the learning dynamics used to obtain the model parameters (e.g. stochastic gradient descent with momentum), and we formulate HO as a constrained optimization problem. See the paper for details.
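As a compact summary of this formulation, the constrained problem and the two ways of computing its hypergradient can be written as follows. The symbols \Phi_t, A_t, B_t, Z_t and \alpha_t follow the notation of the paper rather than anything defined in this README, and Z_0 = 0 assumes that the initial parameters w_0 do not depend on \lambda.

```latex
\begin{aligned}
&\min_{\lambda}\; f(\lambda) = E\bigl(w_T(\lambda), \lambda\bigr)
\quad \text{subject to} \quad
w_t = \Phi_t(w_{t-1}, \lambda), \qquad t = 1, \dots, T, \\[4pt]
&\nabla f(\lambda) = \frac{\partial E}{\partial w}(w_T, \lambda)\, \frac{d w_T}{d \lambda}
  + \frac{\partial E}{\partial \lambda}(w_T, \lambda),
\qquad
A_t = \frac{\partial \Phi_t}{\partial w_{t-1}}, \quad
B_t = \frac{\partial \Phi_t}{\partial \lambda}, \\[4pt]
&\text{forward mode:}\quad Z_0 = 0, \quad Z_t = A_t Z_{t-1} + B_t, \quad
\frac{d w_T}{d \lambda} = Z_T, \\[4pt]
&\text{reverse mode:}\quad \alpha_T = \frac{\partial E}{\partial w}(w_T, \lambda), \quad
\alpha_{t-1} = \alpha_t A_t, \quad
\frac{\partial E}{\partial w}(w_T, \lambda)\, \frac{d w_T}{d \lambda} = \sum_{t=1}^{T} \alpha_t B_t .
\end{aligned}
```

Here \Phi_t denotes one step of the inner optimizer (for instance a gradient descent update on the training error), so the constraint simply says that w_T is the result of T such steps.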

New features and differences from RFHO

  • Simplified interface: optimize parameters and hyperparameters with "just" a call of far.HyperOptimizer.minimize; create variables designated as hyperparameters with far.get_hyperparameter; no more need to vectorize the model weights; far.optimizers only need to specify the update as a list of pairs (v, v_{k+1})
  • Additional capabilities: set an initialization dynamics and optimize the (distribution of) initial weights; explicit dependence of the outer objective on the hyperparameters is allowed; support for multiple outer objectives and multiple inner problems (episode batching, averaging the sampling from distributions, ...)
  • Tighter integration: collections for hyperparameters and hypergradients (use far.GraphKeys), use out-of-the-box models (no need to vectorize the model), use any TensorFlow optimizer for the outer objective (validation error)
  • Lighter package: only code for implementing the algorithms and running the examples
  • Forward hypergradient methods have been reimplemented with a double reverse mode trick, thanks to Jamie Townsend.

Citing

If you use this package, please cite:

@InProceedings{franceschi2017forward,
  title = 	 {Forward and Reverse Gradient-Based Hyperparameter Optimization},
  author = 	 {Luca Franceschi and Michele Donini and Paolo Frasconi and Massimiliano Pontil},
  booktitle = 	 {Proceedings of the 34th International Conference on Machine Learning},
  pages = 	 {1165--1173},
  year = 	 {2017},
  volume = 	 {70},
  series = 	 {Proceedings of Machine Learning Research},
  publisher = 	 {PMLR},
  pdf = 	 {http://proceedings.mlr.press/v70/franceschi17a/franceschi17a.pdf},
}

Works on meta-learning:

@InProceedings{franceschi2018bilevel,
  title = 	 {Bilevel Programming for Hyperparameter Optimization and Meta-learning},
  author = 	 {Luca Franceschi and Paolo Frasconi and Saverio Salzo and Riccardo Grazzi and Massimiliano Pontil},
  booktitle = 	 {Proceedings of the 35th International Conference on Machine Learning (ICML 2018)},
  year = 	 {2018},
  series = 	 {Proceedings of Machine Learning Research},
  publisher = 	 {PMLR},
  pdf = 	 {http://proceedings.mlr.press/v80/franceschi18a/franceschi18a.pdf},
}
@article{franceschi2017bridge,
  title={A Bridge Between Hyperparameter Optimization and Learning-to-learn},
  author={Franceschi, Luca and Frasconi, Paolo and Donini, Michele and Pontil, Massimiliano},
  journal={arXiv preprint arXiv:1712.06283},
  year={2017}
}

This package has been used for the project LDS-GNN: the code for the ICML 2019 paper "Learning Discrete Structures for Graph Neural Networks".
