Pallas strided loads ignore the stride argument on CUDA backend #20895

Open
hirayaku opened this issue Apr 23, 2024 · 0 comments
Labels: bug (Something isn't working), pallas (Issues pertaining to Pallas (GPU or TPU))

Description

When I test strided loads in Pallas kernels on the CUDA backend, pl.load seems to ignore the step of the slice argument. For example, the following code should print [0, 4, 8, 12], but it actually prints [0, 1, 2, 3].

import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl

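def strided(x_ref, o_ref):
    # Load every 4th element of the input: indices 0, 4, 8, 12.
    x = pl.load(x_ref, slice(0, None, 4))
    # Write the 4 loaded values to the output.
    o_ref[:] = x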

x = jnp.arange(16, dtype=jnp.uint32)
out = pl.pallas_call(
    strided, out_shape=jax.ShapeDtypeStruct((4,), x.dtype)
)(x)
print(out)
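For reference, the expected result can be derived from ordinary Python slice semantics without Pallas at all. The sketch below is only an illustration: `expected_strided_load` is a hypothetical helper written for this report, not part of JAX or Pallas, and it assumes a strided load should select the same elements as plain Python indexing with the same slice.

```python
def expected_strided_load(values, idx):
    """Hypothetical reference: elements selected by `idx` under Python slice rules."""
    # slice.indices resolves None bounds against the sequence length.
    start, stop, step = idx.indices(len(values))
    return [values[i] for i in range(start, stop, step)]

x = list(range(16))
expected = expected_strided_load(x, slice(0, None, 4))
observed_on_cuda = [0, 1, 2, 3]  # output reported in this issue

print(expected)                      # [0, 4, 8, 12]
print(expected == observed_on_cuda)  # False
```

The observed output matches what slice(0, 4, 1) would select, which is consistent with the step being dropped while the output shape (4,) is still honored.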

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.12.3 | packaged by Anaconda, Inc. | (main, Apr 19 2024, 16:50:38) [GCC 11.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='HomeLinux', release='6.1.0-9-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.1.27-1 (2023-05-08)', machine='x86_64')

$ nvidia-smi
Tue Apr 23 14:54:02 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05   Driver Version: 525.147.05   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:17:00.0 Off |                  Off |
|  0%   56C    P2    39W / 450W |    495MiB / 24564MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
@hirayaku hirayaku added the bug Something isn't working label Apr 23, 2024
@superbobry superbobry added the pallas Issues pertaining to Pallas (GPU or TPU) label Apr 24, 2024
@superbobry superbobry self-assigned this May 7, 2024