What is the best way to integrate a function? #361

Open
mariogeiger opened this issue Jan 26, 2024 · 5 comments
Labels
question User queries


@mariogeiger

Let's say I have a function $h(t)$ that I can evaluate at any $t$, and I want to calculate $$\int_0^1 h(t)\,dt.$$

Would using Diffrax like this be a good idea?

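```python
import diffrax as dx


def h(t):
    return t**2


# Treat the integral as the ODE dy/dt = h(t), y(0) = 0; then y(1) is the answer.
dx.diffeqsolve(
    dx.ODETerm(lambda t, _y, _args: h(t)),
    dx.Tsit5(),
    0.0,
    1.0,
    dt0=None,
    y0=0.0,
    stepsize_controller=dx.PIDController(rtol=0.0, atol=1e-3),
).ys.squeeze(0)
```
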
@patrick-kidger
Owner

This would work, but might not be that efficient, algorithmically speaking. To explain this, I like to distinguish between ODEs (when the vector field depends on y) and integrals (when the vector field does not depend on y).
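In symbols, the distinction is between a general ODE and the special case where the vector field ignores $y$. Solving the latter from $y(0) = 0$ gives exactly the integral in the question:

$$
\frac{dy}{dt} = f(t, y) \ \text{(ODE)}, \qquad \frac{dy}{dt} = h(t),\ y(0) = 0 \implies y(1) = \int_0^1 h(t)\,dt \ \text{(integral)}.
$$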

ODE solvers are based on the premise that small errors made early on in the integration will result in downstream errors later in the integration, as the evolving y gradually drifts away from the true solution.

But for an integral like this, that premise no longer holds. That means the solver can be less conservative, and e.g. allow itself to take larger timesteps whilst still getting a sufficiently accurate solution. As a result, integral-specific solvers (i.e. "quadrature methods") can accomplish the same goal with less computational work.
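One rough way to see this uses the standard textbook global-error estimate for one-step methods (a general result, not something stated in this thread). Suppose step $i$ ends at $t_i$ and introduces a local error $\tau_i$, and the vector field $f(t, y)$ is Lipschitz in $y$ with constant $L$. Then the global error $E_N$ at the final time $T = t_N$ is bounded by

$$
|E_N| \le \sum_{i=1}^{N} e^{L (T - t_i)}\, |\tau_i|.
$$

For an integral, $f$ does not depend on $y$, so $L = 0$ and the bound reduces to $|E_N| \le \sum_{i=1}^{N} |\tau_i|$: local errors simply add up and are never amplified, which is why a quadrature method can afford to be less conservative.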

Sadly we don't really have a great quadrature library in JAX yet. (Quadrature and interpolation are the two pieces we still have missing, really.) So if you particularly care about efficiency, I'd suggest coding up your own quadrature rule. But if you just care about getting a result and aren't too fussed about speed, then go ahead and keep using Diffrax!
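As a minimal sketch of what coding up your own quadrature rule might look like, here is fixed-order Gauss-Legendre quadrature in JAX. The function name `gauss_legendre` and the choice of 5 nodes are illustrative assumptions, not anything from Diffrax or the thread. The nodes and weights come from NumPy's `np.polynomial.legendre.leggauss`, and the rule is exact for polynomials of degree up to 2n - 1, so $h(t) = t^2$ is integrated exactly (up to float32 rounding) with just 5 function evaluations.

```python
import numpy as np
import jax.numpy as jnp


def gauss_legendre(fn, a, b, n=5):
    """Hypothetical helper: n-point Gauss-Legendre quadrature of fn over [a, b].

    fn must accept an array of t values and return an array of the same shape.
    Exact for polynomials of degree <= 2n - 1.
    """
    # Nodes and weights on the reference interval [-1, 1], computed in NumPy.
    x, w = np.polynomial.legendre.leggauss(n)
    x = jnp.asarray(x)
    w = jnp.asarray(w)
    # Affine map from [-1, 1] to [a, b]; the Jacobian is (b - a) / 2.
    half = 0.5 * (b - a)
    mid = 0.5 * (b + a)
    t = mid + half * x
    return half * jnp.sum(w * fn(t))


def h(t):
    return t**2


result = gauss_legendre(h, 0.0, 1.0)
print(result)  # should be close to 1/3
```

For integrands that are not close to a low-degree polynomial, the usual next step would be a composite or adaptive version (splitting $[a, b]$ into subintervals), which this sketch does not attempt.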

@patrick-kidger patrick-kidger added the question User queries label Jan 26, 2024
@f0uriest

If you don't mind a little self-promotion, I've been working on a library for quadrature in JAX that should do what you need: https://github.com/f0uriest/quadax

And in response to @patrick-kidger's comment, I also have one for interpolation and splines: https://github.com/f0uriest/interpax

@patrick-kidger
Owner

Haha, self-promotion is absolutely encouraged! I think these are both really cool (and I think I've seen them before, and have been meaning to check them out).

Poking through Quadax a little bit, two quick questions/comments:

  1. Heads-up that this `cond` will turn into a `lax.select` when vmap'd. That means the whole `scan` will run to the very end, unconditionally. (This exact issue is actually a big part of why I wrote `eqx.internal.while_loop`.)
  2. I notice you have quite a few little `jax.jit` decorators floating around in the internals, but I don't see them on the top-level `quad*` functions. What's the reason for this? Typically I would follow the paradigm of putting a top-level JIT somewhere, so that everything below it is automatically JIT'd.
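The first point can be checked by inspecting the jaxpr. In this minimal sketch, `step` is a hypothetical stand-in for one iteration of an adaptive loop. Traced on its own, it contains a genuine `cond` primitive, so only one branch runs. Once vmap'd with a per-element predicate, JAX's batching rule evaluates both branches for the whole batch and picks results with `select_n` (the primitive behind `lax.select`), so there is no longer any skipping of work.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(x):
    # Hypothetical stand-in: do "expensive" work only while x > 0.
    return lax.cond(x > 0.0, lambda v: jnp.sin(v) * 2.0, lambda v: v, x)


# Unbatched: the predicate is a single scalar, so a real `cond` is emitted.
unbatched = str(jax.make_jaxpr(step)(jnp.float32(1.0)))

# Batched: the predicate differs per element, so both branches are computed
# for every element and combined with `select_n`.
batched = str(jax.make_jaxpr(jax.vmap(step))(jnp.array([1.0, -1.0])))

print("cond" in unbatched)
print("select_n" in batched)
```

Inside a `scan`, this means the "expensive" branch is evaluated on every iteration for every batch element, whether or not that element still needs it.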

(I should emphasise that I really like the look of both these libraries :) )

@f0uriest

Thanks! I'm a big fan of Equinox, so I had considered using your version of `while_loop`, but I didn't see it in the documented API, so I wasn't sure whether it is for public use or just internal. Could you advise? (FWIW, if you vmap the evaluation of an integral, it's likely that each will take around the same cost anyway, so it's probably not a huge inefficiency to run all the loops to the final step, unless you have some really nasty parameterized integrand.)

With regard to JIT: I think it's likely a holdover from some previous versions, where I was testing against functions that couldn't be JIT'd, so I had to be careful about placement. That's no longer the case, so I'll change it for the next release.
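The "single top-level JIT" paradigm from the earlier comment might look something like the following minimal sketch. The names `quad_trapezoid`, `_nodes` and `_trapezoid` are hypothetical, the rule is a plain composite trapezoid rule chosen for brevity, and none of this is Quadax's actual API. The internal helpers carry no `jax.jit` decorators; because they are called from inside the one jitted entry point, they are traced and compiled as part of that single program anyway. The integrand `fn` and the number of points `n` are marked static, since they determine the structure of the computation.

```python
import functools

import jax
import jax.numpy as jnp


# Internal helpers: deliberately NOT decorated with jax.jit.
def _nodes(a, b, n):
    # n equally spaced sample points on [a, b].
    return jnp.linspace(a, b, n)


def _trapezoid(y, dt):
    # Composite trapezoid rule on equally spaced samples.
    return dt * (jnp.sum(y) - 0.5 * (y[0] + y[-1]))


# One JIT at the public entry point; everything below it is compiled with it.
@functools.partial(jax.jit, static_argnames=("fn", "n"))
def quad_trapezoid(fn, a, b, n):
    t = _nodes(a, b, n)
    return _trapezoid(fn(t), t[1] - t[0])


def h(t):
    return t**2


print(quad_trapezoid(h, 0.0, 1.0, n=1001))  # close to 1/3
```

A side benefit of this layout is that users who wrap the call in their own outer `jax.jit` get the same behaviour, since nested `jit`s are simply inlined into the outer trace.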

@patrick-kidger
Owner

So when it comes to "running all loops until the final step", note that this will be until the final step of the overall scan. To be precise, the (presumably expensive) body function will be evaluated on every step, for the fixed length of the entire scan. This includes the steps after all batch elements are done; in that case the body function is evaluated and its result simply discarded.

If that's what you mean and you're okay with that, then sure. In practice, for diffeqsolves it's fairly common to set a maximum number of steps of e.g. 10^4, but to take only e.g. 10^1 steps most of the time. So there it's important that we exhibit early-exit behaviour.
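The difference between early exit and a fixed-length loop can be sketched with a toy problem. Here a hypothetical "solver step" `_halve` halves an error estimate until it drops below a per-element tolerance; all names and numbers are illustrative assumptions. Both versions return the same step counts under `jax.vmap`, but they do very different amounts of work. The `lax.while_loop` version stops once every batch element is done (10 iterations here). The `scan` version always runs `MAX_STEPS` iterations, and because the `cond` becomes a select under vmap, `_halve` is evaluated on every one of them.

```python
import jax
import jax.numpy as jnp
from jax import lax

MAX_STEPS = 10_000  # fixed upper bound, like a typical max_steps setting


def _init():
    # (error estimate, number of steps taken)
    return jnp.array(1.0, dtype=jnp.float32), jnp.array(0, dtype=jnp.int32)


def _halve(state):
    # Hypothetical stand-in for an expensive solver step.
    err, k = state
    return err / 2.0, k + 1


def steps_while(tol):
    # Early exit: under vmap, iterates only until every element is done.
    _, k = lax.while_loop(lambda s: s[0] > tol, _halve, _init())
    return k


def steps_scan(tol):
    # Fixed length: always MAX_STEPS iterations. Under vmap the cond below
    # becomes a select, so _halve runs every step, even for finished elements.
    def body(state, _):
        done = state[0] <= tol
        return lax.cond(done, lambda s: s, _halve, state), None

    (_, k), _ = lax.scan(body, _init(), None, length=MAX_STEPS)
    return k


tols = jnp.array([0.5, 1e-3])
print(jax.vmap(steps_while)(tols))
print(jax.vmap(steps_scan)(tols))
```

Note that this toy `while_loop` has no step cap of its own; a real solver would also bound the number of iterations.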

The status of Equinox's while loop: it's a stable API. The only reason it isn't documented is that it's easy to footgun (see the warnings in its docstring). I don't think that's safe enough for an average user: the rule I've gone with for the Eqx ecosystem is that I'd rather not do something than do it badly. It's too easy to break user trust otherwise. (This is exactly the reason I set up my own thing rather than use the Julia ecosystem!)

So I think you should be good to use it, just scrutinise what you do carefully :)
