[BACKEND] Propagate mma layout to following elementwise operations. (#3973)

For a matmul followed by arithmetic operations, such as `acc +=
tl.dot(a, b)`, the mma layout of the `dot` result is currently not
propagated into the subsequent `add`. As a result, when the dot is
inside a loop, a layout conversion from mma to blocked is emitted on
every iteration. This change allows the mma layout to be propagated so
that it can be reused.
htyu authored Oct 22, 2024

1 parent ed39cb0 commit 1064b59
Showing 3 changed files with 45 additions and 107 deletions.
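
For context, the accumulation pattern described in the commit message looks roughly like the Triton kernel sketched below. This is a hypothetical, minimal example (the kernel name, parameters, and the assumption that `M`, `N`, `K` are multiples of the block sizes are all illustrative, not taken from this repository); the `acc += tl.dot(a, b)` inside the K loop is the elementwise add that previously forced a convert from the mma layout back to a blocked layout on every iteration.

```python
# Hypothetical kernel illustrating the `acc += tl.dot(a, b)` pattern.
# Assumes M, N, K are multiples of the block sizes (no masking) to keep the sketch short.
import triton
import triton.language as tl


@triton.jit
def matmul_add_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                      stride_am, stride_ak, stride_bk, stride_bn,
                      stride_cm, stride_cn,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        # The add below is the elementwise op that now keeps the dot's mma layout
        # instead of converting to a blocked layout on every iteration.
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc)
```

The new `matmul_add` test in test/TritonGPU/combine.mlir exercises the same shape of loop at the TTGIR level: the `arith.addf` stays in the #mma layout inside the loop, and a single conversion to #blocked is emitted after it.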
95 changes: 3 additions & 92 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -163,85 +163,6 @@ void LayoutRematerialization::cleanup() {
    op->erase();
}

// Look ahead to at the transitive uses and see if there is a convert to mma
// operations.
bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
  SmallVector<Value> queue = {op->getResult(0)};
  SetVector<Operation *> forwardSlice;
  llvm::SmallDenseSet<Value> seen;
  while (!queue.empty()) {
    Value currentValue = queue.back();
    queue.pop_back();
    getForwardSlice(currentValue, &forwardSlice);
    for (Operation *op : forwardSlice) {
      // HACK: Stop propagation if the ReduceOp is using mma layout but is
      // producing tensor smaller than the layout we would like to propagate.
      // This is to avoid stepping into the known bug.
      if (isa<mlir::triton::ReduceOp>(op)) {
        auto tensorType =
            dyn_cast<RankedTensorType>(op->getOperand(0).getType());
        if (tensorType &&
            isa<NvidiaMmaEncodingAttr>(tensorType.getEncoding())) {
          auto mmaInstrShape =
              cast<NvidiaMmaEncodingAttr>(encoding).getInstrShape();
          if (tensorType.getShape()[tensorType.getRank() - 2] <
                  mmaInstrShape[0] ||
              tensorType.getShape()[tensorType.getRank() - 1] <
                  mmaInstrShape[1]) {
            return false;
          }
        }
      }

      if (auto convertOp = dyn_cast<ConvertLayoutOp>(op)) {
        Attribute dstEncoding = convertOp.getType().getEncoding();
        if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(dstEncoding))
          return (mmaLayout.getVersionMajor() > 1) ? true
                                                   : mmaLayout == encoding;
        if (isa<triton::gpu::AMDMfmaEncodingAttr,
                triton::gpu::AMDWmmaEncodingAttr>(dstEncoding))
          return true;
        if (isa<triton::gpu::DotOperandEncodingAttr>(dstEncoding)) {
          if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(encoding)) {
            return mmaLayout.getVersionMajor() > 1;
          } else {
            assert((mlir::isa<triton::gpu::AMDMfmaEncodingAttr,
                              triton::gpu::AMDWmmaEncodingAttr>(encoding)));
            return true;
          }
        }
      }
      bool isMMAV3 =
          isa<NvidiaMmaEncodingAttr>(encoding) &&
          cast<NvidiaMmaEncodingAttr>(encoding).getVersionMajor() == 3;
      if (isMMAV3 && (isa<LocalAllocOp>(op) || isa<LocalStoreOp>(op)))
        return true;
      auto yield = dyn_cast<scf::YieldOp>(op);
      if (!yield)
        continue;
      if (auto ifOp = dyn_cast<scf::IfOp>(yield->getParentOp())) {
        for (OpOperand &operand : yield->getOpOperands()) {
          Operation *def = operand.get().getDefiningOp();
          if (def &&
              (forwardSlice.count(def) || operand.get() == currentValue) &&
              (seen.insert(operand.get()).second == true))
            queue.push_back(ifOp.getResult(operand.getOperandNumber()));
        }
      }
      auto forOp = dyn_cast<scf::ForOp>(yield.getOperation()->getParentOp());
      if (!forOp)
        continue;
      for (OpOperand &operand : yield->getOpOperands()) {
        Operation *def = operand.get().getDefiningOp();
        if (def && (forwardSlice.count(def) || operand.get() == currentValue) &&
            (seen.insert(operand.get()).second == true))
          queue.push_back(forOp.getRegionIterArg(operand.getOperandNumber()));
      }
    }
  }
  return false;
}

// Return true if the op is an op with a layout we don't want to change. We will
// propagate the layout starting from anchor ops.
bool isLayoutAnchor(Operation *op) {
@@ -262,18 +183,8 @@ bool isLayoutAnchor(Operation *op) {
}

void LayoutPropagation::initAnchorLayout() {
  auto maybeAddAnchor = [&](Value v) {
  auto addAnchor = [&](Value v) {
    if (auto tensorType = dyn_cast<RankedTensorType>(v.getType())) {
      // Workaround, don't popagate MMA layout unless there is a convert
      // back to mma further down to avoid generating reduction with MMA
      // layout that may have lower performance.
      // This can be improved with more aggressive backward propagation.
      if (isa<MmaEncodingTrait>(tensorType.getEncoding()) &&
          v.getDefiningOp() &&
          !hasConvertToMMATransisitiveUse(v.getDefiningOp(),
                                          tensorType.getEncoding())) {
        return;
      }
      layouts.insert({v, LayoutInfo(tensorType.getEncoding())});
    }
  };
@@ -282,13 +193,13 @@ void LayoutPropagation::initAnchorLayout() {
  // you can pass a tensor with an encoding as an arg, instead of explicitly
  // calling tt.load.
  for (auto arg : funcOp.getArguments()) {
    maybeAddAnchor(arg);
    addAnchor(arg);
  }

  funcOp.walk([&](Operation *op) {
    if (isLayoutAnchor(op)) {
      for (auto result : op->getResults()) {
        maybeAddAnchor(result);
        addAnchor(result);
      }
    }
  });
15 changes: 0 additions & 15 deletions python/test/unit/language/test_core.py
@@ -3222,21 +3222,6 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
    pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), y_tri, y_tri.stride(0), y_tri.stride(1), w_tri,
                         w_tri.stride(0), w_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), **kern_kwargs)

    if epilogue == 'softmax' and (in_dtype != 'float32' or input_precision == "tf32"):
        if not is_cuda():
            pass
        else:
            ptx = pgm.asm["ptx"]
            start = ptx.find("shfl.sync.bfly")
            end = ptx.find("cvt.rn.f16.f32")
            red_code = ptx[start:end]
            assert len(red_code) > 0

            # skip this check on hopper because there are some functions whose name contain "shared" in ptx.
            # TODO: we should eliminate these unused functions in ptx code.
            if not (capability[0] >= 9):
                assert "shared" not in red_code
            assert "bar.sync" not in red_code
    # torch result
    if in_dtype == 'int8':
        z_ref = np.matmul(x.astype(np.float32), y.astype(np.float32())).astype(np.int32)
42 changes: 42 additions & 0 deletions test/TritonGPU/combine.mlir
@@ -2607,3 +2607,45 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
    tt.return %outLHS : tensor<128x64xf32, #blocked1>
  }
}

// -----

#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#CL = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>

module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} {
  // CHECK-LABEL: matmul_add
  tt.func @matmul_add(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %C : !tt.ptr<f32>) {
    %a_ptr_init = tt.splat %A : !tt.ptr<f16> -> tensor<128x32x!tt.ptr<f16>, #AL>
    %b_ptr_init = tt.splat %B : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>, #BL>
    %c_ptr_init = tt.splat %C : !tt.ptr<f32> -> tensor<128x128x!tt.ptr<f32>, #CL>
    %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #CL>
    %cst = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
    %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
    %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>

    %100:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #CL>) {
      %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr<f16>, #AL>
      %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT>
      %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr<f16>, #BL>
      %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT>
      %c = tt.dot %a, %b, %cst : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
      %t = triton_gpu.convert_layout %c : tensor<128x128xf32, #C> -> tensor<128x128xf32, #CL>
      // CHECK: %[[T0:.*]] = tt.dot
      // CHECK: arith.addf %{{.*}}, %[[T0]] : tensor<128x128xf32, #mma>
      %t2 = arith.addf %prev_c, %t : tensor<128x128xf32, #CL>
      %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
      %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
      // CHECK: scf.yield
      scf.yield %next_a_ptr, %next_b_ptr, %t2 : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #CL>
    }

    // CHECK: triton_gpu.convert_layout {{.*}} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked
    tt.store %c_ptr_init, %100#2 : tensor<128x128x!tt.ptr<f32>, #CL>
    tt.return
  }
}
