Skip to content

Clean up Intel specific code in the common TritonGPU dialect source file. #4469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/triton/Dialect/Triton/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ class DialectVerifyTensorLayoutInterface
virtual LogicalResult
verifyTensorLayout(Attribute layout, RankedTensorType type, Operation *op,
function_ref<InFlightDiagnostic()> emitError) const = 0;

virtual LogicalResult
verifyDotOpLayout(Attribute parent, unsigned opIdx, unsigned kWidth,
function_ref<InFlightDiagnostic()> emitError) const = 0;
};

} // namespace triton
Expand Down
97 changes: 48 additions & 49 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,57 +683,17 @@ LogicalResult DotOperandEncodingAttr::verify(
if (!parent) {
return emitError() << "ttg.dot_op parent parameter cannot be null";
}
if (auto parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper()))
return emitError() << "ttg.dot_op kWidth parameter can only be "
"non-zero for Ampere or Hopper MMA parent";
if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper()))
return emitError() << "ttg.dot_op kWidth parameter is mandatory for "
"Ampere or Hopper MMA parent";
if (opIdx != 0 && parentAttr.isHopper())
return emitError()
<< "ttg.dot_op opIdx parameter must be 0 for "
"Hopper MMA parent, since Hopper WGMMA only allows first "
"operand to be in registers";
return success();
}

if (auto parentAttr = mlir::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
if (kWidth != 16 && parentAttr.getVersion() == 1 ||
kWidth != 8 && kWidth != 16 && parentAttr.getVersion() == 2)
return emitError() << "ttg.dot_op kWidth parameter must be 16 for "
"gfx11 and 8/16 for gfx12";
return success();
}

if (auto parentAttr = mlir::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
if (kWidth == 0)
return emitError() << "ttg.dot_op kWidth parameter is mandatory for "
"MFMA parent";
return success();
}

if (auto parentAttr = mlir::dyn_cast<intel::DpasEncodingAttr>(parent)) {
int opsPerChannel = parentAttr.getOpsPerChannel();
if (opIdx == 0) {
// operand A
if (opsPerChannel == 1) {
if (kWidth != opsPerChannel)
return emitError() << "ttg.dot_op kWidth parameter must match the "
"parent's opsPerChannel";
} else {
if (kWidth != opsPerChannel / 2)
return emitError() << "ttg.dot_op kWidth parameter must match the "
"parent's opsPerChannel";
}
} else {
// operand B
if (kWidth != parentAttr.getOpsPerChannel())
return emitError() << "ttg.dot_op kWidth parameter must match the "
"parent's opsPerChannel";
if (isa<MmaEncodingTrait>(parent)) {
// The MMA layout can be defined in third party dialect.
// Dispatch to the verifier of dialect interface.
Dialect &dialect = parent.getDialect();
auto verifyLayoutInterface =
dyn_cast<mlir::triton::DialectVerifyTensorLayoutInterface>(&dialect);
if (verifyLayoutInterface) {
return verifyLayoutInterface->verifyDotOpLayout(parent, opIdx, kWidth,
emitError);
}

return success();
}

if (auto parentAttr = mlir::dyn_cast<intel::WarpEncodingAttr>(parent)) {
Expand Down Expand Up @@ -2763,6 +2723,45 @@ struct TritonGPUVerifyTensorLayoutInterface

return success();
}

LogicalResult verifyDotOpLayout(
Attribute parent, unsigned opIdx, unsigned kWidth,
function_ref<InFlightDiagnostic()> emitError) const override {

if (auto parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper()))
return emitError() << "ttg.dot_op kWidth parameter can only be "
"non-zero for Ampere or Hopper MMA parent";
if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper()))
return emitError() << "ttg.dot_op kWidth parameter is mandatory for "
"Ampere or Hopper MMA parent";
if (opIdx != 0 && parentAttr.isHopper())
return emitError()
<< "ttg.dot_op opIdx parameter must be 0 for "
"Hopper MMA parent, since Hopper WGMMA only allows first "
"operand to be in registers";
return success();
}

if (auto parentAttr = mlir::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
if (kWidth != 16 && parentAttr.getVersion() == 1 ||
kWidth != 8 && kWidth != 16 && parentAttr.getVersion() == 2)
return emitError() << "ttg.dot_op kWidth parameter must be 16 for "
"gfx11 and 8/16 for gfx12";
return success();
}

if (auto parentAttr = mlir::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
if (kWidth == 0)
return emitError() << "ttg.dot_op kWidth parameter is mandatory for "
"MFMA parent";
return success();
}

return emitError()
<< "ttg.dot_op un-known parent layout of TritonGPU dialect: "
<< parent;
}
};

//===----------------------------------------------------------------------===//
Expand Down
56 changes: 56 additions & 0 deletions third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,59 @@ struct TritonIntelGPUInferLayoutInterface
}
};

struct TritonIntelGPUVerifyTensorLayoutInterface
: public triton::DialectVerifyTensorLayoutInterface {
using DialectVerifyTensorLayoutInterface::DialectVerifyTensorLayoutInterface;

LogicalResult verifyTensorLayout(
Attribute layout, RankedTensorType rankedTy, Operation *op,
function_ref<InFlightDiagnostic()> makeErr) const override {

MLIRContext *ctx = op->getContext();
auto *dialect = ctx->getLoadedDialect("ttg");
assert(dialect && "Not found triton gpu dialect");
auto verifyLayoutInterface =
dyn_cast<mlir::triton::DialectVerifyTensorLayoutInterface>(dialect);
assert(verifyLayoutInterface &&
"Not found verify layout interface of triton gpu dialect");
// re-dispatch the verify to triton gpu dialect
return verifyLayoutInterface->verifyTensorLayout(layout, rankedTy, op,
makeErr);
}

LogicalResult verifyDotOpLayout(
Attribute parent, unsigned opIdx, unsigned kWidth,
function_ref<InFlightDiagnostic()> emitError) const override {

if (auto parentAttr = mlir::dyn_cast<DpasEncodingAttr>(parent)) {
int opsPerChannel = parentAttr.getOpsPerChannel();
if (opIdx == 0) {
// operand A
if (opsPerChannel == 1) {
if (kWidth != opsPerChannel)
return emitError() << "ttg.dot_op kWidth parameter must match the "
"parent's opsPerChannel";
} else {
if (kWidth != opsPerChannel / 2)
return emitError() << "ttg.dot_op kWidth parameter must match the "
"parent's opsPerChannel";
}
} else {
// operand B
if (kWidth != parentAttr.getOpsPerChannel())
return emitError() << "ttg.dot_op kWidth parameter must match the "
"parent's opsPerChannel";
}

return success();
}

return emitError()
<< "ttg.dot_op un-known parent layout of TritonIntelGPU dialect: "
<< parent;
}
};

//===----------------------------------------------------------------------===//

void TritonIntelGPUDialect::initialize() {
Expand All @@ -1126,6 +1179,9 @@ void TritonIntelGPUDialect::initialize() {
>();

addInterfaces<TritonIntelGPUInferLayoutInterface>();
// addInterfaces<TritonGPUOpAsmInterface>();
// addInterfaces<TritonGPUInferLayoutInterface>();
Comment on lines +1182 to +1183
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// addInterfaces<TritonGPUOpAsmInterface>();
// addInterfaces<TritonGPUInferLayoutInterface>();

addInterfaces<TritonIntelGPUVerifyTensorLayoutInterface>();

addOperations<
#define GET_OP_LIST
Expand Down