Explore removing loop in HEALPix FFTs (to reduce JIT compile time for high L) #140

Open
jasonmcewen opened this issue Feb 15, 2023 · 4 comments

Comments

@jasonmcewen
Contributor

No description provided.

@jasonmcewen jasonmcewen changed the title Explore removing loop in HEALPix FFTs (to reduce JIT computation time for high L) Explore removing loop in HEALPix FFTs (to reduce JIT compile time for high L) Feb 15, 2023
@EiffL
Contributor

EiffL commented Feb 17, 2023

Yep I think that can be improved a bit

@EiffL EiffL mentioned this issue Feb 17, 2023
@EiffL
Contributor

EiffL commented Feb 18, 2023

oookkkk I admit defeat.... I can't figure out how to do it. I was hoping some clever use of jax.lax.dynamic_slice (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dynamic_slice.html) would work, but no, the slice size has to be statically known.... CodeGPT/ChatGPT/Copilot were no help at all :-( (there might be some use cases for good old humans after all ^^')
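To make the static-size restriction concrete, here is a minimal sketch (not taken from s2fft; `take_window` and `take_ring` are hypothetical names). Under `jax.jit`, the start index of `jax.lax.dynamic_slice` may be a traced value, but the slice size determines the output shape and so has to be a concrete Python integer. HEALPix rings have different lengths, which is exactly the part that cannot be made dynamic.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(16.0)

@jax.jit
def take_window(x, start):
    # `start` is traced; the slice size (4) is a Python constant, so this works.
    return jax.lax.dynamic_slice(x, (start,), (4,))

@jax.jit
def take_ring(x, start, size):
    # `size` is traced here, so the output shape is unknown at trace time.
    return jax.lax.dynamic_slice(x, (start,), (size,))

window = take_window(x, 3)

try:
    take_ring(x, 3, 4)
    failed = False
except Exception:
    # JAX rejects the traced slice size while tracing the function.
    failed = True
```

Marking `size` as a static argument would make `take_ring` trace, but each distinct ring length would then trigger its own compilation, so it does not remove the per-ring cost.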

Of course, the nukular option remains... One could implement healpix_fft_jax directly in CUDA and wrap it in JAX.... It's only a matter of running a bunch of FFTs, so it's pretty trivial. Here is an example of how it's done: https://github.com/dfm/extending-jax

But that brings two questions:

  • Would you be happy for s2fft to also have some custom compiled ops? The potential problem is that it would require some compilation, as opposed to being a pure JAX library. Plus, long-term maintenance has a higher cost because you have to keep up with the undocumented JAX custom op API.

  • As long as we are implementing things in CUDA, would it be smarter to directly wrap a full SHT for healpix as opposed to just the ring-wise FFTs?

For now though, I guess living with the for loop version is ok. I think the compilation time is quite a bit faster now that I'm running it on JAX v0.4
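As a rough illustration of why the for loop version is costly to compile, the sketch below counts the operations JAX traces for a ring-by-ring FFT. It is an assumption-laden toy, not the s2fft code: `ring_lengths`, `ring_ffts_loop` and `count_eqns` are hypothetical helpers, and the ring sizes follow the standard HEALPix layout (4t pixels on ring t of each polar cap, 4·nside on each equatorial ring). Because the Python loop is unrolled during tracing, the traced program grows with the number of rings, 4·nside − 1.

```python
import jax
import jax.numpy as jnp

def ring_lengths(nside):
    # Pixels per iso-latitude ring: north cap, equatorial belt, south cap.
    cap = [4 * t for t in range(1, nside)]
    belt = [4 * nside] * (2 * nside + 1)
    return cap + belt + cap[::-1]

def ring_ffts_loop(f, nside):
    # One FFT per ring; every iteration has a different static slice length.
    out, start = [], 0
    for n in ring_lengths(nside):
        out.append(jnp.fft.fft(f[start:start + n]))
        start += n
    return jnp.concatenate(out)

def count_eqns(nside):
    # Number of top-level equations in the traced program (jaxpr).
    npix = 12 * nside**2
    closed = jax.make_jaxpr(lambda f: ring_ffts_loop(f, nside))(jnp.zeros(npix))
    return len(closed.jaxpr.eqns)

small, large = count_eqns(2), count_eqns(4)
```

Comparing `small` and `large` shows the traced program getting longer as nside increases, which is the growth that the compiler then has to process.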

@jasonmcewen
Contributor Author

jasonmcewen commented Feb 18, 2023

The nuclear option is a possibility, but only as a very last resort, I would say. As you say, that is not going to be the easiest to maintain. But if that is the only way, we could do it. I think implementing the minimal amount of code in CUDA would be best, which would help with maintainability and shouldn't have any further performance impact, so basically just healpix_fft_jax etc. But I think we should only do this as a last resort. It would be worth trying to explore some other alternatives first.

Beyond a more efficient JAX implementation of the current HEALPix algorithm, we could perhaps consider variants of the algorithm that avoid this issue or at least mitigate it, e.g. perhaps we could compute a fixed higher resolution ring FFT and downsample rings. Will need to look into this in further detail and give it some more thought.
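One possible reading of the fixed-resolution idea, sketched for the synthesis direction only and under strong assumptions: a single ring, non-negative orders m only, no per-ring phase offset, and a ring length `n` that divides the fixed resolution `n_fine`. The function names are hypothetical and this is not an s2fft API. In that restricted case, one inverse FFT at the fixed size followed by a strided pick returns exactly the samples a direct evaluation at `n` points would give.

```python
import jax.numpy as jnp

def synthesize_ring(coeffs, n):
    # Direct evaluation of f(phi) = sum_m coeffs[m] * exp(i m phi)
    # at phi_p = 2 pi p / n, for m = 0 .. len(coeffs) - 1.
    m = jnp.arange(coeffs.shape[0])
    phi = 2 * jnp.pi * jnp.arange(n) / n
    return jnp.exp(1j * jnp.outer(phi, m)) @ coeffs

def synthesize_fixed_then_downsample(coeffs, n_fine, n):
    # Assumes len(coeffs) <= n_fine and that n divides n_fine.
    padded = jnp.zeros(n_fine, dtype=coeffs.dtype).at[: coeffs.shape[0]].set(coeffs)
    fine = jnp.fft.ifft(padded) * n_fine  # undo ifft's 1/n_fine normalisation
    return fine[:: n_fine // n]           # keep every (n_fine / n)-th sample

coeffs = jnp.array([1.0 + 0.5j, -0.3j, 0.25, 0.1 + 0.1j])
direct = synthesize_ring(coeffs, 8)
via_fine = synthesize_fixed_then_downsample(coeffs, 16, 8)
```

Since every ring would share `n_fine`, the inverse FFT could be batched over rings in one call. The catch is that polar-cap rings have 4t samples, which generally do not divide a fixed size, so those rings would need interpolation rather than a plain stride, and that is where approximation error would enter.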

@CosmoMatt
Collaborator

Ah, that's a shame, I was holding out hope you might find a solution! I've already sifted through many of the 'obvious' work-arounds, and even a few hacky ones, but it's not at all clear how best to handle this. I suspect we may want to re-engineer the FFT component of the HEALPix transform, which may introduce some error; HEALPix is only approximate regardless, so perhaps this is a fair trade-off in this case.

The reason we strayed away from doing anything like that so far is that we wanted to ensure we recreated the transforms presented in existing code, but in JAX. If we make these changes the transforms may diverge somewhat...
