diff --git a/BUILD b/BUILD
index d22ec247..9b6a4ffe 100644
--- a/BUILD
+++ b/BUILD
@@ -190,6 +190,7 @@ cc_library(
         "lib/Conversion/TorchToTcp/Elementwise.cpp",
         "lib/Conversion/TorchToTcp/Misc.cpp",
         "lib/Conversion/TorchToTcp/PopulatePatterns.h",
+        "lib/Conversion/TorchToTcp/TcpCustomOp.cpp",
         "lib/Conversion/TorchToTcp/TorchToTcp.cpp",
         "lib/Conversion/TorchToTcp/Utils.cpp",
         "lib/Conversion/TorchToTcp/Utils.h",
diff --git a/include/mlir-tcp/Dialect/IR/TcpOps.td b/include/mlir-tcp/Dialect/IR/TcpOps.td
index bcf504b3..43c84b44 100644
--- a/include/mlir-tcp/Dialect/IR/TcpOps.td
+++ b/include/mlir-tcp/Dialect/IR/TcpOps.td
@@ -48,7 +48,6 @@ def Tcp_ClampOp : Tcp_UnaryElementwiseOp<"clamp", [SameOperandsAndResultElementT
     specified.
   }];
 
-  // TODO: Does clamp need to support complex tensors?
   let arguments = (ins
     Tcp_FloatOrIntTensor:$in,
     OptionalAttr<F32Attr>:$min_float,
@@ -311,6 +310,31 @@ def Tcp_IsolatedGroupOp : Tcp_Op<"isolated_group", [
   let hasVerifier = 1;
 }
 
+def Tcp_CustomOp : Tcp_Op<"custom_op", []> {
+  let summary = "Custom opaque operation in Tcp dialect";
+
+  let description = [{
+    This is useful to represent and lower custom/opaque operations that have
+    tensor inputs and outputs, but
+      (1) lack a clear operational semantics, or
+      (2) have an operational semantics that isn't defined by Tcp (e.g.,
+          user-defined custom operations), or
+      (3) need to "flow through" the middle-end because they target existing
+          pre-written kernels available in the backend.
+  }];
+
+  let arguments = (ins
+    Variadic<Tcp_Tensor>:$inputs,
+    StrAttr:$op_name
+  );
+
+  let results = (outs
+    Variadic<Tcp_Tensor>:$outputs
+  );
+
+  let assemblyFormat = "`(`$op_name`)` $inputs attr-dict `:` type($inputs) `->` type($outputs)";
+}
+
 def Tcp_SqrtOp : Tcp_UnaryElementwiseOp<"sqrt"> {
   let summary = "Computes square root of input, elementwise";
diff --git a/lib/Conversion/TorchToTcp/PopulatePatterns.h b/lib/Conversion/TorchToTcp/PopulatePatterns.h
index d62d0a02..76ee2b17 100644
--- a/lib/Conversion/TorchToTcp/PopulatePatterns.h
+++ b/lib/Conversion/TorchToTcp/PopulatePatterns.h
@@ -26,5 +26,9 @@ void populateDataMovementPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet);
 
+void populateTcpCustomOpPatternsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet);
+
 } // namespace torch_to_tcp
 } // namespace mlir
diff --git a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp
new file mode 100644
index 00000000..7fdc96c5
--- /dev/null
+++ b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp
@@ -0,0 +1,111 @@
+//===------------------------------------------------------------*- C++ -*-===//
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h"
+
+#include "mlir-tcp/Dialect/IR/TcpDialect.h"
+#include "mlir-tcp/Dialect/IR/TcpOps.h"
+
+#include "PopulatePatterns.h"
+#include "Utils.h"
+#include "torch-mlir/Conversion/Utils/Utils.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+
+#include "llvm/ADT/StringSet.h"
+
+using namespace mlir;
+using namespace mlir::tcp;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+namespace {
+class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
+public:
+  using OpConversionPattern<AtenGatherOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenGatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Type> resultTypes;
+    if (failed(
+            OpConversionPattern<AtenGatherOp>::getTypeConverter()->convertTypes(
+                op->getResultTypes(), resultTypes))) {
+      return failure();
+    }
+
+    int64_t dimVal;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimVal)))
+      return failure();
+
+    auto indexAttr =
+        rewriter.getNamedAttr("axis", rewriter.getI64IntegerAttr(dimVal));
+
+    auto newOp = rewriter.replaceOpWithNewOp<tcp::CustomOp>(
+        op, resultTypes, ValueRange{adaptor.getSelf(), adaptor.getIndex()},
+        indexAttr);
+    newOp.setOpName(op->getName().getStringRef());
+    return success();
+  }
+};
+
+class ConvertAtenIndexTensorHackedTwinOp
+    : public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
+public:
+  using OpConversionPattern<AtenIndexTensorHackedTwinOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Type> resultTypes;
+    if (failed(
+            OpConversionPattern<AtenIndexTensorHackedTwinOp>::getTypeConverter()
+                ->convertTypes(op->getResultTypes(), resultTypes))) {
+      return failure();
+    }
+
+    SmallVector<Value> tensorOperands;
+    Value input = adaptor.getSelf();
+    auto inputTensorType = input.getType().dyn_cast<RankedTensorType>();
+    // Check that the input is a tensor type.
+    if (!inputTensorType)
+      return rewriter.notifyMatchFailure(
+          op, "Only tensor types input are currently supported");
+    tensorOperands.push_back(input);
+
+    // Deal with torch.prim.ListConstruct of non-const values to get the
+    // indices.
+    Value indexList = op.getIndices();
+    SmallVector<Value> indicesTorchType;
+    if (!getListConstructElements(indexList, indicesTorchType))
+      return op.emitError(
+          "unimplemented: the tensor list is not from list construct");
+    SmallVector<Value> indexTensors = getTypeConvertedValues(
+        rewriter, op->getLoc(), getTypeConverter(), indicesTorchType);
+
+    tensorOperands.append(indexTensors.begin(), indexTensors.end());
+
+    auto newOp = rewriter.replaceOpWithNewOp<tcp::CustomOp>(op, resultTypes,
+                                                            tensorOperands);
+    newOp.setOpName(op->getName().getStringRef());
+    return success();
+  }
+};
+} // namespace
+
+void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet) {
+
+#define INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenOp)                           \
+  torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<Convert##AtenOp, AtenOp>(   \
+      typeConverter, patterns, target, convertTorchOpsSet)
+  INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenGatherOp);
+  INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenIndexTensorHackedTwinOp);
+#undef INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN
+}
diff --git a/lib/Conversion/TorchToTcp/TorchToTcp.cpp b/lib/Conversion/TorchToTcp/TorchToTcp.cpp
index 988f40ac..9ec63209 100644
--- a/lib/Conversion/TorchToTcp/TorchToTcp.cpp
+++ b/lib/Conversion/TorchToTcp/TorchToTcp.cpp
@@ -85,6 +85,9 @@ class ConvertTorchToTcp : public ConvertTorchToTcpBase<ConvertTorchToTcp> {
     torch_to_tcp::populateDataMovementPatternsAndLegality(
         typeConverter, patterns, target, convertTorchOpsSet);
 
+    torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
+        typeConverter, patterns, target, convertTorchOpsSet);
+
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
       return signalPassFailure();
diff --git a/test/Conversion/TorchToTcp/tcp_custom_ops.mlir b/test/Conversion/TorchToTcp/tcp_custom_ops.mlir
new file mode 100644
index 00000000..1f6593c0
--- /dev/null
+++ b/test/Conversion/TorchToTcp/tcp_custom_ops.mlir
@@ -0,0 +1,38 @@
+// RUN: tcp-opt <%s -convert-torch-to-tcp -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func.func @torch.aten.gather_op(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,2],si64>
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[2,2],f32> -> tensor<2x2xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,2],si64> -> tensor<2x2xi64>
+// CHECK: %[[T2:.*]] = tcp.custom_op("torch.aten.gather") %[[T0]], %[[T1]] {axis = 1 : i64} : tensor<2x2xf32>, tensor<2x2xi64> -> tensor<2x2xf32>
+// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<2x2xf32> -> !torch.vtensor<[2,2],f32>
+// CHECK: return %[[T3]] : !torch.vtensor<[2,2],f32>
+func.func @torch.aten.gather_op(%arg0: !torch.vtensor<[2,2],si64>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> {
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.gather %arg1, %int1, %arg0, %false : !torch.vtensor<[2,2],f32>, !torch.int, !torch.vtensor<[2,2],si64>, !torch.bool -> !torch.vtensor<[2,2],f32>
+  return %0 : !torch.vtensor<[2,2],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.index_hacked_twin_op(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,30,19,41],f32>
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1,1,1,1],si64>
+// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[30,1,1],si64>
+// CHECK-SAME: %[[ARG3:.*]]: !torch.vtensor<[19,1],si64>
+// CHECK-SAME: %[[ARG4:.*]]: !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,30,19,3],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,30,19,41],f32> -> tensor<1x30x19x41xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[1,1,1,1],si64> -> tensor<1x1x1x1xi64>
+// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[30,1,1],si64> -> tensor<30x1x1xi64>
+// CHECK: %[[T3:.*]] = torch_c.to_builtin_tensor %[[ARG3]] : !torch.vtensor<[19,1],si64> -> tensor<19x1xi64>
+// CHECK: %[[T4:.*]] = torch_c.to_builtin_tensor %[[ARG4]] : !torch.vtensor<[3],si64> -> tensor<3xi64>
+// CHECK: %[[T5:.*]] = tcp.custom_op("torch.aten.index.Tensor_hacked_twin") %[[T0]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<1x30x19x41xf32>, tensor<1x1x1x1xi64>, tensor<30x1x1xi64>, tensor<19x1xi64>, tensor<3xi64> -> tensor<1x30x19x3xf32>
+// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x30x19x3xf32> -> !torch.vtensor<[1,30,19,3],f32>
+// CHECK: return %[[T6]] : !torch.vtensor<[1,30,19,3],f32>
+func.func @torch.aten.index_hacked_twin_op(%arg0: !torch.vtensor<[1,30,19,41],f32>, %arg1: !torch.vtensor<[1,1,1,1],si64>, %arg2: !torch.vtensor<[30,1,1],si64>, %arg3: !torch.vtensor<[19,1],si64>, %arg4: !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,30,19,3],f32> {
+  %0 = torch.prim.ListConstruct %arg1, %arg2, %arg3, %arg4 : (!torch.vtensor<[1,1,1,1],si64>, !torch.vtensor<[30,1,1],si64>, !torch.vtensor<[19,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
+  %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[1,30,19,41],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,30,19,3],f32>
+  return %1 : !torch.vtensor<[1,30,19,3],f32>
+}
diff --git a/test/Dialect/tcp_custom_ops.mlir b/test/Dialect/tcp_custom_ops.mlir
new file mode 100644
index 00000000..2835e8ee
--- /dev/null
+++ b/test/Dialect/tcp_custom_ops.mlir
@@ -0,0 +1,29 @@
+// RUN: tcp-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func.func @tcp_custom_op(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: %[[T0:.*]] = tcp.custom_op("torch.aten.my_custom_op") %[[ARG0]] : tensor<?xf32> -> tensor<?xf32>
+// CHECK: return %[[T0]] : tensor<?xf32>
+func.func @tcp_custom_op(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = tcp.custom_op("torch.aten.my_custom_op") %arg0 : tensor<?xf32> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @tcp_custom_op_with_named_attrs(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: %[[T0:.*]] = tcp.custom_op("torch.aten.my_custom_op") %[[ARG0]] {axis = 0 : i32} : tensor<?xf32> -> tensor<?xf32>
+// CHECK: return %[[T0]] : tensor<?xf32>
+func.func @tcp_custom_op_with_named_attrs(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = tcp.custom_op("torch.aten.my_custom_op") %arg0 {axis = 0 : i32} : tensor<?xf32> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+func.func @tcp_custom_op_without_op_name(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  // expected-error@+1{{expected attribute value}}
+  %0 = tcp.custom_op() %arg0 : tensor<?xf32> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
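
Note (not part of the diff): since `$outputs` is declared `Variadic<Tcp_Tensor>` and the assembly format prints `type($outputs)` as a comma-separated list, `tcp.custom_op` can also carry multiple results, which none of the tests above exercise. Below is a minimal hedged sketch of what that would look like; the op name `my_backend.topk` and the `k` attribute are invented for illustration, not part of this PR.

```mlir
// Hypothetical sketch: a multi-result tcp.custom_op, assuming the
// assemblyFormat defined in TcpOps.td above. "my_backend.topk" and
// the `k` attribute are invented names used only for illustration.
func.func @custom_op_multi_result(%arg0: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xi64>) {
  %values, %indices = tcp.custom_op("my_backend.topk") %arg0 {k = 4 : i64}
      : tensor<?xf32> -> tensor<?xf32>, tensor<?xi64>
  return %values, %indices : tensor<?xf32>, tensor<?xi64>
}
```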