[Dispatch] Fold collapse(expand) unit dims #19357

Open
IanWood1 wants to merge 4 commits into main from fold_reshape_unit_dims
Conversation

IanWood1 (Contributor) commented Dec 3, 2024

Fold unit dimensions for the `collapse(expand)` pattern by removing any unnecessary unit dims that are expanded out and subsequently collapsed again. This prevents collapse/expand ops that carry unneeded unit dims from being propagated and reintroducing unit dims into linalg ops; a sketch of the rewrite is shown below.
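For illustration, here is a minimal sketch of the kind of IR the rewrite targets (the value names %src and %d and the concrete shapes are made up for this example, not taken from the PR's tests):

// Before: a unit dim is expanded out and then immediately collapsed away again.
%expanded = tensor.expand_shape %src [[0, 1, 2], [3]] output_shape [4, %d, 1, 128] : tensor<?x128xf16> into tensor<4x?x1x128xf16>
%collapsed = tensor.collapse_shape %expanded [[0], [1, 2], [3]] : tensor<4x?x1x128xf16> into tensor<4x?x128xf16>

// After the fold: the unit dim never materializes, so it cannot be propagated back into linalg ops.
%expanded = tensor.expand_shape %src [[0, 1], [2]] output_shape [4, %d, 128] : tensor<?x128xf16> into tensor<4x?x128xf16>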


This doesn't quite resolve #19263, because there are flow.tensor.bitcast ops preventing the unit dims from being completely folded:

%expanded_24 = tensor.expand_shape %214 [[0, 1, 2], [3]] output_shape [4, %21, 1, 128] : tensor<?x128xf16> into tensor<4x?x1x128xf16>
%323 = flow.tensor.bitcast %expanded_24 : tensor<4x?x1x128xf16>{%21} -> tensor<4x?x1x64xcomplex<f16>>{%21}
%collapsed_40 = tensor.collapse_shape %323 [[0], [1, 2], [3]] : tensor<4x?x1x64xcomplex<f16>> into tensor<4x?x64xcomplex<f16>>

IanWood1 force-pushed the fold_reshape_unit_dims branch from 515c632 to 3849574 on December 12, 2024 at 01:43
Fold unit dimensions for the pattern `collapse(expand)` by removing any
unnecessary unit dims that are expanded out and subsequently collapsed
again.

This prevents collapse/expand ops which contain unneeded unit dims from being
propagated and introducing unit dims back to linalg ops.

Signed-off-by: Ian Wood <[email protected]>
(1) It wasn't checking whether the new reassociation was appended before
calling `.back()`. (2) Unit dims cannot be dropped if they weren't expanded.

Signed-off-by: Ian Wood <[email protected]>
Development

Successfully merging this pull request may close these issues.

Llama 3.1 8B fp16 TP8 sharded fails to compile for CPU and GPU