From 36418dc578e73929ee1f0fb9d90a882164673d09 Mon Sep 17 00:00:00 2001 From: Matthew Francis-Landau Date: Fri, 6 Sep 2024 16:15:34 -0400 Subject: [PATCH] add scatter to tcp custom op (#95) --- lib/Conversion/TorchToTcp/TcpCustomOp.cpp | 23 +++++++++++++++++++ .../Conversion/TorchToTcp/tcp_custom_ops.mlir | 17 ++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp index 85b0c39..3f3468d 100644 --- a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp +++ b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp @@ -349,6 +349,28 @@ class ConvertAtenViewOp : public OpConversionPattern { } }; +class ConvertAtenSliceScatterOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AtenSliceScatterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // this should really have some tcp op to reduce to. So going to CustomOp + // is more of a placeholder than a serious implementation + torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter, + getTypeConverter()}; + helper.addOperand("self", adaptor.getSelf()); + helper.addOperand("src", adaptor.getSrc()); + helper.addIntAttr("dim", op.getDim()); + helper.addIntAttr("start", op.getStart()); + helper.addIntAttr("end", op.getEnd()); + helper.addIntAttr("step", op.getStep()); + + return helper.replace(); + } +}; + } // namespace void torch_to_tcp::populateTcpCustomOpPatternsAndLegality( @@ -369,6 +391,7 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality( INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenSortOp); INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenCumsumOp); INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenMinDimOp); + INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenSliceScatterOp); // AtenViewOp can still live after torch-to-tcp conversion patterns.add(typeConverter, patterns.getContext()); #undef INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN diff --git a/test/Conversion/TorchToTcp/tcp_custom_ops.mlir b/test/Conversion/TorchToTcp/tcp_custom_ops.mlir index 3807d04..bb84631 100644 --- a/test/Conversion/TorchToTcp/tcp_custom_ops.mlir +++ b/test/Conversion/TorchToTcp/tcp_custom_ops.mlir @@ -341,3 +341,20 @@ func.func @torch.aten.view_dynamic_shape(%arg0: !torch.vtensor<[?,384,16],f32>, %4 = torch.aten.view %arg0, %3 : !torch.vtensor<[?,384,16],f32>, !torch.list -> !torch.vtensor<[?,24,16,16],f32> return %4 : !torch.vtensor<[?,24,16,16],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.slice_scatter( +// CHECK-DAG: %[[ARG0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,3],f32> -> tensor<1x3xf32> +// CHECK-DAG: %[[ARG1:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,2],f32> -> tensor<1x2xf32> +// CHECK: %[[OUT:.*]] = tcp.custom_op("torch.aten.slice_scatter") %[[ARG0]], %[[ARG1]] {dim = 1 : i64, end = 3 : i64, start = 2 : i64, step = 4 : i64, torch_operand_names = ["self", "src"]} : tensor<1x3xf32>, tensor<1x2xf32> -> tensor<1x3xf32> +// CHECK: %[[RET:.*]] = torch_c.from_builtin_tensor %[[OUT]] : tensor<1x3xf32> -> !torch.vtensor<[1,3],f32> +// CHECK: return %[[RET]] +func.func @torch.aten.slice_scatter(%arg0: !torch.vtensor<[1,3],f32>, %arg1: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,3],f32> { + %dim = torch.constant.int 1 + %start = torch.constant.int 2 + %end = torch.constant.int 3 + %step = torch.constant.int 4 + %0 = torch.aten.slice_scatter %arg0, %arg1, %dim, %start, %end, %step : !torch.vtensor<[1,3],f32>, !torch.vtensor<[1,2],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3],f32> + return %0 : !torch.vtensor<[1,3],f32> +}