
[optimize-dot-operands]: Fuse load and trans operations - part 3 #4537


Open. Wants to merge 20 commits into base: main.

Conversation

@etiotto (Contributor) commented Jun 18, 2025

Enhance the transformation to allow load+transpose fusion in separate for loops when the def-use chains corresponding to the two load+transpose instances originate at the same make_tensor_ptr operation.

@etiotto etiotto self-assigned this Jun 18, 2025
@etiotto (Contributor, Author) commented Jun 18, 2025

Depends on: #4468

@etiotto etiotto changed the title Etiotto.merge load with trans.3 [optimize-dot-operands]: Fuse load and trans operations - part 3 Jun 19, 2025
@etiotto etiotto requested a review from Copilot June 19, 2025 15:49
@etiotto etiotto marked this pull request as ready for review June 19, 2025 15:49
@Copilot Copilot AI left a comment

Pull Request Overview

This PR enhances the dot-operands optimization by fusing load and transpose operations in separate loops when the def-use chains originate from a make_tensor_ptr, and by refactoring the cleanup routines.

  • Added a new optimization pass (optimize_dot_operands) in multiple backend components.
  • Introduced a new eraseOperations utility and refactored fusion logic in OptimizeDotOperands.cpp.
  • Updated test cases to validate proper fusion and non‐fused behavior.

Reviewed Changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.

Summary per file:

  • third_party/intel/triton_xpu.cc: Added optimize_dot_operands pass registration.
  • third_party/intel/lib/Utils/Utility.cpp: Added a new eraseOperations function for cleanup operations.
  • third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp: Refactored fusion logic and propagation routines to support optimized chaining.
  • third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp: Removed redundant finalize() in favor of using eraseOperations.
  • third_party/intel/include/Utils/Utility.h: Declared the new eraseOperations function.
  • third_party/intel/backend/compiler.py: Registered the new optimize_dot_operands pass in the compiler backend.
  • test/TritonIntelGPU/dot-operands.mlir: Updated test cases to reflect changes in fusion behavior and new pass functionality.
Comments suppressed due to low confidence (2)

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp:161

  • [nitpick] The singleUsersInChain function is quite complex; consider refactoring the logic or adding more inline comments to improve readability and maintainability.
  // Determine whether all operations in the def-use chain from \p start to

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp:112

  • [nitpick] Consider renaming the lambda 'usedByDotOp' to a more descriptive name such as 'isChainedToDotOp' to clarify its purpose.
    auto usedByDotOp = [](tt::TransOp transOp) {

@etiotto etiotto requested a review from anmyachev June 20, 2025 20:27
@@ -68,7 +71,10 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
    for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = desc_k.load([0, offsetk_y])
        if dtype == tl.float8e5:
@etiotto (Contributor, Author) commented:

For fp16 we undo the source code changes we made, so the code is now back to the original. For fp8 we keep the source code changes until we can issue DPAS instructions for them (after packing two fp8 elements into an fp16).

@@ -80,15 +79,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
    %c1_i64 = arith.constant 1 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
    %0 = tt.get_program_id x : i32
@etiotto (Contributor, Author) commented:

Just making the test simpler here

@etiotto (Contributor, Author) commented Jun 23, 2025

Ping @whitneywhtsang, @chengjunlu, @LiyangLingIntel: any comments?

@@ -34,6 +36,94 @@ namespace mlir::triton::gpu::intel {

namespace {

// Represent a def-use chain rooted at 'start' and terminating at tt.trans
A Contributor commented:

Does start always have to be a make_tensor_ptr op? Should we use the MakeTensorPtrOp type instead of the generic Operation type for start?
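A minimal sketch of what that suggestion could look like, assuming the root really is always a make_tensor_ptr (the Chain layout shown here is hypothetical, not the PR's actual definition):

```cpp
#include "mlir/IR/Operation.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

namespace tt = mlir::triton;

// Hypothetical sketch: storing the concrete op type encodes the
// "root is always a make_tensor_ptr" invariant in the type system,
// so use sites need neither a cast nor a runtime assert.
struct Chain {
  Chain(tt::MakeTensorPtrOp start, mlir::Operation *end)
      : start(start), end(end) {}
  tt::MakeTensorPtrOp start; // was: Operation *start
  mlir::Operation *end;      // terminates at the tt.trans user
};
```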

"operation");
}
bool operator<(const Chain &other) const {
return start < other.start || end < other.end;
A Contributor commented:

If there are two chains, chain 1 = [1, 4] and chain 2 = [2, 3], then chain1.start < chain2.start while chain1.end > chain2.end. Which chain is bigger? As written, chain1 < chain2 and chain2 < chain1 are both true, which violates the strict weak ordering that ordered containers require.
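One conventional fix is a lexicographic comparison via std::tie, which does give a strict weak ordering; a minimal sketch, assuming start and end are plain pointers or otherwise <-comparable:

```cpp
#include <tuple>

bool operator<(const Chain &other) const {
  // Compare start first; fall back to end only when the starts are equal.
  return std::tie(start, end) < std::tie(other.start, other.end);
}
```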

"operation");
}
bool operator<(const Chain &other) const {
return start < other.start || end < other.end;
A Contributor commented:

If we want to compare addresses, should it be something like the suggestion below?

Suggested change
return start < other.start || end < other.end;
return start < other.start || &end < &other.end;

Comment on lines +188 to +190
Chains &sameRootChains = rootToChains[start];
sameRootChains.insert(otherChain);
rootToChains[start] = sameRootChains;
A Contributor commented:

Does this work?

Suggested change
Chains &sameRootChains = rootToChains[start];
sameRootChains.insert(otherChain);
rootToChains[start] = sameRootChains;
rootToChains[start].insert(otherChain);
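The suggested one-liner works because the map's operator[] already default-constructs a missing entry and returns a reference to it, so the read, insert, and write-back collapse into a single statement. A standalone sketch with std::map and std::set standing in for the actual containers:

```cpp
#include <map>
#include <set>

int main() {
  std::map<int, std::set<int>> rootToChains; // stand-ins for the real types
  int start = 1, otherChain = 2;

  // operator[] creates an empty set on first access and returns a
  // reference, so the insert mutates the map entry in place:
  rootToChains[start].insert(otherChain);

  // Equivalent to the original three lines; the final write-back there
  // is a redundant self-assignment:
  //   Chains &sameRootChains = rootToChains[start];
  //   sameRootChains.insert(otherChain);
  //   rootToChains[start] = sameRootChains;
  return 0;
}
```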

});

// If the same operation is the root of multiple chains, duplicate it to
A Contributor commented:

Do you think it would be cleaner to have this duplication logic in a separate function?


// Prune candidate chains containing load/trans operations that cannot be
// safely fused.
prune(chains);
A Contributor commented:

Do you think it is worth pruning rootToChains first, keeping only the chains that contain at least one candidate? That way we won't clone a chain when there are no candidates.

A Contributor added:

Another thought is to have a flag indicating whether we want to clone for a particular root in rootToChains: if a root has no candidates, or all of its chains are candidates, then there is no need to clone.
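A minimal sketch of the pruning idea under discussion, using standard containers in place of the LLVM ones and a hypothetical isCandidate predicate (the real check would decide whether a chain's load+trans pair can be fused):

```cpp
#include <algorithm>
#include <iterator>
#include <map>
#include <set>
#include <tuple>

struct Chain { int start, end; }; // stand-in for the real Chain
bool operator<(const Chain &a, const Chain &b) {
  return std::tie(a.start, a.end) < std::tie(b.start, b.end);
}

// Hypothetical predicate; a stub here, the real one checks fusibility.
bool isCandidate(const Chain &c) { return c.start < c.end; }

// Drop roots whose chains contain no fusion candidate, so the later
// duplication step never clones a root that cannot be fused anyway.
void pruneRoots(std::map<int, std::set<Chain>> &rootToChains) {
  for (auto it = rootToChains.begin(); it != rootToChains.end();) {
    bool anyCandidate =
        std::any_of(it->second.begin(), it->second.end(), isCandidate);
    it = anyCandidate ? std::next(it) : rootToChains.erase(it);
  }
}
```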

@etiotto (Contributor, Author) commented Jun 24, 2025

Thanks @whitneywhtsang for the prompt review!


Successfully merging this pull request may close these issues.

[TransOp fusion]: Fuse tt.trans with tt.load to exploit 2D block read operations