Fix jax.scipy.special.sph_harm to agree with scipy.special.sph_harm for multiple theta or phi #20772
base: main
88d8c37 to 5ec3d36
44446fb to 8dc1f05
```diff
@@ -1185,7 +1185,7 @@ def _sph_harm(m: Array,
   cos_colatitude = jnp.cos(phi)

   legendre = _gen_associated_legendre(n_max, cos_colatitude, True)
-  legendre_val = legendre.at[abs(m), n, jnp.arange(len(n))].get(mode="clip")
+  legendre_val = legendre.at[abs(m), n, jnp.arange(len(phi))].get(mode="clip")
```
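The following minimal NumPy sketch shows why the last index matters. NumPy fancy indexing stands in for JAX's `.at[...].get(mode="clip")`. All indices here are in bounds, so clipping never applies. The sketch assumes the Legendre table is laid out as (order, degree, evaluation point), as the call site suggests. Its contents are arbitrary placeholders. With a single `m` and `n` but three values of `phi`, indexing with `arange(len(n))` reads only the first point, while `arange(len(phi))` reads one value per point.

```python
import numpy as np

# Placeholder for the table built by _gen_associated_legendre (assumed layout:
# axis 0 = order m, axis 1 = degree n, axis 2 = evaluation point).
n_max, num_points = 2, 3
table = np.arange((n_max + 1) * (n_max + 1) * num_points).reshape(
    n_max + 1, n_max + 1, num_points)

m = np.array([1])                          # a single order ...
n = np.array([2])                          # ... and a single degree
phi = np.linspace(0.1, 1.0, num_points)    # but several evaluation points

# Old indexing: one point index per entry of n, so only point 0 is read.
old = table[np.abs(m), n, np.arange(len(n))]
# New indexing: length-1 m and n broadcast against one index per point.
new = table[np.abs(m), n, np.arange(len(phi))]

print(old.tolist())  # [15]
print(new.tolist())  # [15, 16, 17]
```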
Won't this break if `n` and `phi` are not broadcast-compatible?
Yes, it will break. But I observed that `scipy.special.sph_harm` and `jax.scipy.special.sph_harm` also raise an error when `n`, `phi` and `theta` are not broadcast-compatible.
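This NumPy sketch shows the failure mode raised in review. It uses an arbitrary placeholder table with the assumed layout (order, degree, point). When `m` and `n` hold two entries but `phi` holds three, the three index arrays cannot be broadcast together. NumPy raises `IndexError` here. The exception JAX raises for the same gather may be of a different type.

```python
import numpy as np

# Placeholder (order, degree, point) table standing in for the Legendre values.
table = np.zeros((4, 4, 3))

m = np.array([0, 1])               # two (m, n) pairs ...
n = np.array([2, 3])
phi = np.array([0.1, 0.5, 0.9])    # ... but three evaluation points

raised = False
try:
    # Index shapes (2,), (2,) and (3,) have no common broadcast shape.
    table[np.abs(m), n, np.arange(len(phi))]
except IndexError as err:
    raised = True
    print(err)
```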
Could we add code at the top of the function to explicitly check for whatever shape compatibility this function expects?
Can we use `jax.lax.broadcast_shapes` at the top of the function to verify whether the shapes of `m`, `n`, `theta`, and `phi` are broadcast-compatible?
If this is semantically correct, then yes, this is how we should do it. Look at the uses of e.g. `jax._src.numpy.util.promote_args` elsewhere in the package.
I'm having trouble evaluating whether this change is correct. This is partly due to lack of test coverage and partly due to lack of documentation: it's just not clear to me what set of input shapes should be considered valid here, and so whether changing
To be clear, this state of things precedes this PR, and is one reason we've previously flagged this API as a candidate for removal (see https://jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html#scipy-special).
```diff
@@ -1235,6 +1235,8 @@ def sph_harm(m: Array,
     A 1D array containing the spherical harmonics at (m, n, theta, phi).
   """

+  lax.broadcast_shapes(jnp.shape(m), jnp.shape(n), jnp.shape(theta), jnp.shape(phi))
+
```
This checks whether the input arguments are broadcast-compatible and raises a `ValueError` if they are not.
Hi @jakevdp, could you please check whether this change is correct?
I think the way to do this would be to use `jax._src.numpy.util.promote_shapes`, something like this:
```python
from jax._src.numpy.util import promote_shapes

# At the top of sph_harm: promote the four inputs to compatible shapes.
m, n, theta, phi = promote_shapes("sph_harm", m, n, theta, phi)
```
That's assuming that the intent of this API is that all inputs are broadcast to a common shape, which is not clear to me from reading the docs, the source, or the tests.
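Suppose the intent is that all inputs are broadcast to a common shape. The following NumPy sketch shows that end state: every output element gets its own (m, n, theta, phi) tuple, matching the 1D return value the docstring describes. It does not reproduce `promote_shapes` itself. `broadcast_sph_harm_args` is a hypothetical helper name.

```python
import numpy as np


def broadcast_sph_harm_args(m, n, theta, phi):
    """Hypothetical: broadcast all four inputs to one common shape and flatten
    them, so each output element has its own (m, n, theta, phi) tuple."""
    m, n, theta, phi = np.broadcast_arrays(m, n, theta, phi)
    return tuple(np.ravel(x) for x in (m, n, theta, phi))


# Scalar m and n are repeated once per (theta, phi) point.
m, n, theta, phi = broadcast_sph_harm_args(
    1, 2, np.zeros(3), np.linspace(0.1, 1.0, 3))
print(m, n)  # [1 1 1] [2 2 2]
```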
…ph_harm for multiple theta or phi
0a384ca to 1860604
Fixes #20769