diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index 59a8d020d4..db2528b966 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -101,6 +101,10 @@ class DialectVerifyTensorLayoutInterface virtual LogicalResult verifyTensorLayout(Attribute layout, RankedTensorType type, Operation *op, function_ref emitError) const = 0; + + virtual LogicalResult + verifyDotOpLayout(Attribute parent, unsigned opIdx, unsigned kWidth, + function_ref emitError) const = 0; }; } // namespace triton diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 3f2393a57a..6c0e8b7657 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -683,57 +683,17 @@ LogicalResult DotOperandEncodingAttr::verify( if (!parent) { return emitError() << "ttg.dot_op parent parameter cannot be null"; } - if (auto parentAttr = mlir::dyn_cast(parent)) { - if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper())) - return emitError() << "ttg.dot_op kWidth parameter can only be " - "non-zero for Ampere or Hopper MMA parent"; - if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper())) - return emitError() << "ttg.dot_op kWidth parameter is mandatory for " - "Ampere or Hopper MMA parent"; - if (opIdx != 0 && parentAttr.isHopper()) - return emitError() - << "ttg.dot_op opIdx parameter must be 0 for " - "Hopper MMA parent, since Hopper WGMMA only allows first " - "operand to be in registers"; - return success(); - } - if (auto parentAttr = mlir::dyn_cast(parent)) { - if (kWidth != 16 && parentAttr.getVersion() == 1 || - kWidth != 8 && kWidth != 16 && parentAttr.getVersion() == 2) - return emitError() << "ttg.dot_op kWidth parameter must be 16 for " - "gfx11 and 8/16 for gfx12"; - return success(); - } - - if (auto parentAttr = mlir::dyn_cast(parent)) { - if (kWidth == 0) - return emitError() << "ttg.dot_op kWidth parameter is mandatory for " - "MFMA parent"; - return success(); - } - - if (auto parentAttr = mlir::dyn_cast(parent)) { - int opsPerChannel = parentAttr.getOpsPerChannel(); - if (opIdx == 0) { - // operand A - if (opsPerChannel == 1) { - if (kWidth != opsPerChannel) - return emitError() << "ttg.dot_op kWidth parameter must match the " - "parent's opsPerChannel"; - } else { - if (kWidth != opsPerChannel / 2) - return emitError() << "ttg.dot_op kWidth parameter must match the " - "parent's opsPerChannel"; - } - } else { - // operand B - if (kWidth != parentAttr.getOpsPerChannel()) - return emitError() << "ttg.dot_op kWidth parameter must match the " - "parent's opsPerChannel"; + if (isa(parent)) { + // The MMA layout can be defined in third party dialect. + // Dispatch to the verifier of dialect interface. + Dialect &dialect = parent.getDialect(); + auto verifyLayoutInterface = + dyn_cast(&dialect); + if (verifyLayoutInterface) { + return verifyLayoutInterface->verifyDotOpLayout(parent, opIdx, kWidth, + emitError); } - - return success(); } if (auto parentAttr = mlir::dyn_cast(parent)) { @@ -2763,6 +2723,45 @@ struct TritonGPUVerifyTensorLayoutInterface return success(); } + + LogicalResult verifyDotOpLayout( + Attribute parent, unsigned opIdx, unsigned kWidth, + function_ref emitError) const override { + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper())) + return emitError() << "ttg.dot_op kWidth parameter can only be " + "non-zero for Ampere or Hopper MMA parent"; + if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper())) + return emitError() << "ttg.dot_op kWidth parameter is mandatory for " + "Ampere or Hopper MMA parent"; + if (opIdx != 0 && parentAttr.isHopper()) + return emitError() + << "ttg.dot_op opIdx parameter must be 0 for " + "Hopper MMA parent, since Hopper WGMMA only allows first " + "operand to be in registers"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 16 && parentAttr.getVersion() == 1 || + kWidth != 8 && kWidth != 16 && parentAttr.getVersion() == 2) + return emitError() << "ttg.dot_op kWidth parameter must be 16 for " + "gfx11 and 8/16 for gfx12"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth == 0) + return emitError() << "ttg.dot_op kWidth parameter is mandatory for " + "MFMA parent"; + return success(); + } + + return emitError() + << "ttg.dot_op un-known parent layout of TritonGPU dialect: " + << parent; + } }; //===----------------------------------------------------------------------===// diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp index 07e1cbfb15..2a28744b64 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp @@ -1117,6 +1117,59 @@ struct TritonIntelGPUInferLayoutInterface } }; +struct TritonIntelGPUVerifyTensorLayoutInterface + : public triton::DialectVerifyTensorLayoutInterface { + using DialectVerifyTensorLayoutInterface::DialectVerifyTensorLayoutInterface; + + LogicalResult verifyTensorLayout( + Attribute layout, RankedTensorType rankedTy, Operation *op, + function_ref makeErr) const override { + + MLIRContext *ctx = op->getContext(); + auto *dialect = ctx->getLoadedDialect("ttg"); + assert(dialect && "Not found triton gpu dialect"); + auto verifyLayoutInterface = + dyn_cast(dialect); + assert(verifyLayoutInterface && + "Not found verify layout interface of triton gpu dialect"); + // re-dispatch the verify to triton gpu dialect + return verifyLayoutInterface->verifyTensorLayout(layout, rankedTy, op, + makeErr); + } + + LogicalResult verifyDotOpLayout( + Attribute parent, unsigned opIdx, unsigned kWidth, + function_ref emitError) const override { + + if (auto parentAttr = mlir::dyn_cast(parent)) { + int opsPerChannel = parentAttr.getOpsPerChannel(); + if (opIdx == 0) { + // operand A + if (opsPerChannel == 1) { + if (kWidth != opsPerChannel) + return emitError() << "ttg.dot_op kWidth parameter must match the " + "parent's opsPerChannel"; + } else { + if (kWidth != opsPerChannel / 2) + return emitError() << "ttg.dot_op kWidth parameter must match the " + "parent's opsPerChannel"; + } + } else { + // operand B + if (kWidth != parentAttr.getOpsPerChannel()) + return emitError() << "ttg.dot_op kWidth parameter must match the " + "parent's opsPerChannel"; + } + + return success(); + } + + return emitError() + << "ttg.dot_op un-known parent layout of TritonIntelGPU dialect: " + << parent; + } +}; + //===----------------------------------------------------------------------===// void TritonIntelGPUDialect::initialize() { @@ -1126,6 +1179,9 @@ void TritonIntelGPUDialect::initialize() { >(); addInterfaces(); + // addInterfaces(); + // addInterfaces(); + addInterfaces(); addOperations< #define GET_OP_LIST