Is scan supported in pallas? #202

Open
hr0nix opened this issue Jul 17, 2023 · 5 comments

hr0nix commented Jul 17, 2023

I have a kernel that contains `jax.lax.map`. It runs fine with `interpret=True`, but lowering to Triton fails with the following error:

```
E         jax_triton.pallas.triton_lowering.TritonLoweringException: Exception while lowering eqn:
E           a:i32[6] = scan[
E           jaxpr={ lambda ; b:Ref{int32[384]} c:i32[]. let
E               d:i32[] = mul c 64
E               e:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)]
E               f:i32[64] = add e d
E               g:bool[64] = lt f 384
E               h:i32[64] <- b[d:d+64]
E               i:i32[] = reduce_min[axes=(0,)] h
E             in (i,) }
E           length=6
E           linear=(False, False)
E           num_carry=0
E           num_consts=1
E           reverse=False
E           unroll=1
E         ] j k
```

Is this because scan is not supported, or is there some other problem? Happy to provide more details if necessary.

hr0nix commented Jul 17, 2023

On a side note, I noticed that the `pl.load` mask (`g` in the jaxpr above) isn't used by the load (`h <- b[d:d+64]`). Is that expected?

hr0nix commented Jul 18, 2023

> On a side note, I noticed that the `pl.load` mask (`g` in the jaxpr above) isn't used by the load (`h <- b[d:d+64]`). Is that expected?

It appears that masked reads don't work correctly with `interpret=True` either, which is likely related.
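For a repro, it helps to pin down what a masked read is expected to return. The sketch below is a pure-Python model, not Pallas code: `masked_load`, `chunk_mins`, and `mask_is_all_true` are hypothetical helpers, and the load assumes the same semantics as Triton's `tl.load(ptr, mask=..., other=...)`, where masked-off lanes return `other` without touching memory. It also assumes the values being mapped over (`k` in the jaxpr) are the chunk indices 0 through 5. Under that assumption, the mask `g` in the original jaxpr is all `True` because 6 * 64 == 384, so dropping it would not change this particular result; it only matters when the length is not a multiple of the block size.

```python
INT32_MAX = 2**31 - 1


def masked_load(buf, start, size, mask, other):
    """Hypothetical model: lanes with mask False return `other` and never read buf."""
    return [buf[start + lane] if mask[lane] else other for lane in range(size)]


def chunk_mins(buf, block):
    """Per-chunk minima, masking off lanes past the end of buf."""
    n = len(buf)
    num_chunks = -(-n // block)  # ceiling division
    mins = []
    for c in range(num_chunks):
        start = c * block
        mask = [start + lane < n for lane in range(block)]  # like g = lt f n
        # Masked-off lanes get INT32_MAX so they can never be the minimum.
        mins.append(min(masked_load(buf, start, block, mask, INT32_MAX)))
    return mins


def mask_is_all_true(n, block):
    """True when every lane of every chunk is in bounds (no partial chunk)."""
    num_chunks = -(-n // block)
    return all(c * block + lane < n for c in range(num_chunks) for lane in range(block))


print(mask_is_all_true(384, 64))  # True: 6 * 64 == 384, no partial chunk
print(mask_is_all_true(380, 64))  # False: the last chunk has 4 out-of-bounds lanes
print(chunk_mins(list(range(380, 0, -1)), 64))  # [317, 253, 189, 125, 61, 1]
```

With 380 elements, reading the last chunk with an all-`True` mask raises `IndexError` in this model; on a GPU the equivalent unmasked read would go out of bounds.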

sharadmv (Collaborator) commented

I suspect you're running into an unsupported use case for `scan`. Scan automatically slices into its inputs and outputs along the leading axis, which isn't supported in Triton as far as I know. Could you post the full traceback so I can double-check?
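The jaxpr's `num_carry=0` is consistent with `jax.lax.map(f, xs)` being a `scan` with an empty carry: each step receives one slice of `xs`, and the per-step results are stacked into the output. Those sliced inputs and stacked outputs are the "extensive" operands. The sketch below is a pure-Python model of that operand layout (consts, carry, extensive inputs) for illustration only; `scan` and `lax_map` here are hypothetical stand-ins, not JAX's implementation. The usage line mirrors the jaxpr above, assuming the mapped-over values are the chunk indices 0 through 5 and `b` is the 384-element const.

```python
def scan(body, consts, init, xs, length):
    """Pure-Python model of scan's operand roles (not JAX's implementation).

    consts: passed unchanged to every step
    init:   the carry, threaded from one step to the next
    xs:     extensive inputs, sliced along the leading axis, one slice per step
    Returns the final carry and the per-step outputs stacked into a list.
    """
    carry = init
    ys = []
    for i in range(length):
        x_i = tuple(x[i] for x in xs)  # slice every extensive input
        carry, y = body(consts, carry, x_i)
        ys.append(y)  # stack the per-step output
    return carry, ys


def lax_map(f, consts, xs):
    """map modeled as a scan with no carry, one extensive input and one output."""
    def body(consts, carry, x_i):
        return carry, f(consts, x_i[0])

    _, ys = scan(body, consts, (), (xs,), len(xs))
    return ys


b = list(range(384, 0, -1))
mins = lax_map(lambda consts, c: min(consts[0][c * 64:(c + 1) * 64]), (b,), list(range(6)))
print(mins)  # [321, 257, 193, 129, 65, 1]
```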

Re: masking, could you open that as a separate issue with a repro?

hr0nix commented Jul 18, 2023

Full stacktrace: https://gist.github.com/hr0nix/195f1ece2e6cde792cd0ae0e2fbf6357

After looking at it more carefully, it does look like I hit an unimplemented or unsupported code path.

sharadmv (Collaborator) commented

Yes, I can confirm that is the issue:

```
>     if num_extensive: raise NotImplementedError
E     NotImplementedError
```
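Presumably `num_extensive` counts the scan operands that are neither consts nor carry. For the failing eqn (`num_consts=1`, `num_carry=0`, operands `j k`), that leaves one extensive input, so the guard fires. One possible workaround, not verified against the Triton lowering, is a `fori_loop`-style loop whose only state is the loop index plus a carry, so nothing is sliced per step. The pure-Python sketch below models both the count and that rewrite; `num_extensive`, `fori_loop`, and `chunk_mins_fori` are hypothetical names, and whether the Triton backend accepts an array carry updated at a dynamic index is an open assumption.

```python
def num_extensive(num_operands, num_consts, num_carry):
    # Operands that are neither consts nor carry get sliced once per step.
    return num_operands - num_consts - num_carry


# Failing eqn: `scan[... num_carry=0 num_consts=1 ...] j k`
# j is the const (the Ref b); k is the i32[6] array being mapped over.
print(num_extensive(num_operands=2, num_consts=1, num_carry=0))  # 1, so the guard fires


def fori_loop(lower, upper, body, init_val):
    """fori_loop-style model: only a loop index and a carry, nothing sliced per step."""
    val = init_val
    for i in range(lower, upper):
        val = body(i, val)
    return val


def chunk_mins_fori(b, block, num_chunks):
    """Per-chunk minima accumulated in the carry; b is closed over like a const."""
    def body(c, acc):
        m = min(b[c * block:(c + 1) * block])
        # Functional update of the carry tuple (a JAX analogue would be acc.at[c].set(m)).
        return acc[:c] + (m,) + acc[c + 1:]

    return list(fori_loop(0, num_chunks, body, (0,) * num_chunks))


b = list(range(384, 0, -1))
print(chunk_mins_fori(b, 64, 6))  # [321, 257, 193, 129, 65, 1]
```

In this model the rewrite has zero extensive operands: `b` plays the role of a const and the results live in the carry.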
