Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add converter for index.Tensor_hacked_twin #98

Merged
6 changes: 4 additions & 2 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ RUN apt-get update && \
clang \
clang-format \
gdb \
black
black \
sudo

# Install bazel
ARG ARCH="x86_64"
Expand All @@ -42,7 +43,8 @@ WORKDIR /opt/src/mlir-tcp
RUN groupadd -o -g ${GID} ${GROUP} && \
useradd -u ${UID} -g ${GROUP} -ms /bin/bash ${USER} && \
usermod -aG sudo ${USER} && \
chown -R ${USER}:${GROUP} /opt/src/mlir-tcp
chown -R ${USER}:${GROUP} /opt/src/mlir-tcp && \
echo "%sudo ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers

# Switch to user
USER ${USER}
3 changes: 3 additions & 0 deletions include/mlir-tcp/Dialect/IR/TcpOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def Tcp_ConstOp : Tcp_Op<"const", [ConstantLike, Pure]> {
let assemblyFormat = "attr-dict `:` type($out)";

let hasFolder = 1;
let hasVerifier = 1;
}

def Tcp_BroadcastOp : Tcp_Op<"broadcast", [
Expand Down Expand Up @@ -617,6 +618,8 @@ def Tcp_GatherOp : Tcp_Op<"gather", [Pure, AllElementTypesMatch<["input", "out"]
);

let assemblyFormat = "$input `,` $indices attr-dict `:` type($input) `,` type($indices) `->` type($out)";

let hasVerifier = 1;
}

def Tcp_SliceOp : Tcp_Op<"slice", [Pure, AllElementTypesMatch<["in", "out"]>, SameVariadicOperandSize]> {
Expand Down
76 changes: 76 additions & 0 deletions lib/Conversion/TorchToTcp/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,79 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
}
};

class ConvertAtenIndexTensorHackedTwin
: public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// ------- Matching the OP -------
auto self = adaptor.getSelf();
auto selfType = cast<RankedTensorType>(self.getType());
auto indicesList = op.getIndices();
SmallVector<Value> indices;
if (!getListConstructElements(indicesList, indices))
return op.emitError("Failed to match list of indices");

for (unsigned int i = 0; i < indices.size(); i++) {
auto ttype = cast<RankedTensorType>(
getTypeConverter()->convertType(indices[i].getType()));
if (ttype.getRank() != selfType.getRank() - i) {
// Can use tensor.gather instead for this. But will require that there
// are some broadcasting to get the shapes to match what is expected
return failure("Failed to rewrite Tensor_hacked_twin. Need the "
"element gather for this");
}
for (int j = 1; j < ttype.getRank(); j++) {
if (ttype.getShape()[j] != 1)
return failure("Expected the axes >=1 to have size 1");
}
}

// ------ Rewriting the OP ---------

indices = getTypeConvertedValues(rewriter, op.getLoc(), getTypeConverter(),
indices);

for (unsigned int i = 0; i < indices.size(); i++) {
auto idx = indices[i];
auto ttype = cast<RankedTensorType>(idx.getType());
auto selfType = cast<RankedTensorType>(self.getType());
SmallVector<int64_t> outShape(selfType.getShape());
outShape[i] = ttype.getNumElements();
auto outType = RankedTensorType::get(
outShape, cast<RankedTensorType>(self.getType()).getElementType());

auto expandedShape = torch_to_tcp::broadcastRankInLeadingDims(
rewriter, idx, outShape.size() - ttype.getRank());

SmallVector<Value> broadcastValues;
SmallVector<int64_t> broadcastAxes;
for (unsigned int j = 0; j < selfType.getRank(); j++) {
if (j != i) {
broadcastAxes.push_back(j);
broadcastValues.push_back(
rewriter.create<tensor::DimOp>(op.getLoc(), self, j));
}
}

auto broadcastedShape = rewriter.create<tcp::BroadcastOp>(
op.getLoc(), RankedTensorType::get(outShape, ttype.getElementType()),
expandedShape, broadcastValues,
rewriter.getI64ArrayAttr(broadcastAxes));

auto gather = rewriter.create<tcp::GatherOp>(op.getLoc(), outType, self,
broadcastedShape.getResult(),
rewriter.getIndexAttr(i));
self = gather.getResult();
}

rewriter.replaceOp(op, self);
return success();
}
};

} // namespace

void torch_to_tcp::populateDataMovementPatternsAndLegality(
Expand All @@ -294,4 +367,7 @@ void torch_to_tcp::populateDataMovementPatternsAndLegality(
torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<ConvertAtenIndexSelectOp,
AtenIndexSelectOp>(
typeConverter, patterns, target, convertTorchOpsSet);
torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<
ConvertAtenIndexTensorHackedTwin, AtenIndexTensorHackedTwinOp>(
typeConverter, patterns, target, convertTorchOpsSet);
}
11 changes: 11 additions & 0 deletions lib/Conversion/TorchToTcp/Misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
Expand Down Expand Up @@ -162,6 +163,16 @@ class ConvertValueTensorLiteralOp
rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType, denseIntAttr);
return success();
}
if (auto elements =
dyn_cast<DenseResourceElementsAttr>(op.getValueAttr())) {
if (resultType.getElementType().isInteger() &&
resultType != adaptor.getValue().getType()) {
auto attr =
DenseResourceElementsAttr::get(resultType, elements.getRawHandle());
rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType, attr);
return success();
}
}

rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType,
adaptor.getValue());
Expand Down
45 changes: 0 additions & 45 deletions lib/Conversion/TorchToTcp/TcpCustomOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,49 +29,6 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenGatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
getTypeConverter()};

helper.addOperand("self", adaptor.getSelf());
helper.addOperand("index", adaptor.getIndex());
helper.addIntAttr("axis", op.getDim());

return helper.replace();
}
};

class ConvertAtenIndexTensorHackedTwinOp
: public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
getTypeConverter()};

Value input = adaptor.getSelf();
auto inputTensorType = input.getType().dyn_cast<RankedTensorType>();
// Check input is a tensor type.
if (!inputTensorType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");

helper.addOperand("self", input);
helper.addAsMultipleTensorOperands("index_", op.getIndices());

return helper.replace();
}
};

class ConvertAten_IndexPutImplOp
: public OpConversionPattern<Aten_IndexPutImplOp> {
Expand Down Expand Up @@ -380,8 +337,6 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
#define INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenOp) \
torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<Convert##AtenOp, AtenOp>( \
typeConverter, patterns, target, convertTorchOpsSet)
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenGatherOp);
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenIndexTensorHackedTwinOp);
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(Aten_IndexPutImplOp);
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenFakeQuantizePerTensorAffineOp);
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(
Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/TorchToTcp/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ getTcpSignednessAttr(MLIRContext *context,
// The parameter input is expected to be of RankedTensorType.
Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
Value input, int64_t rankIncrease) {
if (rankIncrease == 0)
return input;
RankedTensorType inputType = input.getType().cast<RankedTensorType>();

SmallVector<ReassociationExprs> reassociationMap(inputType.getRank());
Expand Down
31 changes: 31 additions & 0 deletions lib/Dialect/IR/TcpOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ LogicalResult IsolatedGroupOp::verify() {

OpFoldResult ConstOp::fold(FoldAdaptor) { return getValueAttr(); }

LogicalResult ConstOp::verify() {
if (getValueAttr().getType() != getType())
return emitOpError("can not be used to cast types");
return success();
}

LogicalResult CastOp::verify() {
auto inputType = getIn().getType().cast<RankedTensorType>();
auto outputType = getOut().getType().cast<RankedTensorType>();
Expand Down Expand Up @@ -170,6 +176,31 @@ LogicalResult CastOp::verify() {
return success();
}

LogicalResult GatherOp::verify() {
auto inputTensor = cast<RankedTensorType>(getInput().getType());
auto indicesTensor = cast<RankedTensorType>(getIndices().getType());
int64_t gatherDim = getDimAttr().getValue().getSExtValue();

if (inputTensor.getRank() != indicesTensor.getRank())
return emitOpError("tcp.gather requires that the input tensor and indices "
"are the same rank");

for (int i = 0; i < inputTensor.getRank(); i++) {
if (inputTensor.getShape()[i] != indicesTensor.getShape()[i] &&
!(inputTensor.getShape()[i] == ShapedType::kDynamic ||
i == gatherDim)) {
return emitOpError("indices tensor does not match expected shape");
matthewfl marked this conversation as resolved.
Show resolved Hide resolved
}
}
matthewfl marked this conversation as resolved.
Show resolved Hide resolved

if (getResult().getType().getShape() != indicesTensor.getShape()) {
return emitOpError(
"Expect the shape of the indicies to match the output shape");
}

return success();
}

//===----------------------------------------------------------------------===//
// BindSymbolicShapeOp
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions test/AotCompile/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ AOT_TEST_SUITE = [
("broadcast_unit_dim_to_dynamic_with_rank_increase", False),
("gather_elements", False),
("gather_slices", False),
("index_hacked_twin", False),
]

py_library(
Expand Down
16 changes: 16 additions & 0 deletions test/AotCompile/model_loader_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,3 +590,19 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return TorchLoaderOutput(
model=GatherSlices(), inputs=(x, y), dynamic_shapes=dynamic_shapes
)


def index_hacked_twin_loader() -> TorchLoaderOutput:
class Model(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# not using dynamic dim currently as the i1 tensor would ideally
# be generated conditioned on the shape
i1 = torch.tensor([[0], [1], [2], [3]])
return x[i1, [2, 5, 7]]

x = torch.rand(4, 10)

return TorchLoaderOutput(
model=Model(),
inputs=(x,),
)
26 changes: 26 additions & 0 deletions test/Conversion/TorchToTcp/data_movement.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,29 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !tor
%0 = torch.aten.index_select %arg0, %int-1, %arg1: !torch.vtensor<[4,3],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,2],f32>
return %0 : !torch.vtensor<[4,2],f32>
}

// -----

// CHECK-label: @torch.aten.index.tensor_hacked_twin
// CHECK-DAG: %[[CAST0:.+]] = torch_c.to_builtin_tensor %arg0
// CHECK-DAG: %[[GATHER0:.+]] = tcp.gather %[[CAST0]], %[[SELECT0:.+]] {dim = 0 : index} : tensor<1x20x30xf32>, tensor<1x20x30xi64> -> tensor<1x20x30xf32>
// CHECK-DAG: %[[GATHER1:.+]] = tcp.gather %[[GATHER0]], %[[SELECT1:.+]] {dim = 1 : index} : tensor<1x20x30xf32>, tensor<1x5x30xi64> -> tensor<1x5x30xf32>
// CHECK-DAG: %[[GATHER2:.+]] = tcp.gather %[[GATHER1]], %[[SELECT2:.+]] {dim = 2 : index} : tensor<1x5x30xf32>, tensor<1x5x20xi64> -> tensor<1x5x20xf32>
// CHECK-DAG: %[[RET:.+]] = torch_c.from_builtin_tensor %[[GATHER2]]
// CHECK: return %[[RET]]
func.func @torch.aten.index.tensor_hacked_twin(%arg0: !torch.vtensor<[1,20,30],f32>, %select1: !torch.vtensor<[5,1],si64>, %select2: !torch.vtensor<[20],si64>) -> !torch.vtensor<[1,5,20],f32> {
// there is a strange pattern that is being generated when selecting one axis. It seems that it uses the Tensor_hacked_twin to select along all axis, but uses
// arange to select all of the
%none = torch.constant.none
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4 // this is a dtype on arange....
%int-1 = torch.constant.int -1
%arange = torch.aten.arange.start_step %int0, %int1, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64>
%arange1 = torch.aten.unsqueeze %arange, %int-1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
%arange2 = torch.aten.unsqueeze %arange1, %int-1 : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.vtensor<[1,1,1],si64>

%l = torch.prim.ListConstruct %arange2, %select1, %select2 : (!torch.vtensor<[1,1,1],si64>, !torch.vtensor<[5,1],si64>, !torch.vtensor<[20],si64>) -> !torch.list<vtensor>
%ret = torch.aten.index.Tensor_hacked_twin %arg0, %l : !torch.vtensor<[1,20,30],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,5,20],f32>
return %ret : !torch.vtensor<[1,5,20],f32>
}
38 changes: 0 additions & 38 deletions test/Conversion/TorchToTcp/tcp_custom_ops.mlir
Original file line number Diff line number Diff line change
@@ -1,43 +1,5 @@
// RUN: tcp-opt <%s -convert-torch-to-tcp-custom-op -canonicalize -split-input-file | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.gather_op(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,2],si64>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> {
// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[2,2],f32> -> tensor<2x2xf32>
// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,2],si64> -> tensor<2x2xi64>
// CHECK: %[[T2:.*]] = tcp.custom_op("torch.aten.gather") %[[T0]], %[[T1]] {axis = 1 : i64, torch_operand_names = ["self", "index"]} : tensor<2x2xf32>, tensor<2x2xi64> -> tensor<2x2xf32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<2x2xf32> -> !torch.vtensor<[2,2],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[2,2],f32>
func.func @torch.aten.gather_op(%arg0: !torch.vtensor<[2,2],si64>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> {
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%0 = torch.aten.gather %arg1, %int1, %arg0, %false : !torch.vtensor<[2,2],f32>, !torch.int, !torch.vtensor<[2,2],si64>, !torch.bool -> !torch.vtensor<[2,2],f32>
return %0 : !torch.vtensor<[2,2],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.index_hacked_twin_op(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,30,19,41],f32>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1,1,1,1],si64>
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[30,1,1],si64>
// CHECK-SAME: %[[ARG3:.*]]: !torch.vtensor<[19,1],si64>
// CHECK-SAME: %[[ARG4:.*]]: !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,30,19,3],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,30,19,41],f32> -> tensor<1x30x19x41xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[1,1,1,1],si64> -> tensor<1x1x1x1xi64>
// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[30,1,1],si64> -> tensor<30x1x1xi64>
// CHECK: %[[T3:.*]] = torch_c.to_builtin_tensor %[[ARG3]] : !torch.vtensor<[19,1],si64> -> tensor<19x1xi64>
// CHECK: %[[T4:.*]] = torch_c.to_builtin_tensor %[[ARG4]] : !torch.vtensor<[3],si64> -> tensor<3xi64>
// CHECK: %[[T5:.*]] = tcp.custom_op("torch.aten.index.Tensor_hacked_twin") %[[T0]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] {torch_operand_names = ["self", "index_0", "index_1", "index_2", "index_3"]} : tensor<1x30x19x41xf32>, tensor<1x1x1x1xi64>, tensor<30x1x1xi64>, tensor<19x1xi64>, tensor<3xi64> -> tensor<1x30x19x3xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x30x19x3xf32> -> !torch.vtensor<[1,30,19,3],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[1,30,19,3],f32>
func.func @torch.aten.index_hacked_twin_op(%arg0: !torch.vtensor<[1,30,19,41],f32>, %arg1: !torch.vtensor<[1,1,1,1],si64>, %arg2: !torch.vtensor<[30,1,1],si64>, %arg3: !torch.vtensor<[19,1],si64>, %arg4: !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,30,19,3],f32> {
%0 = torch.prim.ListConstruct %arg1, %arg2, %arg3, %arg4 : (!torch.vtensor<[1,1,1,1],si64>, !torch.vtensor<[30,1,1],si64>, !torch.vtensor<[19,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
%1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[1,30,19,41],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,30,19,3],f32>
return %1 : !torch.vtensor<[1,30,19,3],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.index_put_impl_op(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[25],f32>
Expand Down
Loading