diff --git a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp
index 88ffe21..37b42ab 100644
--- a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp
+++ b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp
@@ -345,6 +345,51 @@ class ConvertAtenSliceScatterOp
   }
 };
 
+class ConvertAtenArangeStartStepOp
+    : public OpConversionPattern<AtenArangeStartStepOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenArangeStartStepOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    // At this point all tensors should have value semantics, and hence the
+    // `layout` check can be ignored.
+
+    // `pin_memory` should be either `False` or `None`.
+    bool pinMemory;
+    if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
+        (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
+         pinMemory)) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: pin_memory must be either None or false");
+    }
+
+    torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
+                                                            getTypeConverter()};
+    bool allStatic = true;
+    // trt-mlir takes F64Attr, so we need to convert constant ints to fp attrs
+    if (!helper.tryConvertConstToFloatAttr("start", op.getStart())) {
+      allStatic = false;
+      helper.addOperand("start", adaptor.getStart());
+    }
+    if (!helper.tryConvertConstToFloatAttr("end", op.getEnd())) {
+      allStatic = false;
+      helper.addOperand("end", adaptor.getEnd());
+    }
+    if (!helper.tryConvertConstToFloatAttr("step", op.getStep())) {
+      allStatic = false;
+      helper.addOperand("step", adaptor.getStep());
+    }
+    // The fully static start/end/step case is handled through the TOSA dialect
+    if (allStatic)
+      return rewriter.notifyMatchFailure(op,
+                                         "only non-constant values supported");
+
+    return helper.replace();
+  }
+};
+
 } // namespace
 
 void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
@@ -365,8 +410,10 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
   INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenCumsumOp);
   INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenMinDimOp);
   INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenSliceScatterOp);
-  // AtenViewOp can still live after torch-to-tcp conversion
+  // The following ops can still live after torch-to-tcp conversion
   patterns.add<ConvertAtenViewOp>(typeConverter, patterns.getContext());
+  patterns.add<ConvertAtenArangeStartStepOp>(typeConverter,
+                                             patterns.getContext());
 #undef INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN
 
   // Torch -> TOSA doesn't handle transposed convolutions; map them to
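Editorial note, not part of the patch: when `start`, `end`, and `step` are all constants, every `tryConvertConstToFloatAttr` call succeeds, `allStatic` stays true, and the pattern bails out with `notifyMatchFailure` so the existing Torch -> TOSA path can handle the op. A minimal sketch of such a rejected all-static case (function name, shapes, and constants are illustrative only):

    // All three scalars are constants, so ConvertAtenArangeStartStepOp
    // declines this op and it stays on the TOSA lowering path.
    func.func @arange_all_static() -> !torch.vtensor<[3],si64> {
      %none = torch.constant.none
      %int0 = torch.constant.int 0
      %int1 = torch.constant.int 1
      %int3 = torch.constant.int 3
      %0 = torch.aten.arange.start_step %int0, %int3, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
      return %0 : !torch.vtensor<[3],si64>
    }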
diff --git a/lib/Conversion/TorchToTcp/Utils.cpp b/lib/Conversion/TorchToTcp/Utils.cpp
index 479d580..5249287 100644
--- a/lib/Conversion/TorchToTcp/Utils.cpp
+++ b/lib/Conversion/TorchToTcp/Utils.cpp
@@ -529,6 +529,29 @@ void TorchToTcpCustomOpConversionHelper::addFloatAttr(std::string attrName,
       rewriter.getNamedAttr(attrName, rewriter.getF64FloatAttr(constVal)));
 }
 
+bool TorchToTcpCustomOpConversionHelper::tryConvertConstToFloatAttr(
+    std::string attrName, Value value) {
+  if (conversionResult.failed())
+    return false;
+
+  double constFPVal;
+  if (matchPattern(value, torch::Torch::m_TorchConstantFloat(&constFPVal))) {
+    attrs.push_back(
+        rewriter.getNamedAttr(attrName, rewriter.getF64FloatAttr(constFPVal)));
+    return true;
+  }
+
+  // Convert a constant int to fp if possible
+  int64_t constIntVal;
+  if (matchPattern(value, torch::Torch::m_TorchConstantInt(&constIntVal))) {
+    attrs.push_back(rewriter.getNamedAttr(
+        attrName, rewriter.getF64FloatAttr(static_cast<double>(constIntVal))));
+    return true;
+  }
+
+  return false;
+}
+
 void TorchToTcpCustomOpConversionHelper::addListOfIntsAttr(
     std::string attrName, Value value) {
   if (conversionResult.failed())
diff --git a/lib/Conversion/TorchToTcp/Utils.h b/lib/Conversion/TorchToTcp/Utils.h
index 42156ea..23bf029 100644
--- a/lib/Conversion/TorchToTcp/Utils.h
+++ b/lib/Conversion/TorchToTcp/Utils.h
@@ -167,6 +167,9 @@ class TorchToTcpCustomOpConversionHelper {
   // Add value as a named float attribute
   void addFloatAttr(std::string attrName, Value value);
 
+  // Try to convert a constant value to a named float attribute; returns true on success
+  bool tryConvertConstToFloatAttr(std::string attrName, Value value);
+
   // Add value as a named list of integers attribute
   void addListOfIntsAttr(std::string attrName, Value value);
 
diff --git a/test/Conversion/TorchToTcp/tcp_custom_ops.mlir b/test/Conversion/TorchToTcp/tcp_custom_ops.mlir
index 9bfcba9..c44a78f 100644
--- a/test/Conversion/TorchToTcp/tcp_custom_ops.mlir
+++ b/test/Conversion/TorchToTcp/tcp_custom_ops.mlir
@@ -320,3 +320,22 @@ func.func @torch.aten.slice_scatter(%arg0: !torch.vtensor<[1,3],f32>, %arg1: !to
   %0 = torch.aten.slice_scatter %arg0, %arg1, %dim, %start, %end, %step : !torch.vtensor<[1,3],f32>, !torch.vtensor<[1,2],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3],f32>
   return %0 : !torch.vtensor<[1,3],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.arange.start_step(
+// CHECK-SAME:      %[[ARG0:.*]]: !torch.int) -> !torch.vtensor<[?],si32> {
+// CHECK:         %[[IN:.*]] = torch_c.to_i64 %[[ARG0]]
+// CHECK:         %[[OUT:.*]] = tcp.custom_op("torch.aten.arange.start_step") %[[IN]] {start = 0.000000e+00 : f64, step = 1.000000e+00 : f64, torch_operand_names = ["end"]} : i64 -> tensor<?xi32>
+// CHECK:         %[[RET:.*]] = torch_c.from_builtin_tensor %[[OUT]] : tensor<?xi32> -> !torch.vtensor<[?],si32>
+// CHECK:         return %[[RET]] : !torch.vtensor<[?],si32>
+func.func @torch.aten.arange.start_step(%arg0: !torch.int) -> !torch.vtensor<[?],si32> {
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %cpu = torch.constant.device "cpu"
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %1 = torch.aten.arange.start_step %int0, %arg0, %int1, %int3, %none, %cpu, %false : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[?],si32>
+  return %1 : !torch.vtensor<[?],si32>
+}
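Editorial note, not part of the patch: the test above keeps `start` and `step` constant, so they become f64 attributes and only `end` is routed through the adaptor as a runtime operand. In the fully dynamic case all three scalars would become operands and no f64 attributes would be attached; assuming the helper emits operands and `torch_operand_names` in insertion order, such a case would be expected to lower roughly like:

    // Hypothetical all-dynamic lowering (illustrative, not a test in this patch):
    // %out = tcp.custom_op("torch.aten.arange.start_step") %start, %end, %step
    //        {torch_operand_names = ["start", "end", "step"]}
    //        : i64, i64, i64 -> tensor<?xi32>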