Pallas upstream is now working? #219

Open
sh0416 opened this issue Aug 12, 2023 · 2 comments

@sh0416

sh0416 commented Aug 12, 2023

I saw Pallas described in the latest official JAX docs and followed the Pallas quickstart section.

I installed the latest jaxlib and jax from GitHub HEAD.

I encountered the following error.

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/sh0416/research/scripts/pallas_quickstart.py", line 20, in <module>
    print(add_vectors(jnp.arange(8), jnp.arange(8)))
  File "/home/sh0416/research/scripts/pallas_quickstart.py", line 17, in add_vectors
    return pl.pallas_call(add_vectors_kernel, out_shape=out_shape)(x, y)
  File "/home/sh0416/research/venv/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 353, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'pallas_call' not found for platform cuda

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/sh0416/research/scripts/pallas_quickstart.py", line 20, in <module>
    print(add_vectors(jnp.arange(8), jnp.arange(8)))
NotImplementedError: MLIR translation rule for primitive 'pallas_call' not found for platform cuda

Do I need to install something else, or has Pallas not been fully upstreamed yet?

@sharadmv
Collaborator

I think you might not have Triton installed. If Triton is installed properly, pallas_call should work.
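One quick way to check this is to print which of the relevant distributions are actually installed. The following is a minimal sketch using only the standard library; the helpers `installed_version` and `report` are hypothetical, and the names in `PACKAGES` are assumed PyPI distribution names based on the packages mentioned in this thread (your environment may use different names).

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Assumed PyPI distribution names, taken from the packages mentioned in this thread.
PACKAGES = ("jax", "jaxlib", "jax-triton", "triton", "triton-nightly", "torch")


def installed_version(dist_name: str) -> Optional[str]:
    """Hypothetical helper: installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def report(names: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Hypothetical helper: map each distribution name to its version (or None)."""
    return {name: installed_version(name) for name in names}


if __name__ == "__main__":
    for name, version in report().items():
        print(f"{name:15} {version or 'not installed'}")
```

If `triton` and `triton-nightly` both print "not installed", that would match the missing-Triton explanation for the NotImplementedError above.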

@sh0416
Author

sh0416 commented Aug 12, 2023

Right. My workload needs both torch and jax, but recent torch requires triton==2.0.0, while jax-triton requires triton-nightly.

Because of this version conflict, I decided to use triton==2.0.0 instead of triton-nightly.

Is Pallas usable with this setup? Currently, triton==2.0.0 is installed in my environment.
