update custom op conversions (#92)
- Fix a minor issue with the broadcast op (fold the broadcast if the axes
  attr is empty)
- Add the following ops to `tcp.custom_op`:
  - `aten.sort`
  - `aten.cumsum`
  - `aten.min.dim`
  - `aten.view` (dynamic shape only)
  - `aten.topk`


To test:
`bazel test //...` (in docker)
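For reference, a minimal sketch of the lowering these conversions produce, mirroring the `aten.sort` test case added in `tcp_custom_ops.mlir` below (result names and the builtin-tensor operand `%self` are illustrative only):

```mlir
// Torch-dialect input (as in the new aten.sort test):
%int-1 = torch.constant.int -1
%true = torch.constant.bool true
%values, %indices = torch.aten.sort %input, %int-1, %true
    : !torch.vtensor<[?,2304],f32>, !torch.int, !torch.bool
    -> !torch.vtensor<[?,2304],f32>, !torch.vtensor<[?,2304],si64>

// After torch-to-tcp: the op becomes an opaque tcp.custom_op carrying its
// attributes and named operands (%self is the builtin-tensor form of %input):
%sorted:2 = tcp.custom_op("torch.aten.sort") %self
    {descending = true, dim = -1 : i64, torch_operand_names = ["self"]}
    : tensor<?x2304xf32> -> tensor<?x2304xf32>, tensor<?x2304xi64>
```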
zezhang authored Aug 30, 2024
1 parent 7b53fe4 commit 57d5e00
Showing 4 changed files with 242 additions and 1 deletion.
6 changes: 5 additions & 1 deletion lib/Conversion/TorchToTcp/Misc.cpp
@@ -127,11 +127,15 @@ class ConvertAtenBroadcastLikeOps : public OpConversionPattern<AtenOpT> {
      }
    }

    // fold the broadcast if no axes are found
    if (axes.size() == 0) {
      rewriter.replaceOp(op, input);
      return success();
    }
    RankedTensorType resultType =
        OpConversionPattern<AtenOpT>::getTypeConverter()
            ->convertType(op->getResult(0).getType())
            .template cast<RankedTensorType>();

    auto axesAttr = rewriter.getI64ArrayAttr(axes);
    rewriter.replaceOpWithNewOp<tcp::BroadcastOp>(op, resultType, input,
                                                  resultShape, axesAttr);
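For illustration, a rough sketch of the effect of the empty-axes fold above, assuming the pattern is instantiated for an op like `aten.broadcast_to` whose requested shape matches the input shape (the op, shapes, and value names here are hypothetical examples):

```mlir
// Before: the requested shape equals the input shape, so no broadcast axes
// are collected.
%int4 = torch.constant.int 4
%int3 = torch.constant.int 3
%sizes = torch.prim.ListConstruct %int4, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%0 = torch.aten.broadcast_to %arg0, %sizes
    : !torch.vtensor<[4,3],f32>, !torch.list<int> -> !torch.vtensor<[4,3],f32>

// After: with the fold, no tcp.broadcast is emitted; uses of %0 simply take
// the (converted) input value instead.
```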
148 changes: 148 additions & 0 deletions lib/Conversion/TorchToTcp/TcpCustomOp.cpp
@@ -14,9 +14,12 @@

#include "PopulatePatterns.h"
#include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"

#include "llvm/ADT/StringSet.h"

@@ -211,6 +214,145 @@ class ConvertAtenFakeQuantizePerChannelAffineOp
  }
};

class ConvertAtenTopkOp : public OpConversionPattern<AtenTopkOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenTopkOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
                                                            getTypeConverter()};
    helper.addOperand("self", adaptor.getSelf());

    helper.addIntAttr("k", op.getK());
    helper.addIntAttr("dim", op.getDim());
    helper.addBoolAttr("largest", op.getLargest());
    helper.addBoolAttr("sorted", op.getSorted());

    return helper.replace();
  }
};

class ConvertAtenSortOp : public OpConversionPattern<AtenSortOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenSortOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
                                                            getTypeConverter()};
    helper.addOperand("self", adaptor.getSelf());

    helper.addIntAttr("dim", op.getDim());
    helper.addBoolAttr("descending", op.getDescending());

    return helper.replace();
  }
};

class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenCumsumOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
                                                            getTypeConverter()};
    helper.addOperand("self", adaptor.getSelf());

    helper.addIntAttr("dim", op.getDim());
    if (!isa<Torch::ConstantNoneOp>(op.getDtype().getDefiningOp()))
      return rewriter.notifyMatchFailure(op, "Unsupported dtype argument");

    return helper.replace();
  }
};

class ConvertAtenMinDimOp : public OpConversionPattern<AtenMinDimOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenMinDimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
                                                            getTypeConverter()};
    helper.addOperand("self", adaptor.getSelf());

    helper.addIntAttr("dim", op.getDim());
    helper.addBoolAttr("keepdim", op.getKeepdim());

    return helper.replace();
  }
};

class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
                                                            getTypeConverter()};
    Value self = adaptor.getSelf();
    auto srcType = self.getType().cast<RankedTensorType>();
    auto resultType =
        getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();

    SmallVector<int64_t> size;
    // static shape will be handled through TOSA dialect
    if (matchPattern(op.getSize(), m_TorchListOfConstantInts(size)) &&
        srcType.hasStaticShape() && resultType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "only dynamic shape is supported");

    helper.addOperand("self", self);
    Operation *primListOp = op.getSize().getDefiningOp();
    auto listConstruct = dyn_cast<Torch::PrimListConstructOp>(primListOp);
    if (!listConstruct) {
      return rewriter.notifyMatchFailure(
          op, "Size must come from PrimListConstructOp");
    }
    int idx = 0;
    for (Value value : listConstruct.getElements()) {
      int64_t dimSize;
      if (matchPattern(value, m_TorchConstantInt(&dimSize))) {
        size.push_back(dimSize);
      } else {
        size.push_back(ShapedType::kDynamic);
        // dynamic shape should follow pattern:
        // %dim_32 = tensor.dim %arg1, %c0 : tensor<?x2736x16xf32>
        // %1 = arith.index_cast %dim_32 : index to i64
        // %2 = torch_c.from_i64 %1
        // %3 = torch.prim.ListConstruct %2 ...
        if (!isa<TorchConversion::FromI64Op>(value.getDefiningOp()))
          return rewriter.notifyMatchFailure(
              op, "dynamic dim size should come from FromI64Op");
        auto conversionOp =
            dyn_cast<TorchConversion::FromI64Op>(value.getDefiningOp());
        if (!isa<arith::IndexCastOp>(conversionOp.getOperand().getDefiningOp()))
          return rewriter.notifyMatchFailure(
              op, "dynamic dim size should come from IndexCastOp");
        auto indexCastOp = dyn_cast<arith::IndexCastOp>(
            conversionOp.getOperand().getDefiningOp());
        if (!isa<tensor::DimOp>(indexCastOp.getIn().getDefiningOp()))
          return rewriter.notifyMatchFailure(
              op, "dynamic dim size should come from DimOp");
        auto dimOp =
            dyn_cast<tensor::DimOp>(indexCastOp.getIn().getDefiningOp());
        helper.addOperand("idx_" + std::to_string(idx), dimOp);
      }
      idx++;
    }
    helper.addDenseIntArrayAttr("size", size);

    return helper.replace();
  }
};

} // namespace

void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
@@ -227,6 +369,12 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
  INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(
      AtenFakeQuantizePerTensorAffineTensorQparamsOp);
  INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenFakeQuantizePerChannelAffineOp);
  INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenTopkOp);
  INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenSortOp);
  INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenCumsumOp);
  INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenMinDimOp);
  // AtenViewOp can still live after torch-to-tcp conversion: only dynamic-shape
  // views are converted here (static shapes go through the TOSA path).
  patterns.add<ConvertAtenViewOp>(typeConverter, patterns.getContext());
#undef INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN

  // Torch -> TOSA doesn't handle transposed convolutions; map them to
2 changes: 2 additions & 0 deletions lib/InitAll.cpp
@@ -18,10 +18,12 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Dialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"

void mlir::tcp::registerAllDialects(mlir::DialectRegistry &registry) {
  registry.insert<tcp::TcpDialect>();
  registry.insert<torch::Torch::TorchDialect>();
  registry.insert<torch::TorchConversion::TorchConversionDialect>();
  mlir::func::registerInlinerExtension(registry);
  mlir::tcp::registerTilingInterfaceExternalModels(registry);
}
87 changes: 87 additions & 0 deletions test/Conversion/TorchToTcp/tcp_custom_ops.mlir
@@ -254,3 +254,90 @@ func.func @torch.aten.fake_quantize_per_channel_affine_zero_like(%input: !torch.
  %output = torch.aten.fake_quantize_per_channel_affine %input, %scale, %zero_point, %int1, %int0, %int255 : !torch.vtensor<[1,3,32,32],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3,32,32],f32>
  return %output : !torch.vtensor<[1,3,32,32],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.topk(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,2304],f32>) -> !torch.vtensor<[?,80],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,2304],f32> -> tensor<?x2304xf32>
// CHECK: %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.topk") %[[T0]] {dim = -1 : i64, k = 80 : i64, largest = true, sorted = true, torch_operand_names = ["self"]} :
// CHECK-SAME: tensor<?x2304xf32> -> tensor<?x80xf32>, tensor<?x80xi64>
// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM:.*]] : tensor<?x80xf32> -> !torch.vtensor<[?,80],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?,80],f32>
func.func @torch.aten.topk(%input: !torch.vtensor<[?,2304],f32>) -> !torch.vtensor<[?,80],f32> {
  %int-1 = torch.constant.int -1
  %int80 = torch.constant.int 80
  %true = torch.constant.bool true
  %output0, %output1 = torch.aten.topk %input, %int80, %int-1, %true, %true : !torch.vtensor<[?,2304],f32>, !torch.int, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[?,80],f32>, !torch.vtensor<[?,80],si64>
  return %output0 : !torch.vtensor<[?,80],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.sort(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,2304],f32>) -> !torch.vtensor<[?,2304],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,2304],f32> -> tensor<?x2304xf32>
// CHECK: %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.sort") %[[T0]] {descending = true, dim = -1 : i64, torch_operand_names = ["self"]} :
// CHECK-SAME: tensor<?x2304xf32> -> tensor<?x2304xf32>, tensor<?x2304xi64>
// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM:.*]] : tensor<?x2304xf32> -> !torch.vtensor<[?,2304],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?,2304],f32>
func.func @torch.aten.sort(%input: !torch.vtensor<[?,2304],f32>) -> !torch.vtensor<[?,2304],f32> {
  %int-1 = torch.constant.int -1
  %true = torch.constant.bool true
  %output0, %output1 = torch.aten.sort %input, %int-1, %true : !torch.vtensor<[?,2304],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,2304],f32>, !torch.vtensor<[?,2304],si64>
  return %output0 : !torch.vtensor<[?,2304],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.cumsum(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?],si64> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?],si32> -> tensor<?xi32>
// CHECK: %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.cumsum") %[[T0]] {dim = 0 : i64, torch_operand_names = ["self"]} : tensor<?xi32> -> tensor<?xi64>
// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM]] : tensor<?xi64> -> !torch.vtensor<[?],si64>
// CHECK: return %[[RES]] : !torch.vtensor<[?],si64>
func.func @torch.aten.cumsum(%input: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?],si64> {
  %int0 = torch.constant.int 0
  %none = torch.constant.none
  %1 = torch.aten.cumsum %input, %int0, %none : !torch.vtensor<[?],si32>, !torch.int, !torch.none -> !torch.vtensor<[?],si64>
  return %1 : !torch.vtensor<[?],si64>
}

// -----

// CHECK-LABEL: func.func @torch.aten.min.dim(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,80],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,80],f32> -> tensor<?x80xf32>
// CHECK: %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.min.dim") %[[T0]] {dim = 1 : i64, keepdim = false, torch_operand_names = ["self"]} :
// CHECK-SAME: tensor<?x80xf32> -> tensor<?xf32>, tensor<?xi64>
// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM:.*]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?],f32>
func.func @torch.aten.min.dim(%input: !torch.vtensor<[?,80],f32>) -> !torch.vtensor<[?],f32> {
  %int1 = torch.constant.int 1
  %false = torch.constant.bool false
  %output0, %output1 = torch.aten.min.dim %input, %int1, %false : !torch.vtensor<[?,80],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>
  return %output0 : !torch.vtensor<[?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.view_dynamic_shape(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,384,16],f32>, %[[ARG1:.*]]: tensor<?x2736x16xf32>) -> !torch.vtensor<[?,24,16,16],f32> {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,384,16],f32> -> tensor<?x384x16xf32>
// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x2736x16xf32>
// CHECK: %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.view") %[[T0]], %[[DIM]] {size = array<i64: -9223372036854775808, 24, 16, 16>, torch_operand_names = ["self", "idx_0"]} :
// CHECK-SAME: tensor<?x384x16xf32>, index -> tensor<?x24x16x16xf32>
// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM:.*]] : tensor<?x24x16x16xf32> -> !torch.vtensor<[?,24,16,16],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?,24,16,16],f32>
func.func @torch.aten.view_dynamic_shape(%arg0: !torch.vtensor<[?,384,16],f32>, %arg1: tensor<?x2736x16xf32>) -> !torch.vtensor<[?,24,16,16],f32> {
  %c0 = arith.constant 0 : index
  %int24 = torch.constant.int 24
  %int16 = torch.constant.int 16
  %dim_32 = tensor.dim %arg1, %c0 : tensor<?x2736x16xf32>
  %1 = arith.index_cast %dim_32 : index to i64
  %2 = torch_c.from_i64 %1
  %3 = torch.prim.ListConstruct %2, %int24, %int16, %int16 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %4 = torch.aten.view %arg0, %3 : !torch.vtensor<[?,384,16],f32>, !torch.list<int> -> !torch.vtensor<[?,24,16,16],f32>
  return %4 : !torch.vtensor<[?,24,16,16],f32>
}
