Skip to content

Commit

Permalink
add scatter to tcp custom op (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewfl authored Sep 6, 2024
1 parent 78c1c25 commit 36418dc
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
23 changes: 23 additions & 0 deletions lib/Conversion/TorchToTcp/TcpCustomOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,28 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
}
};

class ConvertAtenSliceScatterOp
: public OpConversionPattern<AtenSliceScatterOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenSliceScatterOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// this should really have some tcp op to reduce to. So going to CustomOp
// is more of a placeholder than a serious implementation
torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
getTypeConverter()};
helper.addOperand("self", adaptor.getSelf());
helper.addOperand("src", adaptor.getSrc());
helper.addIntAttr("dim", op.getDim());
helper.addIntAttr("start", op.getStart());
helper.addIntAttr("end", op.getEnd());
helper.addIntAttr("step", op.getStep());

return helper.replace();
}
};

} // namespace

void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
Expand All @@ -369,6 +391,7 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenSortOp);
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenCumsumOp);
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenMinDimOp);
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenSliceScatterOp);
// AtenViewOp can still live after torch-to-tcp conversion
patterns.add<ConvertAtenViewOp>(typeConverter, patterns.getContext());
#undef INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN
Expand Down
17 changes: 17 additions & 0 deletions test/Conversion/TorchToTcp/tcp_custom_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -341,3 +341,20 @@ func.func @torch.aten.view_dynamic_shape(%arg0: !torch.vtensor<[?,384,16],f32>,
%4 = torch.aten.view %arg0, %3 : !torch.vtensor<[?,384,16],f32>, !torch.list<int> -> !torch.vtensor<[?,24,16,16],f32>
return %4 : !torch.vtensor<[?,24,16,16],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.slice_scatter(
// CHECK-DAG: %[[ARG0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,3],f32> -> tensor<1x3xf32>
// CHECK-DAG: %[[ARG1:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,2],f32> -> tensor<1x2xf32>
// CHECK: %[[OUT:.*]] = tcp.custom_op("torch.aten.slice_scatter") %[[ARG0]], %[[ARG1]] {dim = 1 : i64, end = 3 : i64, start = 2 : i64, step = 4 : i64, torch_operand_names = ["self", "src"]} : tensor<1x3xf32>, tensor<1x2xf32> -> tensor<1x3xf32>
// CHECK: %[[RET:.*]] = torch_c.from_builtin_tensor %[[OUT]] : tensor<1x3xf32> -> !torch.vtensor<[1,3],f32>
// CHECK: return %[[RET]]
func.func @torch.aten.slice_scatter(%arg0: !torch.vtensor<[1,3],f32>, %arg1: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,3],f32> {
%dim = torch.constant.int 1
%start = torch.constant.int 2
%end = torch.constant.int 3
%step = torch.constant.int 4
%0 = torch.aten.slice_scatter %arg0, %arg1, %dim, %start, %end, %step : !torch.vtensor<[1,3],f32>, !torch.vtensor<[1,2],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3],f32>
return %0 : !torch.vtensor<[1,3],f32>
}

0 comments on commit 36418dc

Please sign in to comment.