User report #142
Comments
Hi @EiffL, the first error arises because we expect
JIT compilation of forward transform
The slow JIT compilation is a more critical issue (which we are aware of), and there are two points to consider:
tl;dr: this should only be an issue for HEALPix sampling, exists for both the forward and inverse transforms, and is an issue with the FFT component of the transform.
Solutions
A short-term solution would be to switch to another sampling scheme; none of the remaining schemes should exhibit this compilation issue. A longer-term solution, which we are aware of and are investigating in #140, is to optimise this HEALPix FFT loop, perhaps by compressing into a |
right right right... I see.... what would happen if one interpolated the healpix map onto one of the other formats that have the same number of points per ring? |
of course.... one could also not use the healpix scheme ^^ |
If each ring had the same number of samples then the HEALPix FFT should just reduce to a simple |
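A minimal sketch of the equal-ring-size case, assuming the rings are stored as rows of a 2D NumPy array (a hypothetical layout chosen for illustration, not s2fft's internal one). With every ring the same length, a per-ring loop and a single batched call along the last axis give the same result:

```python
import numpy as np

# Hypothetical layout: every ring has the same number of samples.
n_rings, n_phi = 31, 32
rng = np.random.default_rng(0)
rings = rng.standard_normal((n_rings, n_phi))

# One FFT per ring, written as an explicit loop.
looped = np.stack([np.fft.fft(ring) for ring in rings])

# The same transform as a single batched call over the ring axis.
batched = np.fft.fft(rings, axis=-1)
```

Because all rows share one length, the batched call involves a single FFT size rather than one size per ring.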
Thanks for the comments @EiffL. As @CosmoMatt said, the slow JIT compile time for HEALPix is an issue with the HEALPix sampling that we're aware of and need to look into further. This is in fact a general problem with HEALPix, which is also present in CPU HEALPix implementations. In those implementations one would ideally like to plan an FFT to optimise efficiency (e.g. FFTW allows you to plan FFTs to optimise them); however, since the size of the FFT differs from ring to ring, you cannot plan a single efficient FFT. Other codes therefore typically use an estimated FFT, which is perhaps not optimal. But planning each FFT would be a significant overhead, and there is not much of a performance hit in using estimated FFTs. We face a very similar issue here, again due to the varying number of samples on each ring, but it is more acute since it has a significant impact on JIT compile time. We need to look into this some more. This is fairly high priority, but we have some other more pressing commitments for the next couple of weeks at least, so we're not sure we'll be able to look into this in detail until after then. Any insight that you can provide @EiffL would be very welcome! |
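To make the varying ring sizes concrete, here is a small NumPy sketch. It uses the standard HEALPix ring layout (4·nside − 1 rings, with ring lengths growing by 4 in each polar cap and fixed at 4·nside in the equatorial belt). The helper names `healpix_ring_sizes` and `ring_ffts` are made up for this illustration and are not part of s2fft or any HEALPix library.

```python
import numpy as np

def healpix_ring_sizes(nside: int) -> np.ndarray:
    """Number of pixels on each of the 4*nside - 1 HEALPix rings (hypothetical helper)."""
    t = np.arange(1, 4 * nside)  # ring index, north to south
    # Polar caps: 4*t (north) or 4*(4*nside - t) (south); equatorial belt: 4*nside.
    return 4 * np.minimum(np.minimum(t, 4 * nside - t), nside)

def ring_ffts(f: np.ndarray, nside: int) -> list:
    """FFT each ring of a flat, ring-ordered map separately (hypothetical helper)."""
    sizes = healpix_ring_sizes(nside)
    offsets = np.concatenate(([0], np.cumsum(sizes)))
    return [np.fft.fft(f[offsets[i]:offsets[i + 1]]) for i in range(len(sizes))]

nside = 8
sizes = healpix_ring_sizes(nside)
f = np.random.default_rng(0).standard_normal(12 * nside**2)
coeffs = ring_ffts(f, nside)
n_distinct = len(np.unique(sizes))  # number of different FFT lengths needed
```

The number of distinct FFT lengths equals nside, so it grows linearly with resolution. A compiler that specialises on array shapes would presumably see a new FFT size for every ring in a polar cap, which is consistent with the compile-time cost described above.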
Dear s2fft devs,
Thanks for this great package! I just wanted to report on some of the pain-points I ran into to hopefully provide constructive feedback.
SHT with inappropriate healpix nside returns mysterious error
Here is what I tried to do:
triggers an error that ends in:
jitting time is long
I'm sure this is something you are well aware of, and I guess that's also one of the reasons for the two-stage way of doing the transform. I've found that it can take up to 3 minutes on my laptop to compile forward_jax.
It's no big deal though, and there are many ways to accelerate these kinds of things. I'll be very happy to take a look at it to see if I can suggest some improvements in #140