From cd6f99d7cebbb959e361b3a71eb6747d73e41577 Mon Sep 17 00:00:00 2001 From: Matthew Francis-Landau Date: Wed, 18 Sep 2024 11:59:40 -0700 Subject: [PATCH 01/11] add converter for index.Tensor_hacked_twin --- lib/Conversion/TorchToTcp/DataMovement.cpp | 60 +++++++++++++++++++ lib/Conversion/TorchToTcp/TcpCustomOp.cpp | 45 -------------- test/Conversion/TorchToTcp/data_movement.mlir | 26 ++++++++ .../Conversion/TorchToTcp/tcp_custom_ops.mlir | 38 ------------ 4 files changed, 86 insertions(+), 83 deletions(-) diff --git a/lib/Conversion/TorchToTcp/DataMovement.cpp b/lib/Conversion/TorchToTcp/DataMovement.cpp index 312b9e9..069d4a5 100644 --- a/lib/Conversion/TorchToTcp/DataMovement.cpp +++ b/lib/Conversion/TorchToTcp/DataMovement.cpp @@ -278,6 +278,63 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern { } }; +class ConvertAtenIndexTensorHackedTwin + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto self = adaptor.getSelf(); + auto indicesList = op.getIndices(); + SmallVector indices; + if (!getListConstructElements(indicesList, indices)) + return op.emitError("Failed to match list of indices"); + indices = getTypeConvertedValues(rewriter, op.getLoc(), getTypeConverter(), + indices); + + // possible that this should ignore the first batch dim? + if (indices.size() != cast(self.getType()).getRank()) + return op.emitError( + "Expected the number of indicies to equal rank of self"); + + for (unsigned int i = 0; i < indices.size(); i++) { + auto idx = indices[i]; + int numNonOneAxis = 0; + auto ttype = cast(idx.getType()); + for (int j = 0; j < ttype.getRank(); j++) + if (ttype.getShape()[j] != 1) + numNonOneAxis++; + if (numNonOneAxis > 1) + return op.emitError( + "Expected the input shape to have a single non-one axis"); + // convert it to a 1-dim vector + if (ttype.getRank() != 1) { + ReassociationIndices reassocIndices; + for (int j = 0; j < ttype.getRank(); j++) + reassocIndices.push_back(j); + SmallVector ri = {reassocIndices}; + auto reshape = + rewriter.create(op.getLoc(), idx, ri); + idx = reshape.getResult(); + } + + SmallVector outShape( + cast(self.getType()).getShape()); + outShape[i] = ttype.getNumElements(); + auto outType = RankedTensorType::get( + outShape, cast(self.getType()).getElementType()); + + auto gather = rewriter.create( + op.getLoc(), outType, self, idx, rewriter.getIndexAttr(i)); + self = gather.getResult(); + } + + rewriter.replaceOp(op, self); + return success(); + } +}; + } // namespace void torch_to_tcp::populateDataMovementPatternsAndLegality( @@ -294,4 +351,7 @@ void torch_to_tcp::populateDataMovementPatternsAndLegality( torch_to_tcp::addPatternIfOpInConvertTorchOpsSet( typeConverter, patterns, target, convertTorchOpsSet); + torch_to_tcp::addPatternIfOpInConvertTorchOpsSet< + ConvertAtenIndexTensorHackedTwin, AtenIndexTensorHackedTwinOp>( + typeConverter, patterns, target, convertTorchOpsSet); } diff --git a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp index 3f3468d..9d36e0d 100644 --- a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp +++ b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp @@ -29,49 +29,6 @@ using namespace mlir::torch; using namespace mlir::torch::Torch; namespace { -class ConvertAtenGatherOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(AtenGatherOp 
op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter, - getTypeConverter()}; - - helper.addOperand("self", adaptor.getSelf()); - helper.addOperand("index", adaptor.getIndex()); - helper.addIntAttr("axis", op.getDim()); - - return helper.replace(); - } -}; - -class ConvertAtenIndexTensorHackedTwinOp - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter, - getTypeConverter()}; - - Value input = adaptor.getSelf(); - auto inputTensorType = input.getType().dyn_cast(); - // Check input is a tensor type. - if (!inputTensorType) - return rewriter.notifyMatchFailure( - op, "Only tensor types input are currently supported"); - - helper.addOperand("self", input); - helper.addAsMultipleTensorOperands("index_", op.getIndices()); - - return helper.replace(); - } -}; class ConvertAten_IndexPutImplOp : public OpConversionPattern { @@ -380,8 +337,6 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality( #define INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenOp) \ torch_to_tcp::addPatternIfOpInConvertTorchOpsSet( \ typeConverter, patterns, target, convertTorchOpsSet) - INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenGatherOp); - INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenIndexTensorHackedTwinOp); INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(Aten_IndexPutImplOp); INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenFakeQuantizePerTensorAffineOp); INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN( diff --git a/test/Conversion/TorchToTcp/data_movement.mlir b/test/Conversion/TorchToTcp/data_movement.mlir index 4e76298..be5fa1d 100644 --- a/test/Conversion/TorchToTcp/data_movement.mlir +++ b/test/Conversion/TorchToTcp/data_movement.mlir @@ -64,3 +64,29 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !tor %0 = torch.aten.index_select %arg0, %int-1, %arg1: !torch.vtensor<[4,3],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,2],f32> return %0 : !torch.vtensor<[4,2],f32> } + +// ----- + +// CHECK-label: @torch.aten.index.tensor_hacked_twin +// CHECK-DAG: %[[CAST0:.+]] = torch_c.to_builtin_tensor %arg0 +// CHECK-DAG: %[[GATHER0:.+]] = tcp.gather %[[CAST0]], %[[SELECT0:.+]] {dim = 0 : index} : tensor<1x20x30xf32>, tensor<1xi64> -> tensor<1x20x30xf32> +// CHECK-DAG: %[[GATHER1:.+]] = tcp.gather %[[GATHER0]], %[[SELECT1:.+]] {dim = 1 : index} : tensor<1x20x30xf32>, tensor<5xi64> -> tensor<1x5x30xf32> +// CHECK-DAG: %[[GATHER2:.+]] = tcp.gather %[[GATHER1]], %[[SELECT2:.+]] {dim = 2 : index} : tensor<1x5x30xf32>, tensor<20xi64> -> tensor<1x5x20xf32> +// CHECK-DAG: %[[RET:.+]] = torch_c.from_builtin_tensor %[[GATHER2]] +// CHECK: return %[[RET]] +func.func @torch.aten.index.tensor_hacked_twin(%arg0: !torch.vtensor<[1,20,30],f32>, %select1: !torch.vtensor<[5,1],si64>, %select2: !torch.vtensor<[20],si64>) -> !torch.vtensor<[1,5,20],f32> { + // there is a strange pattern that is being generated when selecting one axis. It seems that it uses the Tensor_hacked_twin to select along all axis, but uses + // arange to select all of the + %none = torch.constant.none + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int4 = torch.constant.int 4 // this is a dtype on arange.... 
+ %int-1 = torch.constant.int -1 + %arange = torch.aten.arange.start_step %int0, %int1, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64> + %arange1 = torch.aten.unsqueeze %arange, %int-1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64> + %arange2 = torch.aten.unsqueeze %arange1, %int-1 : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.vtensor<[1,1,1],si64> + + %l = torch.prim.ListConstruct %arange2, %select1, %select2 : (!torch.vtensor<[1,1,1],si64>, !torch.vtensor<[5,1],si64>, !torch.vtensor<[20],si64>) -> !torch.list + %ret = torch.aten.index.Tensor_hacked_twin %arg0, %l : !torch.vtensor<[1,20,30],f32>, !torch.list -> !torch.vtensor<[1,5,20],f32> + return %ret : !torch.vtensor<[1,5,20],f32> +} diff --git a/test/Conversion/TorchToTcp/tcp_custom_ops.mlir b/test/Conversion/TorchToTcp/tcp_custom_ops.mlir index bb84631..9bfcba9 100644 --- a/test/Conversion/TorchToTcp/tcp_custom_ops.mlir +++ b/test/Conversion/TorchToTcp/tcp_custom_ops.mlir @@ -1,43 +1,5 @@ // RUN: tcp-opt <%s -convert-torch-to-tcp-custom-op -canonicalize -split-input-file | FileCheck %s -// CHECK-LABEL: func.func @torch.aten.gather_op( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,2],si64> -// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> { -// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[2,2],f32> -> tensor<2x2xf32> -// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,2],si64> -> tensor<2x2xi64> -// CHECK: %[[T2:.*]] = tcp.custom_op("torch.aten.gather") %[[T0]], %[[T1]] {axis = 1 : i64, torch_operand_names = ["self", "index"]} : tensor<2x2xf32>, tensor<2x2xi64> -> tensor<2x2xf32> -// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<2x2xf32> -> !torch.vtensor<[2,2],f32> -// CHECK: return %[[T3]] : !torch.vtensor<[2,2],f32> -func.func @torch.aten.gather_op(%arg0: !torch.vtensor<[2,2],si64>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> { - %false = torch.constant.bool false - %int1 = torch.constant.int 1 - %0 = torch.aten.gather %arg1, %int1, %arg0, %false : !torch.vtensor<[2,2],f32>, !torch.int, !torch.vtensor<[2,2],si64>, !torch.bool -> !torch.vtensor<[2,2],f32> - return %0 : !torch.vtensor<[2,2],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.index_hacked_twin_op( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,30,19,41],f32> -// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1,1,1,1],si64> -// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[30,1,1],si64> -// CHECK-SAME: %[[ARG3:.*]]: !torch.vtensor<[19,1],si64> -// CHECK-SAME: %[[ARG4:.*]]: !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,30,19,3],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,30,19,41],f32> -> tensor<1x30x19x41xf32> -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[1,1,1,1],si64> -> tensor<1x1x1x1xi64> -// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[30,1,1],si64> -> tensor<30x1x1xi64> -// CHECK: %[[T3:.*]] = torch_c.to_builtin_tensor %[[ARG3]] : !torch.vtensor<[19,1],si64> -> tensor<19x1xi64> -// CHECK: %[[T4:.*]] = torch_c.to_builtin_tensor %[[ARG4]] : !torch.vtensor<[3],si64> -> tensor<3xi64> -// CHECK: %[[T5:.*]] = tcp.custom_op("torch.aten.index.Tensor_hacked_twin") %[[T0]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] {torch_operand_names = ["self", "index_0", "index_1", "index_2", "index_3"]} : 
tensor<1x30x19x41xf32>, tensor<1x1x1x1xi64>, tensor<30x1x1xi64>, tensor<19x1xi64>, tensor<3xi64> -> tensor<1x30x19x3xf32> -// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x30x19x3xf32> -> !torch.vtensor<[1,30,19,3],f32> -// CHECK: return %[[T6]] : !torch.vtensor<[1,30,19,3],f32> -func.func @torch.aten.index_hacked_twin_op(%arg0: !torch.vtensor<[1,30,19,41],f32>, %arg1: !torch.vtensor<[1,1,1,1],si64>, %arg2: !torch.vtensor<[30,1,1],si64>, %arg3: !torch.vtensor<[19,1],si64>, %arg4: !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,30,19,3],f32> { - %0 = torch.prim.ListConstruct %arg1, %arg2, %arg3, %arg4 : (!torch.vtensor<[1,1,1,1],si64>, !torch.vtensor<[30,1,1],si64>, !torch.vtensor<[19,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list - %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[1,30,19,41],f32>, !torch.list -> !torch.vtensor<[1,30,19,3],f32> - return %1 : !torch.vtensor<[1,30,19,3],f32> -} - -// ----- // CHECK-LABEL: func.func @torch.aten.index_put_impl_op( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[25],f32> From db2ba3f2d42ecd6b21c8e7b2d5fdf59f5a23e98f Mon Sep 17 00:00:00 2001 From: Matthew Francis-Landau Date: Wed, 18 Sep 2024 12:29:39 -0700 Subject: [PATCH 02/11] clang format --- lib/Conversion/TorchToTcp/DataMovement.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/TorchToTcp/DataMovement.cpp b/lib/Conversion/TorchToTcp/DataMovement.cpp index 069d4a5..045a08c 100644 --- a/lib/Conversion/TorchToTcp/DataMovement.cpp +++ b/lib/Conversion/TorchToTcp/DataMovement.cpp @@ -329,7 +329,7 @@ class ConvertAtenIndexTensorHackedTwin op.getLoc(), outType, self, idx, rewriter.getIndexAttr(i)); self = gather.getResult(); } - + rewriter.replaceOp(op, self); return success(); } From 936d691ad678b2b38963bd01e35bc1187c4116f8 Mon Sep 17 00:00:00 2001 From: Matthew Francis-Landau Date: Thu, 19 Sep 2024 09:53:50 -0700 Subject: [PATCH 03/11] make it use the same build as everything else when running tcp-opt so that asserts are enabled --- test/AotCompile/BUILD | 1 + tools/aot/aot_compile.bzl | 10 ++++------ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/test/AotCompile/BUILD b/test/AotCompile/BUILD index cd79257..a2cf1da 100644 --- a/test/AotCompile/BUILD +++ b/test/AotCompile/BUILD @@ -37,6 +37,7 @@ AOT_TEST_SUITE = [ ("broadcast_unit_dim_to_dynamic_with_rank_increase", False), ("gather_elements", False), ("gather_slices", False), + ("gather_slices_select", False), ] py_library( diff --git a/tools/aot/aot_compile.bzl b/tools/aot/aot_compile.bzl index 6366008..e8af36e 100644 --- a/tools/aot/aot_compile.bzl +++ b/tools/aot/aot_compile.bzl @@ -146,24 +146,22 @@ def aot_compile( native.genrule( name = "gen_" + name + "_mlir_tcp", - srcs = [_name + "_torch.mlir"], + srcs = [_name + "_torch.mlir", "//:tcp-opt"], outs = [_name + "_tcp.mlir"], cmd = "./$(location //:tcp-opt)" + - " -torch-backend-to-tcp-backend-pipeline $(SRCS)" + + " -torch-backend-to-tcp-backend-pipeline $(location " + _name + "_torch.mlir)" + " > $(OUTS)", - tools = ["//:tcp-opt"], ) native.genrule( name = "gen_" + name + "_mlir_llvm", # When tcp_source is provided, prefer that as the start for aot_compile; # else continue using genrule generated *_tcp.mlir (torch_export workflow) - srcs = [tcp_source or (_name + "_tcp.mlir")], + srcs = [tcp_source or (_name + "_tcp.mlir"), "//:tcp-opt"], outs = [_name + "_llvm.mlir"], cmd = "./$(location //:tcp-opt)" + - " -tcp-to-llvm-pipeline $(SRCS)" + + " -tcp-to-llvm-pipeline $(location " + (tcp_source or 
(_name + "_tcp.mlir")) + ")" + " > $(OUTS)", - tools = ["//:tcp-opt"], ) native.genrule( From 543f8ee44525ce4593de9470b4d4b2393d9615c6 Mon Sep 17 00:00:00 2001 From: Matthew Francis-Landau Date: Thu, 19 Sep 2024 11:32:10 -0700 Subject: [PATCH 04/11] add a verifier for gather --- docker/Dockerfile | 6 ++++-- include/mlir-tcp/Dialect/IR/TcpOps.td | 2 ++ lib/Dialect/IR/TcpOps.cpp | 21 +++++++++++++++++++++ 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index f0eda43..4ea60dc 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -23,7 +23,8 @@ RUN apt-get update && \ clang \ clang-format \ gdb \ - black + black \ + sudo # Install bazel ARG ARCH="x86_64" @@ -42,7 +43,8 @@ WORKDIR /opt/src/mlir-tcp RUN groupadd -o -g ${GID} ${GROUP} && \ useradd -u ${UID} -g ${GROUP} -ms /bin/bash ${USER} && \ usermod -aG sudo ${USER} && \ - chown -R ${USER}:${GROUP} /opt/src/mlir-tcp + chown -R ${USER}:${GROUP} /opt/src/mlir-tcp && \ + echo "%sudo ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers # Switch to user USER ${USER} diff --git a/include/mlir-tcp/Dialect/IR/TcpOps.td b/include/mlir-tcp/Dialect/IR/TcpOps.td index 45fa0e9..e6d6e3a 100644 --- a/include/mlir-tcp/Dialect/IR/TcpOps.td +++ b/include/mlir-tcp/Dialect/IR/TcpOps.td @@ -617,6 +617,8 @@ def Tcp_GatherOp : Tcp_Op<"gather", [Pure, AllElementTypesMatch<["input", "out"] ); let assemblyFormat = "$input `,` $indices attr-dict `:` type($input) `,` type($indices) `->` type($out)"; + + let hasVerifier = 1; } def Tcp_SliceOp : Tcp_Op<"slice", [Pure, AllElementTypesMatch<["in", "out"]>, SameVariadicOperandSize]> { diff --git a/lib/Dialect/IR/TcpOps.cpp b/lib/Dialect/IR/TcpOps.cpp index 3a8e5af..204bcc9 100644 --- a/lib/Dialect/IR/TcpOps.cpp +++ b/lib/Dialect/IR/TcpOps.cpp @@ -170,6 +170,27 @@ LogicalResult CastOp::verify() { return success(); } +LogicalResult GatherOp::verify() { + auto inputTensor = cast(getInput().getType()); + auto indicesTensor = cast(getIndices().getType()); + int64_t gatherDim = getDimAttr().getValue().getSExtValue(); + + if(inputTensor.getRank() != indicesTensor.getRank()) + return emitOpError("tcp.gather requires that the input tensor and indices are the same rank"); + + for(int i = 0; i < inputTensor.getRank(); i++) { + if(inputTensor.getShape()[i] != indicesTensor.getShape()[i]) { + if(!(inputTensor.getShape()[i] == ShapedType::kDynamic || + indicesTensor.getShape()[i] == 1 || + i == gatherDim)) { + return emitOpError("indices tensor does not match expected shape"); + } + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // BindSymbolicShapeOp //===----------------------------------------------------------------------===// From 0e8e84a6eccc646ddb387c16e2beafe8ad87dbca Mon Sep 17 00:00:00 2001 From: Matthew Francis-Landau Date: Thu, 19 Sep 2024 14:30:22 -0700 Subject: [PATCH 05/11] cp --- lib/Conversion/TorchToTcp/DataMovement.cpp | 69 +++++++++++++++++-- lib/Conversion/TorchToTcp/Utils.cpp | 2 + lib/Dialect/IR/TcpOps.cpp | 10 +-- test/AotCompile/model_loader_lib.py | 14 ++++ test/Conversion/TorchToTcp/data_movement.mlir | 6 +- 5 files changed, 88 insertions(+), 13 deletions(-) diff --git a/lib/Conversion/TorchToTcp/DataMovement.cpp b/lib/Conversion/TorchToTcp/DataMovement.cpp index 045a08c..0edea8d 100644 --- a/lib/Conversion/TorchToTcp/DataMovement.cpp +++ b/lib/Conversion/TorchToTcp/DataMovement.cpp @@ -285,23 +285,78 @@ class ConvertAtenIndexTensorHackedTwin LogicalResult matchAndRewrite(AtenIndexTensorHackedTwinOp 
op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // ------- Matching the OP ------- auto self = adaptor.getSelf(); + auto selfType = cast(self.getType()); auto indicesList = op.getIndices(); SmallVector indices; if (!getListConstructElements(indicesList, indices)) return op.emitError("Failed to match list of indices"); + + for(unsigned int i = 0; i < indices.size(); i++) { + auto ttype = cast(getTypeConverter()->convertType(indices[i].getType())); + if(ttype.getRank() != selfType.getRank() - i) { + // Can use tensor.gather instead for this. But will require that there are some broadcasting to get the shapes to match + // what is expected + return failure("Failed to rewrite Tensor_hacked_twin. Need the element gather for this"); + } + for(int j = 1; j < ttype.getRank(); j++) { + if(ttype.getShape()[j] != 1) + return failure("Expected the axes >=1 to have size 1"); + } + } + + // ------ Rewriting the OP --------- + indices = getTypeConvertedValues(rewriter, op.getLoc(), getTypeConverter(), indices); - // possible that this should ignore the first batch dim? - if (indices.size() != cast(self.getType()).getRank()) - return op.emitError( - "Expected the number of indicies to equal rank of self"); - for (unsigned int i = 0; i < indices.size(); i++) { + + for(unsigned int i = 0; i < indices.size(); i++) { + auto idx = indices[i]; + auto ttype = cast(idx.getType()); + auto selfType = cast(self.getType()); + SmallVector outShape(selfType.getShape()); + outShape[i] = ttype.getNumElements(); + auto outType = RankedTensorType::get( + outShape, cast(self.getType()).getElementType()); + + auto expandedShape = torch_to_tcp::broadcastRankInLeadingDims(rewriter, idx, outShape.size() - ttype.getRank()); + SmallVector broadcastValues; + SmallVector broadcastAxes; + + + for(unsigned int j = 0; j < selfType.getRank(); j++) { + if(j != i) { + broadcastAxes.push_back(j); + broadcastValues.push_back(rewriter.create(op.getLoc(), self, j)); + } + } + + auto broadcastedShape = rewriter.create( + op.getLoc(), + RankedTensorType::get(outShape, ttype.getElementType()), + expandedShape, + broadcastValues, + rewriter.getI64ArrayAttr(broadcastAxes) + ); + + auto gather = rewriter.create( + op.getLoc(), outType, self, broadcastedShape.getResult(), rewriter.getIndexAttr(i) + ); + self = gather.getResult(); + } + + /*for (unsigned int i = 0; i < indices.size(); i++) { auto idx = indices[i]; int numNonOneAxis = 0; auto ttype = cast(idx.getType()); + if(ttype.getRank() != indices.size() - i) { + // there is a version of this op, where everything comes in as a single dim and then is should instead select the different indicies from each? + // so the difference would be if it keeps the dim or shrinks it. 
But not 100% clear on what the definition of the different semantics are + return op.emitError("unsure what to do"); + } for (int j = 0; j < ttype.getRank(); j++) if (ttype.getShape()[j] != 1) numNonOneAxis++; @@ -328,7 +383,9 @@ class ConvertAtenIndexTensorHackedTwin auto gather = rewriter.create( op.getLoc(), outType, self, idx, rewriter.getIndexAttr(i)); self = gather.getResult(); - } + }*/ + + // assert(op.getType() == self.getType()); rewriter.replaceOp(op, self); return success(); diff --git a/lib/Conversion/TorchToTcp/Utils.cpp b/lib/Conversion/TorchToTcp/Utils.cpp index c762ff1..16fdd5b 100644 --- a/lib/Conversion/TorchToTcp/Utils.cpp +++ b/lib/Conversion/TorchToTcp/Utils.cpp @@ -41,6 +41,8 @@ getTcpSignednessAttr(MLIRContext *context, // The parameter input is expected to be of RankedTensorType. Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter, Value input, int64_t rankIncrease) { + if(rankIncrease == 0) + return input; RankedTensorType inputType = input.getType().cast(); SmallVector reassociationMap(inputType.getRank()); diff --git a/lib/Dialect/IR/TcpOps.cpp b/lib/Dialect/IR/TcpOps.cpp index 204bcc9..83c098d 100644 --- a/lib/Dialect/IR/TcpOps.cpp +++ b/lib/Dialect/IR/TcpOps.cpp @@ -179,15 +179,17 @@ LogicalResult GatherOp::verify() { return emitOpError("tcp.gather requires that the input tensor and indices are the same rank"); for(int i = 0; i < inputTensor.getRank(); i++) { - if(inputTensor.getShape()[i] != indicesTensor.getShape()[i]) { - if(!(inputTensor.getShape()[i] == ShapedType::kDynamic || - indicesTensor.getShape()[i] == 1 || + if(inputTensor.getShape()[i] != indicesTensor.getShape()[i] && !( + inputTensor.getShape()[i] == ShapedType::kDynamic || i == gatherDim)) { return emitOpError("indices tensor does not match expected shape"); - } } } + if(getResult().getType().getShape() != indicesTensor.getShape()) { + return emitOpError("Expect the shape of the indicies to match the output shape"); + } + return success(); } diff --git a/test/AotCompile/model_loader_lib.py b/test/AotCompile/model_loader_lib.py index 1b3e88c..6f9cc7c 100644 --- a/test/AotCompile/model_loader_lib.py +++ b/test/AotCompile/model_loader_lib.py @@ -590,3 +590,17 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return TorchLoaderOutput( model=GatherSlices(), inputs=(x, y), dynamic_shapes=dynamic_shapes ) + +def gather_slices_select_loader() -> TorchLoaderOutput: + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + i1 = torch.tensor([[0],[1],[2],[3]]) + return x[i1,[2,5,7]] + + x = torch.rand(4,10) + # batch = Dim("batch", min=3) + # dynamic_shapes = {"x": {0: batch}} + + return TorchLoaderOutput( + model=Model(), inputs=(x,),# dynamic_shapes=dynamic_shapes + ) diff --git a/test/Conversion/TorchToTcp/data_movement.mlir b/test/Conversion/TorchToTcp/data_movement.mlir index be5fa1d..dcf5cda 100644 --- a/test/Conversion/TorchToTcp/data_movement.mlir +++ b/test/Conversion/TorchToTcp/data_movement.mlir @@ -69,9 +69,9 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !tor // CHECK-label: @torch.aten.index.tensor_hacked_twin // CHECK-DAG: %[[CAST0:.+]] = torch_c.to_builtin_tensor %arg0 -// CHECK-DAG: %[[GATHER0:.+]] = tcp.gather %[[CAST0]], %[[SELECT0:.+]] {dim = 0 : index} : tensor<1x20x30xf32>, tensor<1xi64> -> tensor<1x20x30xf32> -// CHECK-DAG: %[[GATHER1:.+]] = tcp.gather %[[GATHER0]], %[[SELECT1:.+]] {dim = 1 : index} : tensor<1x20x30xf32>, tensor<5xi64> -> tensor<1x5x30xf32> -// CHECK-DAG: 
%[[GATHER2:.+]] = tcp.gather %[[GATHER1]], %[[SELECT2:.+]] {dim = 2 : index} : tensor<1x5x30xf32>, tensor<20xi64> -> tensor<1x5x20xf32> +// CHECK-DAG: %[[GATHER0:.+]] = tcp.gather %[[CAST0]], %[[SELECT0:.+]] {dim = 0 : index} : tensor<1x20x30xf32>, tensor<1x20x30xi64> -> tensor<1x20x30xf32> +// CHECK-DAG: %[[GATHER1:.+]] = tcp.gather %[[GATHER0]], %[[SELECT1:.+]] {dim = 1 : index} : tensor<1x20x30xf32>, tensor<1x5x30xi64> -> tensor<1x5x30xf32> +// CHECK-DAG: %[[GATHER2:.+]] = tcp.gather %[[GATHER1]], %[[SELECT2:.+]] {dim = 2 : index} : tensor<1x5x30xf32>, tensor<1x5x20xi64> -> tensor<1x5x20xf32> // CHECK-DAG: %[[RET:.+]] = torch_c.from_builtin_tensor %[[GATHER2]] // CHECK: return %[[RET]] func.func @torch.aten.index.tensor_hacked_twin(%arg0: !torch.vtensor<[1,20,30],f32>, %select1: !torch.vtensor<[5,1],si64>, %select2: !torch.vtensor<[20],si64>) -> !torch.vtensor<[1,5,20],f32> { From 60fdf0d7004a9e68787326156cc269b4ecbb9e48 Mon Sep 17 00:00:00 2001 From: Matthew Francis-Landau Date: Fri, 20 Sep 2024 08:15:47 -0700 Subject: [PATCH 06/11] add casting of int types for resources --- include/mlir-tcp/Dialect/IR/TcpOps.td | 1 + lib/Conversion/TorchToTcp/DataMovement.cpp | 44 +--------------------- lib/Conversion/TorchToTcp/Misc.cpp | 9 +++++ lib/Dialect/IR/TcpOps.cpp | 6 +++ 4 files changed, 17 insertions(+), 43 deletions(-) diff --git a/include/mlir-tcp/Dialect/IR/TcpOps.td b/include/mlir-tcp/Dialect/IR/TcpOps.td index e6d6e3a..6aa267a 100644 --- a/include/mlir-tcp/Dialect/IR/TcpOps.td +++ b/include/mlir-tcp/Dialect/IR/TcpOps.td @@ -178,6 +178,7 @@ def Tcp_ConstOp : Tcp_Op<"const", [ConstantLike, Pure]> { let assemblyFormat = "attr-dict `:` type($out)"; let hasFolder = 1; + let hasVerifier = 1; } def Tcp_BroadcastOp : Tcp_Op<"broadcast", [ diff --git a/lib/Conversion/TorchToTcp/DataMovement.cpp b/lib/Conversion/TorchToTcp/DataMovement.cpp index 0edea8d..80868fa 100644 --- a/lib/Conversion/TorchToTcp/DataMovement.cpp +++ b/lib/Conversion/TorchToTcp/DataMovement.cpp @@ -311,8 +311,6 @@ class ConvertAtenIndexTensorHackedTwin indices = getTypeConvertedValues(rewriter, op.getLoc(), getTypeConverter(), indices); - - for(unsigned int i = 0; i < indices.size(); i++) { auto idx = indices[i]; auto ttype = cast(idx.getType()); @@ -323,10 +321,9 @@ class ConvertAtenIndexTensorHackedTwin outShape, cast(self.getType()).getElementType()); auto expandedShape = torch_to_tcp::broadcastRankInLeadingDims(rewriter, idx, outShape.size() - ttype.getRank()); + SmallVector broadcastValues; SmallVector broadcastAxes; - - for(unsigned int j = 0; j < selfType.getRank(); j++) { if(j != i) { broadcastAxes.push_back(j); @@ -348,45 +345,6 @@ class ConvertAtenIndexTensorHackedTwin self = gather.getResult(); } - /*for (unsigned int i = 0; i < indices.size(); i++) { - auto idx = indices[i]; - int numNonOneAxis = 0; - auto ttype = cast(idx.getType()); - if(ttype.getRank() != indices.size() - i) { - // there is a version of this op, where everything comes in as a single dim and then is should instead select the different indicies from each? - // so the difference would be if it keeps the dim or shrinks it. 
But not 100% clear on what the definition of the different semantics are - return op.emitError("unsure what to do"); - } - for (int j = 0; j < ttype.getRank(); j++) - if (ttype.getShape()[j] != 1) - numNonOneAxis++; - if (numNonOneAxis > 1) - return op.emitError( - "Expected the input shape to have a single non-one axis"); - // convert it to a 1-dim vector - if (ttype.getRank() != 1) { - ReassociationIndices reassocIndices; - for (int j = 0; j < ttype.getRank(); j++) - reassocIndices.push_back(j); - SmallVector ri = {reassocIndices}; - auto reshape = - rewriter.create(op.getLoc(), idx, ri); - idx = reshape.getResult(); - } - - SmallVector outShape( - cast(self.getType()).getShape()); - outShape[i] = ttype.getNumElements(); - auto outType = RankedTensorType::get( - outShape, cast(self.getType()).getElementType()); - - auto gather = rewriter.create( - op.getLoc(), outType, self, idx, rewriter.getIndexAttr(i)); - self = gather.getResult(); - }*/ - - // assert(op.getType() == self.getType()); - rewriter.replaceOp(op, self); return success(); } diff --git a/lib/Conversion/TorchToTcp/Misc.cpp b/lib/Conversion/TorchToTcp/Misc.cpp index c66453a..3acffd1 100644 --- a/lib/Conversion/TorchToTcp/Misc.cpp +++ b/lib/Conversion/TorchToTcp/Misc.cpp @@ -16,6 +16,7 @@ #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/DialectResourceBlobManager.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" @@ -162,6 +163,14 @@ class ConvertValueTensorLiteralOp rewriter.replaceOpWithNewOp(op, resultType, denseIntAttr); return success(); } + if (auto elements = dyn_cast(op.getValueAttr())) { + if(resultType.getElementType().isInteger() && + resultType != adaptor.getValue().getType()) { + auto attr = DenseResourceElementsAttr::get(resultType, elements.getRawHandle()); + rewriter.replaceOpWithNewOp(op, resultType, attr); + return success(); + } + } rewriter.replaceOpWithNewOp(op, resultType, adaptor.getValue()); diff --git a/lib/Dialect/IR/TcpOps.cpp b/lib/Dialect/IR/TcpOps.cpp index 83c098d..ffafaa9 100644 --- a/lib/Dialect/IR/TcpOps.cpp +++ b/lib/Dialect/IR/TcpOps.cpp @@ -127,6 +127,12 @@ LogicalResult IsolatedGroupOp::verify() { OpFoldResult ConstOp::fold(FoldAdaptor) { return getValueAttr(); } +LogicalResult ConstOp::verify() { + if(getValueAttr().getType() != getType()) + return emitOpError("can not be used to cast types"); + return success(); +} + LogicalResult CastOp::verify() { auto inputType = getIn().getType().cast(); auto outputType = getOut().getType().cast(); From b7fc082d4c1b52a951324e0f88d9e958f1643dc6 Mon Sep 17 00:00:00 2001 From: Matthew Francis-Landau Date: Fri, 20 Sep 2024 08:35:08 -0700 Subject: [PATCH 07/11] tests pass --- test/AotCompile/BUILD | 2 +- test/AotCompile/model_loader_lib.py | 10 +++++----- tools/aot/aot_compile.bzl | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/test/AotCompile/BUILD b/test/AotCompile/BUILD index a2cf1da..633c8f2 100644 --- a/test/AotCompile/BUILD +++ b/test/AotCompile/BUILD @@ -37,7 +37,7 @@ AOT_TEST_SUITE = [ ("broadcast_unit_dim_to_dynamic_with_rank_increase", False), ("gather_elements", False), ("gather_slices", False), - ("gather_slices_select", False), + ("index_hacked_twin", False), ] py_library( diff --git a/test/AotCompile/model_loader_lib.py b/test/AotCompile/model_loader_lib.py index 6f9cc7c..0c45258 100644 --- a/test/AotCompile/model_loader_lib.py +++ 
b/test/AotCompile/model_loader_lib.py @@ -591,16 +591,16 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: model=GatherSlices(), inputs=(x, y), dynamic_shapes=dynamic_shapes ) -def gather_slices_select_loader() -> TorchLoaderOutput: +def index_hacked_twin_loader() -> TorchLoaderOutput: class Model(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: + # not using dynamic dim currently as the i1 tensor would ideally + # be generated conditioned on the shape i1 = torch.tensor([[0],[1],[2],[3]]) return x[i1,[2,5,7]] - + x = torch.rand(4,10) - # batch = Dim("batch", min=3) - # dynamic_shapes = {"x": {0: batch}} return TorchLoaderOutput( - model=Model(), inputs=(x,),# dynamic_shapes=dynamic_shapes + model=Model(), inputs=(x,), ) diff --git a/tools/aot/aot_compile.bzl b/tools/aot/aot_compile.bzl index e8af36e..fb23e49 100644 --- a/tools/aot/aot_compile.bzl +++ b/tools/aot/aot_compile.bzl @@ -179,7 +179,7 @@ def aot_compile( name = "gen_" + name + "_host_asm", srcs = [_name + ".ll"], outs = [_name + ".S"], - cmd = "./$(location @llvm-project//llvm:llc) -O3 < $(SRCS)" + + cmd = "./$(location @llvm-project//llvm:llc) -O3 --relocation-model=pic < $(SRCS)" + " > $(OUTS)", tools = ["@llvm-project//llvm:llc"], ) From 9d4a0325997c19fca2f042ef62e2e6d63da0fa48 Mon Sep 17 00:00:00 2001 From: Matthew Francis-Landau Date: Fri, 20 Sep 2024 08:38:05 -0700 Subject: [PATCH 08/11] formatting --- lib/Conversion/TorchToTcp/DataMovement.cpp | 53 +++++++++++----------- lib/Conversion/TorchToTcp/Misc.cpp | 16 ++++--- lib/Conversion/TorchToTcp/Utils.cpp | 2 +- lib/Dialect/IR/TcpOps.cpp | 36 ++++++++------- test/AotCompile/model_loader_lib.py | 12 +++-- 5 files changed, 63 insertions(+), 56 deletions(-) diff --git a/lib/Conversion/TorchToTcp/DataMovement.cpp b/lib/Conversion/TorchToTcp/DataMovement.cpp index 80868fa..004bc9c 100644 --- a/lib/Conversion/TorchToTcp/DataMovement.cpp +++ b/lib/Conversion/TorchToTcp/DataMovement.cpp @@ -293,15 +293,17 @@ class ConvertAtenIndexTensorHackedTwin if (!getListConstructElements(indicesList, indices)) return op.emitError("Failed to match list of indices"); - for(unsigned int i = 0; i < indices.size(); i++) { - auto ttype = cast(getTypeConverter()->convertType(indices[i].getType())); - if(ttype.getRank() != selfType.getRank() - i) { - // Can use tensor.gather instead for this. But will require that there are some broadcasting to get the shapes to match - // what is expected - return failure("Failed to rewrite Tensor_hacked_twin. Need the element gather for this"); + for (unsigned int i = 0; i < indices.size(); i++) { + auto ttype = cast( + getTypeConverter()->convertType(indices[i].getType())); + if (ttype.getRank() != selfType.getRank() - i) { + // Can use tensor.gather instead for this. But will require that there + // are some broadcasting to get the shapes to match what is expected + return failure("Failed to rewrite Tensor_hacked_twin. 
Need the " + "element gather for this"); } - for(int j = 1; j < ttype.getRank(); j++) { - if(ttype.getShape()[j] != 1) + for (int j = 1; j < ttype.getRank(); j++) { + if (ttype.getShape()[j] != 1) return failure("Expected the axes >=1 to have size 1"); } } @@ -311,7 +313,7 @@ class ConvertAtenIndexTensorHackedTwin indices = getTypeConvertedValues(rewriter, op.getLoc(), getTypeConverter(), indices); - for(unsigned int i = 0; i < indices.size(); i++) { + for (unsigned int i = 0; i < indices.size(); i++) { auto idx = indices[i]; auto ttype = cast(idx.getType()); auto selfType = cast(self.getType()); @@ -319,30 +321,29 @@ class ConvertAtenIndexTensorHackedTwin outShape[i] = ttype.getNumElements(); auto outType = RankedTensorType::get( outShape, cast(self.getType()).getElementType()); - - auto expandedShape = torch_to_tcp::broadcastRankInLeadingDims(rewriter, idx, outShape.size() - ttype.getRank()); - + + auto expandedShape = torch_to_tcp::broadcastRankInLeadingDims( + rewriter, idx, outShape.size() - ttype.getRank()); + SmallVector broadcastValues; SmallVector broadcastAxes; - for(unsigned int j = 0; j < selfType.getRank(); j++) { - if(j != i) { + for (unsigned int j = 0; j < selfType.getRank(); j++) { + if (j != i) { broadcastAxes.push_back(j); - broadcastValues.push_back(rewriter.create(op.getLoc(), self, j)); + broadcastValues.push_back( + rewriter.create(op.getLoc(), self, j)); } } auto broadcastedShape = rewriter.create( - op.getLoc(), - RankedTensorType::get(outShape, ttype.getElementType()), - expandedShape, - broadcastValues, - rewriter.getI64ArrayAttr(broadcastAxes) - ); - - auto gather = rewriter.create( - op.getLoc(), outType, self, broadcastedShape.getResult(), rewriter.getIndexAttr(i) - ); - self = gather.getResult(); + op.getLoc(), RankedTensorType::get(outShape, ttype.getElementType()), + expandedShape, broadcastValues, + rewriter.getI64ArrayAttr(broadcastAxes)); + + auto gather = rewriter.create(op.getLoc(), outType, self, + broadcastedShape.getResult(), + rewriter.getIndexAttr(i)); + self = gather.getResult(); } rewriter.replaceOp(op, self); diff --git a/lib/Conversion/TorchToTcp/Misc.cpp b/lib/Conversion/TorchToTcp/Misc.cpp index 3acffd1..fdde4d4 100644 --- a/lib/Conversion/TorchToTcp/Misc.cpp +++ b/lib/Conversion/TorchToTcp/Misc.cpp @@ -163,13 +163,15 @@ class ConvertValueTensorLiteralOp rewriter.replaceOpWithNewOp(op, resultType, denseIntAttr); return success(); } - if (auto elements = dyn_cast(op.getValueAttr())) { - if(resultType.getElementType().isInteger() && - resultType != adaptor.getValue().getType()) { - auto attr = DenseResourceElementsAttr::get(resultType, elements.getRawHandle()); - rewriter.replaceOpWithNewOp(op, resultType, attr); - return success(); - } + if (auto elements = + dyn_cast(op.getValueAttr())) { + if (resultType.getElementType().isInteger() && + resultType != adaptor.getValue().getType()) { + auto attr = + DenseResourceElementsAttr::get(resultType, elements.getRawHandle()); + rewriter.replaceOpWithNewOp(op, resultType, attr); + return success(); + } } rewriter.replaceOpWithNewOp(op, resultType, diff --git a/lib/Conversion/TorchToTcp/Utils.cpp b/lib/Conversion/TorchToTcp/Utils.cpp index 16fdd5b..7f80b34 100644 --- a/lib/Conversion/TorchToTcp/Utils.cpp +++ b/lib/Conversion/TorchToTcp/Utils.cpp @@ -41,7 +41,7 @@ getTcpSignednessAttr(MLIRContext *context, // The parameter input is expected to be of RankedTensorType. 
Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter, Value input, int64_t rankIncrease) { - if(rankIncrease == 0) + if (rankIncrease == 0) return input; RankedTensorType inputType = input.getType().cast(); diff --git a/lib/Dialect/IR/TcpOps.cpp b/lib/Dialect/IR/TcpOps.cpp index ffafaa9..ccbe44a 100644 --- a/lib/Dialect/IR/TcpOps.cpp +++ b/lib/Dialect/IR/TcpOps.cpp @@ -128,7 +128,7 @@ LogicalResult IsolatedGroupOp::verify() { OpFoldResult ConstOp::fold(FoldAdaptor) { return getValueAttr(); } LogicalResult ConstOp::verify() { - if(getValueAttr().getType() != getType()) + if (getValueAttr().getType() != getType()) return emitOpError("can not be used to cast types"); return success(); } @@ -178,25 +178,27 @@ LogicalResult CastOp::verify() { LogicalResult GatherOp::verify() { auto inputTensor = cast(getInput().getType()); - auto indicesTensor = cast(getIndices().getType()); - int64_t gatherDim = getDimAttr().getValue().getSExtValue(); - - if(inputTensor.getRank() != indicesTensor.getRank()) - return emitOpError("tcp.gather requires that the input tensor and indices are the same rank"); - - for(int i = 0; i < inputTensor.getRank(); i++) { - if(inputTensor.getShape()[i] != indicesTensor.getShape()[i] && !( - inputTensor.getShape()[i] == ShapedType::kDynamic || - i == gatherDim)) { - return emitOpError("indices tensor does not match expected shape"); - } + auto indicesTensor = cast(getIndices().getType()); + int64_t gatherDim = getDimAttr().getValue().getSExtValue(); + + if (inputTensor.getRank() != indicesTensor.getRank()) + return emitOpError("tcp.gather requires that the input tensor and indices " + "are the same rank"); + + for (int i = 0; i < inputTensor.getRank(); i++) { + if (inputTensor.getShape()[i] != indicesTensor.getShape()[i] && + !(inputTensor.getShape()[i] == ShapedType::kDynamic || + i == gatherDim)) { + return emitOpError("indices tensor does not match expected shape"); } + } - if(getResult().getType().getShape() != indicesTensor.getShape()) { - return emitOpError("Expect the shape of the indicies to match the output shape"); - } + if (getResult().getType().getShape() != indicesTensor.getShape()) { + return emitOpError( + "Expect the shape of the indicies to match the output shape"); + } - return success(); + return success(); } //===----------------------------------------------------------------------===// diff --git a/test/AotCompile/model_loader_lib.py b/test/AotCompile/model_loader_lib.py index 0c45258..68174c8 100644 --- a/test/AotCompile/model_loader_lib.py +++ b/test/AotCompile/model_loader_lib.py @@ -591,16 +591,18 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: model=GatherSlices(), inputs=(x, y), dynamic_shapes=dynamic_shapes ) + def index_hacked_twin_loader() -> TorchLoaderOutput: class Model(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: # not using dynamic dim currently as the i1 tensor would ideally # be generated conditioned on the shape - i1 = torch.tensor([[0],[1],[2],[3]]) - return x[i1,[2,5,7]] - - x = torch.rand(4,10) + i1 = torch.tensor([[0], [1], [2], [3]]) + return x[i1, [2, 5, 7]] + + x = torch.rand(4, 10) return TorchLoaderOutput( - model=Model(), inputs=(x,), + model=Model(), + inputs=(x,), ) From ca634c701cfe438315e5a7bb25f738bd6c47f2fb Mon Sep 17 00:00:00 2001 From: Matthew Francis-Landau Date: Fri, 20 Sep 2024 08:51:06 -0700 Subject: [PATCH 09/11] formatting --- test/Conversion/TorchToTcp/data_movement.mlir | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git 
a/test/Conversion/TorchToTcp/data_movement.mlir b/test/Conversion/TorchToTcp/data_movement.mlir index dcf5cda..b087f9a 100644 --- a/test/Conversion/TorchToTcp/data_movement.mlir +++ b/test/Conversion/TorchToTcp/data_movement.mlir @@ -76,16 +76,16 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !tor // CHECK: return %[[RET]] func.func @torch.aten.index.tensor_hacked_twin(%arg0: !torch.vtensor<[1,20,30],f32>, %select1: !torch.vtensor<[5,1],si64>, %select2: !torch.vtensor<[20],si64>) -> !torch.vtensor<[1,5,20],f32> { // there is a strange pattern that is being generated when selecting one axis. It seems that it uses the Tensor_hacked_twin to select along all axis, but uses - // arange to select all of the + // arange to select all of the %none = torch.constant.none %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 %int4 = torch.constant.int 4 // this is a dtype on arange.... %int-1 = torch.constant.int -1 - %arange = torch.aten.arange.start_step %int0, %int1, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64> - %arange1 = torch.aten.unsqueeze %arange, %int-1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64> - %arange2 = torch.aten.unsqueeze %arange1, %int-1 : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.vtensor<[1,1,1],si64> - + %arange = torch.aten.arange.start_step %int0, %int1, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64> + %arange1 = torch.aten.unsqueeze %arange, %int-1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64> + %arange2 = torch.aten.unsqueeze %arange1, %int-1 : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.vtensor<[1,1,1],si64> + %l = torch.prim.ListConstruct %arange2, %select1, %select2 : (!torch.vtensor<[1,1,1],si64>, !torch.vtensor<[5,1],si64>, !torch.vtensor<[20],si64>) -> !torch.list %ret = torch.aten.index.Tensor_hacked_twin %arg0, %l : !torch.vtensor<[1,20,30],f32>, !torch.list -> !torch.vtensor<[1,5,20],f32> return %ret : !torch.vtensor<[1,5,20],f32> From fe6f405bcbb309d71888473cc6f7176af0a687f9 Mon Sep 17 00:00:00 2001 From: Matthew Francis-Landau Date: Fri, 20 Sep 2024 10:11:13 -0700 Subject: [PATCH 10/11] gather allows for non-gather dims to be smaller than input tensor. 
Also gather custom op seems to be required --- lib/Conversion/TorchToTcp/TcpCustomOp.cpp | 18 ++++++++++++++++++ lib/Dialect/IR/TcpOps.cpp | 3 ++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp index 9d36e0d..88ffe21 100644 --- a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp +++ b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp @@ -30,6 +30,23 @@ using namespace mlir::torch::Torch; namespace { +class ConvertAtenGatherOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AtenGatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter, + getTypeConverter()}; + + helper.addOperand("self", adaptor.getSelf()); + helper.addOperand("index", adaptor.getIndex()); + helper.addIntAttr("axis", op.getDim()); + + return helper.replace(); + } +}; class ConvertAten_IndexPutImplOp : public OpConversionPattern { public: @@ -337,6 +354,7 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality( #define INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenOp) \ torch_to_tcp::addPatternIfOpInConvertTorchOpsSet( \ typeConverter, patterns, target, convertTorchOpsSet) + INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenGatherOp); INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(Aten_IndexPutImplOp); INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenFakeQuantizePerTensorAffineOp); INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN( diff --git a/lib/Dialect/IR/TcpOps.cpp b/lib/Dialect/IR/TcpOps.cpp index ccbe44a..b65f1d9 100644 --- a/lib/Dialect/IR/TcpOps.cpp +++ b/lib/Dialect/IR/TcpOps.cpp @@ -186,8 +186,9 @@ LogicalResult GatherOp::verify() { "are the same rank"); for (int i = 0; i < inputTensor.getRank(); i++) { - if (inputTensor.getShape()[i] != indicesTensor.getShape()[i] && + if (inputTensor.getShape()[i] < indicesTensor.getShape()[i] && !(inputTensor.getShape()[i] == ShapedType::kDynamic || + indicesTensor.getShape()[i] == ShapedType::kDynamic || i == gatherDim)) { return emitOpError("indices tensor does not match expected shape"); } From ea0e5d80bd91a65dc153fab4845125f39fb5a6a5 Mon Sep 17 00:00:00 2001 From: Matthew Francis-Landau Date: Fri, 20 Sep 2024 10:49:18 -0700 Subject: [PATCH 11/11] update error message --- lib/Dialect/IR/TcpOps.cpp | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/lib/Dialect/IR/TcpOps.cpp b/lib/Dialect/IR/TcpOps.cpp index b65f1d9..8054921 100644 --- a/lib/Dialect/IR/TcpOps.cpp +++ b/lib/Dialect/IR/TcpOps.cpp @@ -182,15 +182,20 @@ LogicalResult GatherOp::verify() { int64_t gatherDim = getDimAttr().getValue().getSExtValue(); if (inputTensor.getRank() != indicesTensor.getRank()) - return emitOpError("tcp.gather requires that the input tensor and indices " - "are the same rank"); + return emitOpError( + "requires that the input tensor and indices are the same rank"); for (int i = 0; i < inputTensor.getRank(); i++) { if (inputTensor.getShape()[i] < indicesTensor.getShape()[i] && !(inputTensor.getShape()[i] == ShapedType::kDynamic || indicesTensor.getShape()[i] == ShapedType::kDynamic || i == gatherDim)) { - return emitOpError("indices tensor does not match expected shape"); + std::stringstream ss; + ss << "indicies index " << i + << " expected to be less than or equal to input " + << " (" << indicesTensor.getShape()[i] + << " <= " << inputTensor.getShape()[i] << ")"; + return emitOpError(ss.str()); } } @@ -199,6 +204,11 @@ LogicalResult 
GatherOp::verify() { "Expect the shape of the indicies to match the output shape"); } + if (getResult().getType().getElementType() != inputTensor.getElementType()) { + return emitOpError( + "Expect the element type of the return to match the input"); + } + return success(); }
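
Illustrative note (a minimal sketch; the function name and shapes below are assumed for illustration and are not taken from the patches): taken together, the verifier changes in this series give tcp.gather the following contract: the input and indices tensors must have the same rank, a non-gather dimension of the indices may be smaller than, but not larger than, the corresponding input dimension unless either side is dynamic, the result shape must equal the indices shape, and the result element type must match the input element type. A minimal gather that should satisfy that verifier:

func.func @gather_example(%input: tensor<4x10xf32>, %indices: tensor<4x3xi64>) -> tensor<4x3xf32> {
  // Gather along dim 1; roughly result[i, j] = input[i, indices[i, j]].
  %0 = tcp.gather %input, %indices {dim = 1 : index} : tensor<4x10xf32>, tensor<4x3xi64> -> tensor<4x3xf32>
  return %0 : tensor<4x3xf32>
}

This is the same shape contract the AtenIndexTensorHackedTwin lowering relies on when it broadcasts each index tensor to the gather's output shape before emitting tcp.gather.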