LHS Registers Part 1 - DotOp Hoisting and SMEM-RF Copy Lowering #18
base: llvm-head
Conversation
Addressed all comments from the original PR that are relevant to Part 1 in this PR instead.
@@ -11,6 +11,8 @@
import pytest
import torch
import os
os.environ['TRITON_ALWAYS_COMPILE'] = '1'
os.environ['MLIR_ENABLE_DUMP'] = '1'
looks like these were leftover from debugging
Force-pushed from 04ed621 to b7e2df0
Force-pushed from 3596dc5 to 10d3305
Oops, it seems that the updates we make to maintain the Triton integration will cause PR diffs to break because of force-updates. We might need to figure out a better way to handle this, as we didn't intend for this repo to accept incoming PRs. Apologies, but you will need to rebase again for the diff to include the proper changes.
Force-pushed from 927f2ec to 1b95c9a
Force-pushed from 1b95c9a to 942dad4
@Moerafaat np - I've reapplied my changes on the new main.
@@ -1327,16 +1324,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim
);

let builders = [
// Specially for MMAV1(Volta)
// For MMAV2 and V3
I think the comment should be split into two parts: the if-block is for MMAV1 (Volta), and the part that follows is for MMAV2 and V3.
@@ -87,8 +87,12 @@ SmallVector<Value> unpackI32(const SmallVector<Value> &inValues, Type srcTy,
if (!tensorTy)
return inValues;
auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent())))
if (!encoding)
Maybe wrap this in a helper method and use it here and below?
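For illustration, a minimal sketch of such a helper, reusing the condition visible in the hunk above (the name isDotOperandWithMmaParent is hypothetical, the real helper would use whichever predicate the PR finally settles on, and the TritonGPU dialect headers are assumed to be included in this file):

static bool isDotOperandWithMmaParent(RankedTensorType tensorTy) {
  // Hypothetical helper: factors out the repeated "dot-operand encoding with
  // an NVIDIA MMA parent" check so both call sites can share it.
  auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
  return encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent());
}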
"Ampere or Hopper MMA parent"; | ||
if (opIdx != 0 && parentAttr.isHopper()) | ||
return emitError() | ||
<< "triton_gpu.dot_op opIdx parameter must be 0 for " |
Can we provide an explanation within the error message for the restriction?
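As a purely illustrative sketch of one possible wording (grounded in the PR description, which says RS WGMMA keeps only the LHS operand A in registers), the rationale could be folded into the message:

// Sketch only -- same check as in the hunk above, with the reason spelled out.
if (opIdx != 0 && parentAttr.isHopper())
  return emitError()
         << "triton_gpu.dot_op opIdx parameter must be 0 for a Hopper MMA "
            "parent: Hopper WGMMA only supports the LHS operand (A) in "
            "registers, so operand B cannot use this encoding";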
unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands(
ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
auto shapePerCTA = getShapePerCTA(*this, shape);
int warpsPerCTAM = getWarpsPerCTA()[0];
int warpsPerCTAN = getWarpsPerCTA()[1];
// H100
if (isHopper()) {
return getTotalElemsPerThread(shape, eltTy);
assert(opIdx == 0);
Is this function never invoked for Hopper with opIdx = 1? If we are unsure maybe we can maintain the old code-flow for that case then.
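If the Hopper/opIdx == 1 case cannot be ruled out, one hedged way to keep the old code-flow (identifiers taken from the hunk above) would be to restrict the new early-return to the LHS operand and otherwise fall through:

// Sketch: only take the Hopper early-return for opIdx == 0 (LHS in registers);
// for opIdx == 1, and for non-Hopper layouts, fall through to the pre-existing
// per-operand element count computation below this if-block.
if (isHopper() && opIdx == 0)
  return getTotalElemsPerThread(shape, eltTy);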
return res;
}

if (mmaLayout.isHopper()) { // tensor core v3
Maybe shorter this way?
if (mmaLayout.isHopper() || mmaLayout.isAmpere()) { // tensor core v2 or v3
  if (mmaLayout.isHopper())
    assert(dotOperandLayout.getOpIdx() == 0);
  res = SharedToDotOperandMMAv2OrV3::convertLayout(
      dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout,
      smemObj, typeConverter, getThreadId(rewriter, loc));
} else if (mmaLayout.isVolta() && isMMA) { // tensor core v1
@@ -385,6 +385,14 @@ static bool loadIsMMAv3(Operation *loadOp) {
if (!sharedEnc.getHasLeadingOffset())
return false;

// In case LHS is in registers, don't pipeline for now TODO(ggengnv) is this necessary? |
Were you able to figure out this TODO?
!isPureUnaryInlineAsm(currOp) &&
currOp->getDialect()->getTypeID() !=
TypeID::get<arith::ArithDialect>())
if (!canHoistDotOpEncV2(currOp, dotOpEnc))
The restriction here was more relaxed than the one above. I think we should only fail on the old conditions here.
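A hedged sketch of what that could look like at this call site, reusing the removed condition from the hunk above (the full original condition may have additional clauses outside the hunk, and the bail-out action is not visible here, so it is left as a placeholder comment):

// Sketch: keep the original, more relaxed check at this spot instead of the
// stricter canHoistDotOpEncV2 predicate.
if (!isPureUnaryInlineAsm(currOp) &&
    currOp->getDialect()->getTypeID() !=
        TypeID::get<arith::ArithDialect>()) {
  // ... same bail-out behaviour as before this change ...
}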
(Part 2: #19)
Part 1 of the "WGMMA with LHS operand in registers" feature.
Hopper has two kinds of WGMMAs, "SS" (both operands in shmem) and "RS" (LHS operand A in registers).
In cases where we apply elementwise operations on A before the WGMMA, Triton would previously copy A from global memory (GMEM) into registers (RF), perform the elementwise ops, and then copy the result to shared memory (SMEM) to perform an SS WGMMA.
This PR adds an optimization for the case above that uses an RS WGMMA instead; the required changes are the ones named in the title: DotOp hoisting and lowering of SMEM-to-RF copies.
Since it does not include pipelining, this PR is not expected to show perf gains on its own. Pipelining for the MMAv3 operand in registers is added in Part 2.