Add converter for index.Tensor_hacked_twin #98

Merged
60 changes: 60 additions & 0 deletions lib/Conversion/TorchToTcp/DataMovement.cpp
@@ -278,6 +278,63 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
}
};

class ConvertAtenIndexTensorHackedTwin
    : public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
public:
  using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto self = adaptor.getSelf();
auto indicesList = op.getIndices();
SmallVector<Value> indices;
if (!getListConstructElements(indicesList, indices))
return op.emitError("Failed to match list of indices");
indices = getTypeConvertedValues(rewriter, op.getLoc(), getTypeConverter(),
indices);

// possible that this should ignore the first batch dim?
if (indices.size() != cast<RankedTensorType>(self.getType()).getRank())
return op.emitError(
"Expected the number of indicies to equal rank of self");

for (unsigned int i = 0; i < indices.size(); i++) {
auto idx = indices[i];
int numNonOneAxis = 0;
auto ttype = cast<RankedTensorType>(idx.getType());
for (int j = 0; j < ttype.getRank(); j++)
if (ttype.getShape()[j] != 1)
numNonOneAxis++;
if (numNonOneAxis > 1)
return op.emitError(
"Expected the input shape to have a single non-one axis");
// Collapse the index tensor to 1-D.
if (ttype.getRank() != 1) {
ReassociationIndices reassocIndices;
for (int j = 0; j < ttype.getRank(); j++)
reassocIndices.push_back(j);
SmallVector<ReassociationIndices> ri = {reassocIndices};
auto reshape =
rewriter.create<tensor::CollapseShapeOp>(op.getLoc(), idx, ri);
idx = reshape.getResult();
}

SmallVector<int64_t> outShape(
cast<RankedTensorType>(self.getType()).getShape());
outShape[i] = ttype.getNumElements();
auto outType = RankedTensorType::get(
outShape, cast<RankedTensorType>(self.getType()).getElementType());

auto gather = rewriter.create<tcp::GatherOp>(
op.getLoc(), outType, self, idx, rewriter.getIndexAttr(i));
self = gather.getResult();
}

rewriter.replaceOp(op, self);
return success();
}
};

} // namespace

void torch_to_tcp::populateDataMovementPatternsAndLegality(
@@ -294,4 +351,7 @@ void torch_to_tcp::populateDataMovementPatternsAndLegality(
torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<ConvertAtenIndexSelectOp,
AtenIndexSelectOp>(
typeConverter, patterns, target, convertTorchOpsSet);
torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<
ConvertAtenIndexTensorHackedTwin, AtenIndexTensorHackedTwinOp>(
typeConverter, patterns, target, convertTorchOpsSet);
}
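
The new pattern above lowers aten.index.Tensor_hacked_twin into one tcp.gather per dimension: each index tensor is checked to have at most one non-unit axis, collapsed to 1-D, and then gathered along its dimension. A minimal PyTorch sketch of why that per-dimension decomposition reproduces the op's advanced-indexing semantics, using the shapes from the new test further down (illustrative only, not part of the patch):

# Hedged sketch: successive per-dimension gathers match advanced indexing when
# every index tensor has at most one non-unit axis.
import torch

x = torch.randn(1, 20, 30)
idx0 = torch.arange(1).reshape(1, 1, 1)   # selects the whole (size-1) dim 0
idx1 = torch.randint(0, 20, (5, 1))       # single non-unit axis of length 5
idx2 = torch.randint(0, 30, (20,))        # single non-unit axis of length 20

# Reference semantics of index.Tensor_hacked_twin: advanced indexing.
expected = x[idx0, idx1, idx2]            # shape (1, 5, 20)

# What the converter emits: collapse each index to 1-D, then gather per dim.
result = x
for dim, idx in enumerate((idx0, idx1, idx2)):
    result = result.index_select(dim, idx.reshape(-1))

assert result.shape == (1, 5, 20)
assert torch.equal(result, expected)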
45 changes: 0 additions & 45 deletions lib/Conversion/TorchToTcp/TcpCustomOp.cpp
@@ -29,49 +29,6 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenGatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
getTypeConverter()};

helper.addOperand("self", adaptor.getSelf());
helper.addOperand("index", adaptor.getIndex());
helper.addIntAttr("axis", op.getDim());

return helper.replace();
}
};

class ConvertAtenIndexTensorHackedTwinOp
: public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
getTypeConverter()};

Value input = adaptor.getSelf();
auto inputTensorType = input.getType().dyn_cast<RankedTensorType>();
// Check input is a tensor type.
if (!inputTensorType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");

helper.addOperand("self", input);
helper.addAsMultipleTensorOperands("index_", op.getIndices());

return helper.replace();
}
};

class ConvertAten_IndexPutImplOp
: public OpConversionPattern<Aten_IndexPutImplOp> {
@@ -380,8 +337,6 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
#define INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenOp) \
torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<Convert##AtenOp, AtenOp>( \
typeConverter, patterns, target, convertTorchOpsSet)
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenGatherOp);
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenIndexTensorHackedTwinOp);
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(Aten_IndexPutImplOp);
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenFakeQuantizePerTensorAffineOp);
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(
26 changes: 26 additions & 0 deletions test/Conversion/TorchToTcp/data_movement.mlir
@@ -64,3 +64,29 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !tor
%0 = torch.aten.index_select %arg0, %int-1, %arg1: !torch.vtensor<[4,3],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,2],f32>
return %0 : !torch.vtensor<[4,2],f32>
}

// -----

// CHECK-LABEL: @torch.aten.index.tensor_hacked_twin
// CHECK-DAG: %[[CAST0:.+]] = torch_c.to_builtin_tensor %arg0
// CHECK-DAG: %[[GATHER0:.+]] = tcp.gather %[[CAST0]], %[[SELECT0:.+]] {dim = 0 : index} : tensor<1x20x30xf32>, tensor<1xi64> -> tensor<1x20x30xf32>
// CHECK-DAG: %[[GATHER1:.+]] = tcp.gather %[[GATHER0]], %[[SELECT1:.+]] {dim = 1 : index} : tensor<1x20x30xf32>, tensor<5xi64> -> tensor<1x5x30xf32>
// CHECK-DAG: %[[GATHER2:.+]] = tcp.gather %[[GATHER1]], %[[SELECT2:.+]] {dim = 2 : index} : tensor<1x5x30xf32>, tensor<20xi64> -> tensor<1x5x20xf32>
// CHECK-DAG: %[[RET:.+]] = torch_c.from_builtin_tensor %[[GATHER2]]
// CHECK: return %[[RET]]
func.func @torch.aten.index.tensor_hacked_twin(%arg0: !torch.vtensor<[1,20,30],f32>, %select1: !torch.vtensor<[5,1],si64>, %select2: !torch.vtensor<[20],si64>) -> !torch.vtensor<[1,5,20],f32> {
// There is a strange pattern generated when selecting along a single axis: Tensor_hacked_twin
// is used to select along all axes, with arange supplying full-range indices for the axes
// that are not actually being filtered.
%none = torch.constant.none
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4 // dtype argument for arange (4 = int64)
%int-1 = torch.constant.int -1
%arange = torch.aten.arange.start_step %int0, %int1, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64>
Contributor: torch.aten.arange.start_step will remain unchanged, right? As we don't handle it in tcp.

Contributor Author: Currently this remains unchanged. I was thinking that we should add an optimization pass that removes this when it selects everything, but I have yet to do that.

(A short sketch of that simplification follows this test.)

%arange1 = torch.aten.unsqueeze %arange, %int-1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
%arange2 = torch.aten.unsqueeze %arange1, %int-1 : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.vtensor<[1,1,1],si64>

%l = torch.prim.ListConstruct %arange2, %select1, %select2 : (!torch.vtensor<[1,1,1],si64>, !torch.vtensor<[5,1],si64>, !torch.vtensor<[20],si64>) -> !torch.list<vtensor>
%ret = torch.aten.index.Tensor_hacked_twin %arg0, %l : !torch.vtensor<[1,20,30],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,5,20],f32>
return %ret : !torch.vtensor<[1,5,20],f32>
}
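
As the review reply above notes, the arange-based index left in this test is an identity selection: a gather whose index is arange(0, dim_size) returns its operand unchanged, which is exactly what the proposed follow-up optimization pass would fold away. A hedged PyTorch sketch of that identity (illustrative only, not part of this PR):

import torch

x = torch.randn(1, 20, 30)
full_index = torch.arange(x.shape[0])  # arange over the whole of dim 0
# Selecting every position in order is a no-op, so such a gather could be elided.
assert torch.equal(x.index_select(0, full_index), x)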
38 changes: 0 additions & 38 deletions test/Conversion/TorchToTcp/tcp_custom_ops.mlir
@@ -1,43 +1,5 @@
// RUN: tcp-opt <%s -convert-torch-to-tcp-custom-op -canonicalize -split-input-file | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.gather_op(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,2],si64>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> {
// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[2,2],f32> -> tensor<2x2xf32>
// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,2],si64> -> tensor<2x2xi64>
// CHECK: %[[T2:.*]] = tcp.custom_op("torch.aten.gather") %[[T0]], %[[T1]] {axis = 1 : i64, torch_operand_names = ["self", "index"]} : tensor<2x2xf32>, tensor<2x2xi64> -> tensor<2x2xf32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<2x2xf32> -> !torch.vtensor<[2,2],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[2,2],f32>
func.func @torch.aten.gather_op(%arg0: !torch.vtensor<[2,2],si64>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> {
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%0 = torch.aten.gather %arg1, %int1, %arg0, %false : !torch.vtensor<[2,2],f32>, !torch.int, !torch.vtensor<[2,2],si64>, !torch.bool -> !torch.vtensor<[2,2],f32>
return %0 : !torch.vtensor<[2,2],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.index_hacked_twin_op(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,30,19,41],f32>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1,1,1,1],si64>
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[30,1,1],si64>
// CHECK-SAME: %[[ARG3:.*]]: !torch.vtensor<[19,1],si64>
// CHECK-SAME: %[[ARG4:.*]]: !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,30,19,3],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,30,19,41],f32> -> tensor<1x30x19x41xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[1,1,1,1],si64> -> tensor<1x1x1x1xi64>
// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[30,1,1],si64> -> tensor<30x1x1xi64>
// CHECK: %[[T3:.*]] = torch_c.to_builtin_tensor %[[ARG3]] : !torch.vtensor<[19,1],si64> -> tensor<19x1xi64>
// CHECK: %[[T4:.*]] = torch_c.to_builtin_tensor %[[ARG4]] : !torch.vtensor<[3],si64> -> tensor<3xi64>
// CHECK: %[[T5:.*]] = tcp.custom_op("torch.aten.index.Tensor_hacked_twin") %[[T0]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] {torch_operand_names = ["self", "index_0", "index_1", "index_2", "index_3"]} : tensor<1x30x19x41xf32>, tensor<1x1x1x1xi64>, tensor<30x1x1xi64>, tensor<19x1xi64>, tensor<3xi64> -> tensor<1x30x19x3xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x30x19x3xf32> -> !torch.vtensor<[1,30,19,3],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[1,30,19,3],f32>
func.func @torch.aten.index_hacked_twin_op(%arg0: !torch.vtensor<[1,30,19,41],f32>, %arg1: !torch.vtensor<[1,1,1,1],si64>, %arg2: !torch.vtensor<[30,1,1],si64>, %arg3: !torch.vtensor<[19,1],si64>, %arg4: !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,30,19,3],f32> {
%0 = torch.prim.ListConstruct %arg1, %arg2, %arg3, %arg4 : (!torch.vtensor<[1,1,1,1],si64>, !torch.vtensor<[30,1,1],si64>, !torch.vtensor<[19,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
%1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[1,30,19,41],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,30,19,3],f32>
return %1 : !torch.vtensor<[1,30,19,3],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.index_put_impl_op(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[25],f32>