diff --git a/test/Conversion/TorchToTcp/tcp_custom_ops.mlir b/test/Conversion/TorchToTcp/tcp_custom_ops.mlir index 3807d04..bb84631 100644 --- a/test/Conversion/TorchToTcp/tcp_custom_ops.mlir +++ b/test/Conversion/TorchToTcp/tcp_custom_ops.mlir @@ -341,3 +341,20 @@ func.func @torch.aten.view_dynamic_shape(%arg0: !torch.vtensor<[?,384,16],f32>, %4 = torch.aten.view %arg0, %3 : !torch.vtensor<[?,384,16],f32>, !torch.list -> !torch.vtensor<[?,24,16,16],f32> return %4 : !torch.vtensor<[?,24,16,16],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.slice_scatter( +// CHECK-DAG: %[[ARG0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,3],f32> -> tensor<1x3xf32> +// CHECK-DAG: %[[ARG1:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,2],f32> -> tensor<1x2xf32> +// CHECK: %[[OUT:.*]] = tcp.custom_op("torch.aten.slice_scatter") %[[ARG0]], %[[ARG1]] {dim = 1 : i64, end = 3 : i64, start = 2 : i64, step = 4 : i64, torch_operand_names = ["self", "src"]} : tensor<1x3xf32>, tensor<1x2xf32> -> tensor<1x3xf32> +// CHECK: %[[RET:.*]] = torch_c.from_builtin_tensor %[[OUT]] : tensor<1x3xf32> -> !torch.vtensor<[1,3],f32> +// CHECK: return %[[RET]] +func.func @torch.aten.slice_scatter(%arg0: !torch.vtensor<[1,3],f32>, %arg1: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,3],f32> { + %dim = torch.constant.int 1 + %start = torch.constant.int 2 + %end = torch.constant.int 3 + %step = torch.constant.int 4 + %0 = torch.aten.slice_scatter %arg0, %arg1, %dim, %start, %end, %step : !torch.vtensor<[1,3],f32>, !torch.vtensor<[1,2],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3],f32> + return %0 : !torch.vtensor<[1,3],f32> +}