
Taking more than one gradient fails with default RecursiveCheckpointAdjoint #332

Open
nwlambert opened this issue Nov 9, 2023 · 2 comments
Labels
refactor Tidy things up

Comments

@nwlambert

nwlambert commented Nov 9, 2023

I am a total beginner with JAX and Diffrax, so I am not sure whether this is a bug or expected behaviour, but if I try to find the second or higher derivative of a solution from diffeqsolve() I get an error. Changing the adjoint to DirectAdjoint() seems to fix the problem.

Minimal working example (using the default ODE example from the diffrax introduction):

import jax.numpy as jnp
import numpy as np
import jax
from diffrax import diffeqsolve, ODETerm, Dopri5, DirectAdjoint

z = 2.3
t = 1.
        
def rhot(z):
    def f(t, y, args):
        return -z*y

    term = ODETerm(f)
    solver = Dopri5()
    y0 = jnp.array([2., 3.])
    solution = diffeqsolve(term, solver, t0=0, t1=t, dt0=0.1, y0=y0)
    #solution = diffeqsolve(term, solver, t0=0, t1=t, dt0=0.1, y0=y0, adjoint=DirectAdjoint())  # changing the adjoint fixes it
    return solution.ys[0][0]

drhozdz = jax.grad(rhot, argnums=0)
d2rhozdz = jax.grad(drhozdz, argnums=0)


print("expected state ", np.exp(-z*t)*2.)
print("found state ", rhot(z))


print("expected ", -2.*t*np.exp(-z*t))
print("found 1st deriative ", drhozdz(z))

print("expected 2nd ", 2.*t**2*np.exp(-z*t))
print("found 2nd derivative ", d2rhozdz(z))  #fails with default adjoint

The error returned is:

print("found 2nd derivative ", d2rhozdz(z))  #fails with default adjoint
^^^^^^^^^^^
ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.

@patrick-kidger
Owner

Yup, I'm afraid this is expected. RecursiveCheckpointAdjoint does some smart things under the hood to be very efficient when computing first-order gradients specifically, but unfortunately this also makes it incompatible with certain kinds of higher-order autodiff.

First of all, when looking to compute the Hessian, it is usually more efficient to use forward-over-reverse rather than reverse-over-reverse (and indeed this is what jax.hessian does). RecursiveCheckpointAdjoint should actually be compatible with that in most cases.
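For example, a forward-over-reverse second derivative of the rhot function might look something like the following minimal sketch (reusing rhot, z, t, and the imports from the example above):

# Forward-over-reverse: forward-mode differentiation of the reverse-mode gradient,
# which RecursiveCheckpointAdjoint supports (unlike reverse-over-reverse).
d2rhozdz_fwd = jax.jacfwd(jax.grad(rhot))
print("2nd derivative via forward-over-reverse ", d2rhozdz_fwd(z))
print("expected 2nd ", 2.*t**2*np.exp(-z*t))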

But nonetheless, in the general case, using DirectAdjoint is indeed the appropriate fix. (Handling edge cases like this is the reason it exists.)

You might also like the example on second-order sensitivities from the documentation.


I'm going to tag this under "refactor" as this could probably do with a more informative error message.

@patrick-kidger patrick-kidger added the refactor Tidy things up label Nov 9, 2023
@nwlambert
Author

Thanks for the quick reply. I had missed that documentation; it was very helpful.

Playing around a bit with a more complex example I am struggling with, I see what you mean: doing forward-over-reverse with RecursiveCheckpointAdjoint() works and seems both faster and more memory-efficient than using DirectAdjoint(), so that was extremely useful. Thanks!
