-
Notifications
You must be signed in to change notification settings - Fork 1.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AMD][FA] Improve warp distribution for attention second dot #5892
base: main
Are you sure you want to change the base?
Conversation
f855bf6
to
18be627
Compare
// Check if the result of current tl.dot is used as the operand(0) | ||
// of another tl.dot | ||
bool isChainDotHead(tt::DotOp &dotOp) { | ||
auto filter = [&dotOp](Operation *op) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typically we want to have self-documenting function/variable name for readability. It's not a big concern here given it's a oneliner, but isInSameRegion
would be better than a general filter
name here.
if (isa<tt::DotOp>(op) && (op != dotOp)) { | ||
auto dOp = dyn_cast<tt::DotOp>(op); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
op
should never be dotOp
here given we don't set fwdOpt.inclusive
to true
? To be certain you can put an assert
in the above. Also typically in MLIR we do like
if (auto userDotOp = dyn_cast<tt::DotOpInterface>()) { ... }
Then in the middle you don't need to cast
again.
// ensure output of the first dot is the operand 0 of the second dot | ||
if (isa<tt::DotOp>(op) && (op != dotOp)) { | ||
auto dOp = dyn_cast<tt::DotOp>(op); | ||
auto op0 = dOp.getOperand(0).getDefiningOp(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Prefer to use friendly accessors like .getA()
.
if (op0 && std::find(fwdSlices.begin(), fwdSlices.end(), op0) != | ||
fwdSlices.end()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use fwdSlices.contains
here?
@@ -44,6 +44,50 @@ int getWmmaVersion(StringRef archGen) { | |||
return 0; | |||
} | |||
|
|||
// Check if the result of current tl.dot is used as the operand(0) | |||
// of another tl.dot | |||
bool isChainDotHead(tt::DotOp &dotOp) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In these two functions, use DotOpInterface
instead of hardcoded tt::DotOp
so it works for DotScaledOp
too. Also typically for friendly named Op
s (that is, tt::DotOp
, not Operation *
), we don't pass as a reference; we directly pass as a value because Op
s are just a wrapper of Operation *
.
|
||
// Check if the operand(0) of current tl.dot is the result of | ||
// another tl.dot | ||
bool isChainDotTail(tt::DotOp &dotOp) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similarly for this function.
// For FA-like pattern, i.e. result of 1st tl.dot is used as the | ||
// operand(0) of the 2nd dot. | ||
// We use {numWaprs, 1} for both tl.dots |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These three lines should be the same sentence.
auto ttDotOp = dyn_cast<tt::DotOp>(dotOp); | ||
if (isChainDotHead(ttDotOp) || isChainDotTail(ttDotOp)) { | ||
if ((shape[0] == shapePerWarp.first) && isChainDotTail(ttDotOp)) | ||
return {1, (unsigned)numWarps}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to add a lit test for this.
// {1, numWarps} for the 2nd tl.dot to save registers | ||
auto ttDotOp = dyn_cast<tt::DotOp>(dotOp); | ||
if (isChainDotHead(ttDotOp) || isChainDotTail(ttDotOp)) { | ||
if ((shape[0] == shapePerWarp.first) && isChainDotTail(ttDotOp)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use a local variable to save the result for isChainDotTail
to avoid compute it again. (The C++ compiler should do it but I'm not 100% sure.) Also this is a specific case. I suspect we want to swap here as long as it's more beneficial to distribute along second dot's N dim? That is, ceildiv(shape[0], shapePerWarp.first) < ceildiv(shape[1], shapePerWarp.second)
?
No description provided.