feat: s2fft wigner matrices #140

Open · wants to merge 7 commits into main
Conversation

@lgrcia (Collaborator) commented Feb 27, 2024

I copied some unreleased code from s2fft (only temporarily; I opened an issue there) to test the recurrence-based construction of the Wigner D-matrices. I made one modification so that beta in their utils.rotation.generate_rotate_dls can be non-static.
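For context on why that matters: when beta is marked static, every new value of beta triggers a fresh compilation, and vmap over beta is not possible. A toy sketch of the distinction (a stand-in function, not the s2fft code):

from functools import partial

import jax
import jax.numpy as jnp


# only the integer size argument L is static; beta is traced, so a single
# compilation serves all values of beta and vmap over beta works
@partial(jax.jit, static_argnums=(0,))
def toy_dls(L, beta):
    return jnp.eye(2 * L + 1) * jnp.cos(beta)


betas = jnp.linspace(0.0, jnp.pi, 10)
out = jax.vmap(lambda b: toy_dls(3, b))(betas)  # fine: beta is traced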

Compilation is orders of magnitude faster:

  • ~8 seconds for the new implementation:

    from jaxoplanet.experimental.starry.s2fft_rotation import compute_rotation_matrices as R1
    
    deg = 20
    R = R1(deg, 0.0, 1.0, 0.0, 1.0)
  • ~3 minutes for the current implementation:

    from jaxoplanet.experimental.starry.rotation import compute_rotation_matrices as R2
    
    deg = 20
    R = R2(deg, 0.0, 1.0, 0.0, 1.0)
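For reference, compilation time alone can be isolated with JAX's ahead-of-time lowering API; a sketch (assuming the degree is a static argument of the jitted function):

import jax
from jaxoplanet.experimental.starry.s2fft_rotation import compute_rotation_matrices as R1

deg = 20
f = jax.jit(R1, static_argnums=0)
%time f.lower(deg, 0.0, 1.0, 0.0, 1.0).compile()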

Of course this is preliminary and should be tested more extensively. Just a note that for spherical harmonics of degree L, the output of s2fft.utils.rotation.generate_rotate_dls is a single array of shape (L, 2*L+1, 2*L+1), whereas we currently have a list of matrices with different shapes [(1, 1), (3, 3), (5, 5), ..., (2*L+1, 2*L+1)].

To exploit that, I tried to pad and unpad the Ylm to perform actual batched matrix multiplications (a sketch of the padding layout is below). It doesn't seem to make things faster, which isn't that surprising, presumably because most of the extra operations hit zero-padded entries. Anyway, I'll push that bit to the Ylm class in case it's useful.
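A minimal sketch of that padding layout (pad_ylm is a hypothetical helper, just to illustrate the mapping to a (deg + 1, 2*deg + 1) array):

import numpy as np


def pad_ylm(y_dense, deg):
    # hypothetical helper: reshape a dense (deg + 1)**2 coefficient vector,
    # ordered by degree l and order m, into a (deg + 1, 2 * deg + 1) array;
    # each degree's 2 * l + 1 coefficients sit in a row centered on column deg
    padded = np.zeros((deg + 1, 2 * deg + 1))
    for l in range(deg + 1):
        padded[l, deg - l : deg + l + 1] = y_dense[l**2 : (l + 1) ** 2]
    return padded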

Looking forward to doing more testing!

@lgrcia requested a review from dfm on February 27, 2024 02:03
@dfm (Member) commented Feb 27, 2024

This is great @lgrcia!! I'll take a closer look later.

> I made one modification so that beta in their utils.rotation.generate_rotate_dls can be non-static.

Is this something where you could contribute upstream or is it easier to keep it local?

@lgrcia (Collaborator, author) commented Feb 27, 2024

I opened this issue on the s2fft repo to discuss it there, mostly to be sure about the motives behind the static beta.

@lgrcia (Collaborator, author) commented Feb 27, 2024

Just to keep a reference somewhere, here is the version with the padded Ylm:

import jax
import jax.numpy as jnp
import numpy as np

from jaxoplanet.experimental.starry import Ylm, rotation

deg = 10

# random spherical harmonic coefficients, with Y_{0,0} fixed to 1
y = Ylm.from_dense(np.hstack([1, np.random.rand((deg + 1) ** 2 - 1)]))


def dot_rotation_matrix2(ydeg, x, y, z, theta):
    # stacked (ydeg + 1, 2 * ydeg + 1, 2 * ydeg + 1) rotation matrices from s2fft
    rotation_matrices = rotation.compute_rotation_matrices_s2fft(ydeg, x, y, z, theta)

    def dot(y_padded):
        # one batched product: each degree's block rotates its own row
        padded_y = jnp.einsum("ij,ijk->ik", y_padded, rotation_matrices)
        return padded_y

    return dot


inc, obl = 0.5, 0.8
values = rotation.right_project_axis_angle(inc, obl, 0.0, 0.0)
f = jax.jit(dot_rotation_matrix2(deg, *values))

# pad to (deg + 1, 2 * deg + 1), rotate, then unpad back to a dense vector
rotated_y = Ylm.from_dense_pad(f(y.to_dense_pad())).todense()

This is not faster than working with a non-homogeneous set of rotation matrices.

@lgrcia (Collaborator, author) commented Mar 22, 2024

Following the answer from @CosmoMatt in s2fft's issue #191, I tried an implementation based on equation 8 of this paper.
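If I transcribe it correctly, equation 8 expresses the Wigner d-functions at arbitrary $\beta$ in terms of $\Delta^\ell \equiv d^\ell(\pi/2)$:

$$d^\ell_{mn}(\beta) = i^{n-m} \sum_{m'=-\ell}^{\ell} \Delta^\ell_{m'm} \, \Delta^\ell_{m'n} \, e^{i m' \beta}$$

which is what the einsum below evaluates for all $\ell$, $m$, $n$ at once: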

from s2fft.utils.rotation import generate_rotate_dls
import jax.numpy as jnp


def new_dls(deg):
    # Delta = d^l(pi / 2), stacked over degrees
    delta = generate_rotate_dls(deg, jnp.pi / 2)
    idxs = jnp.indices(delta[0].shape)
    # m' values along the summation axis, from -(deg - 1) to deg - 1
    i = idxs[1][0] - deg + 1
    # i^(n - m) phase factor (rows index m, columns index n)
    inm = 1j ** (idxs[1] - idxs[0])

    def impl(beta):
        # sum over m' of Delta_{m'm} Delta_{m'n} exp(i m' beta), for each degree
        sum_term = jnp.einsum("nij,nik,i->njk", delta, delta, jnp.exp(1j * i * beta))
        m = sum_term * inm
        return m.real

    return impl
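For example, one consistency check of this kind (a sketch; the loose tolerance reflects JAX's default float32 precision):

deg, beta = 8, 1.254
direct = generate_rotate_dls(deg, beta)  # direct recursion at this beta
rebuilt = new_dls(deg)(beta)             # reconstruction from Delta
assert jnp.allclose(direct, rebuilt, atol=1e-5)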

This seems correct and passes all sorts of tests with different degrees and $\beta$. However, the performance I get is lower than what I would have expected. The original generate_rotate_dls gives

import jax
from functools import partial

deg = 20
beta = 1.254
betas = jnp.linspace(0, 2 * jnp.pi, 4000)

f = jax.jit(jax.vmap(partial(generate_rotate_dls, deg)))
%timeit jax.block_until_ready(f(betas))
257 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

whereas the version based on a single call of generate_rotate_dls gives

f = jax.jit(jax.vmap(new_dls(deg)))
%timeit jax.block_until_ready(f(betas))
366 ms ± 6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

@dfm, @CosmoMatt, do you have any intuition as to why the vmap over $\beta$ in generate_rotate_dls still performs so well compared to the implementation I tested, which ends with what might be a lot of expensive matrix multiplications? I would have expected the recursive construction of the dls to be the expensive part. I might be missing some JAX behavior here.
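One thing I might try is comparing XLA's own cost estimates for the two compiled functions (a sketch reusing the objects defined above; cost_analysis gives compiler estimates, not measurements, and on some JAX versions returns a list of dicts rather than a dict):

f_direct = jax.jit(jax.vmap(partial(generate_rotate_dls, deg)))
f_fft = jax.jit(jax.vmap(new_dls(deg)))

for name, fn in [("direct", f_direct), ("fft", f_fft)]:
    # compiler-estimated FLOPs for the whole batch of betas
    cost = fn.lower(betas).compile().cost_analysis()
    print(name, cost["flops"])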

@CosmoMatt commented
Hi @lgrcia, sorry I couldn't look at this sooner. Unfortunately, I can't see any obvious reason why the three-term einsum in impl should be slower than the full recursion. I wonder if this may be a scaling issue that I overlooked in my previous response here.

  • The brute-force (recursive) approach requires $\mathcal{O}(L^3)$ operations per $\beta$, for a full complexity of $\mathcal{O}(N_\beta L^3)$ to compute all elements.

  • The FFT approach requires, for each $\beta$, computing all $(\ell, m, n)$ entries, each of which involves a summation of length $\mathcal{O}(L)$. So the complexity of the FFT approach would presumably be $\mathcal{O}(N_\beta L^4)$.
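To put rough numbers on the benchmark above (my arithmetic, taking $L = 20$, matrix width $D = 2L + 1 = 41$, and $N_\beta = 4000$): the einsum costs about $L D^3 \approx 1.4 \times 10^6$ multiply-adds per $\beta$, i.e. $\approx 5.5 \times 10^9$ in total, versus roughly $L D^2 \approx 3.4 \times 10^4$ per $\beta$ ($\approx 1.3 \times 10^8$ total) for the recursion. That is a ~40x gap in raw operation counts, which the measured 257 ms vs. 366 ms clearly does not reflect.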

I may be missing something here though; perhaps @jasonmcewen has a better answer.

@lgrcia (Collaborator, author) commented Apr 2, 2024

Thanks a lot for your answer @CosmoMatt! That is a really interesting avenue to consider anyway!
