Skip to content

Commit

Permalink
[BACKEND] Avoid circular dependencies (triton-lang#1877)
Browse files Browse the repository at this point in the history
Recent changes made TritonGPU dialect depend on transform utils
(`isExpensiveCat()`), and Triton ops depend on TritonGPU dialect
(`DotOperandEncodingAttr`). This works fine with CMake but circular
dependencies are not ideal and Bazel builds (which we use internally at
Google) try hard to prevent them.

Would it be acceptable to move the `isExpensiveCat()` function back to
TritonGPU dialect (where it was before), and split the TritonGPU
attributes into a separate header? This would avoid diverging our
internal version or creating over-sized bazel targets to avoid circular
dependencies.

Co-authored-by: Keren Zhou <[email protected]>
  • Loading branch information
chsigg and Jokeren authored Jul 5, 2023
1 parent 4255ef0 commit 3442904
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 24 deletions.
7 changes: 7 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Attributes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#ifndef TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_
#define TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_

#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"

#endif // TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_
7 changes: 3 additions & 4 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@

// TritonGPU depends on Triton
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
#include "triton/Dialect/TritonGPU/IR/Traits.h"

#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"

#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"

Expand Down Expand Up @@ -82,9 +80,10 @@ bool isaDistributedLayout(Attribute layout);

bool isSharedEncoding(Value value);

bool isExpensiveCat(CatOp cat, Attribute &targetEncoding);

} // namespace gpu
} // namespace triton

} // namespace mlir

#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
2 changes: 0 additions & 2 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,

bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding);

bool isExpensiveCat(triton::CatOp cat, Attribute &targetEncoding);

bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding);

// skipInit is True when we only consider the operands of the initOp but
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "mlir/IR/OperationSupport.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"

namespace mlir {
namespace triton {
Expand Down
15 changes: 13 additions & 2 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
Expand Down Expand Up @@ -368,9 +367,21 @@ bool isSharedEncoding(Value value) {
return false;
}

bool isExpensiveCat(CatOp cat, Attribute &targetEncoding) {
// If the new elements per thread is less than the old one, we will need to do
// convert encoding that goes through shared memory anyway. So we consider it
// as expensive.
auto tensorTy = cat.getResult().getType().cast<RankedTensorType>();
auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy);
auto shape = tensorTy.getShape();
auto elemTy = tensorTy.getElementType();
auto newTotalElemsPerThread =
gpu::getTotalElemsPerThread(targetEncoding, shape, elemTy);
return newTotalElemsPerThread < totalElemsPerThread;
}

} // namespace gpu
} // namespace triton

} // namespace mlir

static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr,
Expand Down
18 changes: 3 additions & 15 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,26 +104,13 @@ bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
return true;
}

bool isExpensiveCat(triton::CatOp cat, Attribute &targetEncoding) {
// If the new elements per thread is less than the old one, we will need to do
// convert encoding that goes through shared memory anyway. So we consider it
// as expensive.
auto tensorTy = cat.getResult().getType().cast<RankedTensorType>();
auto totalElemsPerThread = triton::gpu::getTotalElemsPerThread(tensorTy);
auto shape = tensorTy.getShape();
auto elemTy = tensorTy.getElementType();
auto newTotalElemsPerThread =
triton::gpu::getTotalElemsPerThread(targetEncoding, shape, elemTy);
return newTotalElemsPerThread < totalElemsPerThread;
}

bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) {
if (!op)
return true;
if (isa<triton::LoadOp, triton::StoreOp>(op))
return isExpensiveLoadOrStore(op, targetEncoding);
if (isa<triton::CatOp>(op))
return isExpensiveCat(cast<triton::CatOp>(op), targetEncoding);
return triton::gpu::isExpensiveCat(cast<triton::CatOp>(op), targetEncoding);
if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
triton::gpu::InsertSliceAsyncOp, triton::AtomicRMWOp,
triton::AtomicCASOp, triton::DotOp>(op))
Expand All @@ -136,7 +123,8 @@ bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) {

bool canFoldConversion(Operation *op, Attribute &targetEncoding) {
if (isa<triton::CatOp>(op))
return !isExpensiveCat(cast<triton::CatOp>(op), targetEncoding);
return !triton::gpu::isExpensiveCat(cast<triton::CatOp>(op),
targetEncoding);
return isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
triton::MakeRangeOp, triton::SplatOp, triton::ViewOp>(op);
}
Expand Down

0 comments on commit 3442904

Please sign in to comment.