Why does pl.dot explicitly convert its output to float32? #199

Open
hr0nix opened this issue Jul 16, 2023 · 1 comment

Comments

@hr0nix

hr0nix commented Jul 16, 2023

Here: https://github.com/jax-ml/jax-triton/blob/main/jax_triton/pallas/primitives.py#L512

This is a problem if I want to keep my computations in the input dtype (say, bfloat16).

@sharadmv
Collaborator

pl.dot is meant to mirror tl.dot (Triton's dot function), which by default accumulates results in f32.
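A small illustration of why f32 accumulation is a sensible default, written in plain jax.numpy rather than Pallas. It does not model how tl.dot or XLA actually accumulate internally (real kernels use blocked accumulation). It only shows what happens when a running total is rounded to bfloat16 after every add. bfloat16 has 8 significant bits, so once the total reaches 256, adding 1 gives 257, which rounds back to 256. `sequential_sum` is a hypothetical helper for this demo.

```python
import jax.numpy as jnp


def sequential_sum(n, acc_dtype):
    """Add 1.0 to a running total n times, keeping the total in acc_dtype."""
    acc = jnp.zeros((), dtype=acc_dtype)
    one = jnp.ones((), dtype=acc_dtype)
    for _ in range(n):
        # The result of each add is stored in acc_dtype, so it is rounded every step.
        acc = acc + one
    return acc


print(sequential_sum(512, jnp.bfloat16))  # stalls at 256
print(sequential_sum(512, jnp.float32))   # reaches 512
```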

I think tl.dot now has an argument that allows accumulating in a specified dtype, which we should port to Pallas.
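A minimal sketch of what such an option could look like, written in plain JAX outside a Pallas kernel. `dot_with_acc_dtype` and its `acc_dtype`/`out_dtype` parameters are hypothetical names, not part of Pallas or Triton; the sketch relies only on the `preferred_element_type` argument of `jax.lax.dot`. Accumulation defaults to f32, but the result is cast back to the input dtype, so bfloat16 inputs give a bfloat16 output. Note that `preferred_element_type` requests a result type from XLA; a backend may still accumulate internally at higher precision.

```python
import jax
import jax.numpy as jnp


def dot_with_acc_dtype(a, b, acc_dtype=jnp.float32, out_dtype=None):
    """Hypothetical helper: matmul producing acc_dtype, then cast to out_dtype.

    out_dtype defaults to the promoted input dtype, so bf16 x bf16 -> bf16.
    """
    if out_dtype is None:
        out_dtype = jnp.result_type(a.dtype, b.dtype)
    # Ask XLA for an acc_dtype result, then cast to the requested output dtype.
    acc = jax.lax.dot(a, b, preferred_element_type=acc_dtype)
    return acc.astype(out_dtype)


a = jnp.ones((4, 8), dtype=jnp.bfloat16)
b = jnp.ones((8, 2), dtype=jnp.bfloat16)
out = dot_with_acc_dtype(a, b)  # bfloat16 result, every entry equal to 8
```

Keeping f32 as the default accumulation dtype while making the output dtype follow the inputs would preserve the tl.dot-like numerics and still let callers stay in bfloat16.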
