Explore removing loop in HEALPix FFTs (to reduce JIT compile time for high L) #140
Yep, I think that can be improved a bit.
OK, I admit defeat... I can't figure out how to do it. I was hoping some clever use of https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dynamic_slice.html would work. Of course, the nuclear option remains... One could implement But that brings two questions:
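For reference, here is a minimal sketch of the problem and one loop-free packing step. It assumes a RING-ordered HEALPix map `f` stored as a 1-D array of length `12 * nside**2`. The helpers `ring_lengths`, `ring_ffts_loop` and `pack_rings` are hypothetical, not the project's code. `jax.lax.dynamic_slice` accepts traced start indices but needs static `slice_sizes`, so on its own it cannot extract rings whose lengths vary (4, 8, ..., 4·nside pixels). A precomputed NumPy index table lets a single gather place every ring in a zero-padded `(4*nside - 1, 4*nside)` array:

```python
import numpy as np
import jax.numpy as jnp


def ring_lengths(nside):
    """Pixels per iso-latitude ring, north to south (4 * nside - 1 rings)."""
    north = [4 * i for i in range(1, nside)]
    equator = [4 * nside] * (2 * nside + 1)
    return north + equator + north[::-1]


def ring_ffts_loop(f, nside):
    """Per-ring FFTs with a Python loop. The loop is unrolled at trace time,
    so the traced program holds one slice + FFT per ring (4 * nside - 1)."""
    out, start = [], 0
    for n in ring_lengths(nside):
        out.append(jnp.fft.fft(f[start:start + n]))
        start += n
    return out


def pack_rings(f, nside):
    """Gather every ring into a zero-padded (nrings, 4 * nside) array.

    The index table is built with NumPy (static, outside tracing), so the
    traced program contains a single gather instead of one slice per ring.
    """
    lengths = np.array(ring_lengths(nside))
    starts = np.concatenate([[0], np.cumsum(lengths)[:-1]])
    cols = np.arange(4 * nside)
    mask = cols[None, :] < lengths[:, None]
    # Entries past the end of a ring point at index 0 and are zeroed by the mask.
    idx = np.where(mask, starts[:, None] + cols[None, :], 0)
    return jnp.where(mask, f[idx], 0.0), lengths
```

Packing alone does not give the ring FFTs. An FFT along the padded axis computes a length-`4*nside` transform of zero-padded data, not the length-`n_r` DFT each polar-cap ring needs, so some further step (resampling, or a different transform) would still be required.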
For now, though, I guess living with the for-loop version is OK. I think compilation is quite a bit faster now that I'm running it on JAX v0.4.
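To put numbers on compile time as `nside` grows, or across JAX versions, one could time ahead-of-time lowering and compilation separately from execution. This is a minimal sketch using `jax.jit(fn).lower(*args).compile()`. `unrolled_ffts` is a hypothetical stand-in for a per-ring loop that emits one FFT per chunk:

```python
import functools
import time

import jax
import jax.numpy as jnp


def compile_seconds(fn, *args):
    """Wall-clock seconds to trace, lower and compile `fn` for `args` (nothing is executed)."""
    t0 = time.perf_counter()
    jax.jit(fn).lower(*args).compile()
    return time.perf_counter() - t0


def unrolled_ffts(x, n_chunks):
    """Hypothetical stand-in for a per-ring loop: one FFT op per chunk.
    `n_chunks` must be a Python int, so it is bound with functools.partial."""
    return [jnp.fft.fft(c) for c in jnp.split(x, n_chunks)]


x = jnp.ones(4096, dtype=jnp.float32)
for n_chunks in (16, 64, 128):
    fn = functools.partial(unrolled_ffts, n_chunks=n_chunks)
    print(n_chunks, f"{compile_seconds(fn, x):.3f} s")
```

Timings depend on hardware and backend, so only relative comparisons (more chunks vs. fewer, one JAX version vs. another) are meaningful.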
The nuclear option is a possibility, but only as a very last resort, I would say. As you say, it is not going to be the easiest to maintain. But if that is the only way, we could do it. I think implementing the minimal amount of code in CUDA would be best, which would help with maintainability and shouldn't have any further performance impact, so basically just Beyond a more efficient JAX implementation of the current HEALPix algorithm, we could perhaps consider variants of the algorithm that avoid this issue or at least mitigate it, e.g. we could compute the ring FFTs at a fixed, higher resolution and then downsample the rings. I will need to look into this in more detail and give it some more thought.
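One loop-free variant, sketched here only as an illustration, replaces the per-ring FFTs with a masked dense DFT over zero-padded rings. It is a swapped-in technique, not the higher-resolution idea in the comment above and not anything the project has adopted. Because `exp(-2πi m k / n_r)` is periodic in `m` with period `n_r`, coefficients with `m >= n_r` come out aliased exactly as `fft(ring)[m % n_r]`. The sketch omits any per-ring phase shift the HEALPix transform applies. `ring_lengths` and `masked_ring_dft` are hypothetical names, and standard RING ordering is assumed:

```python
import numpy as np
import jax.numpy as jnp


def ring_lengths(nside):
    """Pixels per iso-latitude HEALPix ring, north to south (4 * nside - 1 rings)."""
    north = [4 * i for i in range(1, nside)]
    return np.array(north + [4 * nside] * (2 * nside + 1) + north[::-1])


def masked_ring_dft(packed, lengths, m_max):
    """Per-ring DFT coefficients m = 0..m_max without a loop over rings.

    packed:  (nrings, nmax) ring samples, zero beyond each ring's length
    lengths: (nrings,) valid samples per ring
    Returns (nrings, m_max + 1) with
        out[r, m] = sum_{k < n_r} packed[r, k] * exp(-2j * pi * m * k / n_r),
    i.e. fft(ring_r)[m % n_r]; the zero padding contributes nothing.
    """
    k = jnp.arange(packed.shape[1])
    m = jnp.arange(m_max + 1)
    n = jnp.asarray(lengths)[:, None, None]  # (nrings, 1, 1)
    phase = -2j * jnp.pi * (m[None, :, None] * k[None, None, :]) / n
    # Dense (nrings, m_max + 1, nmax) kernel: more FLOPs than per-ring FFTs,
    # but one fixed-shape op, so the traced program does not grow with nrings.
    return jnp.einsum("rmk,rk->rm", jnp.exp(phase), packed)
```

The cost is O(n_rings · m_max · n_max) work and memory for the kernel, so whether this is acceptable at high L is an open question. The appeal is only that the size of the traced program no longer depends on the number of rings.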
Ah, that's a shame; I was holding out hope you might find a solution! I've already sifted through many of the 'obvious' workarounds, and even a few hacky ones, but it's not at all clear how best to handle this. I suspect we may want to re-engineer the FFT component of the HEALPix transform, which may introduce some error; HEALPix is only approximate regardless, so perhaps this is a fair trade-off in this case. The reason we have steered away from anything like that so far is that we wanted to ensure we recreated the transforms presented in existing codes, but in JAX. If we make these changes, the transforms may diverge somewhat...