From 3442904a92e2eabac08abb8c819450f74c08b926 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Wed, 5 Jul 2023 22:08:51 +0200 Subject: [PATCH] [BACKEND] Avoid circular dependencies (#1877) Recent changes made TritonGPU dialect depend on transform utils (`isExpensiveCat()`), and Triton ops depend on TritonGPU dialect (`DotOperandEncodingAttr`). This works fine with CMake but circular dependencies are not ideal and Bazel builds (which we use internally at Google) try hard to prevent them. Would it be acceptable to move the `isExpensiveCat()` function back to TritonGPU dialect (where it was before), and split the TritonGPU attributes into a separate header? This would avoid diverging our internal version or creating over-sized bazel targets to avoid circular dependencies. Co-authored-by: Keren Zhou --- .../triton/Dialect/TritonGPU/IR/Attributes.h | 7 +++++++ include/triton/Dialect/TritonGPU/IR/Dialect.h | 7 +++---- .../Dialect/TritonGPU/Transforms/Utility.h | 2 -- lib/Dialect/Triton/IR/Ops.cpp | 2 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 15 +++++++++++++-- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 18 +++--------------- 6 files changed, 27 insertions(+), 24 deletions(-) create mode 100644 include/triton/Dialect/TritonGPU/IR/Attributes.h diff --git a/include/triton/Dialect/TritonGPU/IR/Attributes.h b/include/triton/Dialect/TritonGPU/IR/Attributes.h new file mode 100644 index 000000000000..4b96af52ddcb --- /dev/null +++ b/include/triton/Dialect/TritonGPU/IR/Attributes.h @@ -0,0 +1,7 @@ +#ifndef TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ +#define TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc" + +#endif // TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 5c29a675cf00..85e3367e4eac 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -8,12 +8,10 @@ // TritonGPU depends on Triton #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" #include "triton/Dialect/TritonGPU/IR/Traits.h" -#define GET_ATTRDEF_CLASSES -#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc" - #define GET_OP_CLASSES #include "triton/Dialect/TritonGPU/IR/Ops.h.inc" @@ -82,9 +80,10 @@ bool isaDistributedLayout(Attribute layout); bool isSharedEncoding(Value value); +bool isExpensiveCat(CatOp cat, Attribute &targetEncoding); + } // namespace gpu } // namespace triton - } // namespace mlir #endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index d395edd67b92..92732d5797f8 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -18,8 +18,6 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op, bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding); -bool isExpensiveCat(triton::CatOp cat, Attribute &targetEncoding); - bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding); // skipInit is True when we only consider the operands of the initOp but diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 894e2a692b86..cd8d1b82e69c 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -6,7 +6,7 @@ #include "mlir/IR/OperationSupport.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" namespace mlir { namespace triton { diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 55eda11332d7..d2a1eacaaabe 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -7,7 +7,6 @@ #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" #include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -368,9 +367,21 @@ bool isSharedEncoding(Value value) { return false; } +bool isExpensiveCat(CatOp cat, Attribute &targetEncoding) { + // If the new elements per thread is less than the old one, we will need to do + // convert encoding that goes through shared memory anyway. So we consider it + // as expensive. + auto tensorTy = cat.getResult().getType().cast(); + auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy); + auto shape = tensorTy.getShape(); + auto elemTy = tensorTy.getElementType(); + auto newTotalElemsPerThread = + gpu::getTotalElemsPerThread(targetEncoding, shape, elemTy); + return newTotalElemsPerThread < totalElemsPerThread; +} + } // namespace gpu } // namespace triton - } // namespace mlir static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr, diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index fb4cec415d72..95130a3f13fd 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -104,26 +104,13 @@ bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding) { return true; } -bool isExpensiveCat(triton::CatOp cat, Attribute &targetEncoding) { - // If the new elements per thread is less than the old one, we will need to do - // convert encoding that goes through shared memory anyway. So we consider it - // as expensive. - auto tensorTy = cat.getResult().getType().cast(); - auto totalElemsPerThread = triton::gpu::getTotalElemsPerThread(tensorTy); - auto shape = tensorTy.getShape(); - auto elemTy = tensorTy.getElementType(); - auto newTotalElemsPerThread = - triton::gpu::getTotalElemsPerThread(targetEncoding, shape, elemTy); - return newTotalElemsPerThread < totalElemsPerThread; -} - bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) { if (!op) return true; if (isa(op)) return isExpensiveLoadOrStore(op, targetEncoding); if (isa(op)) - return isExpensiveCat(cast(op), targetEncoding); + return triton::gpu::isExpensiveCat(cast(op), targetEncoding); if (isa(op)) @@ -136,7 +123,8 @@ bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) { bool canFoldConversion(Operation *op, Attribute &targetEncoding) { if (isa(op)) - return !isExpensiveCat(cast(op), targetEncoding); + return !triton::gpu::isExpensiveCat(cast(op), + targetEncoding); return isa(op); }