[TorchToTcp] lowering support for aten.size.int op (#35)
Lowering the `aten.size.int` op to the `tensor.dim` op during the torch-to-tcp conversion.

Test (in Docker):

`bazel test //test/...`
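
For illustration, a minimal before/after sketch of the rewrite (the function name and shapes here are hypothetical; the new test in `test/Conversion/TorchToTcp/misc.mlir` below carries the exact CHECK lines):

```mlir
// Hypothetical input: query dim 0 of a dynamically shaped tensor.
func.func @example(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.int {
  %int0 = torch.constant.int 0
  %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
  return %0 : !torch.int
}

// After torch-to-tcp, the size query becomes an index constant, a
// tensor.dim, and an index_cast back to i64 (the converted !torch.int):
//   %c0 = arith.constant 0 : index
//   %dim = tensor.dim %t, %c0 : tensor<?x?xf32>
//   %sz = arith.index_cast %dim : index to i64
```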

---------

Co-authored-by: Ze Zhang <[email protected]>
zezhang and Ze Zhang authored Jan 22, 2024
1 parent a241eed commit 49e65fb
Showing 2 changed files with 89 additions and 42 deletions.
31 changes: 31 additions & 0 deletions lib/Conversion/TorchToTcp/Misc.cpp
@@ -133,6 +133,36 @@ class ConvertValueTensorLiteralOp
}
};

class ConvertAtenSizeIntOp : public OpConversionPattern<AtenSizeIntOp> {
public:
  using OpConversionPattern<AtenSizeIntOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenSizeIntOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value self = adaptor.getSelf();
    auto type = self.getType().cast<RankedTensorType>();
    // The dim operand must be produced by a constant int op;
    // getDefiningOp<ConstantIntOp>() returns null for block arguments and
    // non-constant producers, so a single check covers both cases.
    auto constIntOp = op->getOperand(1).getDefiningOp<ConstantIntOp>();
    if (!constIntOp)
      return rewriter.notifyMatchFailure(op, "dim must be a constant int");
    int64_t idxVal = constIntOp.getValueAttr().getValue().getSExtValue();
    if (idxVal < 0 || idxVal >= type.getRank())
      return rewriter.notifyMatchFailure(op, "dim must be in range");
    // Query the dimension with tensor.dim, then cast the index result to
    // i64, the converted type of !torch.int.
    auto idxOp = rewriter.create<arith::ConstantIndexOp>(loc, idxVal);
    auto dimOp = rewriter.create<tensor::DimOp>(loc, self, idxOp);
    auto result =
        rewriter.create<arith::IndexCastOp>(loc, rewriter.getI64Type(), dimOp);

    rewriter.replaceOp(op, result);

    return success();
  }
};

template <typename AtenOpT, int fillVal>
class ConvertAtenZerosOnesOp : public OpConversionPattern<AtenOpT> {
public:
@@ -229,6 +259,7 @@ void torch_to_tcp::populateMiscPatternsAndLegality(
typeConverter, patterns, target, convertTorchOpsSet)
INSERT_ATEN_MISC_OP_PATTERN(AtenBroadcastToOp);
INSERT_ATEN_MISC_OP_PATTERN(ValueTensorLiteralOp);
INSERT_ATEN_MISC_OP_PATTERN(AtenSizeIntOp);
#undef INSERT_ATEN_MISC_OP_PATTERN

#define INSERT_ATEN_ZEROS_ONES_PATTERN(ConvertAtenOpPattern, AtenOp, Val) \
100 changes: 58 additions & 42 deletions test/Conversion/TorchToTcp/misc.mlir
@@ -178,13 +178,13 @@ func.func @torch.aten.ones_ui8(%arg0: !torch.int, %arg1: !torch.int) -> !torch.v
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[BC]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.zeros_like_f32(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-%int0 = torch.constant.int 0
-%int3 = torch.constant.int 3
-%false = torch.constant.bool false
-%none = torch.constant.none
-%cuda3A0 = torch.constant.device "cuda:0"
-%0 = torch.aten.zeros_like %arg0, %int3, %int0, %cuda3A0, %false, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f32>
-return %0 : !torch.vtensor<[?,?],f32>
+  %int0 = torch.constant.int 0
+  %int3 = torch.constant.int 3
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %cuda3A0 = torch.constant.device "cuda:0"
+  %0 = torch.aten.zeros_like %arg0, %int3, %int0, %cuda3A0, %false, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
}

// -----
@@ -202,13 +202,13 @@ return %0 : !torch.vtensor<[?,?],f32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[BC]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32>
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],si32>
func.func @torch.aten.zeros_like_si32(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],si32> {
-%int0 = torch.constant.int 0
-%int3 = torch.constant.int 3
-%false = torch.constant.bool false
-%none = torch.constant.none
-%cuda3A0 = torch.constant.device "cuda:0"
-%0 = torch.aten.zeros_like %arg0, %int3, %int0, %cuda3A0, %false, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[?,?],si32>
-return %0 : !torch.vtensor<[?,?],si32>
+  %int0 = torch.constant.int 0
+  %int3 = torch.constant.int 3
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %cuda3A0 = torch.constant.device "cuda:0"
+  %0 = torch.aten.zeros_like %arg0, %int3, %int0, %cuda3A0, %false, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[?,?],si32>
+  return %0 : !torch.vtensor<[?,?],si32>
}

// -----
@@ -226,13 +226,13 @@ return %0 : !torch.vtensor<[?,?],si32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[BC]] : tensor<?x?xi8> -> !torch.vtensor<[?,?],ui8>
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],ui8>
func.func @torch.aten.zeros_like_ui8(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],ui8> {
-%int0 = torch.constant.int 0
-%int3 = torch.constant.int 3
-%false = torch.constant.bool false
-%none = torch.constant.none
-%cuda3A0 = torch.constant.device "cuda:0"
-%0 = torch.aten.zeros_like %arg0, %int3, %int0, %cuda3A0, %false, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[?,?],ui8>
-return %0 : !torch.vtensor<[?,?],ui8>
+  %int0 = torch.constant.int 0
+  %int3 = torch.constant.int 3
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %cuda3A0 = torch.constant.device "cuda:0"
+  %0 = torch.aten.zeros_like %arg0, %int3, %int0, %cuda3A0, %false, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[?,?],ui8>
+  return %0 : !torch.vtensor<[?,?],ui8>
}

// -----
@@ -250,13 +250,13 @@ return %0 : !torch.vtensor<[?,?],ui8>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[BC]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.ones_like_f32(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-%int0 = torch.constant.int 0
-%int3 = torch.constant.int 3
-%false = torch.constant.bool false
-%none = torch.constant.none
-%cuda3A0 = torch.constant.device "cuda:0"
-%0 = torch.aten.ones_like %arg0, %int3, %int0, %cuda3A0, %false, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f32>
-return %0 : !torch.vtensor<[?,?],f32>
+  %int0 = torch.constant.int 0
+  %int3 = torch.constant.int 3
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %cuda3A0 = torch.constant.device "cuda:0"
+  %0 = torch.aten.ones_like %arg0, %int3, %int0, %cuda3A0, %false, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
}

// -----
@@ -274,13 +274,13 @@ return %0 : !torch.vtensor<[?,?],f32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[BC]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32>
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],si32>
func.func @torch.aten.ones_like_si32(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],si32> {
-%int0 = torch.constant.int 0
-%int3 = torch.constant.int 3
-%false = torch.constant.bool false
-%none = torch.constant.none
-%cuda3A0 = torch.constant.device "cuda:0"
-%0 = torch.aten.ones_like %arg0, %int3, %int0, %cuda3A0, %false, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[?,?],si32>
-return %0 : !torch.vtensor<[?,?],si32>
+  %int0 = torch.constant.int 0
+  %int3 = torch.constant.int 3
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %cuda3A0 = torch.constant.device "cuda:0"
+  %0 = torch.aten.ones_like %arg0, %int3, %int0, %cuda3A0, %false, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[?,?],si32>
+  return %0 : !torch.vtensor<[?,?],si32>
}

// -----
@@ -298,11 +298,27 @@ return %0 : !torch.vtensor<[?,?],si32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[BC]] : tensor<?x?xi8> -> !torch.vtensor<[?,?],ui8>
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],ui8>
func.func @torch.aten.ones_like_ui8(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],ui8> {
-%int0 = torch.constant.int 0
-%int3 = torch.constant.int 3
-%false = torch.constant.bool false
-%none = torch.constant.none
-%cuda3A0 = torch.constant.device "cuda:0"
-%0 = torch.aten.ones_like %arg0, %int3, %int0, %cuda3A0, %false, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[?,?],ui8>
-return %0 : !torch.vtensor<[?,?],ui8>
+  %int0 = torch.constant.int 0
+  %int3 = torch.constant.int 3
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %cuda3A0 = torch.constant.device "cuda:0"
+  %0 = torch.aten.ones_like %arg0, %int3, %int0, %cuda3A0, %false, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[?,?],ui8>
+  return %0 : !torch.vtensor<[?,?],ui8>
}

// -----

// CHECK-LABEL: @torch.aten.size.int(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = torch.constant.int 0
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM0:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[C1:.*]] = arith.index_cast %[[DIM0]] : index to i64
// CHECK: return
func.func @torch.aten.size.int(%arg0: !torch.vtensor<[?,?],f32>) -> () {
%int0 = torch.constant.int 0
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
return
}
