Wrong output of jax.scipy.special.sph_harm #20769

Open
SGENZO opened this issue Apr 16, 2024 · 3 comments · May be fixed by #20772
Labels
bug Something isn't working

Comments

@SGENZO

SGENZO commented Apr 16, 2024

Description

jax.scipy.special.sph_harm(m, n, theta, phi, n_max=None) returns wrong values when the degree of the harmonic is nonzero ($n \neq 0$).
Here is an example:

import jax 
import jax.numpy as jnp
from jax.scipy.special import sph_harm as jnp_sph
from scipy.special import sph_harm

# Generate 200 3D points
seed = 23
key = jax.random.PRNGKey(seed)
key, subkey = jax.random.split(key, 2)
data = jax.random.normal(subkey, shape=(200,3))
r = jnp.linalg.norm(data, ord=2, axis=1)
phi = jnp.array(jnp.arccos(data[:,2]/r))
theta = jnp.array(jnp.arctan2(data[:,1],data[:,0]))

# Calculate sph_harm values with JAX and SciPy
m = 0
n = 1
scipy_result = sph_harm(jnp.array([m]), jnp.array([n]), theta, phi)
jax_result = jnp_sph(jnp.array([m]), jnp.array([n]), theta, phi, n_max=n)
print(jnp.max(jnp.abs(scipy_result - jax_result)))

The printed value should be close to zero, but the actual output is 0.8381599. With $m = 0, n = 0$, the output is 2.9802322e-08, which is correct.
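
As an independent check on the two cases above, the order-0 harmonics of degree 0 and 1 have simple closed forms: $Y_0^0 = 1/(2\sqrt{\pi})$ and $Y_1^0 = \sqrt{3/(4\pi)}\cos\phi$, where phi is the polar angle as in the example. Here is a minimal sketch using only the standard library. The function names y00 and y10 are made up for illustration and are not part of JAX or SciPy:

```python
import math


def y00() -> complex:
    """Closed form of Y_0^0: a constant, independent of both angles."""
    return complex(0.5 / math.sqrt(math.pi), 0.0)


def y10(phi: float) -> complex:
    """Closed form of Y_1^0 for polar angle phi (hypothetical helper)."""
    return complex(math.sqrt(3.0 / (4.0 * math.pi)) * math.cos(phi), 0.0)


# Y_1^0 varies with the polar angle, unlike Y_0^0.
for phi in (0.0, math.pi / 4, math.pi / 2):
    print(phi, y00().real, y10(phi).real)
```

Because $Y_1^0$ depends on phi while $Y_0^0$ does not, any implementation that evaluates the Legendre factor at only one point would still pass for $n = 0$ and fail for $n = 1$.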

I checked the source code of jax.scipy.special.sph_harm, and the error appears to be here:

@partial(jit, static_argnums=(4,))
def _sph_harm(m: Array,
              n: Array,
              theta: Array,
              phi: Array,
              n_max: int) -> Array:
  """Computes the spherical harmonics."""

  cos_colatitude = jnp.cos(phi)

  legendre = _gen_associated_legendre(n_max, cos_colatitude, True)
  legendre_val = legendre.at[abs(m), n, jnp.arange(len(n))].get(mode="clip")

  angle = abs(m) * theta
  vandermonde = lax.complex(jnp.cos(angle), jnp.sin(angle))
  harmonics = lax.complex(legendre_val * jnp.real(vandermonde),
                          legendre_val * jnp.imag(vandermonde))

  # Negative order.
  harmonics = jnp.where(m < 0,
                        (-1.0)**abs(m) * jnp.conjugate(harmonics),
                        harmonics)

  return harmonics

The statement that computes legendre_val indexes the last axis with the wrong array length. It should be changed to
legendre_val = legendre.at[abs(m), n, jnp.arange(len(phi))].get(mode="clip")

The function value is correct for degree $n = 0$ because the Legendre polynomial of that degree is a constant, so every value in legendre.at[abs(m), n, jnp.arange(len(phi))] equals the value in legendre.at[abs(m), n, jnp.arange(len(n))]. When $n \neq 0$, the values differ and the result is wrong.
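
The indexing difference described above can be reproduced on a toy table. This is a sketch, not the JAX code path: it uses NumPy integer-array indexing as a stand-in for legendre.at[...].get(mode="clip"), assuming both broadcast the index arrays the same way for in-range indices. The table holds unnormalized Legendre values filled in by hand, and the names table, buggy, and fixed are illustrative only.

```python
import numpy as np

# Stand-in for cos(phi) at 5 sample points.
x = np.linspace(-1.0, 1.0, 5)

# Toy table with the same layout as `legendre`: (order m, degree n, point).
n_max = 1
table = np.zeros((n_max + 1, n_max + 1, x.size))
table[0, 0] = 1.0                      # P_0^0(x) = 1, constant
table[0, 1] = x                        # P_1^0(x) = x
table[1, 1] = -np.sqrt(1.0 - x**2)     # P_1^1(x), unnormalized

m = np.array([0])
n = np.array([1])

# Indexing as in the reported code: the last index has len(n) == 1 entries,
# so only the value at point 0 is selected.
buggy = table[abs(m), n, np.arange(len(n))]

# Indexing as proposed: one entry per evaluation point.
fixed = table[abs(m), n, np.arange(len(x))]

print(buggy.shape, fixed.shape)

# For degree 0 the table row is constant, so both selections agree.
n0 = np.array([0])
buggy0 = table[abs(m), n0, np.arange(len(n0))]
fixed0 = table[abs(m), n0, np.arange(len(x))]
print(np.allclose(buggy0, fixed0))
```

With length-1 m and n, the first selection yields a single value that is later broadcast across all points, which is why the mismatch only shows up once the row is not constant.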

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.7
jaxlib: 0.4.7
numpy: 1.22.4
python: 3.8.16 (default, Mar 1 2023, 21:19:10) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1

@rajasekharporeddy
Contributor

Hi @SGENZO

Thanks for reporting the bug. I have opened PR #20772 for this. This issue will be closed once the PR is merged.

@jakevdp
Collaborator

jakevdp commented Apr 16, 2024

Thanks – the sph_harm code is unfortunately very poorly implemented and we're considering removing it entirely (see https://jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html#scipy-special). We can probably fix this bug, but long-term I'd suggest finding a different implementation to rely on.

@SGENZO
Author

SGENZO commented Apr 17, 2024

@jakevdp
Got it. I'll consider other implementations, thanks.

3 participants