feat: s2fft wigner matrices #140
base: main
Conversation
This is great @lgrcia!! I'll take a closer look later.

Is this something where you could contribute upstream, or is it easier to keep it local?

---

I opened this issue on the s2fft repo to discuss that there! Just to be sure about the motives behind the static `beta`.

---
Just to keep a reference somewhere, here is the version with the padded `Ylm`:

```python
import jax
import jax.numpy as jnp
import numpy as np

from jaxoplanet.experimental.starry import Ylm, rotation

deg = 10
y = Ylm.from_dense(np.hstack([1, np.random.rand((deg + 1) ** 2 - 1)]))


def dot_rotation_matrix2(ydeg, x, y, z, theta):
    rotation_matrices = rotation.compute_rotation_matrices_s2fft(ydeg, x, y, z, theta)

    def dot(y_padded):
        padded_y = jnp.einsum("ij,ijk->ik", y_padded, rotation_matrices)
        return padded_y

    return dot


inc, obl = 0.5, 0.8
values = rotation.right_project_axis_angle(inc, obl, 0.0, 0.0)
f = jax.jit(dot_rotation_matrix2(deg, *values))
rotated_y = Ylm.from_dense_pad(f(y.to_dense_pad())).todense()
```

This is not faster than working with a non-homogeneous set of rotation matrices.

---
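As an aside, the pad/unpad layout can be sketched with plain numpy; this is only an illustration of the idea, not the `jaxoplanet` implementation, and all names below are made up:

```python
import numpy as np

deg = 3
# ragged per-degree coefficient blocks: degree l carries 2*l + 1 coefficients
blocks = [np.random.rand(2 * l + 1) for l in range(deg + 1)]

# pad every block symmetrically to the widest width, 2*deg + 1,
# so the ragged list stacks into one homogeneous array
width = 2 * deg + 1
padded = np.stack([np.pad(b, ((width - len(b)) // 2,) * 2) for b in blocks])
assert padded.shape == (deg + 1, width)

# unpad by slicing the centre of each row to recover the ragged blocks
unpadded = [padded[l, deg - l : deg + l + 1] for l in range(deg + 1)]
assert all(np.allclose(a, b) for a, b in zip(blocks, unpadded))
```

A homogeneous `(deg + 1, 2*deg + 1)` coefficient array is what makes a single `einsum` against a stacked `(deg + 1, 2*deg + 1, 2*deg + 1)` set of rotation matrices possible in the first place.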
Based on the answer from @CosmoMatt in the s2fft issue:

```python
import jax.numpy as jnp

from s2fft.utils.rotation import generate_rotate_dls


def new_dls(deg):
    delta = generate_rotate_dls(deg, jnp.pi / 2)
    idxs = jnp.indices(delta[0].shape)
    i = idxs[1][0] - deg + 1
    inm = 1j ** (idxs[1] - idxs[0])

    def impl(beta):
        sum_term = jnp.einsum("nij,nik,i->njk", delta, delta, jnp.exp(1j * i * beta))
        m = sum_term * inm
        return m.real

    return impl
```

which seems correct and passes all sorts of tests with different degrees and β values. Benchmarking the directly vmapped `generate_rotate_dls`:

```python
import jax
from functools import partial

deg = 20
beta = 1.254
betas = jnp.linspace(0, 2 * jnp.pi, 4000)

f = jax.jit(jax.vmap(partial(generate_rotate_dls, deg)))
%timeit jax.block_until_ready(f(betas))
```

whereas the version based on a single call of `generate_rotate_dls`:

```python
f = jax.jit(jax.vmap(new_dls(deg)))
%timeit jax.block_until_ready(f(betas))
```

@dfm, @CosmoMatt, do you have any intuition why the …

---
Hi @lgrcia, sorry I couldn't look at this sooner. Unfortunately, I can't see any obvious reason why the three-term …

I may be missing something here though; perhaps @jasonmcewen has a better answer.

---
Thanks a lot for your answer @CosmoMatt! That is a really interesting avenue to consider anyway!

---
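For the record, the identity behind `new_dls` (pushing all of the β dependence into a diagonal phase between two static π/2 Wigner-d matrices) can be checked by hand at degree l = 1, where the small-d matrix is known in closed form. This is a standalone numpy sketch, independent of s2fft:

```python
import numpy as np


def d1(beta):
    # analytic Wigner small-d matrix for l = 1, rows/cols ordered m = 1, 0, -1
    c, s = np.cos(beta), np.sin(beta)
    r2 = np.sqrt(2.0)
    return np.array(
        [
            [(1 + c) / 2, -s / r2, (1 - c) / 2],
            [s / r2, c, -s / r2],
            [(1 - c) / 2, s / r2, (1 + c) / 2],
        ]
    )


beta = 1.254
delta = d1(np.pi / 2)  # the static pi/2 matrix, computed once
m = np.array([1, 0, -1])  # m values matching the ordering of d1
phase = np.exp(1j * m * beta)  # the only beta-dependent factor
inm = 1j ** (m[None, :] - m[:, None])  # i^(n - m) prefactor

# d^1_{mn}(beta) = i^(n - m) * sum_k d^1_{km}(pi/2) d^1_{kn}(pi/2) e^{i k beta}
d = (inm * np.einsum("km,kn,k->mn", delta, delta, phase)).real
assert np.allclose(d, d1(beta))
```

The same structure is what the `einsum` in `impl` computes for all degrees at once, which is why only the cheap phase vector has to be rebuilt per β.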
I copied some unreleased code from s2fft (only temporary, I opened an issue there) to test the recurrence-based construction of the Wigner D-matrices. I made a modification so that `beta` in their `utils.rotation.generate_rotate_dls` can be non-static. Compilation is orders of magnitude faster:

- 8 seconds for (new)
- 3 minutes for (current)

Of course this is preliminary and should be tested more extensively.

Just a note that for spherical harmonics with a degree `L`, the output of `s2fft.utils.rotation.generate_rotate_dls` is `(L, 2*L+1, 2*L+1)`, whereas we currently have a list of matrices with different shapes `[(1,), (3, 3), (5, 5), ..., (2*L+1, 2*L+1)]`. To exploit that, I tried to pad and unpad the `Ylm` to perform actual matrix multiplications. It doesn't seem to make things faster, which isn't that surprising. Anyway, I'll push that bit in the `Ylm` class in case it's useful.

Looking forward to doing more testing!
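On the static-β point: the underlying trade-off is that `jax.jit` retraces (and recompiles) for every new value of a static argument, while a traced argument is compiled once per shape and dtype. A minimal sketch of that behaviour, unrelated to the actual `generate_rotate_dls` internals:

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_counts = {"traced": 0, "static": 0}


@jax.jit
def f_traced(beta):
    trace_counts["traced"] += 1  # Python side effect: runs only when JAX traces
    return jnp.cos(beta)


@partial(jax.jit, static_argnums=0)
def f_static(beta):
    trace_counts["static"] += 1
    return jnp.cos(jnp.asarray(beta))


for b in (0.1, 0.2, 0.3):
    f_traced(b)
    f_static(b)

# traced beta: one trace serves every value; static beta: one trace per value
assert trace_counts == {"traced": 1, "static": 3}
```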