[WIP]: Jax tensor donation #3224

Closed
wants to merge 22 commits

Conversation

mamanain
Collaborator

This is still a work in progress; I just wanted to get some feedback on this version as well.

The main area for improvement is that the donation logic should be more in line with the JAX documentation.
There are currently two problems:

  1. A donated buffer can still be read inside the function, but here it is made available for reuse straight away. Its data can then be overwritten and later read by other computations, which will produce wrong results. We need to change the logic so that a buffer only becomes available for reuse after its last use (see the JAX sketch below).
  2. Right now a donated buffer can be reused only once. Should we keep it in the dictionary and reuse it in other places where it fits?
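
For context, here is a minimal JAX sketch of the donation semantics described in point 1; the function and variable names are illustrative only:

```python
from functools import partial

import jax
import jax.numpy as jnp


# Donate the first argument's buffer so XLA may reuse it for the output.
@partial(jax.jit, donate_argnums=(0,))
def update(state, delta):
    # `state` may still be read inside the function; only after its last use
    # can its buffer safely back the result.
    return state + delta


state = jnp.zeros(4)
new_state = update(state, jnp.ones(4))
# On backends that honour donation, `state` is now invalid and reading it
# again raises an error: this is the "available only after its last usage"
# constraint the pass needs to respect.
```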

codecov bot commented Sep 27, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 90.03%. Comparing base (70fa878) to head (8fc4d95).
Report is 1 commit behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3224      +/-   ##
==========================================
+ Coverage   90.01%   90.03%   +0.02%     
==========================================
  Files         445      446       +1     
  Lines       55850    55953     +103     
  Branches     5351     5357       +6     
==========================================
+ Hits        50274    50378     +104     
  Misses       4169     4169              
+ Partials     1407     1406       -1     


@superlopuh superlopuh marked this pull request as draft September 28, 2024 22:50
Member

@superlopuh superlopuh left a comment

Thank you for this; it's a great first step! A few bigger comments before the nitty-gritty:

  • I'm a little confused by the name of the pass: it's not really converting JAX to linalg, just making sure that the operands marked as donated are actually reused as destinations, so another name feels more appropriate. (I would recommend a prefix other than convert, since that is usually used for converting one dialect to another, and this is more of an optimisation.) Maybe jax-use-donated-arguments?
  • The test is quite specific and uses operations from dialects that aren't strictly involved. I would recommend using the "test" dialect as much as possible to generate the values needed for the test.
  • It's quite a bit more difficult to review PRs with lots of Pyright errors; could you please fix the errors locally and ping me again to take a look?
  • It would be good to be very clear about the limitations of the proposed approach and to raise helpful messages when the pass is used in an unexpected context; could you please add DiagnosticExceptions for when the assumptions are not met, along with tests for those? (A sketch of such a guard follows this list.)
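
For illustration, a guard along the lines of the last point might look like this. check_donation, its parameters, and the message text are hypothetical names for this sketch rather than the PR's actual API; DiagnosticException itself is xDSL's existing exception class:

```python
from xdsl.utils.exceptions import DiagnosticException


def check_donation(donated_type, returned_types):
    """Hypothetical guard: fail loudly when a donated argument cannot be reused."""
    if donated_type not in returned_types:
        raise DiagnosticException(
            f"donated argument of type {donated_type} has no matching "
            f"result to materialize into"
        )
```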

walk_regions_first=True,
)
the_one_pass.rewrite_module(op)
MLIROptPass(arguments=("--linalg-fuse-elementwise-ops",)).apply(ctx, op)
Member

This should be done separately by the user, and not included in this pass.

xdsl/dialects/tensor.py (outdated review thread, resolved)
name = "bufferization.materialize_in_destination"

source = operand_def(
    TensorMemrefInferenceConstraint("T", AnyOf([TensorType, UnrankedTensorType]))
Member

Suggested change
TensorMemrefInferenceConstraint("T", AnyOf([TensorType, UnrankedTensorType]))
TensorMemrefInferenceConstraint("T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr)

(The second constraint definition will need to be added in builtin.py; a possible sketch follows.)
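
For reference, a minimal sketch of what that second definition in builtin.py could look like, assuming BaseAttr is the right building block; the exact name and any generic parameters should mirror the existing AnyTensorTypeConstr definition:

```python
from xdsl.irdl import BaseAttr

# Hypothetical addition to builtin.py (UnrankedTensorType is already defined
# there); mirrors how AnyTensorTypeConstr wraps TensorType.
AnyUnrankedTensorTypeConstr = BaseAttr(UnrankedTensorType)
```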

@superlopuh
Member

As per our discussion, I'm not sure that it's a good idea to merge this until we find a way to get to linalg from JAX again... Let's close the PR for now, keep the changes in the branch, and reopen when we figure out a new way to get to the kernels we want?

@mamanain
Collaborator Author

mamanain commented Nov 4, 2024

Yeah, we can close it for now. It turned out to be a known issue in StableHLO. I think I'll be able to send a PR to them to fix it this week.

@superlopuh
Member

I wouldn't recommend spending time on this; I have a better idea that will be less reliant on JAX.

@superlopuh
Member

Closing for now

@superlopuh superlopuh closed this Nov 4, 2024