I am a total beginner with JAX and diffrax, so I am not sure whether this is a bug or expected behaviour, but if I try to compute the second or higher derivative of a solution from diffeqsolve(), I get an error. Changing the adjoint to DirectAdjoint() seems to fix the problem.
Minimal working example (using the default ODE example from the diffrax introduction):
import jax.numpy as jnp
import numpy as np
import jax
from diffrax import diffeqsolve, ODETerm, Dopri5, DirectAdjoint

z = 2.3
t = 1.

def rhot(z):
    def f(t, y, args):
        return -z*y
    term = ODETerm(f)
    solver = Dopri5()
    y0 = jnp.array([2., 3.])
    solution = diffeqsolve(term, solver, t0=0, t1=t, dt0=0.1, y0=y0)
    #solution = diffeqsolve(term, solver, t0=0, t1=t, dt0=0.1, y0=y0, adjoint = DirectAdjoint()) #changing the adjoint fixes it
    return solution.ys[0][0]

drhozdz = jax.grad(rhot,argnums = 0)
d2rhozdz = jax.grad(drhozdz,argnums = 0)

print("expected state ", np.exp(-z*t)*2.)
print("found state ", rhot(z))
print("expected ", -2.*t*np.exp(-z*t))
print("found 1st derivative ", drhozdz(z))
print("expected 2nd ", 2.*t**2*np.exp(-z*t))
print("found 2nd derivative ", d2rhozdz(z)) #fails with default adjoint
The error returned is:
"print("found 2nd deriative ", d2rhozdz(z)) #fails with default adjoint
^^^^^^^^^^^
ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop."
Yup, I'm afraid this is expected. RecursiveCheckpointAdjoint does some smart things under the hood to be very efficient when computing specifically first-order gradients, but unfortunately this also makes it incompatible with certain kinds of higher-order autodiff.
First of all, when looking to compute the Hessian, it is usually more efficient to use forward-over-reverse rather than reverse-over-reverse (and indeed this is what jax.hessian does). RecursiveCheckpointAdjoint should actually be compatible with that in most cases.
But nonetheless, in the general case, using DirectAdjoint is indeed the appropriate fix. (And handling edge cases like this is the reason it exists.)
Thanks for the quick reply, I missed that documentation, it was very helpful.
Playing around a bit with a more complex example I am struggling with, I see what you mean: doing forward-over-reverse with RecursiveCheckpointAdjoint() works and seems both faster and more memory-efficient than using DirectAdjoint(), so that was extremely useful. Thanks!