Inconsistent NaN results on Triton matmul kernel #185

Open
karen-sy opened this issue Jun 28, 2023 · 1 comment

@karen-sy

I've found that the output of jt.triton_call changes depending on where certain metaparameters are defined (I suspect the ones that determine the grid).

Specifically, for the Triton repo's matmul kernel (source):

(1) When the metaparameters BLOCK_M, BLOCK_N, BLOCK_K, and SPLIT_K are passed directly into the function call, jt.triton_call returns a matrix of NaNs from the second call onwards (the first call is correct).
(2) When those metaparameters are instead selected via triton.autotune (and not passed directly into jt.triton_call), jt.triton_call returns correct results.
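
For context on why the grid is the suspect: the launch grid for this kernel is derived from the same metaparameters. Below is a minimal pure-Python sketch of that calculation, assuming the grid has the form `(cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N), SPLIT_K)`. `matmul_grid` is a hypothetical helper name used only for illustration (it is not part of Triton or jax_triton), and the block sizes are example values, not the ones from the attached script.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division: number of size-b blocks needed to cover a elements."""
    return (a + b - 1) // b


def matmul_grid(M: int, N: int, meta: dict) -> tuple:
    """Hypothetical helper: launch grid as a function of the metaparameters.

    Assumes one program per (BLOCK_M x BLOCK_N) output tile on axis 0
    and SPLIT_K programs along axis 1.
    """
    tiles = cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"])
    return (tiles, meta["SPLIT_K"])


# Example values only; not taken from the repro script.
meta = {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}
grid = matmul_grid(512, 512, meta)
print(grid)
```

The point of the sketch is that the grid and the kernel's constexpr arguments must come from the same set of values, so anything that changes how or when those values are bound can change the launch configuration.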

Also, simply importing Triton's matmul_perf_model (source) changes when the failure starts. With the import, jt.triton_call fails (NaN outputs, as described in (1)) on the second call and beyond; with the import commented out, it fails on the third call and beyond.
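
One way to pin down "which call fails first" is a small harness that calls the function repeatedly and reports the first call whose output contains a NaN. The sketch below is not the attached script: `first_nan_call` and `make_fake_matmul` are hypothetical names, and the fake function merely stands in for the real `jt.triton_call` wrapper so the harness can run without a GPU. With real arrays one would presumably replace the nested-list check with something like `jnp.isnan(out).any()`.

```python
import math


def first_nan_call(fn, n_calls=5):
    """Call fn() up to n_calls times.

    Returns the 1-based index of the first call whose output (a nested
    list of floats here) contains a NaN, or None if no call does.
    """
    for i in range(1, n_calls + 1):
        out = fn()
        if any(math.isnan(v) for row in out for v in row):
            return i
    return None


def make_fake_matmul(fail_from):
    """Hypothetical stand-in for the kernel wrapper.

    Returns a function that yields a valid 2x2 result until call number
    fail_from, and an all-NaN result from that call onwards.
    """
    calls = {"n": 0}

    def fake():
        calls["n"] += 1
        v = float("nan") if calls["n"] >= fail_from else 1.0
        return [[v, v], [v, v]]

    return fake


# Mimics the two cases described above: failure from call 2 or from call 3.
print(first_nan_call(make_fake_matmul(2)))
print(first_nan_call(make_fake_matmul(3)))
```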

I am attaching a script that reproduces this behavior.

I'm wondering whether this is expected behavior and, if so, which jax_triton conventions I should follow when passing metaparameters/tl.constexpr values. More generally, the boundary between args and metaparams seems a bit vague: is a parameter a metaparameter if and only if it is a constexpr?

Thanks for the help!

matmul_repro.txt

@sharadmv
Collaborator

This also reproduces in Triton itself, so it appears to be a Triton compiler issue.
