
jacobjinkelly/easy-neural-ode


Learning Differential Equations that are Easy to Solve

Code for the paper:

Jacob Kelly*, Jesse Bettencourt*, Matthew James Johnson, David Duvenaud. "Learning Differential Equations that are Easy to Solve." Neural Information Processing Systems (2020). [arxiv] [bibtex]

*Equal Contribution

Includes JAX implementations of the following models:

  • Neural ODEs for classification
  • Latent ODEs for time series
  • FFJORD for density estimation

Includes JAX implementations of the following numerical solvers, all adaptive-stepping except RK4:

  • Heun-Euler heun (2nd order)
  • Fehlberg (RK1(2)) fehlberg (2nd order)
  • Bogacki-Shampine bosh (3rd order)
  • Cash-Karp cash_karp (4th order)
  • Fehlberg rk_fehlberg (4th order)
  • Owren-Zennaro owrenzen (4th order)
  • Dormand-Prince dopri (5th order)
  • Owren-Zennaro owrenzen5 (5th order)
  • Tanaka-Yamashita tanyam (7th order)
  • Adams adams (adaptive order)
  • RK4 rk4 (4th order, fixed step-size)
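The adaptive solvers above share one idea: each step computes two estimates of different order from shared function evaluations, treats their difference as a local error estimate, and grows or shrinks the step size accordingly. The sketch below illustrates this with the Heun-Euler pair in plain Python for a scalar ODE. It is not the repository's JAX implementation; the function name, default tolerances, and controller constants (safety factor 0.9, step-size change limited to a factor between 0.2 and 5) are illustrative assumptions.

```python
import math
from typing import Callable, Tuple


def heun_euler_adaptive(
    f: Callable[[float, float], float],
    y0: float,
    t0: float,
    t1: float,
    rtol: float = 1e-5,
    atol: float = 1e-8,
    h0: float = 1e-2,
) -> Tuple[float, int]:
    """Integrate dy/dt = f(t, y) from t0 to t1 with the embedded Heun-Euler pair.

    Hypothetical helper for illustration. Returns the state at t1 and the
    number of function evaluations (NFE).
    """
    t, y, h = t0, y0, h0
    nfe = 0
    while t < t1:
        h = min(h, t1 - t)  # never step past t1
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        nfe += 2
        y_low = y + h * k1                  # 1st-order (Euler) estimate
        y_high = y + 0.5 * h * (k1 + k2)    # 2nd-order (Heun) estimate
        err = abs(y_high - y_low)           # local error estimate for the Euler step
        tol = atol + rtol * max(abs(y), abs(y_high))
        if err <= tol:
            # Accept the step and continue from the higher-order solution.
            t, y = t + h, y_high
        # Rescale h; the exponent 1/2 matches the O(h^2) local error of Euler.
        factor = 0.9 * math.sqrt(tol / err) if err > 0 else 5.0
        h *= min(5.0, max(0.2, factor))
    return y, nfe


y_slow, nfe_slow = heun_euler_adaptive(lambda t, y: -y, 1.0, 0.0, 1.0)
y_fast, nfe_fast = heun_euler_adaptive(lambda t, y: -20.0 * y, 1.0, 0.0, 1.0)
print(y_slow, nfe_slow, nfe_fast)
```

Comparing dy/dt = -y with dy/dt = -20y over [0, 1] shows that faster-changing dynamics force smaller steps and therefore more function evaluations; this solver cost is what the paper's regularizers aim to reduce.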

Requirements

Python

Please use python>=3.8

JAX

Follow the official JAX installation instructions.

Haiku

Follow the official Haiku installation instructions.

Tensorflow Datasets

To use the MNIST dataset, follow the tensorflow-datasets installation instructions.

Usage

A separate script is provided for each task and dataset. In each command, --reg selects the regularizer and --lam sets its weight.

MNIST Classification

python mnist.py --reg r3 --lam 6e-5

Latent ODEs

python latent_ode.py --reg r3 --lam 1e-2

FFJORD (Tabular)

python ffjord_tabular.py --reg r2 --lam 1e-2

FFJORD (MNIST)

python ffjord_mnist.py --reg r2 --lam 3e-4
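The --reg and --lam flags control the regularizer behind the paper's title. As a minimal sketch, assuming --reg rK selects the paper's penalty R_K (the integral over the trajectory of the squared norm of the K-th total time derivative of the state) and --lam is its weight λ, the plain-Python example below evaluates R_K for the scalar linear ODE dz/dt = a·z, where d^K z/dt^K = a^K·z(t) holds exactly. General neural dynamics have no such closed form (the paper computes these derivatives with Taylor-mode automatic differentiation). The function names are hypothetical and do not mirror the repository's code.

```python
import math
from typing import List, Tuple


def rk4_trajectory(a: float, z0: float, t0: float, t1: float,
                   n_steps: int) -> Tuple[List[float], List[float]]:
    """Fixed-step RK4 for dz/dt = a * z; returns the grid times and states."""
    h = (t1 - t0) / n_steps
    ts, zs = [t0], [z0]
    z = z0
    for i in range(n_steps):
        k1 = a * z
        k2 = a * (z + 0.5 * h * k1)
        k3 = a * (z + 0.5 * h * k2)
        k4 = a * (z + h * k3)
        z = z + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        ts.append(t0 + (i + 1) * h)
        zs.append(z)
    return ts, zs


def r_k_penalty(a: float, z0: float, t0: float, t1: float,
                order: int, n_steps: int = 200) -> float:
    """Approximate R_K = integral of (d^K z/dt^K)^2 dt for dz/dt = a * z.

    For linear dynamics the K-th derivative is exactly a**K * z(t).
    """
    _, zs = rk4_trajectory(a, z0, t0, t1, n_steps)
    integrand = [(a ** order * z) ** 2 for z in zs]
    h = (t1 - t0) / n_steps
    # Trapezoidal rule on the RK4 grid.
    return h * (0.5 * integrand[0] + sum(integrand[1:-1]) + 0.5 * integrand[-1])


def regularized_loss(task_loss: float, a: float, z0: float, t0: float,
                     t1: float, order: int, lam: float) -> float:
    """Hypothetical objective: task loss plus lam times the R_K penalty."""
    return task_loss + lam * r_k_penalty(a, z0, t0, t1, order)


print(r_k_penalty(-1.0, 1.0, 0.0, 1.0, order=2))
print(regularized_loss(0.5, -1.0, 1.0, 0.0, 1.0, order=2, lam=1e-2))
```

Because the integrand carries the factor a**(2*order), faster dynamics are penalized sharply, so a positive λ trades some task loss for trajectories that adaptive solvers can integrate in fewer steps.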

Datasets

MNIST

tensorflow-datasets (see the installation instructions above) downloads the data automatically when the training script calls it.

Physionet

The file physionet_data.py, adapted from Latent ODEs for Irregularly-Sampled Time Series, downloads and processes the data when called from the training script. A preprocessed version is available in the repository's releases.

Tabular (FFJORD)

Data must be downloaded following instructions from gpapamak/maf and placed in data/. Only MINIBOONE is needed for experiments in the paper.

Code in datasets/, adapted from Free-form Jacobian of Reversible Dynamics (FFJORD), provides an interface to the MINIBOONE dataset once it has been downloaded; the training script calls it.

Acknowledgements

Code in lib is modified from google/jax, under its Apache 2.0 license.

Several numerical solvers were adapted from torchdiffeq and DifferentialEquations.jl.

BibTeX

@inproceedings{kelly2020easynode,
  title={Learning Differential Equations that are Easy to Solve},
  author={Kelly, Jacob and Bettencourt, Jesse and Johnson, Matthew James and Duvenaud, David},
  booktitle={Neural Information Processing Systems},
  year={2020},
  url={https://arxiv.org/abs/2007.04504}
}