When I test strided loads in Pallas kernels with the CUDA backend, `pl.load` (`jax.experimental.pallas.load`) seems to ignore the `step` of the slice argument. For example, the following code should print [0, 4, 8, 12], but it actually prints [0, 1, 2, 3].
```python
import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl


def strided(x_ref, o_ref):
    # Load every 4th element of the 16-element input: expected [0, 4, 8, 12].
    x = pl.load(x_ref, slice(0, None, 4))
    o_ref[:] = x


x = jnp.arange(16, dtype=jnp.uint32)
out = pl.pallas_call(
    strided, out_shape=jax.ShapeDtypeStruct((4,), x.dtype)
)(x)
print(out)  # expected [0 4 8 12], observed [0 1 2 3]
```
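For reference, the expected result follows from ordinary Python slice semantics, independent of Pallas. The sketch below uses only the standard `slice.indices` method to list which positions `slice(0, None, 4)` selects from a length-16 axis, and contrasts that with what a load that drops the step would return. The helper name `selected_indices` is hypothetical and only for illustration; the "step ignored" line is an assumption about the observed behavior, not a description of the Pallas internals.

```python
def selected_indices(s: slice, length: int) -> list[int]:
    """Return the positions a Python slice selects from an axis of size `length`."""
    start, stop, step = s.indices(length)  # normalizes None/negative bounds
    return list(range(start, stop, step))


values = list(range(16))  # mirrors jnp.arange(16)
strided = slice(0, None, 4)

# Correct strided selection: every 4th element.
expected = [values[i] for i in selected_indices(strided, len(values))]
print(expected)  # [0, 4, 8, 12]

# Assumed buggy behavior: same output length, but step treated as 1.
observed_if_step_ignored = values[: len(expected)]
print(observed_if_step_ignored)  # [0, 1, 2, 3]
```

The second result matches the output reported above, which is consistent with the output shape being computed correctly while the stride is dropped when the load is lowered.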
System info (python version, jaxlib version, accelerator, etc.)