Import error encountered in jax_triton #264

Open
egg5154 opened this issue Feb 22, 2024 · 7 comments

@egg5154

egg5154 commented Feb 22, 2024

Hello, I am running jax_triton on an A100 with CUDA 12.2, and the command python -c 'import jax_triton as jt' fails with the following error:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/lustre/grp/gyqlab/liyh/debugs/jax-triton/jax_triton/__init__.py", line 19, in <module>
    from jax_triton.triton_lib import triton_call
  File "/lustre/grp/gyqlab/liyh/debugs/jax-triton/jax_triton/triton_lib.py", line 50, in <module>
    from triton._C.libtriton import ir as tl_ir
ImportError: cannot import name 'ir' from 'triton._C.libtriton' (/lustre/grp/gyqlab/liyh/anaconda3/envs/jax_triton3/lib/python3.10/site-packages/triton/_C/libtriton.so)

I installed jax_triton following the instructions in google/jax#18603.

@superbobry
Contributor

Hi @egg5154, which version of jax_triton and triton do you have?

@egg5154
Author

egg5154 commented Feb 22, 2024

> Hi @egg5154, which version of jax_triton and triton do you have?

Hello, jax_triton is 0.1.4 and triton (the triton-nightly package) is 2.1.0.post20231216005823.
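Since the answer hinges on which jax, jaxlib, jax_triton and triton builds are installed together, a small script like the one below can collect all of them in one go before filing a report. This is a minimal sketch using only the standard library's importlib.metadata; the distribution names in PACKAGES are assumptions (triton-nightly is included because that is how Triton was installed here), and installed_versions is a hypothetical helper, not part of jax_triton.

```python
from importlib.metadata import PackageNotFoundError, version

# Distributions whose versions have to line up. "triton-nightly" is listed
# because the reporter installed Triton from that nightly package.
PACKAGES = ("jax", "jaxlib", "jax-triton", "triton", "triton-nightly")


def installed_versions(packages=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:15} {ver or 'not installed'}")
```

Pasting this output into the issue answers the version question without a round trip.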

@superbobry
Contributor

superbobry commented Feb 22, 2024

I suspect you might need the main version of jax_triton. There were a number of refactorings in the Triton Python APIs, and the main version should be up to date with them.

@egg5154
Author

egg5154 commented Feb 22, 2024

> I suspect you might need the main version of jax_triton. There were a number of refactorings in the Triton Python APIs, and the main version should be up to date with them.

Hello @superbobry, I switched to the main version, but the same error still occurs.
The line in jax_triton that fails is from triton._C.libtriton import ir as tl_ir; replacing it with from triton._C.libtriton.triton import ir as tl_ir bypasses the error.
I suspected the Triton version was the cause, so I tried Triton 2.0.0, 2.1.0 and 2.2.0, but none of them helped.
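The workaround described above (the ir module is found at triton._C.libtriton in some Triton builds and at triton._C.libtriton.triton in others) can be expressed as a fallback import instead of a hard edit. This is a generic sketch: import_attr_from_first is a hypothetical helper, not part of jax_triton or Triton, and the two module paths are only the ones reported in this thread. It papers over this single symbol; other Triton API changes may still break jax_triton, so a matching jax/jax_triton/triton combination remains the real fix.

```python
import importlib


def import_attr_from_first(attr, module_paths):
    """Return `attr` from the first module in `module_paths` that provides it.

    Hypothetical helper: tries each dotted module path in order and raises
    ImportError listing every attempt if none of them works.
    """
    errors = []
    for path in module_paths:
        try:
            module = importlib.import_module(path)
        except ImportError as exc:  # includes ModuleNotFoundError
            errors.append(f"{path}: {exc}")
            continue
        if hasattr(module, attr):
            return getattr(module, attr)
        errors.append(f"{path}: no attribute {attr!r}")
    raise ImportError(f"could not import {attr!r}; tried:\n  " + "\n  ".join(errors))


# Hypothetical use with the two layouts reported in this thread:
# tl_ir = import_attr_from_first(
#     "ir", ["triton._C.libtriton", "triton._C.libtriton.triton"]
# )
```

Raising one ImportError that lists every attempted path keeps the failure as readable as the original traceback.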

@superbobry
Contributor

Ouch, sorry you have to deal with this. It is indeed quite tricky to find a working jax, jax_triton and triton combination.

If you are open to using Pallas instead of Triton directly, google/jax#19890 changed how Pallas-produced Triton kernels are compiled. Pallas no longer needs jax_triton or triton, as long as you install the nightly jaxlib and jax (starting tomorrow).

@egg5154
Author

egg5154 commented Feb 23, 2024

> Ouch, sorry you have to deal with this. It is indeed quite tricky to find a working jax, jax_triton and triton combination.
>
> If you are open to using Pallas instead of Triton directly, google/jax#19890 changed how Pallas-produced Triton kernels are compiled. Pallas no longer needs jax_triton or triton, as long as you install the nightly jaxlib and jax (starting tomorrow).

Thanks! What I actually want is to use flash attention in JAX. Does that mean Pallas can be used instead of Triton?

@superbobry
Contributor

Yeah, you could use Pallas, which lowers to Triton on GPU without going through the Triton Python APIs.
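For a sense of what writing a kernel in Pallas looks like, here is a minimal element-wise addition kernel. It is only an illustration of the pl.pallas_call entry point in jax.experimental.pallas, not a flash-attention implementation. add_kernel and add_vectors are illustrative names, and it assumes a recent JAX release. interpret=True emulates the kernel with ordinary JAX operations so it also runs on CPU; on a GPU you would drop it so Pallas compiles the kernel, which per this thread goes through Triton's lowering rather than the Triton Python package.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # With no grid, the kernel runs once and each ref covers the whole array.
    o_ref[...] = x_ref[...] + y_ref[...]


@jax.jit
def add_vectors(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        # Emulate the kernel with regular JAX ops (works on CPU); remove this
        # on a GPU to compile the kernel instead.
        interpret=True,
    )(x, y)


if __name__ == "__main__":
    x = jnp.arange(8, dtype=jnp.float32)
    print(add_vectors(x, x))
```

A real attention kernel would add a grid and BlockSpecs to tile the inputs, but the pallas_call structure stays the same.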
