Is scan supported in pallas? #202
On a side note, I can't help but notice that

It appears that masked reads don't work correctly with `interpret=True`, which is likely related to this.

I suspect you're running into an unsupported use case for scan. Scan automatically slices its inputs and outputs, which isn't supported in Triton AFAIK. Could you post the full traceback so I can double-check? Re: masking, could you open that as a separate issue with a repro?
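For reference, the slicing mentioned above can be seen in plain JAX, outside Pallas. The sketch below is illustrative only (`step` and the input data are made up): `lax.scan` indexes its input along the leading axis and stacks the per-step outputs, so it behaves like the explicit Python loop beneath it. Per the comment above, these per-iteration slices of the input and writes into the stacked output are what has no Triton lowering.

```python
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    # x is one slice xs[i]; the carry accumulates a running total.
    new_carry = carry + jnp.sum(x)
    return new_carry, new_carry  # (next carry, per-step output y)

xs = jnp.arange(12.0).reshape(4, 3)

# scan slices xs along axis 0 and stacks each step's y into ys.
final, ys = lax.scan(step, jnp.float32(0.0), xs)

# Equivalent explicit loop: index xs[i], call step, stack the outputs.
carry = jnp.float32(0.0)
ys_loop = []
for i in range(xs.shape[0]):
    carry, y = step(carry, xs[i])
    ys_loop.append(y)
ys_loop = jnp.stack(ys_loop)
```

Here `ys` holds the running totals of the row sums, one entry per slice of `xs`.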
Full stacktrace: https://gist.github.com/hr0nix/195f1ece2e6cde792cd0ae0e2fbf6357. After looking at it carefully, it does look like I hit an unimplemented or unsupported code path.
Yes, I can confirm that is the issue:
I have kernel code that contains `jax.lax.map`. It runs fine with `interpret=True`; however, lowering to Triton fails with the following error:

Is it because scan is not supported, or is there some other problem? Happy to provide more details if necessary.
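One way to check whether a kernel body goes through scan is to inspect its jaxpr: `jax.lax.map` is implemented on top of `lax.scan`, so a `scan` primitive appears. The sketch below runs outside Pallas with a made-up per-row function `f`. The `jax.vmap` version is shown only for contrast; whether a vectorized formulation lowers to Triton for a particular kernel is not established in this thread.

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(row):
    # Hypothetical per-row computation standing in for the kernel body.
    return jnp.sum(row * 2.0)

xs = jnp.arange(12.0).reshape(4, 3)

# lax.map is built on lax.scan, so its jaxpr contains a scan primitive.
map_jaxpr = jax.make_jaxpr(lambda a: lax.map(f, a))(xs)

# vmap batches the same ops across rows instead; no scan appears.
vmap_jaxpr = jax.make_jaxpr(jax.vmap(f))(xs)

print(map_jaxpr)
print(vmap_jaxpr)
```

Both formulations compute the same values here; only the map version contains `scan` in its jaxpr.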