Is jax-triton replaced by Pallas? #273

Closed
gautierronan opened this issue Apr 11, 2024 · 2 comments

@gautierronan

I don't quite understand the relation between triton, jax-triton and jax.experimental.pallas.

  1. Which of the last two is recommended today? It looks like there is much more activity on jax.experimental.pallas.
  2. Does jax.experimental.pallas depend on jax-triton? Or is it independent?

Thanks.

@clintg6

clintg6 commented Apr 16, 2024

I would also like to hear clarification on this from the team. Does Pallas replace jax-triton?

@sharadmv
Collaborator

Triton is an open-source compiler and Python-embedded DSL for writing GPU kernels. You write Python code, and Triton's internals convert that code into the TTIR and TTGIR MLIR dialects, which the Triton compiler then compiles into PTX. Triton ships with bindings to launch this PTX with PyTorch tensor inputs/outputs.
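To make that pipeline concrete, a minimal Triton kernel launched on PyTorch tensors might look like the sketch below, modeled on the vector-addition example in Triton's tutorials. It assumes a CUDA GPU with `torch` and `triton` installed; `add_kernel`, `add`, and `BLOCK_SIZE=1024` are illustrative choices, not anything from this thread.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one contiguous block of BLOCK_SIZE elements.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mask off lanes past the end of the input in the last block.
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Hypothetical wrapper: Triton JIT-compiles add_kernel on first launch and
    # runs it directly on the PyTorch tensors' device memory.
    out = torch.empty_like(x)
    n = out.numel()
    grid = (triton.cdiv(n, 1024),)
    add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
    return out
```

Calling `add(x, y)` on two CUDA tensors of the same shape compiles `add_kernel` on first use and returns their elementwise sum.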

JAX-Triton is a standalone library that lets you write kernels in the Triton Python DSL and call the compiled kernels on JAX arrays instead of PyTorch tensors. It reuses portions of Triton's internals to convert the user's Python kernel into PTX, and it integrates with JAX through an XLA custom call.

Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU directly using JAX and JAX-like APIs. It integrates directly with JAX and JAX transformations. The user writes a kernel in Python using JAX APIs, and the kernel is traced using JAX's tracing machinery into Jaxpr, JAX's internal representation. On TPU, the Jaxpr is converted into Mosaic, a standalone compiler for TPU kernels ("Triton for TPUs", if you will). On GPU, the Jaxpr is converted into TTIR and TTGIR, Triton's MLIR dialects. We then emit a custom call with serialized TTIR/TTGIR that XLA GPU knows how to handle (it already ships with the Triton compiler). As of recently, Pallas no longer depends on JAX-Triton.
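For comparison, here is a minimal Pallas sketch of the same vector add, in the style of the Pallas quickstart. The names `add_kernel` and `add` are illustrative, and `interpret=True` is an assumption made so the example runs without a GPU or TPU. The kernel reads and writes `Ref`s with ordinary JAX indexing instead of raw pointers, and `pl.pallas_call` turns it into a function that composes with `jax.jit`.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Kernel body uses ordinary JAX ops on Refs; Pallas traces it into a Jaxpr.
    o_ref[...] = x_ref[...] + y_ref[...]


@jax.jit
def add(x: jax.Array, y: jax.Array) -> jax.Array:
    # interpret=True evaluates the kernel with JAX's interpreter so this sketch
    # also runs on CPU; without it, Pallas lowers to Triton (GPU) or Mosaic (TPU).
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,
    )(x, y)
```

For example, `add(jnp.arange(8.0), jnp.ones(8))` returns the elementwise sum. With no grid or block specs given, the whole arrays are treated as a single block, which is fine for small inputs like this one.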

Re: is Pallas replacing JAX-Triton? The short answer is no. JAX-Triton is meant specifically for when you want to use the Triton Python DSL to write kernels or to reuse existing Triton kernels. The Pallas frontend exists so you can write kernels with a more JAX-like experience. They have different purposes, and I don't have a strong recommendation one way or the other. I will say that Pallas is easier to install, though, because JAX-Triton is usually pinned to a specific Triton commit rather than a public release.

There is a lot more activity on Pallas mainly because the TPU kernel language is being actively worked on. The GPU parts have changed less because we rely on the Triton compiler for most of the heavy lifting.

JAX-Triton doesn't really need active maintenance other than when we update the Triton version.
