From f79c06376bc6388d975fa77b17bfa2a7f7d36d15 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Fri, 3 Nov 2023 14:55:03 -0700 Subject: [PATCH] Tcp Custom Op and Torch->Tcp conversions + lit tests (#17) Op definition for `tcp.custom_op` and lowerings for `aten.gather` and `aten.index.Tensor_hacked_twin`. This is useful to represent and lower custom/opaque operations that have tensor inputs and outputs, but 1) lack a clear operational semantics, or 2) have an operational semantic that isn't defined by Tcp (such as user defined custom operations), or 3) need to "flow through" the middleend because they target pre-written kernels available in the backend As a follow-on, I'll look into making `torch -> tcp.custom_op` conversions automatic as much as possible, to minimize the overhead in adding individual conversion patterns manually. Then this could be the last step in a frontend lowering pipeline, to have all remaining `torch.aten` ops converted. --- BUILD | 1 + include/mlir-tcp/Dialect/IR/TcpOps.td | 26 +++- lib/Conversion/TorchToTcp/PopulatePatterns.h | 4 + lib/Conversion/TorchToTcp/TcpCustomOp.cpp | 111 ++++++++++++++++++ lib/Conversion/TorchToTcp/TorchToTcp.cpp | 3 + .../Conversion/TorchToTcp/tcp_custom_ops.mlir | 38 ++++++ test/Dialect/tcp_custom_ops.mlir | 29 +++++ 7 files changed, 211 insertions(+), 1 deletion(-) create mode 100644 lib/Conversion/TorchToTcp/TcpCustomOp.cpp create mode 100644 test/Conversion/TorchToTcp/tcp_custom_ops.mlir create mode 100644 test/Dialect/tcp_custom_ops.mlir diff --git a/BUILD b/BUILD index d22ec247..9b6a4ffe 100644 --- a/BUILD +++ b/BUILD @@ -190,6 +190,7 @@ cc_library( "lib/Conversion/TorchToTcp/Elementwise.cpp", "lib/Conversion/TorchToTcp/Misc.cpp", "lib/Conversion/TorchToTcp/PopulatePatterns.h", + "lib/Conversion/TorchToTcp/TcpCustomOp.cpp", "lib/Conversion/TorchToTcp/TorchToTcp.cpp", "lib/Conversion/TorchToTcp/Utils.cpp", "lib/Conversion/TorchToTcp/Utils.h", diff --git a/include/mlir-tcp/Dialect/IR/TcpOps.td b/include/mlir-tcp/Dialect/IR/TcpOps.td index bcf504b3..43c84b44 100644 --- a/include/mlir-tcp/Dialect/IR/TcpOps.td +++ b/include/mlir-tcp/Dialect/IR/TcpOps.td @@ -48,7 +48,6 @@ def Tcp_ClampOp : Tcp_UnaryElementwiseOp<"clamp", [SameOperandsAndResultElementT specified. }]; - // TODO: Does clamp need to support complex tensors? let arguments = (ins Tcp_FloatOrIntTensor:$in, OptionalAttr:$min_float, @@ -311,6 +310,31 @@ def Tcp_IsolatedGroupOp : Tcp_Op<"isolated_group", [ let hasVerifier = 1; } +def Tcp_CustomOp : Tcp_Op<"custom_op", []> { + let summary = "Custom opaque operation in Tcp dialect"; + + let description = [{ + This is useful to represent and lower custom/opaque operations that have + tensor inputs and outputs, but + (1) lack a clear operational semantics or + (2) have an operational semantic that isn't defined by Tcp (e.g user + defined custom operations) or + (3) need to "flow through" the middle-end because they target existing + pre-written kernels available in the backend. + }]; + + let arguments = (ins + Variadic:$inputs, + StrAttr:$op_name + ); + + let results = (outs + Variadic:$outputs + ); + + let assemblyFormat = "`(`$op_name`)` $inputs attr-dict `:` type($inputs) `->` type($outputs)"; +} + def Tcp_SqrtOp : Tcp_UnaryElementwiseOp<"sqrt"> { let summary = "Computes square root of input, elementwise"; diff --git a/lib/Conversion/TorchToTcp/PopulatePatterns.h b/lib/Conversion/TorchToTcp/PopulatePatterns.h index d62d0a02..76ee2b17 100644 --- a/lib/Conversion/TorchToTcp/PopulatePatterns.h +++ b/lib/Conversion/TorchToTcp/PopulatePatterns.h @@ -26,5 +26,9 @@ void populateDataMovementPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet); +void populateTcpCustomOpPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet); + } // namespace torch_to_tcp } // namespace mlir diff --git a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp new file mode 100644 index 00000000..7fdc96c5 --- /dev/null +++ b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp @@ -0,0 +1,111 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h" + +#include "mlir-tcp/Dialect/IR/TcpDialect.h" +#include "mlir-tcp/Dialect/IR/TcpOps.h" + +#include "PopulatePatterns.h" +#include "Utils.h" +#include "torch-mlir/Conversion/Utils/Utils.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" + +#include "llvm/ADT/StringSet.h" + +using namespace mlir; +using namespace mlir::tcp; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace { +class ConvertAtenGatherOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AtenGatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector resultTypes; + if (failed( + OpConversionPattern::getTypeConverter()->convertTypes( + op->getResultTypes(), resultTypes))) { + return failure(); + } + + int64_t dimVal; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimVal))) + return failure(); + + auto indexAttr = + rewriter.getNamedAttr("axis", rewriter.getI64IntegerAttr(dimVal)); + + auto newOp = rewriter.replaceOpWithNewOp( + op, resultTypes, ValueRange{adaptor.getSelf(), adaptor.getIndex()}, + indexAttr); + newOp.setOpName(op->getName().getStringRef()); + return success(); + } +}; + +class ConvertAtenIndexTensorHackedTwinOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector resultTypes; + if (failed( + OpConversionPattern::getTypeConverter() + ->convertTypes(op->getResultTypes(), resultTypes))) { + return failure(); + } + + SmallVector tensorOperands; + Value input = adaptor.getSelf(); + auto inputTensorType = input.getType().dyn_cast(); + // Check input is a tensor type. + if (!inputTensorType) + return rewriter.notifyMatchFailure( + op, "Only tensor types input are currently supported"); + tensorOperands.push_back(input); + + // Deal with torch.prim.ListConstruct of non const value to get the index + Value indexList = op.getIndices(); + SmallVector indicesTorchType; + if (!getListConstructElements(indexList, indicesTorchType)) + return op.emitError( + "unimplemented: the tensor list is not from list construct"); + SmallVector indexTensors = getTypeConvertedValues( + rewriter, op->getLoc(), getTypeConverter(), indicesTorchType); + + tensorOperands.append(indexTensors.begin(), indexTensors.end()); + + auto newOp = rewriter.replaceOpWithNewOp(op, resultTypes, + tensorOperands); + newOp.setOpName(op->getName().getStringRef()); + return success(); + } +}; +} // namespace + +void torch_to_tcp::populateTcpCustomOpPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet) { + +#define INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenOp) \ + torch_to_tcp::addPatternIfOpInConvertTorchOpsSet( \ + typeConverter, patterns, target, convertTorchOpsSet) + INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenGatherOp); + INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenIndexTensorHackedTwinOp); +#undef INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN +} diff --git a/lib/Conversion/TorchToTcp/TorchToTcp.cpp b/lib/Conversion/TorchToTcp/TorchToTcp.cpp index 988f40ac..9ec63209 100644 --- a/lib/Conversion/TorchToTcp/TorchToTcp.cpp +++ b/lib/Conversion/TorchToTcp/TorchToTcp.cpp @@ -85,6 +85,9 @@ class ConvertTorchToTcp : public ConvertTorchToTcpBase { torch_to_tcp::populateDataMovementPatternsAndLegality( typeConverter, patterns, target, convertTorchOpsSet); + torch_to_tcp::populateTcpCustomOpPatternsAndLegality( + typeConverter, patterns, target, convertTorchOpsSet); + if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { return signalPassFailure(); diff --git a/test/Conversion/TorchToTcp/tcp_custom_ops.mlir b/test/Conversion/TorchToTcp/tcp_custom_ops.mlir new file mode 100644 index 00000000..1f6593c0 --- /dev/null +++ b/test/Conversion/TorchToTcp/tcp_custom_ops.mlir @@ -0,0 +1,38 @@ +// RUN: tcp-opt <%s -convert-torch-to-tcp -canonicalize -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.gather_op( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,2],si64> +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[2,2],f32> -> tensor<2x2xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,2],si64> -> tensor<2x2xi64> +// CHECK: %[[T2:.*]] = tcp.custom_op("torch.aten.gather") %[[T0]], %[[T1]] {axis = 1 : i64} : tensor<2x2xf32>, tensor<2x2xi64> -> tensor<2x2xf32> +// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<2x2xf32> -> !torch.vtensor<[2,2],f32> +// CHECK: return %[[T3]] : !torch.vtensor<[2,2],f32> +func.func @torch.aten.gather_op(%arg0: !torch.vtensor<[2,2],si64>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> { + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %0 = torch.aten.gather %arg1, %int1, %arg0, %false : !torch.vtensor<[2,2],f32>, !torch.int, !torch.vtensor<[2,2],si64>, !torch.bool -> !torch.vtensor<[2,2],f32> + return %0 : !torch.vtensor<[2,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.index_hacked_twin_op( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,30,19,41],f32> +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1,1,1,1],si64> +// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[30,1,1],si64> +// CHECK-SAME: %[[ARG3:.*]]: !torch.vtensor<[19,1],si64> +// CHECK-SAME: %[[ARG4:.*]]: !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,30,19,3],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,30,19,41],f32> -> tensor<1x30x19x41xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[1,1,1,1],si64> -> tensor<1x1x1x1xi64> +// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[30,1,1],si64> -> tensor<30x1x1xi64> +// CHECK: %[[T3:.*]] = torch_c.to_builtin_tensor %[[ARG3]] : !torch.vtensor<[19,1],si64> -> tensor<19x1xi64> +// CHECK: %[[T4:.*]] = torch_c.to_builtin_tensor %[[ARG4]] : !torch.vtensor<[3],si64> -> tensor<3xi64> +// CHECK: %[[T5:.*]] = tcp.custom_op("torch.aten.index.Tensor_hacked_twin") %[[T0]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<1x30x19x41xf32>, tensor<1x1x1x1xi64>, tensor<30x1x1xi64>, tensor<19x1xi64>, tensor<3xi64> -> tensor<1x30x19x3xf32> +// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x30x19x3xf32> -> !torch.vtensor<[1,30,19,3],f32> +// CHECK: return %[[T6]] : !torch.vtensor<[1,30,19,3],f32> +func.func @torch.aten.index_hacked_twin_op(%arg0: !torch.vtensor<[1,30,19,41],f32>, %arg1: !torch.vtensor<[1,1,1,1],si64>, %arg2: !torch.vtensor<[30,1,1],si64>, %arg3: !torch.vtensor<[19,1],si64>, %arg4: !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,30,19,3],f32> { + %0 = torch.prim.ListConstruct %arg1, %arg2, %arg3, %arg4 : (!torch.vtensor<[1,1,1,1],si64>, !torch.vtensor<[30,1,1],si64>, !torch.vtensor<[19,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list + %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[1,30,19,41],f32>, !torch.list -> !torch.vtensor<[1,30,19,3],f32> + return %1 : !torch.vtensor<[1,30,19,3],f32> +} diff --git a/test/Dialect/tcp_custom_ops.mlir b/test/Dialect/tcp_custom_ops.mlir new file mode 100644 index 00000000..2835e8ee --- /dev/null +++ b/test/Dialect/tcp_custom_ops.mlir @@ -0,0 +1,29 @@ +// RUN: tcp-opt %s -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @tcp_custom_op( +// CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor { +// CHECK: %[[T0:.*]] = tcp.custom_op("torch.aten.my_custom_op") %[[ARG0]] : tensor -> tensor +// CHECK: return %[[T0]] : tensor +func.func @tcp_custom_op(%arg0: tensor) -> tensor { + %0 = tcp.custom_op("torch.aten.my_custom_op") %arg0 : tensor -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func.func @tcp_custom_op_with_named_attrs( +// CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor { +// CHECK: %[[T0:.*]] = tcp.custom_op("torch.aten.my_custom_op") %[[ARG0]] {axis = 0 : i32} : tensor -> tensor +// CHECK: return %[[T0]] : tensor +func.func @tcp_custom_op_with_named_attrs(%arg0: tensor) -> tensor { + %0 = tcp.custom_op("torch.aten.my_custom_op") %arg0 {axis = 0 : i32} : tensor -> tensor + return %0 : tensor +} + +// ----- + +func.func @tcp_custom_op_without_op_name(%arg0: tensor) -> tensor { + // expected-error@+1{{expected attribute value}} + %0 = tcp.custom_op() %arg0 : tensor -> tensor + return %0 : tensor +}