[optimize-dot-operands]: Fuse load and trans operations - part 3 #4537
base: main
Conversation
Depends on: #4468
Pull Request Overview
This PR enhances dot-operand optimization by fusing load and transpose operations that live in separate loops when their def-use chains originate from the same make_tensor_ptr operation, and by refactoring cleanup routines.
- Added a new optimization pass (optimize_dot_operands) in multiple backend components.
- Introduced a new eraseOperations utility and refactored fusion logic in OptimizeDotOperands.cpp.
- Updated test cases to validate proper fusion and non-fused behavior.
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.
File | Description |
---|---|
third_party/intel/triton_xpu.cc | Added optimize_dot_operands pass registration. |
third_party/intel/lib/Utils/Utility.cpp | Added a new eraseOperations function for cleanup operations. |
third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp | Refactored fusion logic and propagation routines to support optimized chaining. |
third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp | Removed redundant finalize() in favor of using eraseOperations. |
third_party/intel/include/Utils/Utility.h | Declared the new eraseOperations function. |
third_party/intel/backend/compiler.py | Registered the new optimize_dot_operands pass in the compiler backend. |
test/TritonIntelGPU/dot-operands.mlir | Updated test cases to reflect changes in fusion behavior and new pass functionality. |
Comments suppressed due to low confidence (2)
third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp:161

- [nitpick] The singleUsersInChain function is quite complex; consider refactoring the logic or adding more inline comments to improve readability and maintainability.

```cpp
// Determine whether all operations in the def-use chain from \p start to
```

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp:112

- [nitpick] Consider renaming the lambda 'usedByDotOp' to a more descriptive name such as 'isChainedToDotOp' to clarify its purpose.

```cpp
auto usedByDotOp = [](tt::TransOp transOp) {
```
```
@@ -68,7 +71,10 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
    for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = desc_k.load([0, offsetk_y])
        if dtype == tl.float8e5:
```
For fp16 we undo the source code changes we made, and the code is now back to the original. For fp8 we keep the source code changes until we can issue DPAS instructions for them (after making two fp8 elements into an fp16).
```
@@ -80,15 +79,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
    %c1_i64 = arith.constant 1 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
    %0 = tt.get_program_id x : i32
```
Just making the test simpler here.
Ping @whitneywhtsang, @chengjunlu, @LiyangLingIntel: any comments?
```
@@ -34,6 +36,94 @@ namespace mlir::triton::gpu::intel {

namespace {

// Represent a def-use chain rooted at 'start' and terminating at tt.trans
```
Does start always have to be a make_tensor_ptr op? Should we use the MakeTensorPtrOp type instead of a generic operation type for start?
"operation"); | ||
} | ||
bool operator<(const Chain &other) const { | ||
return start < other.start || end < other.end; |
If there are two chains, chain 1 = [1, 4] and chain 2 = [2, 3], then chain1.start < chain2.start and chain1.end > chain2.end; which chain is bigger? As written, chain1 < chain2 and chain2 < chain1 both hold, so this comparison is not a strict weak ordering.
"operation"); | ||
} | ||
bool operator<(const Chain &other) const { | ||
return start < other.start || end < other.end; |
If we want to compare addresses, should it be something like the suggestion below?
```diff
-    return start < other.start || end < other.end;
+    return start < other.start || &end < &other.end;
```
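For reference, here is a minimal self-contained sketch of a comparator that does give the strict weak ordering ordered containers require, comparing (start, end) lexicographically. The int members stand in for the pass's real operation handles; this is an illustration, not code from the PR:

```cpp
#include <set>
#include <tuple>

// Stand-in for the pass's Chain struct; 'start'/'end' are ints here in place
// of the real operation handles.
struct Chain {
  int start;
  int end;
  // Lexicographic (start, end) comparison is a strict weak ordering. The
  // quoted "start < other.start || end < other.end" is not: for [1, 4] and
  // [2, 3], each chain compares as less than the other.
  bool operator<(const Chain &other) const {
    return std::tie(start, end) < std::tie(other.start, other.end);
  }
};

int main() {
  std::set<Chain> chains{{1, 4}, {2, 3}};
  return chains.size() == 2 ? 0 : 1; // both chains kept, consistently ordered
}
```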
```cpp
    Chains &sameRootChains = rootToChains[start];
    sameRootChains.insert(otherChain);
    rootToChains[start] = sameRootChains;
```
Does this work?
```diff
-    Chains &sameRootChains = rootToChains[start];
-    sameRootChains.insert(otherChain);
-    rootToChains[start] = sameRootChains;
+    rootToChains[start].insert(otherChain);
```
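A standalone illustration of why the one-liner is equivalent, with std::map and std::set standing in for the pass's actual container types:

```cpp
#include <map>
#include <set>

int main() {
  std::map<int, std::set<int>> rootToChains;
  // operator[] default-constructs the mapped set on first access and returns
  // a reference to it, so inserting through that reference mutates the stored
  // set in place. In the quoted code, sameRootChains is already such a
  // reference, which makes the final write-back a self-assignment.
  rootToChains[42].insert(7);
  return rootToChains[42].count(7) == 1 ? 0 : 1;
}
```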
```cpp
  });

  // If the same operation is the root of multiple chains, duplicate it to
```
Do you think it would be cleaner to have this duplication logic in a separate function?
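A rough sketch of what such a helper could look like, with made-up stand-in types; it only illustrates the suggested refactoring, not the PR's actual implementation:

```cpp
#include <cstddef>
#include <deque>
#include <map>
#include <vector>

// Invented stand-ins for the pass's root/chain types.
struct Root { int id; };
struct Chain { Root *root = nullptr; };

// Give every chain after the first its own clone of a shared root so each
// chain can be rewritten independently. 'clones' owns the copies; a deque
// keeps earlier clone pointers valid as it grows.
void duplicateSharedRoots(std::map<Root *, std::vector<Chain *>> &rootToChains,
                          std::deque<Root> &clones) {
  for (auto &[root, chains] : rootToChains)
    for (std::size_t i = 1; i < chains.size(); ++i) {
      clones.push_back(*root);
      chains[i]->root = &clones.back();
    }
}

int main() {
  std::deque<Root> clones;
  Root shared{0};
  Chain a{&shared}, b{&shared};
  std::map<Root *, std::vector<Chain *>> rootToChains{{&shared, {&a, &b}}};
  duplicateSharedRoots(rootToChains, clones);
  return (a.root != b.root) ? 0 : 1; // the second chain now points at a clone
}
```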
```cpp
  // Prune candidate chains containing load/trans operations that cannot be
  // safely fused.
  prune(chains);
```
Do you think it is worth first pruning rootToChains down to roots whose chains contain at least one candidate? That way we won't clone a chain when there will be no candidates.
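Sketching that ordering with invented stand-in types (none of these names come from the PR):

```cpp
#include <algorithm>
#include <vector>

// Stand-in: 'fusible' marks a chain containing at least one load+trans pair
// that could actually be fused.
struct ChainInfo { bool fusible; };

// Drop non-candidates before the comparatively expensive root-cloning step,
// so roots whose chains can never fuse are not cloned at all.
void pruneThenClone(std::vector<ChainInfo> &chains) {
  chains.erase(std::remove_if(chains.begin(), chains.end(),
                              [](const ChainInfo &c) { return !c.fusible; }),
               chains.end());
  // ... clone roots / rewrite only the surviving chains ...
}

int main() {
  std::vector<ChainInfo> chains{{true}, {false}, {true}};
  pruneThenClone(chains);
  return chains.size() == 2 ? 0 : 1;
}
```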
Another thought is to have a flag indicating whether we want to clone for a particular root in rootToChains: if a root has no candidates, or only candidates, then there is no need to clone.
Thanks @whitneywhtsang for the prompt review!
Enhance the transformation to allow load+transpose fusion in separate for loops when the def-use chains corresponding to the two load+transpose instances originate at the make_tensor_ptr operation.