Tcp Custom Op and Torch->Tcp conversions + lit tests (#17)
Adds the op definition for `tcp.custom_op` and lowerings for `aten.gather` and `aten.index.Tensor_hacked_twin`. The op is useful for representing and lowering custom/opaque operations that have tensor inputs and outputs but that 1) lack clear operational semantics, 2) have operational semantics not defined by Tcp (such as user-defined custom operations), or 3) need to "flow through" the middle end because they target pre-written kernels available in the backend.

As a follow-on, I'll look into making `torch -> tcp.custom_op` conversions as automatic as possible, to minimize the overhead of adding individual conversion patterns by hand. That would let this be the last step in a frontend lowering pipeline, converting all remaining `torch.aten` ops.
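For illustration, here is a minimal sketch of the `aten.gather` lowering this enables, mirroring the lit test added in this commit (SSA names such as `%self`, `%index`, `%int1`, and `%false` are placeholders, and the `torch_c.to_builtin_tensor`/`from_builtin_tensor` type conversions are elided); the constant gather `dim` is carried on the custom op as an `axis` integer attribute:

```mlir
// Before (Torch dialect): the gather dim is a torch.constant.int.
%out = torch.aten.gather %self, %int1, %index, %false
    : !torch.vtensor<[2,2],f32>, !torch.int, !torch.vtensor<[2,2],si64>, !torch.bool -> !torch.vtensor<[2,2],f32>

// After (Tcp dialect): the original op name survives as the op-name string,
// and the dim survives as the `axis` attribute on the opaque custom op.
%out = tcp.custom_op("torch.aten.gather") %self, %index {axis = 1 : i64}
    : tensor<2x2xf32>, tensor<2x2xi64> -> tensor<2x2xf32>
```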
1 parent 690574e · commit f79c063 · Showing 7 changed files with 211 additions and 1 deletion.
@@ -0,0 +1,111 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h"

#include "mlir-tcp/Dialect/IR/TcpDialect.h"
#include "mlir-tcp/Dialect/IR/TcpOps.h"

#include "PopulatePatterns.h"
#include "Utils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

#include "llvm/ADT/StringSet.h"

using namespace mlir;
using namespace mlir::tcp;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
public:
  using OpConversionPattern<AtenGatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenGatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    if (failed(
            OpConversionPattern<AtenGatherOp>::getTypeConverter()->convertTypes(
                op->getResultTypes(), resultTypes))) {
      return failure();
    }

    int64_t dimVal;
    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimVal)))
      return failure();

    auto indexAttr =
        rewriter.getNamedAttr("axis", rewriter.getI64IntegerAttr(dimVal));

    auto newOp = rewriter.replaceOpWithNewOp<tcp::CustomOp>(
        op, resultTypes, ValueRange{adaptor.getSelf(), adaptor.getIndex()},
        indexAttr);
    newOp.setOpName(op->getName().getStringRef());
    return success();
  }
};

class ConvertAtenIndexTensorHackedTwinOp
    : public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
public:
  using OpConversionPattern<AtenIndexTensorHackedTwinOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    if (failed(
            OpConversionPattern<AtenIndexTensorHackedTwinOp>::getTypeConverter()
                ->convertTypes(op->getResultTypes(), resultTypes))) {
      return failure();
    }

    SmallVector<Value> tensorOperands;
    Value input = adaptor.getSelf();
    auto inputTensorType = input.getType().dyn_cast<RankedTensorType>();
    // Check that the input is a ranked tensor type.
    if (!inputTensorType)
      return rewriter.notifyMatchFailure(
          op, "only ranked tensor inputs are currently supported");
    tensorOperands.push_back(input);

    // Unpack the torch.prim.ListConstruct of (possibly non-constant) index
    // tensors.
    Value indexList = op.getIndices();
    SmallVector<Value> indicesTorchType;
    if (!getListConstructElements(indexList, indicesTorchType))
      return op.emitError(
          "unimplemented: the tensor list is not from list construct");
    SmallVector<Value> indexTensors = getTypeConvertedValues(
        rewriter, op->getLoc(), getTypeConverter(), indicesTorchType);

    tensorOperands.append(indexTensors.begin(), indexTensors.end());

    auto newOp = rewriter.replaceOpWithNewOp<tcp::CustomOp>(op, resultTypes,
                                                            tensorOperands);
    newOp.setOpName(op->getName().getStringRef());
    return success();
  }
};
} // namespace

void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet) {

#define INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenOp)                           \
  torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<Convert##AtenOp, AtenOp>(   \
      typeConverter, patterns, target, convertTorchOpsSet)
  INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenGatherOp);
  INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenIndexTensorHackedTwinOp);
#undef INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN
}
@@ -0,0 +1,38 @@
// RUN: tcp-opt <%s -convert-torch-to-tcp -canonicalize -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.gather_op(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,2],si64>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[2,2],f32> -> tensor<2x2xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,2],si64> -> tensor<2x2xi64>
// CHECK: %[[T2:.*]] = tcp.custom_op("torch.aten.gather") %[[T0]], %[[T1]] {axis = 1 : i64} : tensor<2x2xf32>, tensor<2x2xi64> -> tensor<2x2xf32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<2x2xf32> -> !torch.vtensor<[2,2],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[2,2],f32>
func.func @torch.aten.gather_op(%arg0: !torch.vtensor<[2,2],si64>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> {
  %false = torch.constant.bool false
  %int1 = torch.constant.int 1
  %0 = torch.aten.gather %arg1, %int1, %arg0, %false : !torch.vtensor<[2,2],f32>, !torch.int, !torch.vtensor<[2,2],si64>, !torch.bool -> !torch.vtensor<[2,2],f32>
  return %0 : !torch.vtensor<[2,2],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.index_hacked_twin_op(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,30,19,41],f32>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1,1,1,1],si64>
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[30,1,1],si64>
// CHECK-SAME: %[[ARG3:.*]]: !torch.vtensor<[19,1],si64>
// CHECK-SAME: %[[ARG4:.*]]: !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,30,19,3],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,30,19,41],f32> -> tensor<1x30x19x41xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[1,1,1,1],si64> -> tensor<1x1x1x1xi64>
// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[30,1,1],si64> -> tensor<30x1x1xi64>
// CHECK: %[[T3:.*]] = torch_c.to_builtin_tensor %[[ARG3]] : !torch.vtensor<[19,1],si64> -> tensor<19x1xi64>
// CHECK: %[[T4:.*]] = torch_c.to_builtin_tensor %[[ARG4]] : !torch.vtensor<[3],si64> -> tensor<3xi64>
// CHECK: %[[T5:.*]] = tcp.custom_op("torch.aten.index.Tensor_hacked_twin") %[[T0]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<1x30x19x41xf32>, tensor<1x1x1x1xi64>, tensor<30x1x1xi64>, tensor<19x1xi64>, tensor<3xi64> -> tensor<1x30x19x3xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x30x19x3xf32> -> !torch.vtensor<[1,30,19,3],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[1,30,19,3],f32>
func.func @torch.aten.index_hacked_twin_op(%arg0: !torch.vtensor<[1,30,19,41],f32>, %arg1: !torch.vtensor<[1,1,1,1],si64>, %arg2: !torch.vtensor<[30,1,1],si64>, %arg3: !torch.vtensor<[19,1],si64>, %arg4: !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,30,19,3],f32> {
  %0 = torch.prim.ListConstruct %arg1, %arg2, %arg3, %arg4 : (!torch.vtensor<[1,1,1,1],si64>, !torch.vtensor<[30,1,1],si64>, !torch.vtensor<[19,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
  %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[1,30,19,41],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,30,19,3],f32>
  return %1 : !torch.vtensor<[1,30,19,3],f32>
}
@@ -0,0 +1,29 @@
// RUN: tcp-opt %s -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func.func @tcp_custom_op(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[T0:.*]] = tcp.custom_op("torch.aten.my_custom_op") %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: return %[[T0]] : tensor<?x?xf32>
func.func @tcp_custom_op(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = tcp.custom_op("torch.aten.my_custom_op") %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func.func @tcp_custom_op_with_named_attrs(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[T0:.*]] = tcp.custom_op("torch.aten.my_custom_op") %[[ARG0]] {axis = 0 : i32} : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: return %[[T0]] : tensor<?x?xf32>
func.func @tcp_custom_op_with_named_attrs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = tcp.custom_op("torch.aten.my_custom_op") %arg0 {axis = 0 : i32} : tensor<?x?xf32> -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// -----

func.func @tcp_custom_op_without_op_name(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
  // expected-error@+1{{expected attribute value}}
  %0 = tcp.custom_op() %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}