Skip to content

Commit

Permalink
Move byte manipulation ops from elwise ops conversion. (#28)
Browse files Browse the repository at this point in the history
Signed-off-by: Ilya Enkovich <[email protected]>
  • Loading branch information
ienkovich authored Jun 20, 2024
1 parent 1cf81ef commit 44035c9
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 129 deletions.
1 change: 1 addition & 0 deletions third_party/cpu/include/TritonToTritonCPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace cpu {
#include "cpu/include/TritonToTritonCPU/Passes.h.inc"

std::unique_ptr<OperationPass<ModuleOp>> createConvertElementwiseOps();
std::unique_ptr<OperationPass<ModuleOp>> createConvertElemManipOps();
std::unique_ptr<OperationPass<ModuleOp>> createConvertMemoryOps();
std::unique_ptr<OperationPass<ModuleOp>> createConvertPtrOps();
std::unique_ptr<OperationPass<ModuleOp>> createConvertDotOp();
Expand Down
14 changes: 14 additions & 0 deletions third_party/cpu/include/TritonToTritonCPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ def ConvertElementwiseOps : Pass<"triton-cpu-convert-elementwise-ops", "mlir::Mo
"mlir::triton::cpu::TritonCPUDialect"];
}

def ConvertElemManipOps : Pass<"triton-cpu-convert-elem-manip-ops", "mlir::ModuleOp"> {
let summary = "Convert elements manipulation ops (transpose, shuffle, etc.).";
let description = [{

}];
let constructor = "mlir::triton::cpu::createConvertElemManipOps()";

let dependentDialects = ["mlir::arith::ArithDialect",
"mlir::memref::MemRefDialect",
"mlir::vector::VectorDialect",
"mlir::triton::TritonDialect",
"mlir::triton::cpu::TritonCPUDialect"];
}

def ConvertPtrOps : Pass<"triton-cpu-convert-ptr-ops", "mlir::ModuleOp"> {
let summary = "Convert Triton ops related to pointer arithmetics.";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_triton_library(TritonToTritonCPU
ConvertControlFlowOps.cpp
ConvertDotOp.cpp
ConvertElementwiseOps.cpp
ConvertElemManipOps.cpp
ConvertHistogramOp.cpp
ConvertMemoryOps.cpp
ConvertPtrOps.cpp
Expand Down
208 changes: 208 additions & 0 deletions third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
#include "OpTypeConversion.h"
#include "TypeConverter.h"

#include "cpu/include/TritonToTritonCPU/Passes.h"

#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Membar.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonCPU/IR/Dialect.h"

namespace mlir {
namespace triton {
#define GEN_PASS_DEF_CONVERTELEMMANIPOPS
#include "cpu/include/TritonToTritonCPU/Passes.h.inc"
} // namespace triton
} // namespace mlir

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::cpu;

namespace {

class ElemManipOpConversionTarget : public ConversionTarget {
public:
explicit ElemManipOpConversionTarget(MLIRContext &ctx,
TypeConverter &converter)
: ConversionTarget(ctx) {
addLegalDialect<vector::VectorDialect>();
addLegalDialect<arith::ArithDialect>();
addLegalDialect<math::MathDialect>();
addLegalDialect<TritonDialect>();
addLegalDialect<TritonCPUDialect>();
addLegalOp<mlir::UnrealizedConversionCastOp>();

addIllegalOp<triton::BroadcastOp>();
addIllegalOp<triton::ExpandDimsOp>();
addIllegalOp<triton::ReshapeOp>();
addIllegalOp<triton::TransOp>();
addIllegalOp<triton::JoinOp>();
addIllegalOp<triton::CatOp>();
}
};

struct ReshapeOpConversion : public OpConversionPattern<triton::ReshapeOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
assert(isa<RankedTensorType>(op.getType()));
auto loc = op.getLoc();
auto src = rewriter.getRemappedValue(op.getSrc());
auto srcShape = dyn_cast<VectorType>(src.getType()).getShape();
auto resTy =
dyn_cast<VectorType>(getTypeConverter()->convertType(op.getType()));
auto dstShape = resTy.getShape();
auto elemTy = resTy.getElementType();

// There are restrictions on how shape can be modified by ShapeCastOp
// when rank is changed. For now, we simply detect it and handle through
// a cast to 1D vector. Better solution may be required later.
if (canCastShape(srcShape, dstShape)) {
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
op, VectorType::get(dstShape, elemTy), src);
} else {
SmallVector<int64_t> tmpShape({resTy.getNumElements()});
auto tmp = rewriter.create<vector::ShapeCastOp>(
loc, VectorType::get(tmpShape, elemTy), src);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
op, VectorType::get(dstShape, elemTy), tmp);
}
return success();
}

private:
bool canCastShape(ArrayRef<int64_t> src, ArrayRef<int64_t> dst) const {
if (src.size() == dst.size())
return true;
if (src.size() > dst.size())
return canCastShape(dst, src);

size_t srcIdx = 0;
size_t dstIdx = 0;
while (srcIdx < src.size() && dstIdx < dst.size()) {
if (src[srcIdx] == 1) {
++srcIdx;
} else {
// Source dim size should be a product of continuous dest dim sizes.
int64_t srcSize = src[srcIdx++];
int64_t dstSize = dst[dstIdx++];
while (dstSize < srcSize && dstIdx < dst.size())
dstSize *= dst[dstIdx++];
if (dstSize != srcSize)
return false;
}
}

// Skip trailing 1s.
while (srcIdx < src.size() && src[srcIdx] == 1)
++srcIdx;
while (dstIdx < dst.size() && dst[dstIdx] == 1)
++dstIdx;

return srcIdx == src.size() && dstIdx == dst.size();
}
};

struct TransOpConversion : public OpConversionPattern<triton::TransOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto val = rewriter.getRemappedValue(op.getSrc());
auto order = op.getOrder();
SmallVector<int64_t> permutation(order.begin(), order.end());
rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, val, permutation);
return success();
}
};

struct JoinOpConversion : public OpConversionPattern<triton::JoinOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto lhs = rewriter.getRemappedValue(op.getLhs());
auto rhs = rewriter.getRemappedValue(op.getRhs());
auto interleave = rewriter.create<vector::InterleaveOp>(loc, lhs, rhs);
// JoinOp creates a new dimension, but InterleaveOp doubles the final one.
// Use ShapeCastOp to get the required shape.
auto resTy = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resTy, interleave);
return success();
}
};

struct CatOpConversion : public OpConversionPattern<triton::CatOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(triton::CatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto lhs = rewriter.getRemappedValue(op.getLhs());
auto rhs = rewriter.getRemappedValue(op.getRhs());
auto lhsTy = dyn_cast<VectorType>(lhs.getType());
auto rhsTy = dyn_cast<VectorType>(rhs.getType());
SmallVector<int64_t> indices(lhsTy.getShape()[0] + rhsTy.getShape()[0]);
std::iota(indices.begin(), indices.end(), 0);
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(op, lhs, rhs, indices);
return success();
}
};

struct ConvertElemManipOps
: public triton::impl::ConvertElemManipOpsBase<ConvertElemManipOps> {
using ConvertElemManipOpsBase::ConvertElemManipOpsBase;

ConvertElemManipOps() : ConvertElemManipOpsBase() {}

void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();

TritonToTritonCPUTypeConverter typeConverter;
ElemManipOpConversionTarget convTarget(*context, typeConverter);
RewritePatternSet patterns(context);
patterns.add<OpTypeConversion<triton::BroadcastOp, vector::BroadcastOp>>(
typeConverter, context);
patterns.add<OpTypeConversion<triton::ExpandDimsOp, vector::ShapeCastOp>>(
typeConverter, context);
patterns.add<ReshapeOpConversion>(typeConverter, context);
patterns.add<TransOpConversion>(typeConverter, context);
patterns.add<JoinOpConversion>(typeConverter, context);
patterns.add<CatOpConversion>(typeConverter, context);

if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
}
};

} // namespace

namespace mlir {
namespace triton {
namespace cpu {

std::unique_ptr<OperationPass<ModuleOp>> createConvertElemManipOps() {
return std::make_unique<ConvertElemManipOps>();
}

} // namespace cpu
} // namespace triton
} // namespace mlir
Loading

0 comments on commit 44035c9

Please sign in to comment.