Tcp Custom Op and Torch->Tcp conversions + lit tests (#17)
Adds the op definition for `tcp.custom_op` and lowerings for `aten.gather` and
`aten.index.Tensor_hacked_twin`.

`tcp.custom_op` is useful to represent and lower custom/opaque operations that
have tensor inputs and outputs, but
1) lack clear operational semantics, or
2) have operational semantics that aren't defined by Tcp (such as
user-defined custom operations), or
3) need to "flow through" the middle-end because they target pre-written
kernels available in the backend.

As a follow-on, I'll look into making `torch -> tcp.custom_op` conversions as
automatic as possible, to minimize the overhead of adding individual
conversion patterns manually. That could then be the last step in a frontend
lowering pipeline, converting all remaining `torch.aten` ops, as sketched
below.
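
For illustration only, here is a minimal sketch of what such a catch-all
conversion could look like, written against MLIR's generic `ConversionPattern`
API. The class name `ConvertTorchOpToTcpCustomOp` is hypothetical and not part
of this commit; the sketch assumes the same headers and using-directives as
`lib/Conversion/TorchToTcp/TcpCustomOp.cpp` below, and a real implementation
would still have to decide how non-tensor operands (constant ints, lists, etc.)
and target legality are handled.

namespace {
// Hypothetical fallback pattern: convert any remaining op from the `torch`
// dialect into an opaque tcp.custom_op, carrying over the original op name
// and attributes. Sketch only; not part of this commit.
class ConvertTorchOpToTcpCustomOp : public ConversionPattern {
public:
  ConvertTorchOpToTcpCustomOp(TypeConverter &typeConverter,
                              MLIRContext *context)
      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                          context) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only pick up ops from the torch dialect that no dedicated
    // Torch->Tcp pattern has already claimed.
    if (op->getName().getDialectNamespace() != "torch")
      return failure();

    SmallVector<Type> resultTypes;
    if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
                                                resultTypes)))
      return failure();

    // Forward the (type-converted) operands and original attributes, and
    // record the original op name on the tcp.custom_op.
    auto newOp = rewriter.replaceOpWithNewOp<tcp::CustomOp>(
        op, resultTypes, operands, op->getAttrs());
    newOp.setOpName(op->getName().getStringRef());
    return success();
  }
};
} // namespace

Such a pattern would presumably be registered with the lowest benefit (and
possibly gated behind a flag) so that it only fires after all dedicated
Torch->Tcp conversion patterns have had a chance to match.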
sjain-stanford authored Nov 3, 2023
1 parent 690574e commit f79c063
Showing 7 changed files with 211 additions and 1 deletion.
1 change: 1 addition & 0 deletions BUILD
@@ -190,6 +190,7 @@ cc_library(
"lib/Conversion/TorchToTcp/Elementwise.cpp",
"lib/Conversion/TorchToTcp/Misc.cpp",
"lib/Conversion/TorchToTcp/PopulatePatterns.h",
"lib/Conversion/TorchToTcp/TcpCustomOp.cpp",
"lib/Conversion/TorchToTcp/TorchToTcp.cpp",
"lib/Conversion/TorchToTcp/Utils.cpp",
"lib/Conversion/TorchToTcp/Utils.h",
26 changes: 25 additions & 1 deletion include/mlir-tcp/Dialect/IR/TcpOps.td
@@ -48,7 +48,6 @@ def Tcp_ClampOp : Tcp_UnaryElementwiseOp<"clamp", [SameOperandsAndResultElementT
specified.
}];

// TODO: Does clamp need to support complex tensors?
let arguments = (ins
Tcp_FloatOrIntTensor:$in,
OptionalAttr<F32Attr>:$min_float,
@@ -311,6 +310,31 @@ def Tcp_IsolatedGroupOp : Tcp_Op<"isolated_group", [
let hasVerifier = 1;
}

def Tcp_CustomOp : Tcp_Op<"custom_op", []> {
let summary = "Custom opaque operation in Tcp dialect";

let description = [{
This is useful to represent and lower custom/opaque operations that have
tensor inputs and outputs, but
(1) lack clear operational semantics, or
(2) have operational semantics that aren't defined by Tcp (e.g. user-defined
custom operations), or
(3) need to "flow through" the middle-end because they target existing
pre-written kernels available in the backend.
}];

let arguments = (ins
Variadic<AnyType>:$inputs,
StrAttr:$op_name
);

let results = (outs
Variadic<AnyType>:$outputs
);

let assemblyFormat = "`(`$op_name`)` $inputs attr-dict `:` type($inputs) `->` type($outputs)";
}

def Tcp_SqrtOp : Tcp_UnaryElementwiseOp<"sqrt"> {
let summary = "Computes square root of input, elementwise";

4 changes: 4 additions & 0 deletions lib/Conversion/TorchToTcp/PopulatePatterns.h
@@ -26,5 +26,9 @@ void populateDataMovementPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet);

void populateTcpCustomOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet);

} // namespace torch_to_tcp
} // namespace mlir
111 changes: 111 additions & 0 deletions lib/Conversion/TorchToTcp/TcpCustomOp.cpp
@@ -0,0 +1,111 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h"

#include "mlir-tcp/Dialect/IR/TcpDialect.h"
#include "mlir-tcp/Dialect/IR/TcpOps.h"

#include "PopulatePatterns.h"
#include "Utils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

#include "llvm/ADT/StringSet.h"

using namespace mlir;
using namespace mlir::tcp;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
public:
using OpConversionPattern<AtenGatherOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenGatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> resultTypes;
if (failed(
OpConversionPattern<AtenGatherOp>::getTypeConverter()->convertTypes(
op->getResultTypes(), resultTypes))) {
return failure();
}

int64_t dimVal;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimVal)))
return failure();

auto indexAttr =
rewriter.getNamedAttr("axis", rewriter.getI64IntegerAttr(dimVal));

auto newOp = rewriter.replaceOpWithNewOp<tcp::CustomOp>(
op, resultTypes, ValueRange{adaptor.getSelf(), adaptor.getIndex()},
indexAttr);
newOp.setOpName(op->getName().getStringRef());
return success();
}
};

class ConvertAtenIndexTensorHackedTwinOp
: public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
public:
using OpConversionPattern<AtenIndexTensorHackedTwinOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> resultTypes;
if (failed(
OpConversionPattern<AtenIndexTensorHackedTwinOp>::getTypeConverter()
->convertTypes(op->getResultTypes(), resultTypes))) {
return failure();
}

SmallVector<Value> tensorOperands;
Value input = adaptor.getSelf();
auto inputTensorType = input.getType().dyn_cast<RankedTensorType>();
// Check that the input is a ranked tensor.
if (!inputTensorType)
return rewriter.notifyMatchFailure(
op, "only ranked tensor inputs are currently supported");
tensorOperands.push_back(input);

// Unpack the torch.prim.ListConstruct that supplies the (possibly non-constant) index tensors.
Value indexList = op.getIndices();
SmallVector<Value> indicesTorchType;
if (!getListConstructElements(indexList, indicesTorchType))
return op.emitError(
"unimplemented: the tensor list is not from list construct");
SmallVector<Value> indexTensors = getTypeConvertedValues(
rewriter, op->getLoc(), getTypeConverter(), indicesTorchType);

tensorOperands.append(indexTensors.begin(), indexTensors.end());

auto newOp = rewriter.replaceOpWithNewOp<tcp::CustomOp>(op, resultTypes,
tensorOperands);
newOp.setOpName(op->getName().getStringRef());
return success();
}
};
} // namespace

void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet) {

#define INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenOp) \
torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<Convert##AtenOp, AtenOp>( \
typeConverter, patterns, target, convertTorchOpsSet)
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenGatherOp);
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenIndexTensorHackedTwinOp);
#undef INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN
}
3 changes: 3 additions & 0 deletions lib/Conversion/TorchToTcp/TorchToTcp.cpp
@@ -85,6 +85,9 @@ class ConvertTorchToTcp : public ConvertTorchToTcpBase<ConvertTorchToTcp> {
torch_to_tcp::populateDataMovementPatternsAndLegality(
typeConverter, patterns, target, convertTorchOpsSet);

torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
typeConverter, patterns, target, convertTorchOpsSet);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
return signalPassFailure();
38 changes: 38 additions & 0 deletions test/Conversion/TorchToTcp/tcp_custom_ops.mlir
@@ -0,0 +1,38 @@
// RUN: tcp-opt <%s -convert-torch-to-tcp -canonicalize -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.gather_op(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,2],si64>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[2,2],f32> -> tensor<2x2xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,2],si64> -> tensor<2x2xi64>
// CHECK: %[[T2:.*]] = tcp.custom_op("torch.aten.gather") %[[T0]], %[[T1]] {axis = 1 : i64} : tensor<2x2xf32>, tensor<2x2xi64> -> tensor<2x2xf32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<2x2xf32> -> !torch.vtensor<[2,2],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[2,2],f32>
func.func @torch.aten.gather_op(%arg0: !torch.vtensor<[2,2],si64>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> {
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%0 = torch.aten.gather %arg1, %int1, %arg0, %false : !torch.vtensor<[2,2],f32>, !torch.int, !torch.vtensor<[2,2],si64>, !torch.bool -> !torch.vtensor<[2,2],f32>
return %0 : !torch.vtensor<[2,2],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.index_hacked_twin_op(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,30,19,41],f32>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1,1,1,1],si64>
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[30,1,1],si64>
// CHECK-SAME: %[[ARG3:.*]]: !torch.vtensor<[19,1],si64>
// CHECK-SAME: %[[ARG4:.*]]: !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,30,19,3],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,30,19,41],f32> -> tensor<1x30x19x41xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[1,1,1,1],si64> -> tensor<1x1x1x1xi64>
// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[30,1,1],si64> -> tensor<30x1x1xi64>
// CHECK: %[[T3:.*]] = torch_c.to_builtin_tensor %[[ARG3]] : !torch.vtensor<[19,1],si64> -> tensor<19x1xi64>
// CHECK: %[[T4:.*]] = torch_c.to_builtin_tensor %[[ARG4]] : !torch.vtensor<[3],si64> -> tensor<3xi64>
// CHECK: %[[T5:.*]] = tcp.custom_op("torch.aten.index.Tensor_hacked_twin") %[[T0]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<1x30x19x41xf32>, tensor<1x1x1x1xi64>, tensor<30x1x1xi64>, tensor<19x1xi64>, tensor<3xi64> -> tensor<1x30x19x3xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x30x19x3xf32> -> !torch.vtensor<[1,30,19,3],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[1,30,19,3],f32>
func.func @torch.aten.index_hacked_twin_op(%arg0: !torch.vtensor<[1,30,19,41],f32>, %arg1: !torch.vtensor<[1,1,1,1],si64>, %arg2: !torch.vtensor<[30,1,1],si64>, %arg3: !torch.vtensor<[19,1],si64>, %arg4: !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,30,19,3],f32> {
%0 = torch.prim.ListConstruct %arg1, %arg2, %arg3, %arg4 : (!torch.vtensor<[1,1,1,1],si64>, !torch.vtensor<[30,1,1],si64>, !torch.vtensor<[19,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
%1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[1,30,19,41],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,30,19,3],f32>
return %1 : !torch.vtensor<[1,30,19,3],f32>
}
29 changes: 29 additions & 0 deletions test/Dialect/tcp_custom_ops.mlir
@@ -0,0 +1,29 @@
// RUN: tcp-opt %s -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func.func @tcp_custom_op(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[T0:.*]] = tcp.custom_op("torch.aten.my_custom_op") %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: return %[[T0]] : tensor<?x?xf32>
func.func @tcp_custom_op(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = tcp.custom_op("torch.aten.my_custom_op") %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func.func @tcp_custom_op_with_named_attrs(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[T0:.*]] = tcp.custom_op("torch.aten.my_custom_op") %[[ARG0]] {axis = 0 : i32} : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: return %[[T0]] : tensor<?x?xf32>
func.func @tcp_custom_op_with_named_attrs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = tcp.custom_op("torch.aten.my_custom_op") %arg0 {axis = 0 : i32} : tensor<?x?xf32> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}

// -----

func.func @tcp_custom_op_without_op_name(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error@+1{{expected attribute value}}
%0 = tcp.custom_op() %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
