Zero gradient when using jnp.piecewise inside an ODE #363

Open
dflocher opened this issue Jan 29, 2024 · 2 comments
@dflocher

Hi,
When applying jax.grad to a function that uses diffrax to integrate a piecewise-defined ODE, I observe that one partial derivative is unexpectedly zero. The ODE solver returns correct function values; only the gradient is wrong. I’m wondering whether this is a bug, or whether I’m doing something wrong.
Thanks in advance!
David

Example:

Consider the piecewise-defined ODE

$$\frac{dy}{dt} = -k(t) \cdot y, \qquad y(0) = y_0, \qquad \mathrm{with} \ k(t) = \begin{cases} k_0, \ t \leq T \\ 0, \ t > T \end{cases},$$

to which the solution reads

$$y(t) = y_0 \cdot \begin{cases} e^{-k_0 t}, \ t \leq T \\ e^{-k_0 T}, \ t > T \end{cases}$$

I'm interested in the partial derivatives w.r.t. $T$ and $k_0$. In the code example below, I compare the gradient obtained from integrating the ODE using diffrax to the analytical solution and to a finite difference calculation.
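For reference, differentiating the closed-form solution above in the regime $t > T$ (my own quick derivation, not part of the original report) gives

$$\frac{\partial y}{\partial T} = -k_0\, y_0\, e^{-k_0 T}, \qquad \frac{\partial y}{\partial k_0} = -T\, y_0\, e^{-k_0 T} \qquad (t > T),$$

which for $T = 1$, $k_0 = 5.5$, $y_0 = 100$ evaluates to roughly $-2.248$ and $-0.409$, matching the analytical gradient printed below.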

import jax
import jax.numpy as jnp
import diffrax
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_debug_nans", True)


def calc_y_analytically(params, t, y0):
    # Closed-form solution: exponential decay up to T, constant afterwards.
    T, k0 = params
    return y0 * jnp.piecewise(t, [t <= T, t > T], [lambda x: jnp.exp(-k0*x), lambda x: jnp.exp(-k0*T)])


def calc_y_ode(params, t, y0):

    def ode(t, y, args):
        T, k0 = args
        # Piecewise decay rate: k(t) = k0 for t <= T, and 0 afterwards.
        k_of_t = jnp.piecewise(t, [t <= T, t > T], [k0, 0.0])
        d_y = -1 * k_of_t * y
        return d_y

    term = diffrax.ODETerm(ode)
    solver = diffrax.Tsit5()
    sol = diffrax.diffeqsolve(term, solver, t0=0.0, t1=t, dt0=0.00001, y0=y0, args=params, max_steps=200000)
    return sol.ys[0]


if __name__ == '__main__':

    params = jnp.array([1.0, 5.5])  # (T, k0)
    t = 1.2
    y0 = 100.0

    # calculate y(t) and the gradient w.r.t. T and k0 analytically
    y_ana, grads_ana = jax.value_and_grad(calc_y_analytically)(params, t, y0)

    # propagate y(0)=y0 until t by solving the ODE and calculate the gradient w.r.t. T and k0
    y_diff, grads_diff = jax.value_and_grad(calc_y_ode)(params, t, y0)

    # perform finite differences method on 0th parameter for verification
    eps = 1e-4
    param0 = params[0]
    params_plus = params.at[0].set(param0 + eps)
    params_minus = params.at[0].set(param0 - eps)

    y_ana_plus = calc_y_analytically(params_plus, t, y0)
    y_ana_minus = calc_y_analytically(params_minus, t, y0)
    part_deriv_ana = (y_ana_plus - y_ana_minus) / (2 * eps)

    y_diff_plus = calc_y_ode(params_plus, t, y0)
    y_diff_minus = calc_y_ode(params_minus, t, y0)
    part_deriv_diff = (y_diff_plus - y_diff_minus) / (2 * eps)

    print('\ny(t):')
    print('Analytical: {y:.6f}'.format(y=y_ana))
    print('Diffrax:    {y:.6f}'.format(y=y_diff))

    print('\nGradient:')
    print('Analytical: ' + str(grads_ana))
    print('Diffrax:    ' + str(grads_diff))

    print('\nPartial derivative w.r.t. parameter 0 via finite difference:')
    print('Analytical: {p:.6f}'.format(p=part_deriv_ana))
    print('Diffrax:    {p:.6f}'.format(p=part_deriv_diff))

prints out the following:

y(t):
Analytical: 0.408677
Diffrax:    0.408675

Gradient:
Analytical: [-2.24772429 -0.40867714]
Diffrax:    [ 0.         -0.40867537]

Partial derivative w.r.t. parameter 0 via finite difference:
Analytical: -2.247724
Diffrax:    -2.247712

I'm using Python 3.11.7, jax 0.4.23, jaxlib 0.4.23.dev20231223, diffrax 0.5.0, macOS 14.2.1 (x86_64), running on CPU.

@patrick-kidger (Owner)

Ah, this is a known (although pretty obscure) limitation of using autodifferentiable differential equation solvers with discontinuous vector fields.

TL;DR: the solution

First of all, the solution: explicitly declare the jump time in the stepsize controller, typically by doing stepsize_controller=PIDController(..., jump_ts=[T]). (I think you could also use StepTo here if you wanted a fixed step size rather than an adaptive one.)
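A minimal sketch of that fix, adapted from the code above (the rtol/atol values and dt0 here are my own illustrative choices, not something prescribed by diffrax):

import jax
import jax.numpy as jnp
import diffrax
jax.config.update("jax_enable_x64", True)


def calc_y_ode_with_jump(params, t, y0):

    def ode(t, y, args):
        T, k0 = args
        k_of_t = jnp.piecewise(t, [t <= T, t > T], [k0, 0.0])
        return -1 * k_of_t * y

    T = params[0]
    term = diffrax.ODETerm(ode)
    solver = diffrax.Tsit5()
    # Declare the discontinuity at T, so the controller places a step boundary
    # exactly there and T enters the computation graph differentiably.
    controller = diffrax.PIDController(rtol=1e-8, atol=1e-8, jump_ts=jnp.array([T]))
    sol = diffrax.diffeqsolve(term, solver, t0=0.0, t1=t, dt0=0.001, y0=y0,
                              args=params, stepsize_controller=controller,
                              max_steps=200000)
    return sol.ys[0]


# e.g. jax.value_and_grad(calc_y_ode_with_jump)(jnp.array([1.0, 5.5]), 1.2, 100.0)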

What went wrong?

As for what's going on, we can explain this in a few different ways.

  • First of all, from the point of view of autodifferentiable software: in your code, note that T is only used to generate some boolean masks (t <= T, t > T). There's no way for a gradient to flow backwards from d_y into T; boolean masks never have gradients. (See the small sketch after this list.)
  • To expand on that from a numerical perspective: try writing out the explicit Euler method on your vector field and looking at the resulting computation graph (which, up to the choice of solver, is what we get with Diffrax). Here too, we wouldn't expect to get any gradient with respect to T, as it is never used in a differentiable way.
  • Third, let's consider the mathematics: what is happening here is that you are asking for d/dT \int_0^t f(s, y(s), T) ds, i.e. a derivative of an integral. (Here f is your vector field, the RHS of your ODE, and s is the evolving time of the system.) However, because f is discontinuous, we cannot exchange the derivative and the integral: there is no meaningful notion of df/dT. And having a meaningful notion of df/dT is exactly what we rely on when differentiating through the solver!
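
To illustrate the first point, here's a tiny standalone sketch (my own, not from the original code) showing that jax.grad returns zero for a parameter that only appears inside a boolean comparison:

import jax
import jax.numpy as jnp

def f(T, t, k0):
    k_of_t = jnp.where(t <= T, k0, 0.0)  # T only enters through the boolean mask
    return -k_of_t * 1.0

print(jax.grad(f)(1.0, 0.5, 5.5))  # prints 0.0: no gradient flows through the comparison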

Why does the solution above work?

So how do we fix this? We've seen that writing diffeqsolve(..., t0=0, t1=t) doesn't work.

Our first insight into fixing this is to observe that if we had split this into diffeqsolve(..., t0=0, t1=T) and diffeqsolve(..., t0=jnp.minimum(t, T), t1=t), then we would have a well-behaved ODE on each piece.

In fact, go ahead and test this, and you'll get the expected gradient! Reasoning in terms of the computation graphs described above, the reason is that T now enters the computation graph differentiably, as an integration endpoint.
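A rough sketch of that two-solve version (my own, reusing the definitions from the original example and assuming t > T, as in the reported numbers):

import jax.numpy as jnp
import diffrax


def calc_y_ode_split(params, t, y0):

    def ode(s, y, args):
        T, k0 = args
        k_of_s = jnp.piecewise(s, [s <= T, s > T], [k0, 0.0])
        return -1 * k_of_s * y

    T = params[0]
    term = diffrax.ODETerm(ode)
    solver = diffrax.Tsit5()
    # Integrate up to the jump, then restart from it: T now appears as an
    # integration endpoint, which is differentiable.
    sol1 = diffrax.diffeqsolve(term, solver, t0=0.0, t1=T, dt0=0.00001,
                               y0=y0, args=params, max_steps=200000)
    sol2 = diffrax.diffeqsolve(term, solver, t0=jnp.minimum(t, T), t1=t, dt0=0.00001,
                               y0=sol1.ys[0], args=params, max_steps=200000)
    return sol2.ys[0]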

So using PIDController(..., jump_ts=[T]) does essentially the same thing: it ensures we make a numerical step landing exactly on that point. (It will also be faster to compile than the double-diffeqsolve approach, which doubles the number of operations JAX has to compile.)

Can we do better?

This is an unfortunate user footgun! But I don't know of an automatic solution to this; so far as I know it may be an open question in the theory of autodifferentiation. (?)

Spitballing, I imagine this could maybe be solved by having the ODE solver try to detect when it thinks a jump has occurred, and if so, solve a root-finding problem to locate the jump, and then use that in its step size control.

I think investigating this might be an interesting research question in autodifferentiation, for those curious enough to try :)

@patrick-kidger added the question (User queries) label Feb 2, 2024
@dflocher (Author) commented Feb 5, 2024

Thank you, Patrick, for your detailed and instructive answer!
Your proposed solution works fine.
Best,
David
