
LHS Registers Part 1 - DotOp Hoisting and SMEM-RF Copy Lowering #18

Open
wants to merge 12 commits into llvm-head from lhs-reg-hoist

Conversation

ggengnv commented Sep 23, 2024

(Part 2: #19)
Part 1 of the "WGMMA with LHS operand in registers" feature.

Hopper has two kinds of WGMMA: "SS" (both operands in shared memory) and "RS" (the LHS operand A in registers).
In cases where we apply elementwise operations on A before the WGMMA, Triton previously would copy A from global memory (GMEM) into registers (RF), perform the elementwise ops, and then copy the result to shared memory (SMEM) to perform an SS WGMMA.

This PR adds an optimization for the case above to use an RS GEMM instead. This requires the following changes:

  • In the TritonGPU OptimizeDotOperands pass, add an optimization that converts an SS GEMM into an RS GEMM where doing so is possible and beneficial
  • Add a TritonGPU -> LLVM lowering for copying from SMEM to RF in the MMA v3 dotOperand layout

Without pipelining, this PR on its own is not expected to show perf gains. Pipelining for the MMAv3 operand in registers is added in Part 2.
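For intuition, the "possible and beneficial" check largely reduces to whether every op between the load of A and the dot is elementwise (and therefore layout-agnostic), so it can be replayed on register-resident values. Below is a minimal sketch of such a predicate, assuming MLIR-style APIs; the helper name is hypothetical, and the PR's actual logic lives in the OptimizeDotOperands pass.

// Sketch only: mirrors the kinds of ops the pass treats as hoistable.
static bool isHoistableElementwiseOp(Operation *op) {
  // Pure unary inline asm and arith-dialect ops (add, mul, trunc, ...) do not
  // depend on the operand layout, so they can be applied after A is already
  // in registers in the MMA v3 dotOperand layout.
  return isPureUnaryInlineAsm(op) ||
         op->getDialect()->getTypeID() == TypeID::get<arith::ArithDialect>();
}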

ggengnv (Author) commented Sep 23, 2024

@gflegar @chsigg @vwbaker @Moerafaat This is part 1 of the two PRs split from #17, as suggested by @vwbaker. Please assign reviewers and review, thanks.
The same goes for part 2: #19

ggengnv (Author) commented Sep 23, 2024

I've addressed all comments from the original PR that are relevant to part 1 in this PR instead.

@@ -11,6 +11,8 @@
import pytest
import torch
import os
os.environ['TRITON_ALWAYS_COMPILE'] = '1'
os.environ['MLIR_ENABLE_DUMP'] = '1'

Member commented:

looks like these were leftover from debugging

Moerafaat (Member) commented:
Oops, it seems that the updates we make to maintain the Triton integration will cause PR diffs to break because of force-pushes. We might need to figure out a better way to handle this, as we didn't intend for this repo to accept incoming PRs. Apologies, but you will need to rebase again for the diff to include the proper changes.

ggengnv force-pushed the lhs-reg-hoist branch 2 times, most recently from 927f2ec to 1b95c9a on September 25, 2024 at 21:07
ggengnv (Author) commented Sep 25, 2024

@Moerafaat np - I've reapplied my changes on the new main

@@ -1327,16 +1324,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim
);

let builders = [
// Specially for MMAV1(Volta)
// For MMAV2 and V3

Member commented:

I think the comment should be split into 2 parts: the "if-block" is for MMAV1 (Volta), and the following part is for MMAV2 and V3.

@@ -87,8 +87,12 @@ SmallVector<Value> unpackI32(const SmallVector<Value> &inValues, Type srcTy,
if (!tensorTy)
return inValues;
auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent())))
if (!encoding)

Member commented:

Maybe wrap this in a helper method and use it here and below?
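One possible shape for that helper, as a sketch (the name is hypothetical, and the exact predicate depends on which variant of the check both call sites end up needing):

// Returns the dot-operand encoding if `type` is a tensor whose encoding is a
// DotOperandEncodingAttr with an NVIDIA MMA parent, and null otherwise.
static DotOperandEncodingAttr getNvidiaMmaDotOpEncoding(Type type) {
  auto tensorTy = dyn_cast<RankedTensorType>(type);
  if (!tensorTy)
    return {};
  auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
  if (!encoding || !isa<NvidiaMmaEncodingAttr>(encoding.getParent()))
    return {};
  return encoding;
}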

"Ampere or Hopper MMA parent";
if (opIdx != 0 && parentAttr.isHopper())
return emitError()
<< "triton_gpu.dot_op opIdx parameter must be 0 for "

Member commented:

Can we provide an explanation within the error message for the restriction?
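As a suggestion for the wording (not the PR's actual message), the diagnostic could state the hardware reason directly:

if (opIdx != 0 && parentAttr.isHopper())
  return emitError()
         << "triton_gpu.dot_op opIdx parameter must be 0 for "
            "Hopper MMA parent: Hopper WGMMA only supports taking the A (LHS) "
            "operand from registers, so the B operand must stay in shared memory";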

unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands(
ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
auto shapePerCTA = getShapePerCTA(*this, shape);
int warpsPerCTAM = getWarpsPerCTA()[0];
int warpsPerCTAN = getWarpsPerCTA()[1];
// H100
if (isHopper()) {
return getTotalElemsPerThread(shape, eltTy);
assert(opIdx == 0);

Member commented:

Is this function never invoked for Hopper with opIdx = 1? If we are unsure maybe we can maintain the old code-flow for that case then.
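If that can't be ruled out, one way to keep the old code-flow, as a sketch (not the PR's code): restrict the Hopper shortcut to the register-resident A operand and let opIdx == 1 fall through to the existing per-operand computation.

// H100 (Hopper): only the A operand (opIdx == 0) uses the register layout,
// so only that case takes the shortcut.
if (isHopper() && opIdx == 0)
  return getTotalElemsPerThread(shape, eltTy);
// opIdx == 1 (and pre-Hopper layouts) keep the existing code path below.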

return res;
}

if (mmaLayout.isHopper()) { // tensor core v3

Member commented:

Maybe shorter this way?

if (mmaLayout.isHopper() || mmaLayout.isAmpere()) { // tensor core v2 or v3
  if(mmaLayout.isHopper())
    assert(dotOperandLayout.getOpIdx() == 0);
  res = SharedToDotOperandMMAv2OrV3::convertLayout(
    dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout,
    smemObj, typeConverter, getThreadId(rewriter, loc));
} else if (mmaLayout.isVolta() && isMMA) { // tensor core v1

@@ -385,6 +385,14 @@ static bool loadIsMMAv3(Operation *loadOp) {
if (!sharedEnc.getHasLeadingOffset())
return false;

// In case LHS is in registers, don't pipeline for now TODO(ggengnv) is this necessary?

Member commented:

Were you able to figure out this TODO?

!isPureUnaryInlineAsm(currOp) &&
currOp->getDialect()->getTypeID() !=
TypeID::get<arith::ArithDialect>())
if (!canHoistDotOpEncV2(currOp, dotOpEnc))

Member commented:

The restriction here was more relaxed than the one above. I think we should only fail on the old conditions here.
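A hedged reading of this suggestion (the leading clauses of the original condition and the failure path are cut off in the diff, so both are placeholders here): keep the original, more permissive check at this call site instead of the stricter canHoistDotOpEncV2 predicate.

// Old, more relaxed condition: accept pure unary inline asm as well as
// arith-dialect ops.
if (/* ...elided leading clauses... */
    !isPureUnaryInlineAsm(currOp) &&
    currOp->getDialect()->getTypeID() !=
        TypeID::get<arith::ArithDialect>())
  break; // placeholder for whatever the original failure path does here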
