
Fix jax.scipy.special.sph_harm to agree with scipy.special.sph_harm for multiple theta or phi values #20772

Open
wants to merge 1 commit into
base: main
from


rajasekharporeddy
Contributor

Fixes #20769

@@ -1185,7 +1185,7 @@ def _sph_harm(m: Array,
cos_colatitude = jnp.cos(phi)

legendre = _gen_associated_legendre(n_max, cos_colatitude, True)
- legendre_val = legendre.at[abs(m), n, jnp.arange(len(n))].get(mode="clip")
+ legendre_val = legendre.at[abs(m), n, jnp.arange(len(phi))].get(mode="clip")
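The change swaps the last index array from `jnp.arange(len(n))` to `jnp.arange(len(phi))`. The sketch below illustrates why that matters, using a toy `table` in place of the output of `_gen_associated_legendre`. It assumes (inferred from the diff, not confirmed by the docs) that the table has shape `(n_max + 1, n_max + 1, len(phi))`, with the last axis running over the evaluation points; the input values are illustrative.

```python
import jax.numpy as jnp

# Toy stand-in for the Legendre table. Assumed shape: (n_max+1, n_max+1, len(phi)),
# where the last axis indexes the evaluation points (one per phi value).
n_max = 2
phi = jnp.array([0.1, 0.2, 0.3])
table = jnp.arange((n_max + 1) * (n_max + 1) * len(phi), dtype=jnp.float32).reshape(
    n_max + 1, n_max + 1, len(phi))

m = jnp.array([1])
n = jnp.array([2])

# Old indexing: the last index array has length len(n) == 1, so only the
# first evaluation point (the column for phi[0]) is ever read.
old = table.at[abs(m), n, jnp.arange(len(n))].get(mode="clip")

# New indexing: the last index array has length len(phi) == 3; m and n
# (shape (1,)) broadcast against it, giving one value per phi.
new = table.at[abs(m), n, jnp.arange(len(phi))].get(mode="clip")

print(old.shape, new.shape)  # (1,) (3,)
```

Under the same assumption, if `m` and `n` had length 2 while `phi` had length 3, the three index arrays would have shapes `(2,)`, `(2,)` and `(3,)`, which do not broadcast together, so the indexing would fail. That is the case raised in the first review comment.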
Collaborator


Won't this break if n and phi are not broadcast-compatible?

Contributor Author


Yes, it will break. But I observed that scipy.special.sph_harm and jax.scipy.special.sph_harm also raise an error when n, phi and theta are not broadcast-compatible.

Collaborator


Could we add code at the top of the function to explicitly check for whatever shape compatibility this function expects?

Contributor Author


Can we use jax.lax.broadcast_shapes at the top of the function to verify whether the shapes of m, n, theta, and phi are broadcast-compatible?

Collaborator


If this is semantically correct, then yes, this is how we should do it. Look at the uses of, e.g., jax._src.numpy.util.promote_args elsewhere in the package.

@jakevdp
Collaborator

jakevdp commented Apr 16, 2024

I'm having trouble evaluating whether this change is correct. This is partly due to a lack of test coverage and partly due to a lack of documentation: it's just not clear to me what set of input shapes should be considered valid here, and so whether changing range(len(n)) to range(len(phi)) is correct or not is difficult to discern.

To be clear, this state of things precedes this PR, and is one reason we've previously flagged this API as a candidate for removal (see https://jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html#scipy-special).

@@ -1235,6 +1235,8 @@ def sph_harm(m: Array,
A 1D array containing the spherical harmonics at (m, n, theta, phi).
"""

+ lax.broadcast_shapes(jnp.shape(m), jnp.shape(n), jnp.shape(theta), jnp.shape(phi))

Contributor Author


This checks whether the input arguments are broadcast-compatible and raises a ValueError if they are not.
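A minimal sketch of what this check does in isolation, using the same `lax.broadcast_shapes(jnp.shape(...), ...)` call as the diff. `check_sph_harm_shapes` is a hypothetical wrapper name introduced only for illustration, and the inputs are illustrative.

```python
import jax.numpy as jnp
from jax import lax


def check_sph_harm_shapes(m, n, theta, phi):
    """Hypothetical wrapper around the shape check added in this PR.

    Returns the common broadcast shape of the four inputs, or raises
    ValueError if their shapes cannot be broadcast together.
    """
    return lax.broadcast_shapes(
        jnp.shape(m), jnp.shape(n), jnp.shape(theta), jnp.shape(phi))


# Scalar m and n with three (theta, phi) points: (), (), (3,), (3,) broadcast to (3,).
compatible = check_sph_harm_shapes(1, 2, jnp.zeros(3), jnp.zeros(3))
print(compatible)  # (3,)

# Length-2 m and n against length-3 theta and phi: (2,) and (3,) do not broadcast.
incompatible_error = None
try:
    check_sph_harm_shapes(jnp.array([1, 1]), jnp.array([1, 2]), jnp.zeros(3), jnp.zeros(3))
except ValueError as err:
    incompatible_error = err
print(type(incompatible_error).__name__)  # ValueError
```

Note that this only validates shapes; it does not broadcast the arrays themselves, so the indexing inside `_sph_harm` still has to handle the common shape correctly.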

Contributor Author


Hi @jakevdp, could you please check whether this change is correct?

Collaborator


I think the way to do this would be to use jax._src.numpy.util.promote_shapes, something like this:

from jax._src.numpy.util import promote_shapes

m, n, theta, phi = promote_shapes("sph_harm", m, n, theta, phi)

That's assuming that the intent of this API is that all inputs are broadcast to a common shape, which is not clear to me from reading the docs, the source, or the tests.
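If the intent is indeed that all inputs broadcast to a common shape, the same idea can be sketched with the public `jnp.broadcast_arrays` rather than the private `promote_shapes` helper named above. This is a swapped-in technique, not the reviewer's code: `prepare_sph_harm_inputs` is a hypothetical name, and flattening to 1D is an assumption based on the docstring in the diff ("A 1D array containing the spherical harmonics").

```python
import jax.numpy as jnp


def prepare_sph_harm_inputs(m, n, theta, phi):
    """Hypothetical helper: broadcast all inputs to one common shape, then flatten.

    Assumes (not confirmed by the docs) that sph_harm's inputs are meant to
    broadcast elementwise. Raises an error if the shapes are incompatible.
    """
    m, n, theta, phi = jnp.broadcast_arrays(m, n, theta, phi)
    common_shape = m.shape  # kept so a flat result could be reshaped back later
    return common_shape, [jnp.ravel(x) for x in (m, n, theta, phi)]


common_shape, (m, n, theta, phi) = prepare_sph_harm_inputs(
    jnp.array([1]), jnp.array([2]),
    jnp.linspace(0.0, 1.0, 3), jnp.linspace(0.5, 1.5, 3))
print(common_shape, n.tolist())  # (3,) [2, 2, 2]
```

With every input flattened to the same length, an index such as `jnp.arange(len(phi))` lines up with `m` and `n` by construction. Whether multi-dimensional inputs should instead keep their shape in the output is part of the open question in the comment above.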

Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Wrong output of jax.scipy.special.sph_harm
2 participants