Add converter for index.Tensor_hacked_twin (#98)
* Add converter for index.Tensor_hacked_twin -> tcp.gather &
tcp.broadcast
* Remove the `tcp.custom_op` variants of `Tensor_hacked_twin`
* The Tcp variants do not yet have full coverage of the PyTorch op, but we
should seek to expand the coverage of our converters
* Fix tcp.const when its value is a dense resource of ints so that the value
is cast to the converted result type
* Add verifiers for `tcp.gather` and `tcp.const` to ensure they are used
correctly
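
As a rough illustration of the new lowering (adapted from the lit test added in
this commit; the broadcast of each index tensor and the intermediate value names
are abbreviated), the op and its list of index tensors become a chain of
tcp.gather ops, one per indexed dimension:

  // torch dialect input
  %ret = torch.aten.index.Tensor_hacked_twin %x, %idxs : !torch.vtensor<[1,20,30],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,5,20],f32>

  // after conversion to the tcp dialect
  %g0 = tcp.gather %x0, %i0 {dim = 0 : index} : tensor<1x20x30xf32>, tensor<1x20x30xi64> -> tensor<1x20x30xf32>
  %g1 = tcp.gather %g0, %i1 {dim = 1 : index} : tensor<1x20x30xf32>, tensor<1x5x30xi64> -> tensor<1x5x30xf32>
  %g2 = tcp.gather %g1, %i2 {dim = 2 : index} : tensor<1x5x30xf32>, tensor<1x5x20xi64> -> tensor<1x5x20xf32>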
matthewfl authored Sep 20, 2024
1 parent 3fc3290 commit 2f129aa
Showing 12 changed files with 187 additions and 75 deletions.
6 changes: 4 additions & 2 deletions docker/Dockerfile
@@ -23,7 +23,8 @@ RUN apt-get update && \
clang \
clang-format \
gdb \
black
black \
sudo

# Install bazel
ARG ARCH="x86_64"
@@ -42,7 +43,8 @@ WORKDIR /opt/src/mlir-tcp
RUN groupadd -o -g ${GID} ${GROUP} && \
useradd -u ${UID} -g ${GROUP} -ms /bin/bash ${USER} && \
usermod -aG sudo ${USER} && \
chown -R ${USER}:${GROUP} /opt/src/mlir-tcp
chown -R ${USER}:${GROUP} /opt/src/mlir-tcp && \
echo "%sudo ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers

# Switch to user
USER ${USER}
3 changes: 3 additions & 0 deletions include/mlir-tcp/Dialect/IR/TcpOps.td
@@ -218,6 +218,7 @@ def Tcp_ConstOp : Tcp_Op<"const", [ConstantLike, Pure]> {
let assemblyFormat = "attr-dict `:` type($out)";

let hasFolder = 1;
let hasVerifier = 1;
}

def Tcp_BroadcastOp : Tcp_Op<"broadcast", [
@@ -657,6 +658,8 @@ def Tcp_GatherOp : Tcp_Op<"gather", [Pure, AllElementTypesMatch<["input", "out"]
);

let assemblyFormat = "$input `,` $indices attr-dict `:` type($input) `,` type($indices) `->` type($out)";

let hasVerifier = 1;
}

def Tcp_SliceOp : Tcp_Op<"slice", [Pure, AllElementTypesMatch<["in", "out"]>, SameVariadicOperandSize]> {
76 changes: 76 additions & 0 deletions lib/Conversion/TorchToTcp/DataMovement.cpp
@@ -278,6 +278,79 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
}
};

class ConvertAtenIndexTensorHackedTwin
: public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// ------- Matching the OP -------
auto self = adaptor.getSelf();
auto selfType = cast<RankedTensorType>(self.getType());
auto indicesList = op.getIndices();
SmallVector<Value> indices;
if (!getListConstructElements(indicesList, indices))
return op.emitError("Failed to match list of indices");

for (unsigned int i = 0; i < indices.size(); i++) {
auto ttype = cast<RankedTensorType>(
getTypeConverter()->convertType(indices[i].getType()));
if (ttype.getRank() != selfType.getRank() - i) {
// Can use tensor.gather instead for this, but that would require some
// broadcasting to get the shapes to match what is expected
return failure("Failed to rewrite Tensor_hacked_twin. Need the "
"element gather for this");
}
for (int j = 1; j < ttype.getRank(); j++) {
if (ttype.getShape()[j] != 1)
return failure("Expected the axes >=1 to have size 1");
}
}

// ------ Rewriting the OP ---------

indices = getTypeConvertedValues(rewriter, op.getLoc(), getTypeConverter(),
indices);

for (unsigned int i = 0; i < indices.size(); i++) {
auto idx = indices[i];
auto ttype = cast<RankedTensorType>(idx.getType());
auto selfType = cast<RankedTensorType>(self.getType());
SmallVector<int64_t> outShape(selfType.getShape());
outShape[i] = ttype.getNumElements();
auto outType = RankedTensorType::get(
outShape, cast<RankedTensorType>(self.getType()).getElementType());

auto expandedShape = torch_to_tcp::broadcastRankInLeadingDims(
rewriter, idx, outShape.size() - ttype.getRank());

SmallVector<Value> broadcastValues;
SmallVector<int64_t> broadcastAxes;
for (unsigned int j = 0; j < selfType.getRank(); j++) {
if (j != i) {
broadcastAxes.push_back(j);
broadcastValues.push_back(
rewriter.create<tensor::DimOp>(op.getLoc(), self, j));
}
}

auto broadcastedShape = rewriter.create<tcp::BroadcastOp>(
op.getLoc(), RankedTensorType::get(outShape, ttype.getElementType()),
expandedShape, broadcastValues,
rewriter.getI64ArrayAttr(broadcastAxes));

auto gather = rewriter.create<tcp::GatherOp>(op.getLoc(), outType, self,
broadcastedShape.getResult(),
rewriter.getIndexAttr(i));
self = gather.getResult();
}

rewriter.replaceOp(op, self);
return success();
}
};

} // namespace

void torch_to_tcp::populateDataMovementPatternsAndLegality(
@@ -294,4 +367,7 @@ void torch_to_tcp::populateDataMovementPatternsAndLegality(
torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<ConvertAtenIndexSelectOp,
AtenIndexSelectOp>(
typeConverter, patterns, target, convertTorchOpsSet);
torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<
ConvertAtenIndexTensorHackedTwin, AtenIndexTensorHackedTwinOp>(
typeConverter, patterns, target, convertTorchOpsSet);
}
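
A minimal sketch of what a single iteration of the converter loop above emits, for
the i = 1 step of the test case in this commit (the tcp.broadcast textual syntax
and the SSA names are approximations, not verbatim converter output): the index
tensor is rank-expanded to match the input, broadcast along every non-gathered axis
using the input's dim sizes, and then used by tcp.gather along axis i.

  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %d0 = tensor.dim %self, %c0 : tensor<1x20x30xf32>
  %d2 = tensor.dim %self, %c2 : tensor<1x20x30xf32>
  // rank-expanded index tensor<1x5x1xi64>, broadcast along the non-gathered axes 0 and 2
  %bcast = tcp.broadcast %expanded, %d0, %d2 {axes = [0, 2]} : tensor<1x5x1xi64>, index, index -> tensor<1x5x30xi64>
  %g1 = tcp.gather %self, %bcast {dim = 1 : index} : tensor<1x20x30xf32>, tensor<1x5x30xi64> -> tensor<1x5x30xf32>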
11 changes: 11 additions & 0 deletions lib/Conversion/TorchToTcp/Misc.cpp
@@ -16,6 +16,7 @@
#include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
@@ -162,6 +163,16 @@ class ConvertValueTensorLiteralOp
rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType, denseIntAttr);
return success();
}
if (auto elements =
dyn_cast<DenseResourceElementsAttr>(op.getValueAttr())) {
if (resultType.getElementType().isInteger() &&
resultType != adaptor.getValue().getType()) {
auto attr =
DenseResourceElementsAttr::get(resultType, elements.getRawHandle());
rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType, attr);
return success();
}
}

rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType,
adaptor.getValue());
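
Roughly, the new case handles a vtensor literal whose value is a dense resource of
signed ints: the resource handle is re-wrapped with the converted (signless) result
type so that the emitted tcp.const type-checks. A sketch, with an illustrative
resource name and shape:

  // before conversion (torch dialect)
  %t = torch.vtensor.literal(dense_resource<some_blob> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  // after conversion, the same blob with the converted element type
  %c = tcp.const {value = dense_resource<some_blob> : tensor<4xi64>} : tensor<4xi64>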
29 changes: 1 addition & 28 deletions lib/Conversion/TorchToTcp/TcpCustomOp.cpp
@@ -29,6 +29,7 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {

class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -46,33 +47,6 @@ class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
return helper.replace();
}
};

class ConvertAtenIndexTensorHackedTwinOp
: public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
getTypeConverter()};

Value input = adaptor.getSelf();
auto inputTensorType = input.getType().dyn_cast<RankedTensorType>();
// Check input is a tensor type.
if (!inputTensorType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");

helper.addOperand("self", input);
helper.addAsMultipleTensorOperands("index_", op.getIndices());

return helper.replace();
}
};

class ConvertAten_IndexPutImplOp
: public OpConversionPattern<Aten_IndexPutImplOp> {
public:
@@ -381,7 +355,6 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<Convert##AtenOp, AtenOp>( \
typeConverter, patterns, target, convertTorchOpsSet)
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenGatherOp);
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenIndexTensorHackedTwinOp);
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(Aten_IndexPutImplOp);
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenFakeQuantizePerTensorAffineOp);
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(
2 changes: 2 additions & 0 deletions lib/Conversion/TorchToTcp/Utils.cpp
@@ -49,6 +49,8 @@ Signedness getTcpSignedness(IntegerType::SignednessSemantics signednessInfo) {
// The parameter input is expected to be of RankedTensorType.
Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
Value input, int64_t rankIncrease) {
if (rankIncrease == 0)
return input;
RankedTensorType inputType = input.getType().cast<RankedTensorType>();

SmallVector<ReassociationExprs> reassociationMap(inputType.getRank());
42 changes: 42 additions & 0 deletions lib/Dialect/IR/TcpOps.cpp
@@ -127,6 +127,12 @@ LogicalResult IsolatedGroupOp::verify() {

OpFoldResult ConstOp::fold(FoldAdaptor) { return getValueAttr(); }

LogicalResult ConstOp::verify() {
if (getValueAttr().getType() != getType())
return emitOpError("can not be used to cast types");
return success();
}
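
For example, the new verifier rejects a const whose value attribute type does not
match its result type (textual form approximated from the op's attr-dict assembly
format):

  // accepted: attribute type and result type agree
  %ok = tcp.const {value = dense<[1, 2, 3]> : tensor<3xi64>} : tensor<3xi64>
  // rejected: "can not be used to cast types"
  %bad = tcp.const {value = dense<[1, 2, 3]> : tensor<3xsi64>} : tensor<3xi64>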

LogicalResult CastOp::verify() {
auto inputType = getIn().getType().cast<RankedTensorType>();
auto outputType = getOut().getType().cast<RankedTensorType>();
@@ -170,6 +176,42 @@ LogicalResult CastOp::verify() {
return success();
}

LogicalResult GatherOp::verify() {
auto inputTensor = cast<RankedTensorType>(getInput().getType());
auto indicesTensor = cast<RankedTensorType>(getIndices().getType());
int64_t gatherDim = getDimAttr().getValue().getSExtValue();

if (inputTensor.getRank() != indicesTensor.getRank())
return emitOpError(
"requires that the input tensor and indices are the same rank");

for (int i = 0; i < inputTensor.getRank(); i++) {
if (inputTensor.getShape()[i] < indicesTensor.getShape()[i] &&
!(inputTensor.getShape()[i] == ShapedType::kDynamic ||
indicesTensor.getShape()[i] == ShapedType::kDynamic ||
i == gatherDim)) {
std::stringstream ss;
ss << "indicies index " << i
<< " expected to be less than or equal to input "
<< " (" << indicesTensor.getShape()[i]
<< " <= " << inputTensor.getShape()[i] << ")";
return emitOpError(ss.str());
}
}

if (getResult().getType().getShape() != indicesTensor.getShape()) {
return emitOpError(
"Expect the shape of the indicies to match the output shape");
}

if (getResult().getType().getElementType() != inputTensor.getElementType()) {
return emitOpError(
"Expect the element type of the return to match the input");
}

return success();
}
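
A sketch of what this verifier enforces, with illustrative static shapes: input and
indices must have the same rank, every indices dim outside the gather dim must not
exceed the corresponding input dim, and the result must have the indices' shape
with the input's element type.

  // accepted: dim 1 is the gather dim; indices dims 0 and 2 are <= the input dims
  %ok = tcp.gather %input, %indices {dim = 1 : index} : tensor<1x20x30xf32>, tensor<1x5x30xi64> -> tensor<1x5x30xf32>
  // rejected: indices dim 2 (40) exceeds input dim 2 (30) outside the gather dim
  %bad = tcp.gather %input, %indices2 {dim = 1 : index} : tensor<1x20x30xf32>, tensor<1x5x40xi64> -> tensor<1x5x40xf32>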

//===----------------------------------------------------------------------===//
// BindSymbolicShapeOp
//===----------------------------------------------------------------------===//
1 change: 1 addition & 0 deletions test/AotCompile/BUILD
@@ -37,6 +37,7 @@ AOT_TEST_SUITE = [
("broadcast_unit_dim_to_dynamic_with_rank_increase", False),
("gather_elements", False),
("gather_slices", False),
("index_hacked_twin", False),
]

py_library(
16 changes: 16 additions & 0 deletions test/AotCompile/model_loader_lib.py
@@ -590,3 +590,19 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return TorchLoaderOutput(
model=GatherSlices(), inputs=(x, y), dynamic_shapes=dynamic_shapes
)


def index_hacked_twin_loader() -> TorchLoaderOutput:
class Model(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# not using a dynamic dim currently, as the i1 index tensor would ideally
# be generated conditioned on the shape
i1 = torch.tensor([[0], [1], [2], [3]])
return x[i1, [2, 5, 7]]

x = torch.rand(4, 10)

return TorchLoaderOutput(
model=Model(),
inputs=(x,),
)
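
For reference, the op this loader exercises after export looks roughly like the
following (shapes follow from x: (4, 10), i1: (4, 1), and the literal [2, 5, 7]
broadcasting to a (4, 3) result; the exact exported IR is illustrative):

  %ret = torch.aten.index.Tensor_hacked_twin %x, %indices : !torch.vtensor<[4,10],f32>, !torch.list<vtensor> -> !torch.vtensor<[4,3],f32>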
26 changes: 26 additions & 0 deletions test/Conversion/TorchToTcp/data_movement.mlir
@@ -64,3 +64,29 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !tor
%0 = torch.aten.index_select %arg0, %int-1, %arg1: !torch.vtensor<[4,3],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,2],f32>
return %0 : !torch.vtensor<[4,2],f32>
}

// -----

// CHECK-LABEL: @torch.aten.index.tensor_hacked_twin
// CHECK-DAG: %[[CAST0:.+]] = torch_c.to_builtin_tensor %arg0
// CHECK-DAG: %[[GATHER0:.+]] = tcp.gather %[[CAST0]], %[[SELECT0:.+]] {dim = 0 : index} : tensor<1x20x30xf32>, tensor<1x20x30xi64> -> tensor<1x20x30xf32>
// CHECK-DAG: %[[GATHER1:.+]] = tcp.gather %[[GATHER0]], %[[SELECT1:.+]] {dim = 1 : index} : tensor<1x20x30xf32>, tensor<1x5x30xi64> -> tensor<1x5x30xf32>
// CHECK-DAG: %[[GATHER2:.+]] = tcp.gather %[[GATHER1]], %[[SELECT2:.+]] {dim = 2 : index} : tensor<1x5x30xf32>, tensor<1x5x20xi64> -> tensor<1x5x20xf32>
// CHECK-DAG: %[[RET:.+]] = torch_c.from_builtin_tensor %[[GATHER2]]
// CHECK: return %[[RET]]
func.func @torch.aten.index.tensor_hacked_twin(%arg0: !torch.vtensor<[1,20,30],f32>, %select1: !torch.vtensor<[5,1],si64>, %select2: !torch.vtensor<[20],si64>) -> !torch.vtensor<[1,5,20],f32> {
// There is a strange pattern that is generated when selecting one axis: it seems to use Tensor_hacked_twin to select along all axes, but uses
// arange to select everything along the axes that are not explicitly indexed.
%none = torch.constant.none
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4 // this is a dtype on arange....
%int-1 = torch.constant.int -1
%arange = torch.aten.arange.start_step %int0, %int1, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64>
%arange1 = torch.aten.unsqueeze %arange, %int-1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
%arange2 = torch.aten.unsqueeze %arange1, %int-1 : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.vtensor<[1,1,1],si64>

%l = torch.prim.ListConstruct %arange2, %select1, %select2 : (!torch.vtensor<[1,1,1],si64>, !torch.vtensor<[5,1],si64>, !torch.vtensor<[20],si64>) -> !torch.list<vtensor>
%ret = torch.aten.index.Tensor_hacked_twin %arg0, %l : !torch.vtensor<[1,20,30],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,5,20],f32>
return %ret : !torch.vtensor<[1,5,20],f32>
}
38 changes: 0 additions & 38 deletions test/Conversion/TorchToTcp/tcp_custom_ops.mlir
@@ -1,43 +1,5 @@
// RUN: tcp-opt <%s -convert-torch-to-tcp-custom-op -canonicalize -split-input-file | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.gather_op(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,2],si64>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> {
// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[2,2],f32> -> tensor<2x2xf32>
// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,2],si64> -> tensor<2x2xi64>
// CHECK: %[[T2:.*]] = tcp.custom_op("torch.aten.gather") %[[T0]], %[[T1]] {axis = 1 : i64, torch_operand_names = ["self", "index"]} : tensor<2x2xf32>, tensor<2x2xi64> -> tensor<2x2xf32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<2x2xf32> -> !torch.vtensor<[2,2],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[2,2],f32>
func.func @torch.aten.gather_op(%arg0: !torch.vtensor<[2,2],si64>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> {
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%0 = torch.aten.gather %arg1, %int1, %arg0, %false : !torch.vtensor<[2,2],f32>, !torch.int, !torch.vtensor<[2,2],si64>, !torch.bool -> !torch.vtensor<[2,2],f32>
return %0 : !torch.vtensor<[2,2],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.index_hacked_twin_op(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,30,19,41],f32>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1,1,1,1],si64>
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[30,1,1],si64>
// CHECK-SAME: %[[ARG3:.*]]: !torch.vtensor<[19,1],si64>
// CHECK-SAME: %[[ARG4:.*]]: !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,30,19,3],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,30,19,41],f32> -> tensor<1x30x19x41xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[1,1,1,1],si64> -> tensor<1x1x1x1xi64>
// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[30,1,1],si64> -> tensor<30x1x1xi64>
// CHECK: %[[T3:.*]] = torch_c.to_builtin_tensor %[[ARG3]] : !torch.vtensor<[19,1],si64> -> tensor<19x1xi64>
// CHECK: %[[T4:.*]] = torch_c.to_builtin_tensor %[[ARG4]] : !torch.vtensor<[3],si64> -> tensor<3xi64>
// CHECK: %[[T5:.*]] = tcp.custom_op("torch.aten.index.Tensor_hacked_twin") %[[T0]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] {torch_operand_names = ["self", "index_0", "index_1", "index_2", "index_3"]} : tensor<1x30x19x41xf32>, tensor<1x1x1x1xi64>, tensor<30x1x1xi64>, tensor<19x1xi64>, tensor<3xi64> -> tensor<1x30x19x3xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x30x19x3xf32> -> !torch.vtensor<[1,30,19,3],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[1,30,19,3],f32>
func.func @torch.aten.index_hacked_twin_op(%arg0: !torch.vtensor<[1,30,19,41],f32>, %arg1: !torch.vtensor<[1,1,1,1],si64>, %arg2: !torch.vtensor<[30,1,1],si64>, %arg3: !torch.vtensor<[19,1],si64>, %arg4: !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,30,19,3],f32> {
%0 = torch.prim.ListConstruct %arg1, %arg2, %arg3, %arg4 : (!torch.vtensor<[1,1,1,1],si64>, !torch.vtensor<[30,1,1],si64>, !torch.vtensor<[19,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
%1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[1,30,19,41],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,30,19,3],f32>
return %1 : !torch.vtensor<[1,30,19,3],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.index_put_impl_op(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[25],f32>