Complex support for implicit solver #388
This is a difficult thing to try and do, and it comes with some questions we need to think pretty carefully about (most notably, what backpropagation does in this scenario). FWIW, you should be able to make this work today by just treating the real and imaginary parts separately.
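A minimal sketch of the "real and imaginary parts separately" workaround, assuming plain JAX. The helpers `to_real`, `to_complex` and `realify` are hypothetical names, not part of Diffrax. They pack a complex state into a real array of shape `(2, ...)` and wrap a complex vector field so that it only ever sees and returns real arrays. The `(t, y, args)` signature is the one `diffrax.ODETerm` expects.

```python
import jax.numpy as jnp


def to_real(z):
    """Pack a complex array into a real array of shape (2, ...): [Re z, Im z]."""
    return jnp.stack([jnp.real(z), jnp.imag(z)])


def to_complex(x):
    """Inverse of to_real: rebuild the complex array from [Re, Im]."""
    return x[0] + 1j * x[1]


def realify(vector_field):
    """Wrap a complex vector field f(t, z, args) as a real one acting on [Re z, Im z]."""
    def real_vector_field(t, x, args):
        return to_real(vector_field(t, to_complex(x), args))
    return real_vector_field


# Example: a Schrodinger-like equation dz/dt = -i H z, with H passed as args.
def schrodinger_like(t, z, args):
    return -1j * (args @ z)


H = jnp.array([[1.0, 0.5], [0.5, -1.0]])
z0 = jnp.array([1.0 + 0.0j, 0.0 + 0.0j])
f_real = realify(schrodinger_like)  # real-in, real-out vector field
x0 = to_real(z0)                    # real initial condition, shape (2, 2)
```

The real pair `f_real` / `x0` could then be handed to `diffrax.ODETerm` together with an implicit solver such as `diffrax.Kvaerno5`, and the saved states converted back with `to_complex`. The state doubles in size, but every operation stays real-valued.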
Hi @patrick-kidger, we'd be pretty interested in complex support for implicit solvers in the context of dynamiqs. I understand it's a non-trivial task, but if you're interested, we can definitely put in the effort and make the required PRs for diffrax/lineax in the coming months. What is your concern regarding backpropagation? I'm guessing regular autodiff should work out of the box, but the recursive checkpoint method might be more subtle? Or is it something else? Regarding the trick of separating the real and imaginary parts, wouldn't it be slower overall due to repeated memory accesses? Or is this optimized away by the JIT? Thanks :)
So the main thing that's needed here is just loads of tests for Diffrax! As it wasn't originally written with complex support in mind, it's entirely possible we have places where we write something like … I imagine the actual change in lines of code for adding this feature should be fairly small: …
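As a hypothetical illustration of the kind of real-only assumption that can hide in numerical code (not a claim about what Diffrax actually contains), consider an RMS norm written as `sqrt(mean(x**2))`. For complex input, `x**2` is no longer non-negative, so the "norm" of a nonzero vector can come out as zero, which would badly mislead an adaptive step-size controller. Taking `abs` first fixes it. Both function names below are made up for the example.

```python
import jax.numpy as jnp


def naive_rms_norm(x):
    # Assumes x is real: x**2 >= 0. For complex x, terms can cancel.
    return jnp.sqrt(jnp.mean(x ** 2))


def complex_safe_rms_norm(x):
    # |x|^2 is real and non-negative for both real and complex x.
    return jnp.sqrt(jnp.mean(jnp.abs(x) ** 2))


x = jnp.array([1.0 + 1.0j, 1.0 - 1.0j])
naive = naive_rms_norm(x)        # (1+i)^2 + (1-i)^2 = 2i - 2i = 0, so this is 0
safe = complex_safe_rms_norm(x)  # sqrt((2 + 2) / 2) = sqrt(2)
```

Checks like this, run on every code path with complex inputs, are cheap to write and catch exactly this class of silent error.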
On backpropagation, this is in large part about making sure that … (In practice, what this usually means is that when doing autodiff with complex numbers in JAX, you should take the conjugate of your gradients before performing SGD.) I suspect (but am not sure) that we're already doing the right thing for backpropagation, so if you like, you can imagine putting this under the 'it needs to be tested' banner.
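A small sketch of the conjugation point, using plain JAX rather than Diffrax. For a real-valued loss of a complex parameter `z = x + iy`, `jax.grad` returns `df/dx - i df/dy`, following the convention described in JAX's Autodiff Cookbook. The steepest-descent step is therefore `z - lr * conj(grad)`, and using the raw gradient moves the imaginary part the wrong way.

```python
import jax
import jax.numpy as jnp


def loss(z):
    # Real-valued loss of a complex parameter: |z|^2 = x^2 + y^2.
    x, y = jnp.real(z), jnp.imag(z)
    return x ** 2 + y ** 2


z = jnp.array(3.0 + 4.0j)
g = jax.grad(loss)(z)  # df/dx - i df/dy = 6 - 8j

lr = 0.1
z_conj_step = z - lr * jnp.conj(g)  # 2.4 + 3.2j: loss drops from 25 to 16
z_naive_step = z - lr * g           # 2.4 + 4.8j: loss rises from 25 to 28.8
```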
Hi, I am trying to solve a stiff differential equation using Diffrax. I need to use complex numbers and would like to use an implicit solver. I would like to try implementing complex support for the implicit solvers if it is not too difficult. Is this possible, and do you have any general advice to help me get started?
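One way to see what an implicit step needs is to write one by hand. The sketch below is a minimal backward (implicit) Euler step, not Diffrax's implementation. It applies the real/imaginary split suggested earlier in the thread, so Newton's method only ever sees real arrays. This also works for vector fields that are not holomorphic (e.g. ones involving `conj` or `abs`). The helpers `to_real`, `to_complex` and `backward_euler_step` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def to_real(z):
    # Complex array -> real array of shape (2, ...): [Re z, Im z].
    return jnp.stack([jnp.real(z), jnp.imag(z)])


def to_complex(x):
    # Inverse of to_real.
    return x[0] + 1j * x[1]


def backward_euler_step(f, t, z, h, args, newton_iters=5):
    """Solve z1 = z + h * f(t + h, z1, args) for complex z by Newton's method
    in the real [Re, Im] representation."""

    def residual(x1):
        # Implicit-Euler residual, mapped back to real coordinates.
        z1 = to_complex(x1)
        return to_real(z1 - z - h * f(t + h, z1, args))

    x1 = to_real(z)  # initial guess: the current state
    for _ in range(newton_iters):
        r = residual(x1)
        J = jax.jacfwd(residual)(x1)  # real Jacobian, shape (2, n, 2, n)
        m = r.size                    # m = 2n real unknowns
        dx = jnp.linalg.solve(J.reshape(m, m), r.reshape(m))
        x1 = x1 - dx.reshape(x1.shape)
    return to_complex(x1)


# Usage: stiff linear test equation dz/dt = lam * z with complex lam.
lam = -50.0 + 10.0j
h = 0.1
z0 = jnp.array([1.0 + 0.0j])
z1 = backward_euler_step(lambda t, z, lam: lam * z, 0.0, z0, h, lam)
# For a linear problem this equals z0 / (1 - h * lam) = 1 / (6 - 1j).
```

Working in real coordinates sidesteps the question of whether the Jacobian should be taken with `holomorphic=True`, at the cost of a linear system with twice as many (real) unknowns.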