From 2f129aa39b0b020dfef1424f07d027e48489410c Mon Sep 17 00:00:00 2001
From: Matthew Francis-Landau
Date: Fri, 20 Sep 2024 14:20:33 -0400
Subject: [PATCH] Add converter for index.Tensor_hacked_twin (#98)

* Add converter for index.Tensor_hacked_twin -> tcp.gather & tcp.broadcast
* Remove the `tcp.custom_op` variants of `Tensor_hacked_twin`
  * The Tcp variants do not yet cover the full semantics of the PyTorch op,
    but we should seek to expand the coverage of our converters
* Fix tcp.const to cast the value when it is a dense resource of ints
* Add verifiers for `tcp.gather` and `tcp.const` to ensure they are used
  correctly
---
 docker/Dockerfile                             |  6 +-
 include/mlir-tcp/Dialect/IR/TcpOps.td         |  3 +
 lib/Conversion/TorchToTcp/DataMovement.cpp    | 76 +++++++++++++++++++
 lib/Conversion/TorchToTcp/Misc.cpp            | 11 +++
 lib/Conversion/TorchToTcp/TcpCustomOp.cpp     | 29 +------
 lib/Conversion/TorchToTcp/Utils.cpp           |  2 +
 lib/Dialect/IR/TcpOps.cpp                     | 42 ++++++++++
 test/AotCompile/BUILD                         |  1 +
 test/AotCompile/model_loader_lib.py           | 16 ++++
 test/Conversion/TorchToTcp/data_movement.mlir | 26 +++++++
 .../Conversion/TorchToTcp/tcp_custom_ops.mlir | 38 ----------
 tools/aot/aot_compile.bzl                     | 12 ++-
 12 files changed, 187 insertions(+), 75 deletions(-)

diff --git a/docker/Dockerfile b/docker/Dockerfile
index f0eda43e..4ea60dc8 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -23,7 +23,8 @@ RUN apt-get update && \
     clang \
     clang-format \
     gdb \
-    black
+    black \
+    sudo

 # Install bazel
 ARG ARCH="x86_64"
@@ -42,7 +43,8 @@ WORKDIR /opt/src/mlir-tcp
 RUN groupadd -o -g ${GID} ${GROUP} && \
     useradd -u ${UID} -g ${GROUP} -ms /bin/bash ${USER} && \
     usermod -aG sudo ${USER} && \
-    chown -R ${USER}:${GROUP} /opt/src/mlir-tcp
+    chown -R ${USER}:${GROUP} /opt/src/mlir-tcp && \
+    echo "%sudo ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers

 # Switch to user
 USER ${USER}
diff --git a/include/mlir-tcp/Dialect/IR/TcpOps.td b/include/mlir-tcp/Dialect/IR/TcpOps.td
index d73d28b5..659bf06e 100644
--- a/include/mlir-tcp/Dialect/IR/TcpOps.td
+++ b/include/mlir-tcp/Dialect/IR/TcpOps.td
@@ -218,6 +218,7 @@ def Tcp_ConstOp : Tcp_Op<"const", [ConstantLike, Pure]> {
   let assemblyFormat = "attr-dict `:` type($out)";

   let hasFolder = 1;
+  let hasVerifier = 1;
 }

 def Tcp_BroadcastOp : Tcp_Op<"broadcast", [
@@ -657,6 +658,8 @@ def Tcp_GatherOp : Tcp_Op<"gather", [Pure, AllElementTypesMatch<["input", "out"]
   );

   let assemblyFormat = "$input `,` $indices attr-dict `:` type($input) `,` type($indices) `->` type($out)";
+
+  let hasVerifier = 1;
 }

 def Tcp_SliceOp : Tcp_Op<"slice", [Pure, AllElementTypesMatch<["in", "out"]>, SameVariadicOperandSize]> {
diff --git a/lib/Conversion/TorchToTcp/DataMovement.cpp b/lib/Conversion/TorchToTcp/DataMovement.cpp
index 312b9e9f..004bc9c0 100644
--- a/lib/Conversion/TorchToTcp/DataMovement.cpp
+++ b/lib/Conversion/TorchToTcp/DataMovement.cpp
@@ -278,6 +278,79 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
   }
 };

+class ConvertAtenIndexTensorHackedTwin
+    : public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // ------- Matching the OP -------
+    auto self = adaptor.getSelf();
+    auto selfType = cast<RankedTensorType>(self.getType());
+    auto indicesList = op.getIndices();
+    SmallVector<Value> indices;
+    if (!getListConstructElements(indicesList, indices))
+      return op.emitError("Failed to match list of indices");
+
+    for (unsigned int i = 0; i < indices.size(); i++) {
+      auto ttype = cast<RankedTensorType>(
+          getTypeConverter()->convertType(indices[i].getType()));
+      if (ttype.getRank() != selfType.getRank() - i) {
+        // Could use tensor.gather for this instead, but that would require
+        // some broadcasting to get the shapes to match what is expected.
+        return rewriter.notifyMatchFailure(
+            op, "failed to rewrite Tensor_hacked_twin; need the element "
+                "gather for this");
+      }
+      for (int j = 1; j < ttype.getRank(); j++) {
+        if (ttype.getShape()[j] != 1)
+          return rewriter.notifyMatchFailure(
+              op, "expected the axes >= 1 to have size 1");
+      }
+    }
+
+    // ------ Rewriting the OP ---------
+
+    indices = getTypeConvertedValues(rewriter, op.getLoc(), getTypeConverter(),
+                                     indices);
+
+    // Convert one dimension at a time: broadcast the i-th index tensor to
+    // the running shape of self, then gather along dim i.
+    for (unsigned int i = 0; i < indices.size(); i++) {
+      auto idx = indices[i];
+      auto ttype = cast<RankedTensorType>(idx.getType());
+      auto selfType = cast<RankedTensorType>(self.getType());
+      SmallVector<int64_t> outShape(selfType.getShape());
+      outShape[i] = ttype.getNumElements();
+      auto outType = RankedTensorType::get(
+          outShape, cast<RankedTensorType>(self.getType()).getElementType());
+
+      auto expandedShape = torch_to_tcp::broadcastRankInLeadingDims(
+          rewriter, idx, outShape.size() - ttype.getRank());
+
+      SmallVector<Value> broadcastValues;
+      SmallVector<int64_t> broadcastAxes;
+      for (unsigned int j = 0; j < selfType.getRank(); j++) {
+        if (j != i) {
+          broadcastAxes.push_back(j);
+          broadcastValues.push_back(
+              rewriter.create<tensor::DimOp>(op.getLoc(), self, j));
+        }
+      }
+
+      auto broadcastedShape = rewriter.create<tcp::BroadcastOp>(
+          op.getLoc(), RankedTensorType::get(outShape, ttype.getElementType()),
+          expandedShape, broadcastValues,
+          rewriter.getI64ArrayAttr(broadcastAxes));
+
+      auto gather = rewriter.create<tcp::GatherOp>(
+          op.getLoc(), outType, self, broadcastedShape.getResult(),
+          rewriter.getIndexAttr(i));
+      self = gather.getResult();
+    }
+
+    rewriter.replaceOp(op, self);
+    return success();
+  }
+};
+
 } // namespace

 void torch_to_tcp::populateDataMovementPatternsAndLegality(
@@ -294,4 +367,7 @@ void torch_to_tcp::populateDataMovementPatternsAndLegality(
   torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<ConvertAtenIndexSelectOp,
                                                    AtenIndexSelectOp>(
       typeConverter, patterns, target, convertTorchOpsSet);
+  torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<
+      ConvertAtenIndexTensorHackedTwin, AtenIndexTensorHackedTwinOp>(
+      typeConverter, patterns, target, convertTorchOpsSet);
 }
diff --git a/lib/Conversion/TorchToTcp/Misc.cpp b/lib/Conversion/TorchToTcp/Misc.cpp
index c66453a3..fdde4d43 100644
--- a/lib/Conversion/TorchToTcp/Misc.cpp
+++ b/lib/Conversion/TorchToTcp/Misc.cpp
@@ -16,6 +16,7 @@
 #include "Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
@@ -162,6 +163,16 @@ class ConvertValueTensorLiteralOp
       rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType, denseIntAttr);
       return success();
     }
+    if (auto elements =
+            dyn_cast<DenseResourceElementsAttr>(op.getValueAttr())) {
+      // Rebuild integer dense-resource literals with the converted result
+      // type so that the tcp.const value type matches the op's result type.
+      if (resultType.getElementType().isInteger() &&
+          resultType != adaptor.getValue().getType()) {
+        auto attr =
+            DenseResourceElementsAttr::get(resultType, elements.getRawHandle());
+        rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType, attr);
+        return success();
+      }
+    }

     rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType,
                                               adaptor.getValue());
diff --git a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp
index 3f3468d2..88ffe21d 100644
--- a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp
+++ b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp
@@ -29,6 +29,7 @@ using namespace mlir::torch;
 using namespace mlir::torch::Torch;

 namespace {
+
 class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -46,33 +47,6 @@ class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
     return helper.replace();
   }
 };
-
-class ConvertAtenIndexTensorHackedTwinOp
-    : public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
-                                                            getTypeConverter()};
-
-    Value input = adaptor.getSelf();
-    auto inputTensorType = input.getType().dyn_cast<RankedTensorType>();
-    // Check input is a tensor type.
-    if (!inputTensorType)
-      return rewriter.notifyMatchFailure(
-          op, "Only tensor types input are currently supported");
-
-    helper.addOperand("self", input);
-    helper.addAsMultipleTensorOperands("index_", op.getIndices());
-
-    return helper.replace();
-  }
-};
-
 class ConvertAten_IndexPutImplOp
     : public OpConversionPattern<Aten_IndexPutImplOp> {
 public:
@@ -381,7 +355,6 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
   torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<Convert##AtenOp, AtenOp>(   \
       typeConverter, patterns, target, convertTorchOpsSet)
   INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenGatherOp);
-  INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenIndexTensorHackedTwinOp);
   INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(Aten_IndexPutImplOp);
   INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenFakeQuantizePerTensorAffineOp);
   INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(
diff --git a/lib/Conversion/TorchToTcp/Utils.cpp b/lib/Conversion/TorchToTcp/Utils.cpp
index 85a52673..479d580c 100644
--- a/lib/Conversion/TorchToTcp/Utils.cpp
+++ b/lib/Conversion/TorchToTcp/Utils.cpp
@@ -49,6 +49,8 @@ Signedness getTcpSignedness(IntegerType::SignednessSemantics signednessInfo) {
 // The parameter input is expected to be of RankedTensorType.
 Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
                                  Value input, int64_t rankIncrease) {
+  if (rankIncrease == 0)
+    return input;
   RankedTensorType inputType = input.getType().cast<RankedTensorType>();

   SmallVector<ReassociationExprs> reassociationMap(inputType.getRank());
diff --git a/lib/Dialect/IR/TcpOps.cpp b/lib/Dialect/IR/TcpOps.cpp
index 3a8e5afa..80549217 100644
--- a/lib/Dialect/IR/TcpOps.cpp
+++ b/lib/Dialect/IR/TcpOps.cpp
@@ -127,6 +127,12 @@ LogicalResult IsolatedGroupOp::verify() {

 OpFoldResult ConstOp::fold(FoldAdaptor) { return getValueAttr(); }

+LogicalResult ConstOp::verify() {
+  if (getValueAttr().getType() != getType())
+    return emitOpError("cannot be used to cast types");
+  return success();
+}
+
 LogicalResult CastOp::verify() {
   auto inputType = getIn().getType().cast<RankedTensorType>();
   auto outputType = getOut().getType().cast<RankedTensorType>();
@@ -170,6 +176,42 @@ LogicalResult CastOp::verify() {
   return success();
 }

+LogicalResult GatherOp::verify() {
+  auto inputTensor = cast<RankedTensorType>(getInput().getType());
+  auto indicesTensor = cast<RankedTensorType>(getIndices().getType());
+  int64_t gatherDim = getDimAttr().getValue().getSExtValue();
+
+  if (inputTensor.getRank() != indicesTensor.getRank())
+    return emitOpError(
+        "requires that the input tensor and indices are the same rank");
+
+  for (int i = 0; i < inputTensor.getRank(); i++) {
+    if (inputTensor.getShape()[i] < indicesTensor.getShape()[i] &&
+        !(inputTensor.getShape()[i] == ShapedType::kDynamic ||
+          indicesTensor.getShape()[i] == ShapedType::kDynamic ||
+          i == gatherDim)) {
+      std::stringstream ss;
+      ss << "indices dim " << i
+         << " expected to be less than or equal to the input dim ("
+         << indicesTensor.getShape()[i]
+         << " <= " << inputTensor.getShape()[i] << ")";
+      return emitOpError(ss.str());
+    }
+  }
+
+  if (getResult().getType().getShape() != indicesTensor.getShape()) {
+    return emitOpError(
+        "expected the shape of the indices to match the output shape");
+  }
+
+  if (getResult().getType().getElementType() != inputTensor.getElementType()) {
+    return emitOpError(
+        "expected the element type of the result to match the input");
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // BindSymbolicShapeOp
 //===----------------------------------------------------------------------===//
diff --git a/test/AotCompile/BUILD b/test/AotCompile/BUILD
index cd792579..633c8f23 100644
--- a/test/AotCompile/BUILD
+++ b/test/AotCompile/BUILD
@@ -37,6 +37,7 @@ AOT_TEST_SUITE = [
     ("broadcast_unit_dim_to_dynamic_with_rank_increase", False),
     ("gather_elements", False),
     ("gather_slices", False),
+    ("index_hacked_twin", False),
 ]

 py_library(
diff --git a/test/AotCompile/model_loader_lib.py b/test/AotCompile/model_loader_lib.py
index 1b3e88ca..68174c85 100644
--- a/test/AotCompile/model_loader_lib.py
+++ b/test/AotCompile/model_loader_lib.py
@@ -590,3 +590,19 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     return TorchLoaderOutput(
         model=GatherSlices(), inputs=(x, y), dynamic_shapes=dynamic_shapes
     )
+
+
+def index_hacked_twin_loader() -> TorchLoaderOutput:
+    class Model(torch.nn.Module):
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            # Not using a dynamic dim currently, as the i1 tensor would
+            # ideally be generated conditioned on the shape.
+            i1 = torch.tensor([[0], [1], [2], [3]])
+            return x[i1, [2, 5, 7]]
+
+    x = torch.rand(4, 10)
+
+    return TorchLoaderOutput(
+        model=Model(),
+        inputs=(x,),
+    )
diff --git a/test/Conversion/TorchToTcp/data_movement.mlir b/test/Conversion/TorchToTcp/data_movement.mlir
index 4e762985..b087f9a5 100644
--- a/test/Conversion/TorchToTcp/data_movement.mlir
+++ b/test/Conversion/TorchToTcp/data_movement.mlir
@@ -64,3 +64,29 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !tor
   %0 = torch.aten.index_select %arg0, %int-1, %arg1: !torch.vtensor<[4,3],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,2],f32>
   return %0 : !torch.vtensor<[4,2],f32>
 }
+
+// -----
+
+// CHECK-LABEL: @torch.aten.index.tensor_hacked_twin
+// CHECK-DAG: %[[CAST0:.+]] = torch_c.to_builtin_tensor %arg0
+// CHECK-DAG: %[[GATHER0:.+]] = tcp.gather %[[CAST0]], %[[SELECT0:.+]] {dim = 0 : index} : tensor<1x20x30xf32>, tensor<1x20x30xi64> -> tensor<1x20x30xf32>
+// CHECK-DAG: %[[GATHER1:.+]] = tcp.gather %[[GATHER0]], %[[SELECT1:.+]] {dim = 1 : index} : tensor<1x20x30xf32>, tensor<1x5x30xi64> -> tensor<1x5x30xf32>
+// CHECK-DAG: %[[GATHER2:.+]] = tcp.gather %[[GATHER1]], %[[SELECT2:.+]] {dim = 2 : index} : tensor<1x5x30xf32>, tensor<1x5x20xi64> -> tensor<1x5x20xf32>
+// CHECK-DAG: %[[RET:.+]] = torch_c.from_builtin_tensor %[[GATHER2]]
+// CHECK: return %[[RET]]
+func.func @torch.aten.index.tensor_hacked_twin(%arg0: !torch.vtensor<[1,20,30],f32>, %select1: !torch.vtensor<[5,1],si64>, %select2: !torch.vtensor<[20],si64>) -> !torch.vtensor<[1,5,20],f32> {
+  // A strange pattern is generated when selecting along a single axis: Tensor_hacked_twin
+  // is used to select along all axes, with an arange selecting everything along the axes that are not indexed.
+  %none = torch.constant.none
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int4 = torch.constant.int 4 // the dtype argument to arange (4 = si64)
+  %int-1 = torch.constant.int -1
+  %arange = torch.aten.arange.start_step %int0, %int1, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64>
+  %arange1 = torch.aten.unsqueeze %arange, %int-1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
+  %arange2 = torch.aten.unsqueeze %arange1, %int-1 : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.vtensor<[1,1,1],si64>
+
+  %l = torch.prim.ListConstruct %arange2, %select1, %select2 : (!torch.vtensor<[1,1,1],si64>, !torch.vtensor<[5,1],si64>, !torch.vtensor<[20],si64>) -> !torch.list<vtensor>
+  %ret = torch.aten.index.Tensor_hacked_twin %arg0, %l : !torch.vtensor<[1,20,30],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,5,20],f32>
+  return %ret : !torch.vtensor<[1,5,20],f32>
+}
diff --git a/test/Conversion/TorchToTcp/tcp_custom_ops.mlir b/test/Conversion/TorchToTcp/tcp_custom_ops.mlir
index bb846314..9bfcba93 100644
--- a/test/Conversion/TorchToTcp/tcp_custom_ops.mlir
+++ b/test/Conversion/TorchToTcp/tcp_custom_ops.mlir
@@ -1,43 +1,5 @@
 // RUN: tcp-opt <%s -convert-torch-to-tcp-custom-op -canonicalize -split-input-file | FileCheck %s

-// CHECK-LABEL: func.func @torch.aten.gather_op(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,2],si64>
-// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> {
-// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[2,2],f32> -> tensor<2x2xf32>
-// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,2],si64> -> tensor<2x2xi64>
-// CHECK: %[[T2:.*]] = tcp.custom_op("torch.aten.gather") %[[T0]], %[[T1]] {axis = 1 : i64, torch_operand_names = ["self", "index"]} : tensor<2x2xf32>, tensor<2x2xi64> -> tensor<2x2xf32>
-// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<2x2xf32> -> !torch.vtensor<[2,2],f32>
-// CHECK: return %[[T3]] : !torch.vtensor<[2,2],f32>
-func.func @torch.aten.gather_op(%arg0: !torch.vtensor<[2,2],si64>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> {
-  %false = torch.constant.bool false
-  %int1 = torch.constant.int 1
-  %0 = torch.aten.gather %arg1, %int1, %arg0, %false : !torch.vtensor<[2,2],f32>, !torch.int, !torch.vtensor<[2,2],si64>, !torch.bool -> !torch.vtensor<[2,2],f32>
-  return %0 : !torch.vtensor<[2,2],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.index_hacked_twin_op(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,30,19,41],f32>
-// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1,1,1,1],si64>
-// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[30,1,1],si64>
-// CHECK-SAME: %[[ARG3:.*]]: !torch.vtensor<[19,1],si64>
-// CHECK-SAME: %[[ARG4:.*]]: !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,30,19,3],f32> {
-// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,30,19,41],f32> -> tensor<1x30x19x41xf32>
-// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[1,1,1,1],si64> -> tensor<1x1x1x1xi64>
-// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[30,1,1],si64> -> tensor<30x1x1xi64>
-// CHECK: %[[T3:.*]] = torch_c.to_builtin_tensor %[[ARG3]] : !torch.vtensor<[19,1],si64> -> tensor<19x1xi64>
-// CHECK: %[[T4:.*]] = torch_c.to_builtin_tensor %[[ARG4]] : !torch.vtensor<[3],si64> -> tensor<3xi64>
-// CHECK: %[[T5:.*]] = tcp.custom_op("torch.aten.index.Tensor_hacked_twin") %[[T0]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] {torch_operand_names = ["self", "index_0", "index_1", "index_2", "index_3"]} : tensor<1x30x19x41xf32>, tensor<1x1x1x1xi64>, tensor<30x1x1xi64>, tensor<19x1xi64>, tensor<3xi64> -> tensor<1x30x19x3xf32>
-// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x30x19x3xf32> -> !torch.vtensor<[1,30,19,3],f32>
-// CHECK: return %[[T6]] : !torch.vtensor<[1,30,19,3],f32>
-func.func @torch.aten.index_hacked_twin_op(%arg0: !torch.vtensor<[1,30,19,41],f32>, %arg1: !torch.vtensor<[1,1,1,1],si64>, %arg2: !torch.vtensor<[30,1,1],si64>, %arg3: !torch.vtensor<[19,1],si64>, %arg4: !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,30,19,3],f32> {
-  %0 = torch.prim.ListConstruct %arg1, %arg2, %arg3, %arg4 : (!torch.vtensor<[1,1,1,1],si64>, !torch.vtensor<[30,1,1],si64>, !torch.vtensor<[19,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
-  %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[1,30,19,41],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,30,19,3],f32>
-  return %1 : !torch.vtensor<[1,30,19,3],f32>
-}
-
-// -----

 // CHECK-LABEL: func.func @torch.aten.index_put_impl_op(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[25],f32>
diff --git a/tools/aot/aot_compile.bzl b/tools/aot/aot_compile.bzl
index 63660080..fb23e494 100644
--- a/tools/aot/aot_compile.bzl
+++ b/tools/aot/aot_compile.bzl
@@ -146,24 +146,22 @@ def aot_compile(
     native.genrule(
         name = "gen_" + name + "_mlir_tcp",
-        srcs = [_name + "_torch.mlir"],
+        srcs = [_name + "_torch.mlir", "//:tcp-opt"],
         outs = [_name + "_tcp.mlir"],
         cmd = "./$(location //:tcp-opt)" +
-              " -torch-backend-to-tcp-backend-pipeline $(SRCS)" +
+              " -torch-backend-to-tcp-backend-pipeline $(location " + _name + "_torch.mlir)" +
               " > $(OUTS)",
-        tools = ["//:tcp-opt"],
     )

     native.genrule(
         name = "gen_" + name + "_mlir_llvm",
         # When tcp_source is provided, prefer that as the start for aot_compile;
         # else continue using genrule generated *_tcp.mlir (torch_export workflow)
-        srcs = [tcp_source or (_name + "_tcp.mlir")],
"_tcp.mlir")], + srcs = [tcp_source or (_name + "_tcp.mlir"), "//:tcp-opt"], outs = [_name + "_llvm.mlir"], cmd = "./$(location //:tcp-opt)" + - " -tcp-to-llvm-pipeline $(SRCS)" + + " -tcp-to-llvm-pipeline $(location " + (tcp_source or (_name + "_tcp.mlir")) + ")" + " > $(OUTS)", - tools = ["//:tcp-opt"], ) native.genrule( @@ -181,7 +179,7 @@ def aot_compile( name = "gen_" + name + "_host_asm", srcs = [_name + ".ll"], outs = [_name + ".S"], - cmd = "./$(location @llvm-project//llvm:llc) -O3 < $(SRCS)" + + cmd = "./$(location @llvm-project//llvm:llc) -O3 --relocation-model=pic < $(SRCS)" + " > $(OUTS)", tools = ["@llvm-project//llvm:llc"], )