
Complex support for implicit solver #388

Open
Ricky5389 opened this issue Mar 15, 2024 · 3 comments
Labels: question (User queries)

Comments


Ricky5389 commented Mar 15, 2024

Hi, I am trying to solve a stiff differential equation using diffrax. I need to use complex numbers and would like to use an implicit solver. I would like to try implementing complex support for the implicit solvers if it is not too difficult. Is it possible, and do you have any general advice that would help me get started?

@patrick-kidger (Owner)

This is a difficult thing to try and do, and comes with some questions we need to think pretty carefully about (most notably what backpropagation does in this scenario).

FWIW you should be able to make this work today by just treating the real and imaginary parts separately.
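As a concrete illustration of that workaround, here is a minimal sketch. It uses a hypothetical stiff linear test problem dz/dt = a·z (not code from this thread), stacks the real and imaginary parts into a pytree, and integrates with one of Diffrax's implicit solvers:

```python
import jax.numpy as jnp
import diffrax

# Hypothetical stiff linear test problem: dz/dt = a * z, with complex a.
a = -1.0 + 10.0j

def vector_field(t, y, args):
    # y is a pytree (re, im) holding the real and imaginary parts of z.
    re, im = y
    # dz/dt = a * z, split into real and imaginary components:
    d_re = a.real * re - a.imag * im
    d_im = a.imag * re + a.real * im
    return (d_re, d_im)

term = diffrax.ODETerm(vector_field)
solver = diffrax.Kvaerno5()  # an implicit (ESDIRK) solver
y0 = (jnp.array(1.0), jnp.array(0.0))  # z(0) = 1 + 0j
sol = diffrax.diffeqsolve(
    term,
    solver,
    t0=0.0,
    t1=1.0,
    dt0=0.01,
    y0=y0,
    stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
)
z_final = sol.ys[0] + 1j * sol.ys[1]  # reassemble the complex solution
```

(On the cost of the split: under jit, XLA is generally free to fuse this arithmetic, so in practice the overhead tends to be bookkeeping rather than extra memory traffic — though that is worth benchmarking for any particular problem.)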

@patrick-kidger added the question label Mar 15, 2024
@gautierronan

Hi @patrick-kidger, we'd be pretty interested in complex support for implicit solvers in the context of dynamiqs. I understand it's a non-trivial task, but if you're interested, we can definitely put in the effort and make the required PRs for diffrax/lineax in the coming months.

What is your concern regarding backpropagation? I'm guessing regular autodiff should work out of the box, but the recursive checkpoint method might be more subtle? Or is it something else?

Regarding the trick to separate real/imaginary parts, would it not be overall slower due to repeated memory accesses? Or is this optimized by the JIT?

Thanks :)

@patrick-kidger (Owner)

So the main thing that's needed here is just loads of tests for Diffrax! As it wasn't originally written with complex support in mind, it's entirely possible we have places where we write something like x**2 intending to compute a norm, but which with complex numbers will silently misbehave.
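A small sketch of that failure mode: squaring is a fine (squared) norm for reals, but silently wrong for complex inputs.

```python
import jax.numpy as jnp

z = jnp.array([3.0 + 4.0j])

jnp.sum(jnp.abs(z) ** 2)        # 25.0 -- a genuine squared norm
jnp.sum(z ** 2)                 # (-7+24j) -- silently wrong as a norm
jnp.sum(z * jnp.conj(z)).real   # 25.0 -- the complex-aware fix
```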

I imagine the actual change in lines of code for adding this feature should be fairly small:

  • In Diffrax itself, fixing up any examples like the above.
  • In Optimistix, being sure that optx.implicit_jvp does the right thing.
  • In Lineax, this should already basically be done thanks to the hard work of @Randl! The main blocker there was that they uncovered an XLA bug ([BUG] Inconsistent compiled JAX running results on CPU openxla/xla#8471), but that actually got fixed just a few days ago. Once the new jaxlib release is out then I imagine things should be good to go there; a sketch of the kind of solve involved follows this list.
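To make the Lineax point concrete, here is a hypothetical example (the matrix and right-hand side are made up) of the kind of complex linear solve that an implicit step eventually bottoms out in; once complex support lands, something like this should just work:

```python
import jax.numpy as jnp
import lineax as lx

# A made-up complex system, standing in for the linearised implicit step.
A = jnp.array([[2.0 + 1.0j, 1.0 + 0.0j],
               [0.0 + 0.0j, 1.0 - 3.0j]])
b = jnp.array([1.0 + 0.0j, 0.0 + 1.0j])

operator = lx.MatrixLinearOperator(A)
solution = lx.linear_solve(operator, b)  # default solver choice (LU here)
x = solution.value                       # satisfies A @ x ≈ b
```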

On backpropagation, this is in large part about making sure that optx.implicit_jvp does the right thing. JAX follows a pretty quirky convention when it comes to complex backpropagation, which is that the VJP is given by the transpose of the JVP, not the conjugate transpose of the JVP. This is unlike what you normally do when computing the adjoint of a complex matrix, and is also different to PyTorch. I highly recommend reading the PyTorch docs on autograd for complex numbers btw, they're very informative.

(In practice what this usually means is that when doing autodiff in complex numbers in JAX, you should compute the conjugate of your gradients before performing SGD.)
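A small sketch of that convention, using f(z) = |z|², which is real-valued so jax.grad applies directly:

```python
import jax
import jax.numpy as jnp

# f: C -> R, f(z) = |z|^2. Real-valued output, so jax.grad applies directly.
def f(z):
    return (z * jnp.conj(z)).real

z = jnp.array(3.0 + 4.0j)
g = jax.grad(f)(z)  # 6 - 8j, i.e. 2 * conj(z) under JAX's convention

# For gradient descent on f, conjugate first (as noted above):
lr = 0.1
z_new = z - lr * jnp.conj(g)  # steps towards 0, the minimiser of |z|^2
```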

I suspect (but am not sure) that we're already doing the right thing for backpropagation, so if you like you can imagine putting this under the 'it needs to be tested' banner.
