
Running Pallas #224

Open
mehdiataei opened this issue Aug 23, 2023 · 1 comment

@mehdiataei

Hi,

With the nightly builds of jax, jaxlib, and jax-triton installed, running a Pallas kernel fails with the following error:

```
  File "/home/mehdi/Repos/venvs/jax-dev/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 1721, in pallas_call_lowering
    backend_config=zlib.compress(kernel_call.to_proto(serialized_metadata)),
TypeError: to_proto(): incompatible function arguments. The following argument types are supported:
    1. (self: jaxlib.cuda._triton.TritonKernelCall, arg0: str, arg1: str) -> bytes

Invoked with: <jaxlib.cuda._triton.TritonKernelCall object at 0x7faebfce6ef0>, b''
I0000 00:00:1692795278.467295  180484 tfrt_cpu_pjrt_client.cc:469] TfrtCpuClient destroyed.
```
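The traceback shows the Pallas lowering in jax calling `TritonKernelCall.to_proto` with a single `bytes` argument, while the jaxlib binding that was actually loaded only accepts two `str` arguments. This usually means the pure-Python `jax` package and the compiled `jaxlib` extension come from different source snapshots; that is an inference from the error message, not something confirmed in this thread. A minimal pure-Python sketch of the same failure mode, using a hypothetical stand-in class (`KernelCallStandIn`) rather than the real jaxlib type:

```python
# Hypothetical stand-in for a compiled extension type; this is NOT the real
# jaxlib.cuda._triton.TritonKernelCall. Its method signature mirrors the one
# reported in the error message: to_proto(self, arg0: str, arg1: str) -> bytes.
class KernelCallStandIn:
    def to_proto(self, name: str, metadata: str) -> bytes:
        # Pretend to serialize both fields into a byte string.
        return f"{name}:{metadata}".encode()


def call_with_one_argument(call: KernelCallStandIn, serialized_metadata: bytes) -> bytes:
    # The caller passes a single bytes argument, as the lowering code in the
    # traceback does, so it no longer matches the callee's signature.
    return call.to_proto(serialized_metadata)


try:
    call_with_one_argument(KernelCallStandIn(), b"")
except TypeError as err:
    print("TypeError:", err)
```

Plain Python reports this as a missing positional argument rather than pybind11's "incompatible function arguments", but in both cases the caller and the callee disagree about the method's signature.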

```
Package        Version                           Editable project location
-------------- --------------------------------- ----------------------------
absl-py        1.4.0
filelock       3.12.2
jax            0.4.15
jax-triton     0.1.4                             /home/mehdi/Repos/jax-triton
jaxlib         0.4.15.dev20230822+cuda12.cudnn89
ml-dtypes      0.2.0
numpy          1.25.2
opt-einsum     3.3.0
pip            22.0.2
scipy          1.11.2
setuptools     59.6.0
triton-nightly 2.1.0.dev20230714011643
```
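For a report like this, the versions that matter most are those of jax, jaxlib, and the Triton packages. A minimal sketch that prints just those, using the standard-library `importlib.metadata` module; the package names come from the table above, and the helper name `installed_versions` is hypothetical:

```python
from importlib.metadata import PackageNotFoundError, version

# Distributions relevant to this report, as named in the pip listing above.
PACKAGES = ["jax", "jaxlib", "jax-triton", "triton-nightly"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name:<16} {ver or 'not installed'}")
```

Running it in the same virtual environment as the failing example gives one line per package; comparing the `jax` and `jaxlib` lines is a quick first check for a Python/extension mismatch.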

@mehdiataei

The example code is:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
  # In this code, `x_ref`, `y_ref` and `o_ref` are (8,)-shaped `Ref`s
  x = x_ref[:]
  y = y_ref[:]
  o_ref[:] = x + y


x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32))
add(x, y)
```
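One way to narrow the problem down, assuming the kernel itself is correct, is to run the same kernel with `interpret=True`. In interpret mode Pallas evaluates the kernel with ordinary JAX operations instead of compiling it through Triton, so it should not reach the `to_proto` call from the traceback. If this variant succeeds while the original fails, the fault most likely lies in the Triton lowering path or the jax/jaxlib pairing rather than in `add_kernel`. A sketch (the name `add_interpreted` is only for illustration):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Read both (8,)-shaped inputs and write their elementwise sum.
    o_ref[:] = x_ref[:] + y_ref[:]


x, y = jnp.arange(8), jnp.arange(8, 16)

# interpret=True evaluates the kernel with regular JAX ops instead of
# lowering it to a Triton kernel, so no TritonKernelCall is built.
add_interpreted = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
    interpret=True,
)
print(add_interpreted(x, y))  # should equal x + y: 8, 10, 12, ..., 22
```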
