Pallas Broken after making JAX-Triton calls serializable update #179

Open
adam-hartshorne opened this issue Jun 24, 2023 · 6 comments


@adam-hartshorne

The new update has broken Pallas again, producing the error shown below. I have tried updating jax / jaxlib to head, and the issue persists.

I0624 11:43:47.866600 139777320678464 xla_bridge.py:568] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA Interpreter Host
I0624 11:43:47.866749 139777320678464 xla_bridge.py:568] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0624 11:43:47.866775 139777320678464 xla_bridge.py:568] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
(4, 8)
Traceback (most recent call last):
  File "/media/adam/shared_drive/PycharmProjects/triton_test/pallas_test.py", line 231, in <module>
    app.run(main)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/media/adam/shared_drive/PycharmProjects/triton_test/pallas_test.py", line 226, in main
    print(pl.pallas_call(kernel1, out_shape=out_shape, grid=grid)(x, y))
  File "/home/adam/Downloads/jax-triton/jax_triton/pallas/pallas_call.py", line 352, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: __init__(): incompatible constructor arguments. The following argument types are supported:
    1. jaxlib.cuda._triton.TritonKernel(arg0: str, arg1: str, arg2: int, arg3: int)

pallas_error.txt

@sharadmv
Collaborator

Could you provide your jax-triton and jaxlib versions?
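For anyone answering the version question, a minimal sketch of one way to collect the installed versions in a single step. It uses only the standard library's `importlib.metadata`. The distribution names queried (`jax`, `jaxlib`, `jax-triton`, `triton`) and the helper name `report_versions` are assumptions for illustration; a source checkout of jax-triton that was never pip-installed may show up as not installed.

```python
from importlib import metadata


def report_versions(packages=("jax", "jaxlib", "jax-triton", "triton")):
    """Return {distribution name: installed version, or None if not found}.

    The default package names are assumptions; adjust them to match how
    the packages were installed in your environment.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed as a distribution (or installed under another name).
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in report_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Pasting the printed lines into the issue, together with the CUDA version, should be enough to compare environments.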

@adam-hartshorne
Author

adam-hartshorne commented Jun 26, 2023

OK, so I have gone back through and created a fresh install.

jax / jaxlib 0.4.12 with CUDA 11 works with the latest commits.

With jax / jaxlib 0.4.13 I get the error above. However, 0.4.13 does work with the code from before the serialisation change (commit f947255).
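The `TypeError` in the traceback says the installed jaxlib's `TritonKernel` constructor only accepts `(str, str, int, int)`, which suggests the jax-triton checkout and the jaxlib build disagree on that signature. A rough sketch for printing what the installed jaxlib actually accepts is below. The module path `jaxlib.cuda._triton` and the class name `TritonKernel` are taken from the error message; the helper name `describe_triton_kernel_ctor` is hypothetical, and the approach assumes the extension exposes its overloads through the constructor docstring, which may not hold for every jaxlib build.

```python
import importlib


def describe_triton_kernel_ctor():
    """Return the constructor docstring of jaxlib's TritonKernel, if available.

    Falls back to a short explanatory string when the module or class
    cannot be found, so it is safe to run on any installation.
    """
    try:
        module = importlib.import_module("jaxlib.cuda._triton")
    except ImportError as exc:
        # jaxlib missing, or built without the CUDA Triton extension.
        return f"jaxlib.cuda._triton not importable: {exc}"

    kernel_cls = getattr(module, "TritonKernel", None)
    if kernel_cls is None:
        return "TritonKernel not found in jaxlib.cuda._triton"

    # Assumption: the binding layer lists accepted argument types here.
    return kernel_cls.__init__.__doc__ or "no constructor docstring available"


if __name__ == "__main__":
    print(describe_triton_kernel_ctor())
```

Comparing this output between a working install (0.4.12) and a failing one (0.4.13) may help narrow down which side changed.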

@aniquetahir

I'm having the same issue.

@sharadmv
Collaborator

sharadmv commented Jul 5, 2023

I haven't been able to reproduce this, but I was using jaxlib 0.4.14 from the nightlies: https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html (specifically the 0705 one).

@aniquetahir

I was using the CUDA 11 version of jax / jaxlib.

@sharadmv
Collaborator

sharadmv commented Jul 5, 2023

Are you able to try a nightly?
