
jax.nn.softmax is inconsistent under jax.jit #20856

Open
francois-rozet opened this issue Apr 21, 2024 · 7 comments
Labels
bug Something isn't working

Comments

@francois-rozet

Description

In the following code, jax.nn.softmax returns different results with and without jax.jit.

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (3, 1, 1))

def f(x):
    return jax.nn.softmax(x, axis=0)

print(f(x))
print(jax.jit(f)(x))
[[[0.75250584]]

 [[0.0755428 ]]

 [[0.17195134]]]
[[[1.]]

 [[1.]]

 [[1.]]]

Using a "numerically safe" version of softmax based on log_softmax solves the issue.

def f(x):
    return jnp.exp(jax.nn.log_softmax(x, axis=0))

print(f(x))
print(jax.jit(f)(x))
[[[0.75250584]]

 [[0.07554279]]

 [[0.17195135]]]
[[[0.75250584]]

 [[0.07554279]]

 [[0.17195135]]]

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.2
python: 3.9.18 | packaged by conda-forge | (main, Aug 30 2023, 03:49:32)  [GCC 12.3.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='thinkpad', release='6.5.0-0.deb12.4-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.5.10-1~bpo12+1 (2023-11-23)', machine='x86_64')
@francois-rozet
Author

francois-rozet commented Apr 21, 2024

Note that printing within f weirdly fixes the issue.

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (3, 1, 1))

def f(x):
    jax.debug.print('{}', x[0, 0])
    return jax.nn.softmax(x, axis=0)

print(f(x))
print(jax.jit(f)(x))
[1.8160863]
[[[0.75250584]]

 [[0.0755428 ]]

 [[0.17195134]]]
[1.8160863]
[[[0.75250584]]

 [[0.0755428 ]]

 [[0.17195134]]]

@jakevdp
Collaborator

jakevdp commented Apr 21, 2024

Thanks for the report – it looks like this is a CPU-only issue; I get the expected result when running on GPU.
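A minimal sketch for reproducing this on a machine where GPU is the default backend (assuming CPU devices are still visible via jax.devices("cpu")) is to pin the computation to CPU:

import jax

x = jax.random.normal(jax.random.key(0), (3, 1, 1))

def f(x):
    return jax.nn.softmax(x, axis=0)

# jax.default_device pins dispatch to the given device; on affected versions
# the jitted call prints all ones when forced onto CPU.
with jax.default_device(jax.devices("cpu")[0]):
    print(jax.jit(f)(x))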

@jakevdp
Collaborator

jakevdp commented Apr 21, 2024

Also, it seems like it was introduced in JAX v0.4.22; JAX v0.4.21 and earlier return the expected results for the handful of versions I've tried.

@francois-rozet
Author

francois-rozet commented Apr 21, 2024

it looks like this is a CPU-only issue

Oh, I should have checked. Also, it could be related to this comment in the JAX source:

# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
#@partial(jax.jit, static_argnames=("axis",))
def softmax(x: ArrayLike,

@jakevdp
Collaborator

jakevdp commented Apr 21, 2024

It looks like this bug appears only when the XLA softmax rewriter is enabled, so it's likely related to this XLA change: openxla/xla#7540

I determined this by running the following on JAX v0.4.21 and v0.4.22:

import jax
import jax.numpy as jnp
print(jax.__version__)

x = jax.random.normal(jax.random.key(0), (3, 1, 1))

def f(x):
    return jax.nn.softmax(x, axis=0)

print(f(x))
print(jax.jit(f)(x))

print(jax.jit(f).lower(x).compile().as_text())

Outputs:

0.4.22
[[[0.75250584]]

 [[0.0755428 ]]

 [[0.17195134]]]
[[[1.]]

 [[1.]]

 [[1.]]]
HloModule jit_f, entry_computation_layout={(f32[3,1,1]{2,1,0})->f32[3,1,1]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true}

ENTRY %main.25 (Arg_0.1: f32[3,1,1]) -> f32[3,1,1] {
  %Arg_0.1 = f32[3,1,1]{2,1,0} parameter(0), sharding={replicated}
  ROOT %custom-call = f32[3,1,1]{2,1,0} custom-call(f32[3,1,1]{2,1,0} %Arg_0.1), custom_call_target="__onednn$softmax", metadata={op_name="jit(f)/jit(main)/div" source_file="<ipython-input-4-b7a605546166>" source_line=10}
}
0.4.21
[[[0.75250584]]

 [[0.0755428 ]]

 [[0.17195134]]]
[[[0.75250584]]

 [[0.0755428 ]]

 [[0.17195134]]]
HloModule jit_f, entry_computation_layout={(f32[3,1,1]{2,1,0})->f32[3,1,1]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true}

%region_0.4 (Arg_0.5: f32[], Arg_1.6: f32[]) -> f32[] {
  %Arg_0.5 = f32[] parameter(0)
  %Arg_1.6 = f32[] parameter(1)
  ROOT %maximum.7 = f32[] maximum(f32[] %Arg_0.5, f32[] %Arg_1.6), metadata={op_name="jit(f)/jit(main)/reduce_max[axes=(0,)]" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
}

%region_1.15 (Arg_0.16: f32[], Arg_1.17: f32[]) -> f32[] {
  %Arg_0.16 = f32[] parameter(0)
  %Arg_1.17 = f32[] parameter(1)
  ROOT %add.18 = f32[] add(f32[] %Arg_0.16, f32[] %Arg_1.17), metadata={op_name="jit(f)/jit(main)/reduce_sum[axes=(0,)]" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
}

%fused_computation (param_0: f32[3,1,1], param_1.2: f32[1,1]) -> f32[3,1,1] {
  %param_0 = f32[3,1,1]{2,1,0} parameter(0)
  %param_1.2 = f32[1,1]{1,0} parameter(1)
  %bitcast.2 = f32[] bitcast(f32[1,1]{1,0} %param_1.2), metadata={op_name="jit(f)/jit(main)/reduce_sum[axes=(0,)]" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
  %broadcast.2 = f32[3,1,1]{2,1,0} broadcast(f32[] %bitcast.2), dimensions={}, metadata={op_name="jit(f)/jit(main)/div" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
  ROOT %divide.0 = f32[3,1,1]{2,1,0} divide(f32[3,1,1]{2,1,0} %param_0, f32[3,1,1]{2,1,0} %broadcast.2), metadata={op_name="jit(f)/jit(main)/div" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
}

%fused_computation.1 (param_0.2: f32[3,1,1], param_1.5: f32[1,1]) -> f32[3,1,1] {
  %param_0.2 = f32[3,1,1]{2,1,0} parameter(0)
  %param_1.5 = f32[1,1]{1,0} parameter(1)
  %bitcast.3 = f32[] bitcast(f32[1,1]{1,0} %param_1.5), metadata={op_name="jit(f)/jit(main)/reduce_max[axes=(0,)]" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
  %broadcast.3 = f32[3,1,1]{2,1,0} broadcast(f32[] %bitcast.3), dimensions={}, metadata={op_name="jit(f)/jit(main)/sub" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
  %subtract.0 = f32[3,1,1]{2,1,0} subtract(f32[3,1,1]{2,1,0} %param_0.2, f32[3,1,1]{2,1,0} %broadcast.3), metadata={op_name="jit(f)/jit(main)/sub" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
  ROOT %exponential.0 = f32[3,1,1]{2,1,0} exponential(f32[3,1,1]{2,1,0} %subtract.0), metadata={op_name="jit(f)/jit(main)/exp" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
}

ENTRY %main.25 (Arg_0.1: f32[3,1,1]) -> f32[3,1,1] {
  %Arg_0.1 = f32[3,1,1]{2,1,0} parameter(0), sharding={replicated}
  %constant.3 = f32[] constant(-inf)
  %reduce.8 = f32[1,1]{1,0} reduce(f32[3,1,1]{2,1,0} %Arg_0.1, f32[] %constant.3), dimensions={0}, to_apply=%region_0.4, metadata={op_name="jit(f)/jit(main)/reduce_max[axes=(0,)]" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
  %fusion.1 = f32[3,1,1]{2,1,0} fusion(f32[3,1,1]{2,1,0} %Arg_0.1, f32[1,1]{1,0} %reduce.8), kind=kLoop, calls=%fused_computation.1, metadata={op_name="jit(f)/jit(main)/exp" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
  %constant.2 = f32[] constant(0)
  %reduce.19 = f32[1,1]{1,0} reduce(f32[3,1,1]{2,1,0} %fusion.1, f32[] %constant.2), dimensions={0}, to_apply=%region_1.15, metadata={op_name="jit(f)/jit(main)/reduce_sum[axes=(0,)]" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
  ROOT %fusion = f32[3,1,1]{2,1,0} fusion(f32[3,1,1]{2,1,0} %fusion.1, f32[1,1]{1,0} %reduce.19), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(f)/jit(main)/div" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
}

@jakevdp
Collaborator

jakevdp commented Apr 21, 2024

And it looks like this is "fixed" on JAX's main branch, probably because #20643 changed the raw HLO so that the XLA rewrite logic no longer recognizes it as a softmax. Yeesh, what a mess.
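A quick way to double-check that (a sketch, assuming a main-branch CPU build and reusing f and x from the snippets above) is to confirm the oneDNN custom call no longer appears in the compiled module:

# On main the compiled module should no longer contain the oneDNN softmax
# custom call that shows up in the 0.4.22 dump above.
hlo = jax.jit(f).lower(x).compile().as_text()
print("__onednn$softmax" in hlo)  # expected: False on main, True on 0.4.22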

@NeilGirdhar
Contributor

Whoa, that is a mess. Just looking at the oneDNN softmax, it only supports axis=-1 (it's hard-coded).
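If the hard-coded axis really is the culprit, a possible workaround on affected versions (besides the log_softmax trick above) might be to move the reduction axis to the last position so it matches what the kernel expects; a hypothetical sketch (f_workaround is just an illustrative name):

import jax.numpy as jnp

def f_workaround(x):
    # move axis 0 to the end, softmax over the last axis, then move it back
    y = jnp.moveaxis(x, 0, -1)
    y = jax.nn.softmax(y, axis=-1)
    return jnp.moveaxis(y, -1, 0)

print(jax.jit(f_workaround)(x))  # should match the un-jitted softmax(x, axis=0)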
