jax.numpy.insert returning incorrect results when jitted on Metal #20918

Open
fkeruzore opened this issue Apr 24, 2024 · 1 comment
Labels
Apple GPU (Metal) plugin, bug

@fkeruzore

Description

(pasting from Apple Developer forum)

The jax.numpy.insert() function returns an incorrect result when compiled with jax.jit: the inserted values are missing and the array is padded with zeros instead. When not jitted, the results are correct.

MWE

import jax
import jax.numpy as jnp

x = jnp.arange(20).reshape(5, 4)
print(f"{x=}\n")

def return_arr_with_ins(arr, ins):
    return jnp.insert(arr, 2, ins, axis=1)

x2 = return_arr_with_ins(x, 99)
print(f"{x2=}\n")

return_arr_with_ins_jit = jax.jit(return_arr_with_ins)
x3 = return_arr_with_ins_jit(x, 99)
print(f"{x3=}\n")
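The MWE can also be turned into a self-checking reproducer, which may be handy as a regression test once a fix lands. The sketch below is not part of the original report: it compares both the eager and the jitted result against `numpy.insert`, computed on the host as a reference. The names `expected`, `eager_ok` and `jit_ok` are illustrative. Based on the output reported below, one would expect `jit_ok=False` on the Metal backend and both checks to pass on CPU.

```python
import numpy as np
import jax
import jax.numpy as jnp

def return_arr_with_ins(arr, ins):
    # Insert a column filled with `ins` before column index 2.
    return jnp.insert(arr, 2, ins, axis=1)

x = jnp.arange(20).reshape(5, 4)

# Host-side NumPy reference result (no XLA compilation involved).
expected = np.insert(np.arange(20).reshape(5, 4), 2, 99, axis=1)

eager = np.asarray(return_arr_with_ins(x, 99))
jitted = np.asarray(jax.jit(return_arr_with_ins)(x, 99))

# array_equal compares values, so the int32 vs int64 dtype difference is irrelevant.
eager_ok = np.array_equal(eager, expected)
jit_ok = np.array_equal(jitted, expected)
print(f"backend={jax.default_backend()} eager_ok={eager_ok} jit_ok={jit_ok}")
```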

Output

  • x2 (computed with the non-jitted function) is correct; x3 (computed with the jitted function) contains no column of 99 and instead has a column of zeros appended at the end:
x=Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15],
       [16, 17, 18, 19]], dtype=int32)

x2=Array([[ 0,  1, 99,  2,  3],
       [ 4,  5, 99,  6,  7],
       [ 8,  9, 99, 10, 11],
       [12, 13, 99, 14, 15],
       [16, 17, 99, 18, 19]], dtype=int32)

x3=Array([[ 0,  1,  2,  3,  0],
       [ 4,  5,  6,  7,  0],
       [ 8,  9, 10, 11,  0],
       [12, 13, 14, 15,  0],
       [16, 17, 18, 19,  0]], dtype=int32)
  • The same code run on a non-Metal machine gives the correct results for both x2 and x3:
x=Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15],
       [16, 17, 18, 19]], dtype=int32)

x2=Array([[ 0,  1, 99,  2,  3],
       [ 4,  5, 99,  6,  7],
       [ 8,  9, 99, 10, 11],
       [12, 13, 99, 14, 15],
       [16, 17, 99, 18, 19]], dtype=int32)

x3=Array([[ 0,  1, 99,  2,  3],
       [ 4,  5, 99,  6,  7],
       [ 8,  9, 99, 10, 11],
       [12, 13, 99, 14, 15],
       [16, 17, 99, 18, 19]], dtype=int32)
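A possible stopgap until a fix is available is to build the same result from slicing and `jnp.concatenate` rather than `jnp.insert`. This is only a sketch under the assumption that the miscompilation is specific to the ops `jnp.insert` lowers to; it has not been verified on the Metal plugin. `insert_column` is a hypothetical helper that handles only a single constant column on `axis=1`.

```python
import jax
import jax.numpy as jnp

def insert_column(arr, index, value):
    """Hypothetical helper: insert a constant column before `index` along axis 1.

    Uses slicing + jnp.concatenate instead of jnp.insert. Whether this avoids
    the Metal miscompilation is untested.
    """
    col = jnp.full((arr.shape[0], 1), value, dtype=arr.dtype)
    return jnp.concatenate([arr[:, :index], col, arr[:, index:]], axis=1)

# `index` must be static so the slice bounds are known at trace time.
insert_column_jit = jax.jit(insert_column, static_argnums=1)

x = jnp.arange(20).reshape(5, 4)
x4 = insert_column_jit(x, 2, 99)
print(f"{x4=}")
```

Marking `index` as static means a new compilation per distinct index value, which is acceptable for a fixed insertion position like the one in the MWE.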

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.23
numpy:  1.26.2
python: 3.12.3 | packaged by Anaconda, Inc. | (main, Apr 19 2024, 11:44:52) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:10:42 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6000', machine='arm64')
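The system information above appears to follow the format of `jax.print_environment_info()`, which JAX provides for bug reports. A minimal sketch of generating it, either printed or captured as a string:

```python
import jax

# Prints jax/jaxlib/numpy/python versions, visible devices and platform details.
jax.print_environment_info()

# Alternatively, capture the same report as a string (e.g. to paste into an issue).
info = jax.print_environment_info(return_string=True)
```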
fkeruzore added the bug label Apr 24, 2024
@shuhand0
Collaborator

It is reproducible. The jitted module is incorrectly optimized, and we will look into a fix.
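To see where an optimization goes wrong, one generic JAX approach is to compare the module before and after backend compilation using the ahead-of-time lowering API. This is a debugging sketch, not a description of how the maintainers investigated the issue. Note that `Compiled.as_text()` may return `None` when a backend or plugin does not expose its compiled representation.

```python
import jax
import jax.numpy as jnp

def return_arr_with_ins(arr, ins):
    return jnp.insert(arr, 2, ins, axis=1)

x = jnp.arange(20).reshape(5, 4)

lowered = jax.jit(return_arr_with_ins).lower(x, 99)

# StableHLO module handed to the backend, before backend-specific optimization.
stablehlo_text = lowered.as_text()

# Module after the backend's compilation pipeline; may be None if unavailable.
compiled_text = lowered.compile().as_text()

print(stablehlo_text)
print(compiled_text)
```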
