
smsharma/jax-conditional-flows


Conditional normalizing flows in Jax

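Implementations of some common normalizing flow models that accept a conditioning context, written with Jax, Flax, and Distrax. The following are currently implemented:

- Masked Autoregressive Flow (MaskedAutoregressiveFlow, in models/maf)
- Neural Spline Flow (NeuralSplineFlow, in models/nsf)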

Examples

Basic usage

import jax
import jax.numpy as jnp

from models.maf import MaskedAutoregressiveFlow
from models.nsf import NeuralSplineFlow

n_dim = 2  # Feature dim
n_context = 1  # Context dim
n_samples = 1000  # Number of samples to draw (example value)

## Define flow model
# model = MaskedAutoregressiveFlow(n_dim=n_dim, n_context=n_context, hidden_dims=[128, 128], n_transforms=12, activation="tanh", use_random_permutations=False)
model = NeuralSplineFlow(n_dim=n_dim, n_context=n_context, hidden_dims=[128, 128], n_transforms=8, activation="gelu", n_bins=4)

## Initialize model and params
# Split the PRNG key so data, context, init, and sampling use independent randomness
key = jax.random.PRNGKey(42)
key, key_x, key_context, key_init, key_sample = jax.random.split(key, 5)
x_test = jax.random.uniform(key=key_x, shape=(64, n_dim))
context = jax.random.uniform(key=key_context, shape=(64, n_context))
params = model.init(key_init, x_test, context)

## Log-prob and sampling
# The context needs one row per input (or per sample): shape (batch, n_context)
log_prob = model.apply(params, x_test, jnp.ones((x_test.shape[0], n_context)))
samples = model.apply(params, n_samples, key_sample, jnp.ones((n_samples, n_context)), method=model.sample)
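Both the log-prob and sampling calls rest on the change-of-variables formula, with the context feeding the network that parameterizes each transform. The snippet below is not this repo's code: it is a minimal pure-JAX sketch of a single conditional affine layer, assuming a hypothetical one-layer linear conditioner (init_conditioner, shift_and_log_scale, log_prob, and sample are illustrative names). The NSF and MAF models above stack several, more expressive transforms, but the bookkeeping is the same.

```python
import jax
import jax.numpy as jnp


# Hypothetical conditioner: one linear layer mapping the context c to a
# per-dimension shift and log-scale. Real flows use deeper networks.
def init_conditioner(key, n_context, n_dim):
    k_w, k_b = jax.random.split(key)
    w = 0.1 * jax.random.normal(k_w, (n_context, 2 * n_dim))
    b = 0.1 * jax.random.normal(k_b, (2 * n_dim,))
    return {"w": w, "b": b}


def shift_and_log_scale(params, context):
    out = context @ params["w"] + params["b"]  # (batch, 2 * n_dim)
    shift, log_scale = jnp.split(out, 2, axis=-1)  # each (batch, n_dim)
    return shift, log_scale


def log_prob(params, x, context):
    # Normalizing direction: z = (x - shift(c)) * exp(-log_scale(c))
    shift, log_scale = shift_and_log_scale(params, context)
    z = (x - shift) * jnp.exp(-log_scale)
    # Standard normal base density, summed over feature dimensions
    base = jnp.sum(-0.5 * z**2 - 0.5 * jnp.log(2 * jnp.pi), axis=-1)
    # Change of variables: log|det dz/dx| = -sum(log_scale)
    return base - jnp.sum(log_scale, axis=-1)


def sample(params, key, context):
    # Generative direction: z ~ N(0, I), then x = shift(c) + exp(log_scale(c)) * z
    shift, log_scale = shift_and_log_scale(params, context)
    z = jax.random.normal(key, shift.shape)
    return shift + jnp.exp(log_scale) * z


# Usage: one context row per data point, as in the example above
n_dim, n_context = 2, 1
k_params, k_x, k_c, k_s = jax.random.split(jax.random.PRNGKey(0), 4)
params = init_conditioner(k_params, n_context, n_dim)
x = jax.random.uniform(k_x, (64, n_dim))
context = jax.random.uniform(k_c, (64, n_context))
lp = log_prob(params, x, context)  # shape (64,)
samples = sample(params, k_s, context)  # shape (64, 2)
```

A single affine layer only yields a Gaussian whose mean and scale depend on the context; stacking many such layers with autoregressive or spline conditioners is what lets models like the ones above represent non-Gaussian conditional densities.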

