-
Notifications
You must be signed in to change notification settings - Fork 70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP]: Jax tensor donation #3224
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #3224 +/- ##
==========================================
+ Coverage 90.01% 90.03% +0.02%
==========================================
Files 445 446 +1
Lines 55850 55953 +103
Branches 5351 5357 +6
==========================================
+ Hits 50274 50378 +104
Misses 4169 4169
+ Partials 1407 1406 -1 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for this, it's a great first step! A few bigger comments before the nitty-gritty:
- I'm a little confused by the name of the pass, it's not really converting JAX to linalg, just making sure that the operands marked as donating are actually reused as destinations, so another name feels more appropriate. (I would recommend making the prefix of the pass something that's not convert, as that's usually used for converting one dialect to another, and in this case it's more of an optimisation.) Maybe
jax-use-donated-arguments
? - The test is quite specific, and uses operations from dialects that aren't strictly involved. I would recommend using the "test" dialect as much as possible to generate values that are necessary for the test.
- It's quite a bit more difficult to review PRs with lots of Pyright errors, could you please fix the errors locally and ping me again to take a look?
- It would be good to be very clear about the limitations of the proposed approach, and to raise helpful messages if used in an unexpected context, could you please add
DiagnosticException
s if the assumptions are not met and tests for those?
walk_regions_first=True, | ||
) | ||
the_one_pass.rewrite_module(op) | ||
MLIROptPass(arguments=("--linalg-fuse-elementwise-ops",)).apply(ctx, op) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be done separately by the user, and not included in this pass
tests/filecheck/mlir-conversion/with-mlir/jax-use-donated-arguments.mlir
Outdated
Show resolved
Hide resolved
xdsl/dialects/bufferization.py
Outdated
name = "bufferization.materialize_in_destination" | ||
|
||
source = operand_def( | ||
TensorMemrefInferenceConstraint("T", AnyOf([TensorType, UnrankedTensorType])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TensorMemrefInferenceConstraint("T", AnyOf([TensorType, UnrankedTensorType])) | |
TensorMemrefInferenceConstraint("T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr) |
(will need to add the second constraint definition in builtin.py)
93a8703
to
7869d09
Compare
As per our discussion, I'm not sure that it's a good idea to merge this until we find a way to get to linalg from JAX again... Let's close the PR for now, keep the changes in the branch, and reopen when we figure out a new way to get to the kernels we want? |
Yeah, we can close it for now. It turned out to be a known issue in stablehlo. I think, I'll be able to send a pr to them to fix it this week. |
I wouldn't recommend spending time on this, I have a better idea that will be less reliant on Jax |
Closing for now |
This is still work in progress, just wanted to get some feedback for this version as well.
The main area of improvement: donation logic should be more in line with the jax documentation.
There are currently two problems here: