Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMD][FA] Improve warp distribution for attention second dot #5892

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

zhanglx13
Copy link
Collaborator

No description provided.

// Check if the result of current tl.dot is used as the operand(0)
// of another tl.dot
bool isChainDotHead(tt::DotOp &dotOp) {
auto filter = [&dotOp](Operation *op) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typically we want to have self-documenting function/variable name for readability. It's not a big concern here given it's a oneliner, but isInSameRegion would be better than a general filter name here.

Comment on lines +59 to +60
if (isa<tt::DotOp>(op) && (op != dotOp)) {
auto dOp = dyn_cast<tt::DotOp>(op);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

op should never be dotOp here given we don't set fwdOpt.inclusive to true? To be certain you can put an assert in the above. Also typically in MLIR we do like

if (auto userDotOp = dyn_cast<tt::DotOpInterface>()) { ... }

Then in the middle you don't need to cast again.

// ensure output of the first dot is the operand 0 of the second dot
if (isa<tt::DotOp>(op) && (op != dotOp)) {
auto dOp = dyn_cast<tt::DotOp>(op);
auto op0 = dOp.getOperand(0).getDefiningOp();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefer to use friendly accessors like .getA().

Comment on lines +62 to +63
if (op0 && std::find(fwdSlices.begin(), fwdSlices.end(), op0) !=
fwdSlices.end()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use fwdSlices.contains here?

@@ -44,6 +44,50 @@ int getWmmaVersion(StringRef archGen) {
return 0;
}

// Check if the result of current tl.dot is used as the operand(0)
// of another tl.dot
bool isChainDotHead(tt::DotOp &dotOp) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In these two functions, use DotOpInterface instead of hardcoded tt::DotOp so it works for DotScaledOp too. Also typically for friendly named Ops (that is, tt::DotOp, not Operation *), we don't pass as a reference; we directly pass as a value because Ops are just a wrapper of Operation *.


// Check if the operand(0) of current tl.dot is the result of
// another tl.dot
bool isChainDotTail(tt::DotOp &dotOp) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly for this function.

Comment on lines +99 to +101
// For FA-like pattern, i.e. result of 1st tl.dot is used as the
// operand(0) of the 2nd dot.
// We use {numWaprs, 1} for both tl.dots
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These three lines should be the same sentence.

auto ttDotOp = dyn_cast<tt::DotOp>(dotOp);
if (isChainDotHead(ttDotOp) || isChainDotTail(ttDotOp)) {
if ((shape[0] == shapePerWarp.first) && isChainDotTail(ttDotOp))
return {1, (unsigned)numWarps};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to add a lit test for this.

// {1, numWarps} for the 2nd tl.dot to save registers
auto ttDotOp = dyn_cast<tt::DotOp>(dotOp);
if (isChainDotHead(ttDotOp) || isChainDotTail(ttDotOp)) {
if ((shape[0] == shapePerWarp.first) && isChainDotTail(ttDotOp))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use a local variable to save the result for isChainDotTail to avoid compute it again. (The C++ compiler should do it but I'm not 100% sure.) Also this is a specific case. I suspect we want to swap here as long as it's more beneficial to distribute along second dot's N dim? That is, ceildiv(shape[0], shapePerWarp.first) < ceildiv(shape[1], shapePerWarp.second)?

@antiagainst antiagainst changed the title [AMD][FA] Force the 2nd dot to have warpsPerCTA={1, numWarps} if BLOCK_M == mDim [AMD][FA] Improve warp distribution for attention second dot Feb 12, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants