
Example notebook with attention isn't working #180

Open
hr0nix opened this issue Jun 24, 2023 · 3 comments

Comments

@hr0nix

hr0nix commented Jun 24, 2023

I've tried running the example notebook using `jax_triton` and `jaxlib` installed from head. Unfortunately, it doesn't seem to work: running the cell with `test_triton_jax(2, 32, 2048, 64)` hangs indefinitely.

  • Should the example work?
  • Is there a tested version of flax attention that is guaranteed to work with latest jax_triton?

Thanks!

@sharadmv
Collaborator

> Should the example work?

Sorry about that, I don't expect that particular example to work. The examples don't have regression tests, and the notebook hasn't been touched for a while. I haven't had the bandwidth to keep the examples up to date with Triton changes.

> Is there a tested version of flax attention that is guaranteed to work with latest jax_triton?

Do you mean flash attention? If so, the implementation in `jax_triton/pallas/ops/attention.py` should always work.
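For anyone comparing against that op: a fused flash-attention kernel computes the same result as plain softmax attention, just without materializing the full score matrix. A minimal NumPy reference of that computation (useful as a correctness oracle when testing the fused kernel; this is a sketch, not jax_triton's API) looks like:

```python
import numpy as np

def attention_reference(q, k, v):
    """Plain softmax attention over one head.

    q, k, v: arrays of shape [seq_len, head_dim]. A fused
    flash-attention kernel should produce the same output
    (up to floating-point tolerance) without materializing
    the full [seq_len, seq_len] score matrix.
    """
    scores = q @ k.T / np.sqrt(q.shape[-1])
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v
```

Comparing the fused kernel's output against a reference like this with `np.allclose` (with a loose tolerance, since fused kernels often run in lower precision) is the usual way to sanity-check it.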

@hr0nix
Author

hr0nix commented Jun 26, 2023

Thanks for the quick answer!

> Do you mean flash attention?

Oops. I guess I've been thinking a lot about flax lately.

> If so, the implementation in `jax_triton/pallas/ops/attention.py` should always work.

I don't fully understand what pallas is and what pallas ops are. Are they supposed to be as performant as native triton kernels, but more powerful because they can be vmapped etc.?

@sharadmv
Collaborator

> I don't fully understand what pallas is and what pallas ops are. Are they supposed to be as performant as native triton kernels, but more powerful because they can be vmapped etc.?

Pallas is an extension to JAX that lets you write your Triton kernels using JAX directly. As you surmised, one of the core benefits is compatibility with JAX transformations (vmap just works, AD is WIP). Personally, I also find it a friendlier front-end than Triton's.

> Are they supposed to be as performant as native triton kernels?

That is the goal, though since we are generating Triton kernels, we might hit some unoptimized code paths. The main difference between the two is that Pallas handles a lot of the pointer-arithmetic and indexing logic that you normally have to write by hand in Triton. As a result, we might not generate the exact indexing logic that the fastest handwritten Triton kernel would.
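To make the "indexing logic" concrete: in raw Triton each program instance computes its own element offsets from its program id and masks out-of-bounds lanes by hand, whereas Pallas derives the per-block slicing from the grid and block specs. A hedged NumPy sketch of the by-hand version (an illustration of the arithmetic, not Triton or Pallas code) looks like:

```python
import numpy as np

def add_blocked_manual(x, y, block_size):
    """Elementwise add, emulating the index math a raw Triton
    kernel does by hand. Each 'program' pid handles one block:
    it computes its offsets (like tl.program_id * BLOCK +
    tl.arange(0, BLOCK)) and masks lanes past the array end.
    """
    n = x.shape[0]
    out = np.empty_like(x)
    num_programs = (n + block_size - 1) // block_size  # ceil-div grid
    for pid in range(num_programs):
        offs = pid * block_size + np.arange(block_size)  # per-block offsets
        mask = offs < n                                  # boundary mask, by hand
        out[offs[mask]] = x[offs[mask]] + y[offs[mask]]
    return out
```

In Pallas this offset and mask bookkeeping is what the framework generates for you from the block shapes, which is exactly where a handwritten kernel might still squeeze out slightly better indexing code.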

If you find any performance gaps, please let us know and we can investigate!
