Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tcp.gather_nd and rework index.Tensor_hacked_twin to use gather_nd #101

Merged
merged 5 commits into from
Nov 7, 2024

Conversation

matthewfl
Copy link
Contributor

@matthewfl matthewfl commented Oct 18, 2024

Add tcp.gather_nd and rework index.Tensor_hacked_twin to use gather_nd

Copy link
Collaborator

@navahgar navahgar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks for taking this up.

Some minor comments inline.

include/mlir-tcp/Dialect/IR/TcpOps.td Outdated Show resolved Hide resolved
lib/Conversion/TcpToLinalg/DataMovement.cpp Outdated Show resolved Hide resolved
lib/Conversion/TcpToLinalg/DataMovement.cpp Outdated Show resolved Hide resolved
lib/Conversion/TcpToLinalg/DataMovement.cpp Show resolved Hide resolved
lib/Conversion/TorchToTcp/Utils.cpp Outdated Show resolved Hide resolved
lib/Conversion/TorchToTcp/Utils.cpp Outdated Show resolved Hide resolved
lib/Conversion/TorchToTcp/Utils.cpp Outdated Show resolved Hide resolved
@sjain-stanford
Copy link
Collaborator

sjain-stanford commented Oct 23, 2024

Curious, does this functionality need additional e2e tests in aot_compile or do existing tests already cover this op? If former, could you include e2e tests as well?

@matthewfl
Copy link
Contributor Author

matthewfl commented Oct 23, 2024

@sjain-stanford there is an e2e test for the index.Tensor_hacked_twin already. This PR changes the index.Tensor_hacked_twin to use tcp.gather_nd whereas before it was using tcp.gather. The reason for the change is that more complex indexing that is supported by index.Tensor_hacked_twin requires being able to gather over multiple dimensions at the same time. E.g. you can use index.Tensor_hacked_twin to select the diagonal of a matrix by doing something like x[torch.arange(x.shape[0]), torch.arange(x.shape[0])].

I will fix the nits that you pointed out. I am waiting to merge this PR as I would like to validate that the index.Tensor_hacked_twin change works but I am currently blocked by trt-mlir upgrade for testing the tcp.gather_nd. I already know that there needs to be some additional casting added to the index.Tensor_hacked_twin -> tcp.gather_nd to handle cases where different int types are used

@matthewfl matthewfl merged commit 0f4b396 into cruise-automation:main Nov 7, 2024
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Development

Successfully merging this pull request may close these issues.

3 participants