
Fitting ODE model with diffeqsolve is extremely slow using NUTS on GPU #338

Open · kokbent opened this issue Dec 8, 2023 · 3 comments
Labels: question (User queries)

kokbent commented Dec 8, 2023

As the title says, I've been trying to fit my SIR ODE model using NUTS on GPU, but the fit is extremely slow compared to CPU. I'm using JAX and NumPyro to do the fitting. I ran this on Google Colab:

CPU
sample: 100%|██████████| 2000/2000 [02:15<00:00, 14.78it/s, 7 steps of size 3.16e-01. acc. prob=0.94]

GPU (had to interrupt because it was too slow)
warmup: 1%| | 16/2000 [05:11<10:43:38, 19.47s/it, 1 steps of size 2.14e-04. acc. prob=0.58]

This is not an issue specific to diffrax; I had the same problem when using odeint as my ODE solver. Searching around, it seems a similar issue (with odeint) was reported in JAX: Gradients with odeint slow on GPU #5006. According to one of the replies, the tight loop structure in odeint is not XLA-GPU friendly. Given that I've seen the same behaviour with diffeqsolve, I guess it uses a similar technique and suffers from the same issue? The question then is: is there any way to circumvent the problem within the diffrax package, perhaps with another type of implementation?


Here's the code I use:

import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from jax import random
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5

numpyro.set_platform("cpu")  # set to "gpu" for the GPU run


def sir_ode(state, _, parameters):
    # Unpack state
    s, i, r = state
    beta, gamma = parameters
    population = s + i + r

    # Compute flows
    ds_to_i = beta * s * i / population
    di_to_r = gamma * i

    # Compute derivatives
    ds = -ds_to_i
    di = ds_to_i - di_to_r
    dr = di_to_r

    return (ds, di, dr)  # jnp.stack([ds, di, dr])


# Parameters
rng = np.random.default_rng(seed=867530)
beta = 1.5 / 4.5
gamma = 1.0 / 4.5
population = 10000
initial_infections = 1.0

initial_state = (
    population - initial_infections,  # s
    initial_infections,  # i
    0.0,  # r
)

# Solve ODE
# diffrax expects the vector field as f(t, y, args), so adapt sir_ode's argument order
term = ODETerm(lambda t, state, parameters: sir_ode(state, t, parameters))
solver = Tsit5()
t0 = 0.0
t1 = 100.0
dt0 = 0.1
times = jnp.linspace(t0, t1, 101)
saveat = SaveAt(ts=times)



def des(initial_state, args):
    solution = diffeqsolve(
        term,
        solver,
        t0,
        t1,
        dt0,
        initial_state,
        args=args,
        saveat=saveat,
    )
    return solution


sol = des(initial_state, [beta, gamma])

# Generate incidence sample
rng = np.random.default_rng(seed=8675309)
incidence = -np.diff(sol.ys[0], axis=0)
incidence_sample = rng.poisson(incidence)


# Sampling model
def sir(times, incidence):
    # Parameters
    initial_infections = numpyro.sample("initial_infections", dist.Exponential(1.0))
    beta = numpyro.sample("beta", dist.Exponential(1.0))
    gamma = numpyro.sample("gamma", dist.Exponential(1.0))

    initial_state = (
        population - initial_infections,  # s
        initial_infections,  # i
        0.0,  # r
    )

    # Integrate the model
    solution = des(initial_state, [beta, gamma])
    model_incidence = -jnp.diff(solution.ys[0], axis=0)

    # Observed incidence
    numpyro.sample("incidence", dist.Poisson(model_incidence), obs=incidence)


rng_key = random.PRNGKey(8811)
nuts_kernel = NUTS(sir, dense_mass=True)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc.run(rng_key, times, incidence_sample)

patrick-kidger (Owner) commented

The first thing that jumps out is that you don't appear to be explicitly JIT'ing your computation. Diffrax already does this for you internally for the most part, but even so, best practice is to put an equinox.filter_jit on des.

The second is that it looks like beta and gamma might be Python floats rather than JAX arrays, in which case I suspect things are recompiling every time. Make them NumPy or JAX arrays. (When using equinox.filter_jit, the rule is that things will recompile if a JAX/NumPy array changes shape or dtype, and if anything else changes in any way at all.)
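
Untested, but something like this should work for both suggestions (reusing the names from your snippet):

import equinox as eqx
import jax.numpy as jnp


# JIT the whole solve once, rather than relying only on diffrax's internal JIT.
@eqx.filter_jit
def des(initial_state, args):
    return diffeqsolve(
        term,
        solver,
        t0,
        t1,
        dt0,
        initial_state,
        args=args,
        saveat=saveat,
    )


# Pass JAX arrays rather than Python floats, so that changing their values
# doesn't trigger recompilation.
sol = des(initial_state, jnp.array([beta, gamma]))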

patrick-kidger added the question (User queries) label on Dec 8, 2023

kokbent commented Dec 11, 2023

Hi Patrick, thanks for the response. I've JIT'ed my des function as you suggested. As for beta and gamma, making them JAX arrays in the first part of the code doesn't seem to have much effect (they are only used to generate a random sample). Within the sampling model sir(), they are handled by NumPyro, and I believe all the sampled parameters are already some form of traceable JAX arrays. The MCMC is still very slow. I should probably also raise the issue with numpyro.

patrick-kidger (Owner) commented

You can double-check whether recompilation is happening with equinox.debug.assert_max_traces, by the way.
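
Untested sketch, but usage would look something like:

import equinox as eqx


# Raises an error at trace time if `des` is compiled more than once.
@eqx.filter_jit
@eqx.debug.assert_max_traces(max_traces=1)
def des(initial_state, args):
    ...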
