diff --git a/include/triton/Dialect/TritonGPU/IR/Attributes.h b/include/triton/Dialect/TritonGPU/IR/Attributes.h new file mode 100644 index 000000000000..4b96af52ddcb --- /dev/null +++ b/include/triton/Dialect/TritonGPU/IR/Attributes.h @@ -0,0 +1,7 @@ +#ifndef TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ +#define TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc" + +#endif // TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 5c29a675cf00..85e3367e4eac 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -8,12 +8,10 @@ // TritonGPU depends on Triton #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" #include "triton/Dialect/TritonGPU/IR/Traits.h" -#define GET_ATTRDEF_CLASSES -#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc" - #define GET_OP_CLASSES #include "triton/Dialect/TritonGPU/IR/Ops.h.inc" @@ -82,9 +80,10 @@ bool isaDistributedLayout(Attribute layout); bool isSharedEncoding(Value value); +bool isExpensiveCat(CatOp cat, Attribute &targetEncoding); + } // namespace gpu } // namespace triton - } // namespace mlir #endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index d395edd67b92..92732d5797f8 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -18,8 +18,6 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op, bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding); -bool isExpensiveCat(triton::CatOp cat, Attribute &targetEncoding); - bool isExpensiveToRemat(Operation *op, 
Attribute &targetEncoding); // skipInit is True when we only consider the operands of the initOp but diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 894e2a692b86..cd8d1b82e69c 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -6,7 +6,7 @@ #include "mlir/IR/OperationSupport.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" namespace mlir { namespace triton { diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 55eda11332d7..d2a1eacaaabe 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -7,7 +7,6 @@ #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" #include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -368,9 +367,21 @@ bool isSharedEncoding(Value value) { return false; } +bool isExpensiveCat(CatOp cat, Attribute &targetEncoding) { + // If the new number of elements per thread is less than the old one, we will + // need an encoding conversion that goes through shared memory anyway, so we + // consider the cat expensive. 
+ auto tensorTy = cat.getResult().getType().cast(); + auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy); + auto shape = tensorTy.getShape(); + auto elemTy = tensorTy.getElementType(); + auto newTotalElemsPerThread = + gpu::getTotalElemsPerThread(targetEncoding, shape, elemTy); + return newTotalElemsPerThread < totalElemsPerThread; +} + } // namespace gpu } // namespace triton - } // namespace mlir static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr, diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index fb4cec415d72..95130a3f13fd 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -104,26 +104,13 @@ bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding) { return true; } -bool isExpensiveCat(triton::CatOp cat, Attribute &targetEncoding) { - // If the new elements per thread is less than the old one, we will need to do - // convert encoding that goes through shared memory anyway. So we consider it - // as expensive. 
- auto tensorTy = cat.getResult().getType().cast(); - auto totalElemsPerThread = triton::gpu::getTotalElemsPerThread(tensorTy); - auto shape = tensorTy.getShape(); - auto elemTy = tensorTy.getElementType(); - auto newTotalElemsPerThread = - triton::gpu::getTotalElemsPerThread(targetEncoding, shape, elemTy); - return newTotalElemsPerThread < totalElemsPerThread; -} - bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) { if (!op) return true; if (isa(op)) return isExpensiveLoadOrStore(op, targetEncoding); if (isa(op)) - return isExpensiveCat(cast(op), targetEncoding); + return triton::gpu::isExpensiveCat(cast(op), targetEncoding); if (isa(op)) @@ -136,7 +123,8 @@ bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) { bool canFoldConversion(Operation *op, Attribute &targetEncoding) { if (isa(op)) - return !isExpensiveCat(cast(op), targetEncoding); + return !triton::gpu::isExpensiveCat(cast(op), + targetEncoding); return isa(op); }