jax.nn.softmax is inconsistent under jax.jit #20856
Comments
Note what happens when printing within the jitted function:

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (3, 1, 1))

def f(x):
    jax.debug.print('{}', x[0, 0])
    return jax.nn.softmax(x, axis=0)

print(f(x))
print(jax.jit(f)(x))
```
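One quick check (my addition, not from the thread): softmax along axis 0 must sum to 1 along that axis, so the inconsistency can be surfaced directly:

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (3, 1, 1))
f = lambda x: jax.nn.softmax(x, axis=0)

# Both of these should print [[1.]] if softmax is computed correctly;
# a mismatch isolates the jit-compiled path as the culprit.
print(f(x).sum(axis=0))
print(jax.jit(f)(x).sum(axis=0))
```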
Thanks for the report – it looks like this is a CPU-only issue; I get the expected result when running on GPU.
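For anyone reproducing this on a machine with an accelerator attached, a sketch of pinning the computation to CPU (my addition; setting `JAX_PLATFORMS=cpu` in the environment is another option):

```python
import jax
import jax.numpy as jnp

print(jax.default_backend())  # 'cpu', 'gpu', or 'tpu'

x = jax.random.normal(jax.random.key(0), (3, 1, 1))
f = lambda x: jax.nn.softmax(x, axis=0)

# Run the jitted function on CPU explicitly to look for the bad result.
with jax.default_device(jax.devices('cpu')[0]):
    print(jax.jit(f)(x))
```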
Also, it seems like it was introduced in JAX v0.4.22; JAX v0.4.21 and earlier return the expected results for the handful of versions I've tried.
Oh, I should have checked. Also, it could be related to this comment: lines 536 to 538 in e498bca.
It looks like this bug appears only when the XLA softmax rewriter is enabled, so it's likely related to this XLA change: openxla/xla#7540. Determined this by running the following on JAX v0.4.21 and 0.4.22:

```python
import jax
import jax.numpy as jnp

print(jax.__version__)

x = jax.random.normal(jax.random.key(0), (3, 1, 1))

def f(x):
    return jax.nn.softmax(x, axis=0)

print(f(x))
print(jax.jit(f)(x))
print(jax.jit(f).lower(x).compile().as_text())
```

Outputs:
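A hedged way to check which path was taken is to search the compiled HLO for the oneDNN custom call; the target string below is my guess, not something confirmed in this thread:

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (3, 1, 1))

def f(x):
    return jax.nn.softmax(x, axis=0)

hlo = jax.jit(f).lower(x).compile().as_text()
# '__onednn$softmax' is an assumed custom-call target name for the rewrite
# from openxla/xla#7540; inspect the printed HLO to confirm the exact name.
print('__onednn$softmax' in hlo)
```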
And it looks like this is "fixed" on JAX's main branch, probably because #20643 changed the raw HLO so that the XLA rewrite logic no longer recognizes it as a softmax. Yeesh, what a mess.
Whoa, that is a mess. Just looking at the onednn-softmax, it only supports …
Description

In the following code, `jax.nn.softmax` returns different results under `jax.jit`. Using a "numerically safe" version of `softmax` based on `log_softmax` solves the issue.

System info (python version, jaxlib version, accelerator, etc.)
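A minimal sketch of the workaround described above, assuming it means exponentiating `log_softmax` (the helper name `stable_softmax` is mine):

```python
import jax
import jax.numpy as jnp

def stable_softmax(x, axis=-1):
    # Compute softmax via log_softmax and exponentiate; this emits
    # different HLO, which presumably no longer matches the pattern
    # the XLA softmax rewriter looks for.
    return jnp.exp(jax.nn.log_softmax(x, axis=axis))

x = jax.random.normal(jax.random.key(0), (3, 1, 1))
g = jax.jit(lambda x: stable_softmax(x, axis=0))
print(g(x))  # consistent with the eager result
```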