Skip to content

Commit

Permalink
Convert index_put op to a custom op (#21)
Browse files Browse the repository at this point in the history
### Testing

```
bazel test --config=clang_linux //test:Conversion/TorchToTcp/tcp_custom_ops.mlir.test
```
  • Loading branch information
navahgar authored Nov 9, 2023
1 parent eaaa7cf commit ec83b88
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
51 changes: 51 additions & 0 deletions lib/Conversion/TorchToTcp/TcpCustomOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,56 @@ class ConvertAtenIndexTensorHackedTwinOp
return success();
}
};

class ConvertAten_IndexPutImplOp
: public OpConversionPattern<Aten_IndexPutImplOp> {
public:
using OpConversionPattern<Aten_IndexPutImplOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(Aten_IndexPutImplOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> resultTypes;
if (failed(OpConversionPattern<Aten_IndexPutImplOp>::getTypeConverter()
->convertTypes(op->getResultTypes(), resultTypes))) {
return failure();
}

SmallVector<Value> operands;
operands.push_back(adaptor.getSelf());

// Handle indices
SmallVector<Value> indicesTorchType;
if (!getListConstructElements(adaptor.getIndices(), indicesTorchType))
return op.emitError(
"unimplemented: the tensor list is not from list construct");
SmallVector<Value> indexTensors = getTypeConvertedValues(
rewriter, op->getLoc(), getTypeConverter(), indicesTorchType);
operands.append(indexTensors.begin(), indexTensors.end());

operands.push_back(adaptor.getValues());

bool accumulate;
if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate)))
return op.emitError("expected accumulate operand to be a bool constant");
bool unsafe;
if (!matchPattern(op.getUnsafe(), m_TorchConstantBool(&unsafe)))
return op.emitError("expected unsafe operand to be a bool constant");

SmallVector<NamedAttribute> attrs;
attrs.push_back(
rewriter.getNamedAttr("accumulate", rewriter.getBoolAttr(accumulate)));
attrs.push_back(
rewriter.getNamedAttr("unsafe", rewriter.getBoolAttr(unsafe)));
attrs.push_back(rewriter.getNamedAttr(
"op_name", rewriter.getStringAttr(op->getName().getStringRef())));

rewriter.replaceOpWithNewOp<tcp::CustomOp>(op, resultTypes, operands,
attrs);
return success();
}
};

} // namespace

void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
Expand All @@ -107,5 +157,6 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
typeConverter, patterns, target, convertTorchOpsSet)
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenGatherOp);
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenIndexTensorHackedTwinOp);
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(Aten_IndexPutImplOp);
#undef INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN
}
21 changes: 21 additions & 0 deletions test/Conversion/TorchToTcp/tcp_custom_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,24 @@ func.func @torch.aten.index_hacked_twin_op(%arg0: !torch.vtensor<[1,30,19,41],f3
%1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[1,30,19,41],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,30,19,3],f32>
return %1 : !torch.vtensor<[1,30,19,3],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.index_put_impl_op(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[25],f32>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[10],si32>
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[25],f32>
// CHECK: %[[TO:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[25],f32> -> tensor<25xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[10],si32> -> tensor<10xi32>
// CHECK: %[[CUSTOM:.*]] = tcp.custom_op("torch.aten._index_put_impl") %[[T0]], %[[T2]], %[[T1]]
// CHECK-SAME: {accumulate = false, unsafe = false}
// CHECK-SAME: tensor<25xf32>, tensor<10xi32>, tensor<f32> -> tensor<25xf32>
// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM]] : tensor<25xf32> -> !torch.vtensor<[25],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[25],f32>
func.func @torch.aten.index_put_impl_op(%arg0: !torch.vtensor<[25],f32>, %arg1: !torch.vtensor<[10],si32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[25],f32> {
%false = torch.constant.bool false
%0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[10],si32>) -> !torch.list<optional<vtensor>>
%1 = torch.aten._index_put_impl %arg0, %0, %arg2, %false, %false : !torch.vtensor<[25],f32>, !torch.list<optional<vtensor>>, !torch.vtensor<[],f32>, !torch.bool, !torch.bool -> !torch.vtensor<[25],f32>
return %1 : !torch.vtensor<[25],f32>
}

0 comments on commit ec83b88

Please sign in to comment.