
Normalizing flows in PyTorch

This package implements normalizing flows and their building blocks. It allows:

  • easy use of normalizing flows as trainable distributions;
  • easy implementation of new normalizing flows.

Example use:

import torch
from normalizing_flows import Flow
from normalizing_flows.architectures import RealNVP

torch.manual_seed(0)

n_data = 1000
n_dim = 3

x = torch.randn(n_data, n_dim)  # Generate some training data
bijection = RealNVP(n_dim)  # Create the bijection
flow = Flow(bijection)  # Create the normalizing flow

flow.fit(x)  # Fit the normalizing flow to training data
log_prob = flow.log_prob(x)  # Compute the log probability of training data
x_new = flow.sample(50)  # Sample 50 new data points

print(log_prob.shape)  # (1000,)
print(x_new.shape)  # (50, 3)

We provide more examples here.

Installing

Install the package:

pip install git+https://github.com/davidnabergoj/normalizing-flows.git

Setup for development:

git clone https://github.com/davidnabergoj/normalizing-flows.git
cd normalizing-flows
pip install -r requirements.txt

We support Python 3.7 and newer.

Brief background

A normalizing flow (NF) is a flexible trainable distribution. It is defined as a bijective transformation of a simple base distribution, such as a standard Gaussian. The bijection is typically an invertible neural network. Training an NF on a dataset means optimizing the bijection's parameters to make the dataset likely under the NF. We can use a trained NF to compute the density of a data point or to draw new independent samples from its approximation of the data-generating process.
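
To make the training objective concrete, below is a minimal maximum-likelihood training loop in the spirit of flow.fit. It assumes the flow object is a torch.nn.Module exposing a log_prob method (as Flow does in the example above); the package's actual fit implementation may differ.

import torch

# Hypothetical sketch of maximum-likelihood training: minimize the
# negative log-likelihood of the data under the flow.
def train(flow, x, n_epochs=100, lr=1e-3):
    optimizer = torch.optim.Adam(flow.parameters(), lr=lr)
    for _ in range(n_epochs):
        optimizer.zero_grad()
        loss = -flow.log_prob(x).mean()  # negative log-likelihood
        loss.backward()
        optimizer.step()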

The density of an NF $q(x)$ with bijection $f(z) = x$ and base distribution $p(z)$ is defined as: $$\log q(x) = \log p(f^{-1}(x)) + \log\left|\det J_{f^{-1}}(x)\right|.$$ Sampling from an NF means sampling from the base distribution and transforming the sample with the bijection.
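
As a standalone toy illustration of this formula (not using the package), take the elementwise affine bijection $f(z) = \mu + \sigma z$ with a standard Gaussian base distribution; here the inverse $f^{-1}(x) = (x - \mu) / \sigma$ and the log determinant $-\sum_i \log \sigma_i$ are available in closed form.

import torch

# Toy affine flow: f(z) = mu + sigma * z with base distribution N(0, I).
mu = torch.tensor([0.5, -1.0])
log_sigma = torch.tensor([0.2, -0.3])
base = torch.distributions.Normal(torch.zeros(2), torch.ones(2))

def log_prob(x):
    z = (x - mu) * torch.exp(-log_sigma)   # f^{-1}(x) = (x - mu) / sigma
    log_det = -log_sigma.sum()             # log|det J_{f^{-1}}(x)|
    return base.log_prob(z).sum(-1) + log_det

def sample(n):
    z = base.sample((n,))                  # sample from the base distribution
    return mu + torch.exp(log_sigma) * z   # transform with f

x = sample(10)
print(log_prob(x).shape)  # torch.Size([10])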

Supported architectures

We list the supported NF architectures below, classified as autoregressive, residual, or continuous, following Papamakarios et al. (2021). For each architecture, the forward pass, the inverse pass, and the log determinant of the Jacobian of the transformation are either exact or computed numerically; the inverse pass is not implemented for Planar, Radial, and Sylvester flows. An exact forward pass guarantees exact density estimation, and an exact inverse pass guarantees exact sampling. The two directions can always be swapped, which enables exact computation for the opposite task.

Architecture            Bijection type
NICE                    Autoregressive
Real NVP                Autoregressive
MAF                     Autoregressive
IAF                     Autoregressive
Rational quadratic NSF  Autoregressive
Linear rational NSF     Autoregressive
NAF                     Autoregressive
UMNN                    Autoregressive
Planar                  Residual
Radial                  Residual
Sylvester               Residual
Invertible ResNet       Residual
ResFlow                 Residual
Proximal ResFlow        Residual
FFJORD                  Continuous
RNODE                   Continuous
DDNF                    Continuous
OT flow                 Continuous

We also support simple bijections, all with exact forward passes, inverse passes, and log determinants (see the sketch after this list):

  • Permutation
  • Elementwise translation (shift vector)
  • Elementwise scaling (diagonal matrix)
  • Rotation (orthogonal matrix)
  • Triangular matrix
  • Dense matrix (using the QR or LU decomposition)
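
These linear bijections admit closed-form log determinants. For instance, for a triangular (or LU-factored) linear bijection $f(z) = Lz$, the log absolute determinant is just the sum of the logs of the absolute diagonal entries, which makes it cheap to evaluate. A standalone illustration (not the package's code):

import torch

# Triangular linear bijection f(z) = L @ z: log|det L| is the sum of the
# logs of the absolute diagonal entries, so it costs O(n) to evaluate.
n_dim = 3
diag = torch.rand(n_dim) + 0.5                          # positive diagonal entries
L = torch.randn(n_dim, n_dim).tril(-1) + torch.diag(diag)  # lower-triangular matrix

log_det = diag.log().sum()                              # O(n) closed form
print(torch.allclose(log_det, torch.logdet(L)))         # True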
