Frequent JIT-recompile of discrete_terminating_event #333

Open
nikolas-claussen opened this issue Nov 18, 2023 · 2 comments
Labels
question User queries

nikolas-claussen commented Nov 18, 2023

Hi,

I am running into a strange issue when using diffrax.diffeqsolve with the discrete_terminating_event argument. I believe it is caused by a large number of JIT recompiles, which make execution slow.

For context, I am solving an ODE until a stopping criterion occurs. Then, I make some modification to the arguments of the ODE, and restart it. Schematically:

my_event = diffrax.DiscreteTerminatingEvent(lambda state, **kwargs: my_function(state.y, *kwargs["args"]))
tcurrent = t0
y0 = my_initial_condition
args = my_initial_args
while tcurrent < t1:
    solution = diffrax.diffeqsolve(term, solver, args=args, t0=tcurrent, dt0=dt0, y0=y0,
                                   discrete_terminating_event=my_event,
                                   stepsize_controller=stepsize_controller, max_steps=None)
    tcurrent = float(solution.ts[-1])
    y0 = solution.ys[-1]
    args = modify_args(solution.ys[-1], args)

All the functions (my_function, modify_args, the function wrapped by term) are written in JAX and JITed. The first time I run the while loop - as a cell in a jupyter notebook - it takes approx. 10 seconds. When I run it again, with identical my_initial_condition, it is significantly faster, approx. 0.3s. I assume this difference is due to the JIT compilation overhead - no problem.

However, when I re-run this with a slightly modified initial condition, e.g. y0 = my_initial_condition+1e-5, I am back to a 10s runtime. This is a problem, because I want to run this code block a large number of times for different values of my_initial_condition. I ran the following tests to see what might be going on:

  • If the times at which the ODE solver is stopped are predefined, i.e. not triggered by discrete_terminating_event, then the problem is gone, even if I am still passing a discrete_terminating_event argument (modified so as to never trigger a stop).
  • If I define my_event inside the while loop, then I always get the ~10s execution time, even if I'm re-running the cell with identical inputs. I.e.:
while tcurrent < t1:
    my_event = diffrax.DiscreteTerminatingEvent(lambda state, **kwargs: my_function(state.y, *kwargs["args"]))
    solution = diffrax.diffeqsolve(term, solver, args=args, t0=tcurrent, dt0=dt0, y0=y0,
    ...
  • When I evaluate my_function (the function inside my_event) with different values of y or args, I do not trigger a JIT recompile.

This has led me to believe that diffrax JIT-recompiles the discrete_terminating_event every time integration is stopped due to an event. Is there a way to avoid this?

Best,

Nikolas

@patrick-kidger
Owner

It's a little hard to tease out an explanation for each individual case you've tested, but fundamentally recompilations happen every time you pass in a new function, or new bool/int/float/complex (that isn't wrapped into a JAX array), or when you change the shape/dtype of an array. But one example that is straightforward to explain is when you put my_event inside the loop, then you are creating a fresh lambda function every time (and Python doesn't offer a way to detect that this looks identical to the previous lambda functions you've created), and so this is what causes recompilation.
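The recompilation triggers listed above can be observed in plain JAX by counting how often a jitted function is traced. The following is a minimal sketch, not diffrax's own code: `apply`, `traces` and `count_new_traces` are hypothetical names, and `fn` and `scale` are explicitly marked static so that a new function object or a new Python float changes the cache key, mimicking the behaviour described in this comment.

```python
from functools import partial

import jax
import jax.numpy as jnp

traces = {"count": 0}  # incremented only while JAX traces `apply`


@partial(jax.jit, static_argnames=("fn", "scale"))
def apply(fn, y, scale):
    # This Python side effect runs at trace time, not on cached calls.
    traces["count"] += 1
    return scale * fn(y)


def count_new_traces(thunk):
    """Run `thunk` and report how many new traces it caused."""
    before = traces["count"]
    thunk()
    return traces["count"] - before


double = lambda v: 2.0 * v  # created once and reused below
y = jnp.ones(3)

first_call = count_new_traces(lambda: apply(double, y, scale=1.0))
# Same function object, same shape/dtype, different array values: cache hit.
new_values = count_new_traces(lambda: apply(double, y + 1e-5, scale=1.0))
# Changing the shape or dtype of an array argument forces a new trace.
new_shape = count_new_traces(lambda: apply(double, jnp.ones(4), scale=1.0))
new_dtype = count_new_traces(
    lambda: apply(double, jnp.ones(3, dtype=jnp.int32), scale=1.0)
)
# A new static Python float forces a new trace.
new_scalar = count_new_traces(lambda: apply(double, y, scale=3.0))
# A fresh lambda is a different object, even with identical source code.
fresh_lambda = count_new_traces(lambda: apply(lambda v: 2.0 * v, y, scale=1.0))
```

Only `new_values` reuses the compiled code. In particular, the fresh lambda in the last call is the same situation as constructing my_event inside the loop.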

Fundamentally, what you almost certainly want to do is to JIT your whole computation -- including the diffeqsolve call -- and not just to JIT individual pieces. See point 1 in this guidance. You can convert your Python while loop into a jax.lax.while_loop to make this possible.
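As a rough illustration of that restructuring, the sketch below puts the whole solve-until-event-then-restart loop inside a single jax.jit using jax.lax.while_loop. It deliberately avoids diffrax: `integrate_until_event` is a toy fixed-step Euler integrator standing in for diffeqsolve with a terminating event, and `DT`, `THRESHOLD`, the restart rule in `modify_args` and the `traces` counter are all hypothetical choices made for the example.

```python
import jax
import jax.numpy as jnp

DT = 0.125       # fixed step size of the toy integrator (hypothetical)
THRESHOLD = 1.0  # the "event" fires once y reaches this value (hypothetical)

traces = {"count": 0}  # counts how often `simulate` is traced


def integrate_until_event(t, y, rate, t1):
    """Toy stand-in for diffeqsolve: Euler steps of dy/dt = rate until the event or t1."""

    def keep_going(carry):
        t, y = carry
        return (t < t1) & (y < THRESHOLD)

    def euler_step(carry):
        t, y = carry
        return t + DT, y + DT * rate

    return jax.lax.while_loop(keep_going, euler_step, (t, y))


def modify_args(y, rate):
    """Hypothetical restart rule: reset the state and double the rate."""
    return jnp.zeros_like(y), 2.0 * rate


@jax.jit
def simulate(y0, rate0, t1):
    traces["count"] += 1  # Python side effect, runs only while tracing

    def not_finished(carry):
        t, y, rate, segments = carry
        return t < t1

    def run_segment(carry):
        t, y, rate, segments = carry
        t, y = integrate_until_event(t, y, rate, t1)
        y, rate = modify_args(y, rate)
        return t, y, rate, segments + 1

    init = (jnp.zeros_like(t1), y0, rate0, jnp.zeros((), dtype=jnp.int32))
    return jax.lax.while_loop(not_finished, run_segment, init)


# All inputs are JAX arrays, so new values reuse the compiled program.
t_a, y_a, rate_a, segments_a = simulate(
    jnp.float32(0.0), jnp.float32(1.0), jnp.float32(2.0)
)
t_b, y_b, rate_b, segments_b = simulate(
    jnp.float32(0.5), jnp.float32(1.0), jnp.float32(2.0)
)
```

Because the initial condition, the arguments and the end time enter as arrays, the second call with a different y0 takes a different number of restarts but does not trigger a second trace of `simulate`.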

@patrick-kidger patrick-kidger added the question User queries label Nov 18, 2023
@nikolas-claussen
Author

Thanks a lot - that made it work. I realized in the process of JIT-ing the whole while loop that my modify_args was actually not JIT compatible. But based on your advice about the jax.lax-control flow operators, I was able to fix that.
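The thread does not show the original modify_args, so the following is only a guess at the kind of problem involved: a Python `if` on a traced value cannot be JIT-compiled, whereas jax.lax.cond can. Both function bodies and the halving rule are hypothetical.

```python
import jax
import jax.numpy as jnp


def modify_args_python(y, rate):
    # A Python `if` needs a concrete bool, which a traced value cannot provide.
    if y.sum() > 1.0:
        return rate * 0.5
    return rate


def modify_args_traceable(y, rate):
    # jax.lax.cond selects a branch at run time, so it works under jit.
    return jax.lax.cond(y.sum() > 1.0, lambda r: r * 0.5, lambda r: r, rate)


try:
    jax.jit(modify_args_python)(jnp.ones(3), jnp.float32(2.0))
    python_version_jits = True
except jax.errors.ConcretizationTypeError:
    python_version_jits = False

halved = jax.jit(modify_args_traceable)(jnp.ones(3), jnp.float32(2.0))
unchanged = jax.jit(modify_args_traceable)(jnp.zeros(3), jnp.float32(2.0))
```

Both calls to the traceable version share one compiled program, since only the array values differ.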
