Fitting ODE model with diffeqsolve is extremely slow using NUTS on GPU #338
Labels: question (User queries)
As the title says, I've been trying to fit my SIR ODE model using NUTS on GPU. However, the fit was extremely slow compared to running it on CPU. I'm using jax and numpyro to do the fitting, and I ran this on Google Colab. This is not an issue specific to diffrax: I had the same problem when using odeint as my ODE solver. I've searched online, and it seems a similar issue (but with odeint) was reported in JAX: Gradients with odeint slow on GPU #5006. According to one of the replies, the tight loop structure in odeint is not XLA GPU friendly. Given that I see a similar slowdown with diffeqsolve, I guess it uses a similar technique and suffers from the same issue? The question, then, is: is there any way to circumvent the problem within the diffrax package, perhaps with a different type of implementation?

Here's the code I use:
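The code itself is not reproduced here. As a hypothetical illustration only (not the poster's code, and not how diffrax implements diffeqsolve), the sketch below integrates an SIR model with a fixed-step RK4 loop written with jax.lax.scan. It shows the kind of tight, strictly sequential loop discussed above: each step depends on the previous one, so the work cannot be spread across time steps, and with a state of only three numbers the per-step overhead on a GPU is likely to dominate. The function names sir_rhs, rk4_step and solve_sir, the step size, and the parameter values are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def sir_rhs(y, beta, gamma):
    """Right-hand side of the SIR model, with y = (S, I, R) as population fractions."""
    s, i, r = y
    ds = -beta * s * i
    di = beta * s * i - gamma * i
    dr = gamma * i
    return jnp.stack([ds, di, dr])


def rk4_step(y, dt, beta, gamma):
    """One classical fourth-order Runge-Kutta step of size dt."""
    k1 = sir_rhs(y, beta, gamma)
    k2 = sir_rhs(y + 0.5 * dt * k1, beta, gamma)
    k3 = sir_rhs(y + 0.5 * dt * k2, beta, gamma)
    k4 = sir_rhs(y + dt * k3, beta, gamma)
    return y + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def solve_sir(y0, beta, gamma, dt=0.1, num_steps=1000):
    """Integrate from t=0 to t=dt*num_steps, returning the state after every step.

    lax.scan compiles to a single sequential loop: step n+1 cannot start
    until step n has finished.
    """
    def body(y, _):
        y_next = rk4_step(y, dt, beta, gamma)
        return y_next, y_next

    _, ys = jax.lax.scan(body, y0, xs=None, length=num_steps)
    return ys


# num_steps sets the loop length, so it must be static under jit.
solve_sir_jit = jax.jit(solve_sir, static_argnames=("num_steps",))

if __name__ == "__main__":
    y0 = jnp.array([0.99, 0.01, 0.0])
    ys = solve_sir_jit(y0, 0.3, 0.1)
    print(ys[-1])  # final (S, I, R) fractions

    # NUTS needs gradients of the log-density, which means differentiating
    # back through every step of the loop above.
    grad_final_r = jax.grad(lambda b: solve_sir(y0, b, 0.1)[-1, 2])(0.3)
    print(grad_final_r)
```

Because NUTS evaluates such gradients many times per sample (once per leapfrog step), any per-iteration cost in this loop is paid over and over during sampling.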