diff --git a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp index b23d0610..1e64956c 100644 --- a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp +++ b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp @@ -19,6 +19,7 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/ADT/StringSet.h" +#include "llvm/Support/Debug.h" using namespace mlir; using namespace mlir::tcp; @@ -146,6 +147,88 @@ class ConvertAten_IndexPutImplOp } }; +class ConvertAtenConvolutionOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector resultTypes; + if (failed(OpConversionPattern::getTypeConverter() + ->convertTypes(op->getResultTypes(), resultTypes))) { + return failure(); + } + + SmallVector operands; + SmallVector operandNames; + + auto addOperand = [&](std::string name, Value value) { + operandNames.push_back(name); + operands.push_back(value); + }; + + addOperand("input", adaptor.getInput()); + addOperand("weight", adaptor.getWeight()); + if (!adaptor.getBias().getType().isa()) { + addOperand("bias", adaptor.getBias()); + } + + SmallVector attrs; + + attrs.push_back(rewriter.getNamedAttr( + "torch_operand_names", rewriter.getStrArrayAttr(operandNames))); + + auto addListOfIntAttr = [&](const std::string &name, Value value) { + SmallVector valueInt; + if (!matchPattern(value, m_TorchListOfConstantInts(valueInt))) + return rewriter.notifyMatchFailure(op, std::string("non-const") + name + + "list unsupported"); + attrs.push_back( + rewriter.getNamedAttr(name, rewriter.getIndexArrayAttr(valueInt))); + return success(); + }; + + if (auto result = addListOfIntAttr("stride", adaptor.getStride()); + result.failed()) { + return result; + } + if (auto result = addListOfIntAttr("padding", adaptor.getPadding()); + result.failed()) { + return result; + } + if (auto result = addListOfIntAttr("dilation", adaptor.getDilation()); + result.failed()) { + return result; + } + if (auto result = + addListOfIntAttr("output_padding", adaptor.getOutputPadding()); + result.failed()) { + return result; + } + + bool transposed; + if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) + return rewriter.notifyMatchFailure(op, + "non const transposed unsupported"); + attrs.push_back( + rewriter.getNamedAttr("transposed", rewriter.getBoolAttr(transposed))); + + int64_t groups; + if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups))) + return rewriter.notifyMatchFailure(op, "non const groups unsupported"); + attrs.push_back( + rewriter.getNamedAttr("groups", rewriter.getI64IntegerAttr(groups))); + + auto replOp = rewriter.replaceOpWithNewOp(op, resultTypes, + operands, attrs); + + replOp.setOpName(op->getName().getStringRef()); + + return success(); + } +}; + } // namespace void torch_to_tcp::populateTcpCustomOpPatternsAndLegality( @@ -159,4 +242,18 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality( INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenIndexTensorHackedTwinOp); INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(Aten_IndexPutImplOp); #undef INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN + + auto isTransposedConvOp = [](AtenConvolutionOp op) { + bool transposed; + if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) + return false; + return transposed; + }; + + // Only want to convert transposed conv ops, i.e., if its not transposed, + // its "legal", i.e., will not get converted. + torch_to_tcp::addPatternIfOpInConvertTorchOpsSet( + typeConverter, patterns, target, convertTorchOpsSet, + [&](AtenConvolutionOp op) { return !isTransposedConvOp(op); }); } diff --git a/lib/Conversion/TorchToTcp/Utils.h b/lib/Conversion/TorchToTcp/Utils.h index 68bff615..2a4e81a8 100644 --- a/lib/Conversion/TorchToTcp/Utils.h +++ b/lib/Conversion/TorchToTcp/Utils.h @@ -80,7 +80,10 @@ bool getConstTensorWithType(ConversionPatternRewriter &rewriter, Operation *op, template inline void addPatternIfOpInConvertTorchOpsSet( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet) { + ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet, + std::function dynamicLegalityFcn = [](AtenOp) { + return false; + }) { MLIRContext *context = patterns.getContext(); std::optional opName = TorchToTcpPattern(context).getRootKind(); @@ -90,7 +93,7 @@ inline void addPatternIfOpInConvertTorchOpsSet( if (convertTorchOpsSet.empty() || convertTorchOpsSet.contains( opName->getStringRef().ltrim(torch::Torch::kTorchOpPrefix))) { - target.addIllegalOp(); + target.addDynamicallyLegalOp(dynamicLegalityFcn); patterns.add(typeConverter, context); } } diff --git a/test/Conversion/TorchToTcp/tcp_custom_ops.mlir b/test/Conversion/TorchToTcp/tcp_custom_ops.mlir index a7968e5f..e72ac608 100644 --- a/test/Conversion/TorchToTcp/tcp_custom_ops.mlir +++ b/test/Conversion/TorchToTcp/tcp_custom_ops.mlir @@ -57,3 +57,44 @@ func.func @torch.aten.index_put_impl_op(%arg0: !torch.vtensor<[25],f32>, %arg1: %1 = torch.aten._index_put_impl %arg0, %0, %arg2, %false, %false : !torch.vtensor<[25],f32>, !torch.list>, !torch.vtensor<[],f32>, !torch.bool, !torch.bool -> !torch.vtensor<[25],f32> return %1 : !torch.vtensor<[25],f32> } + + +// ----- + +// CHECK: tcp.custom_op("torch.aten.convolution") %{{.*}}, %{{.*}}, %{{.*}} { +// CHECK-SAME: dilation = [1 : index, 1 : index], +// CHECK-SAME: groups = 1 : i64, +// CHECK-SAME: output_padding = [1 : index, 1 : index], +// CHECK-SAME: padding = [1 : index, 1 : index], +// CHECK-SAME: stride = [2 : index, 2 : index], +// CHECK-SAME: torch_operand_names = ["input", "weight", "bias"], +// CHECK-SAME: transposed = true} : tensor<1x64x1x100xf32>, tensor<64x64x3x3xf32>, tensor<64xf32> -> tensor<1x64x2x200xf32> +func.func @torcn.aten.transposed_convolution(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> { + %true = torch.constant.bool true + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %weight = torch.vtensor.literal(dense<0.0> : tensor<64x64x3x3xf32>) : !torch.vtensor<[64,64,3,3],f32> + %bias = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32> + %stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,64,2,200],f32> + + return %output : !torch.vtensor<[1,64,2,200],f32> +} + +// ----- + +// CHECK: torch.aten.convolution %{{.*}} +func.func @torch.aten.regular_convolution() -> !torch.vtensor<[1,32,16,1600],f32> { + %false = torch.constant.bool false + %input = torch.vtensor.literal(dense<0.0> : tensor<1x9x16x1600xf32>) : !torch.vtensor<[1,9,16,1600],f32> + %weights = torch.vtensor.literal(dense<0.0> : tensor<32x9x3x3xf32>) : !torch.vtensor<[32,9,3,3],f32> + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int0x0 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %none = torch.constant.none + %output = torch.aten.convolution %input, %weights, %none, %int1x1, %int1x1, %int1x1, %false, %int0x0, %int1 : !torch.vtensor<[1,9,16,1600],f32>, !torch.vtensor<[32,9,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,32,16,1600],f32> + + return %output : !torch.vtensor<[1,32,16,1600],f32> +}