From 46d3066e44c9f9867a5bfc334d61ea611a3eb117 Mon Sep 17 00:00:00 2001
From: Matthew Francis-Landau
Date: Fri, 18 Oct 2024 08:04:31 -0700
Subject: [PATCH 1/5] gather nd and tensor_hacked_twin using gather nd

---
 include/mlir-tcp/Dialect/IR/TcpOps.td         |  24 ++++
 lib/Conversion/TcpToLinalg/DataMovement.cpp   |  86 +++++++++++++
 lib/Conversion/TorchToTcp/DataMovement.cpp    |  86 ++++++-------
 lib/Conversion/TorchToTcp/Utils.cpp           | 118 ++++++++++++++++++
 lib/Conversion/TorchToTcp/Utils.h             |  11 ++
 lib/Dialect/IR/TcpOps.cpp                     |  40 ++++++
 .../Conversion/TcpToLinalg/data_movement.mlir |  21 ++++
 test/Conversion/TorchToTcp/data_movement.mlir |  27 ++--
 8 files changed, 357 insertions(+), 56 deletions(-)

diff --git a/include/mlir-tcp/Dialect/IR/TcpOps.td b/include/mlir-tcp/Dialect/IR/TcpOps.td
index 659bf06e..4c900b32 100644
--- a/include/mlir-tcp/Dialect/IR/TcpOps.td
+++ b/include/mlir-tcp/Dialect/IR/TcpOps.td
@@ -662,6 +662,30 @@ def Tcp_GatherOp : Tcp_Op<"gather", [Pure, AllElementTypesMatch<["input", "out"]
   let hasVerifier = 1;
 }
 
+def Tcp_GatherNDOp : Tcp_Op<"gather_nd", [Pure, AllElementTypesMatch<["input", "out"]>]> {
+
+  let summary = "Gather elements from input based on indices over numtiple dimentions";
+
+  let description = [{
+    Gathers elements from a given tensor based on indices that index along multiple dimensions.
+
+    More details regarding this op: docs/gather.md
+  }];
+
+  let arguments = (ins
+    Tcp_Tensor:$input,
+    Tcp_IntTensor:$indices
+  );
+
+  let results = (outs
+    Tcp_Tensor:$out
+  );
+
+  let assemblyFormat = "$input `,` $indices attr-dict `:` type($input) `,` type($indices) `->` type($out)";
+
+  let hasVerifier = 1;
+}
+
 def Tcp_SliceOp : Tcp_Op<"slice", [Pure, AllElementTypesMatch<["in", "out"]>, SameVariadicOperandSize]> {
   let summary = "Extracts a slice of the input tensor";
 
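For reference, the shape contract that this op's verifier (added below in lib/Dialect/IR/TcpOps.cpp) enforces is: with an indices tensor of shape I1 x ... x Ik x N and an input of rank >= N, the result shape is I1 x ... x Ik followed by the input dims after the first N. A minimal sketch reusing the shapes from the TcpToLinalg test added later in this patch:

    // out[i, a, b] == input[indices[i, 0], indices[i, 1], a, b]
    %out = "tcp.gather_nd" (%input, %indices)
        : (tensor<7x11x13x17xf32>, tensor<3x2xi64>) -> tensor<3x13x17xf32>
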
diff --git a/lib/Conversion/TcpToLinalg/DataMovement.cpp b/lib/Conversion/TcpToLinalg/DataMovement.cpp
index f3ba7895..c3a13297 100644
--- a/lib/Conversion/TcpToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TcpToLinalg/DataMovement.cpp
@@ -91,6 +91,90 @@ class ConvertGatherOp : public OpConversionPattern<GatherOp> {
   }
 };
 
+class ConvertGatherNDOp : public OpConversionPattern<GatherNDOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(GatherNDOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto resultTensorType = getTypeConverter()
+                                ->convertType(op.getOut().getType())
+                                .cast<RankedTensorType>();
+
+    auto inputTensor = adaptor.getInput();
+    auto indicesTensor = adaptor.getIndices();
+    auto indiciesType = cast<RankedTensorType>(indicesTensor.getType());
+    auto inputType = cast<RankedTensorType>(inputTensor.getType());
+    int numGatherAxes = indiciesType.getShape()[indiciesType.getRank() - 1];
+
+    SmallVector<Value> resultDimSizes;
+    for (int i = 0; i < indiciesType.getRank() - 1; i++) {
+      resultDimSizes.push_back(
+          rewriter.createOrFold<tensor::DimOp>(loc, indicesTensor, i));
+    }
+    for (int i = numGatherAxes; i < inputType.getRank(); i++) {
+      resultDimSizes.push_back(
+          rewriter.createOrFold<tensor::DimOp>(loc, inputTensor, i));
+    }
+
+    assert(resultDimSizes.size() == resultTensorType.getRank());
+
+    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
+        loc, getAsOpFoldResult(resultDimSizes),
+        resultTensorType.getElementType());
+
+    auto bodyBuilder = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
+      SmallVector<Value> valueIndices, gatherIndices;
+      for (int i = 0; i < indiciesType.getRank() - 1; i++) {
+        auto idx = b.create<linalg::IndexOp>(loc, b.getIndexType(),
+                                             b.getI64IntegerAttr(i));
+        gatherIndices.push_back(idx);
+      }
+      for (int i = 0; i < numGatherAxes; i++) {
+        SmallVector<Value> gi = gatherIndices;
+        auto gidx = b.create<arith::ConstantOp>(loc, b.getIndexAttr(i));
+        gi.push_back(gidx);
+        assert(gi.size() == indiciesType.getRank());
+        auto idxExtract = b.create<tensor::ExtractOp>(
+            loc, indiciesType.getElementType(), indicesTensor, gi);
+        auto idxCast =
+            b.create<arith::IndexCastOp>(loc, b.getIndexType(), idxExtract);
+        valueIndices.push_back(idxCast);
+      }
+      for (int i = indiciesType.getRank() - 1; i < resultTensorType.getRank();
+           i++) {
+        auto idx = b.create<linalg::IndexOp>(loc, b.getIndexType(),
+                                             b.getI64IntegerAttr(i));
+        valueIndices.push_back(idx);
+      }
+      assert(valueIndices.size() == inputType.getRank());
+      auto extract =
+          b.create<tensor::ExtractOp>(loc, resultTensorType.getElementType(),
+                                      inputTensor, valueIndices)
+              .getResult();
+
+      b.create<linalg::YieldOp>(loc, extract);
+    };
+
+    SmallVector<Value> empty;
+    SmallVector<AffineMap> indexingMaps;
+    indexingMaps.push_back(
+        rewriter.getMultiDimIdentityMap(resultTensorType.getRank()));
+    SmallVector<utils::IteratorType> iteratorTypes(
+        resultTensorType.getRank(), utils::IteratorType::parallel);
+
+    auto generic = rewriter.create<linalg::GenericOp>(
+        loc, resultTensorType, empty, emptyTensor, indexingMaps, iteratorTypes,
+        bodyBuilder);
+
+    rewriter.replaceOp(op, generic.getResult(0));
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::TcpToLinalg::populateDataMovementPatternsAndLegality(
@@ -100,4 +184,6 @@ void mlir::TcpToLinalg::populateDataMovementPatternsAndLegality(
   target.addIllegalOp<GatherOp>();
   patterns.add<ConvertGatherOp>(typeConverter, context);
+  target.addIllegalOp<GatherNDOp>();
+  patterns.add<ConvertGatherNDOp>(typeConverter, context);
 }
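On the @gatherND test case below (indices 3x2, input 7x11x13x17), the pattern above emits roughly the following IR. This is a hand-condensed sketch of the FileCheck expectations from test/Conversion/TcpToLinalg/data_movement.mlir, not verbatim compiler output, and SSA names are illustrative:

    %empty = tensor.empty() : tensor<3x13x17xf32>
    %ret = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
                           iterator_types = ["parallel", "parallel", "parallel"]}
        outs(%empty : tensor<3x13x17xf32>) {
    ^bb0(%out: f32):
      %i0 = linalg.index 0 : index
      %c0 = arith.constant 0 : index
      %g0 = tensor.extract %indices[%i0, %c0] : tensor<3x2xi64>
      %g0c = arith.index_cast %g0 : i64 to index
      %c1 = arith.constant 1 : index
      %g1 = tensor.extract %indices[%i0, %c1] : tensor<3x2xi64>
      %g1c = arith.index_cast %g1 : i64 to index
      %i1 = linalg.index 1 : index
      %i2 = linalg.index 2 : index
      %v = tensor.extract %input[%g0c, %g1c, %i1, %i2] : tensor<7x11x13x17xf32>
      linalg.yield %v : f32
    } -> tensor<3x13x17xf32>
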
diff --git a/lib/Conversion/TorchToTcp/DataMovement.cpp b/lib/Conversion/TorchToTcp/DataMovement.cpp
index 004bc9c0..c6466c8a 100644
--- a/lib/Conversion/TorchToTcp/DataMovement.cpp
+++ b/lib/Conversion/TorchToTcp/DataMovement.cpp
@@ -278,6 +278,11 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
   }
 };
 
+/**
+ * The index.Tensor_hacked_twin takes a list of tensors which have to be
+ * broadcast together to be the same shape, and then those are feed into a
+ * gather which will select the different axes
+ */
 class ConvertAtenIndexTensorHackedTwin
     : public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -285,68 +290,51 @@ class ConvertAtenIndexTensorHackedTwin
   LogicalResult
   matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // ------- Matching the OP -------
     auto self = adaptor.getSelf();
-    auto selfType = cast<RankedTensorType>(self.getType());
     auto indicesList = op.getIndices();
     SmallVector<Value> indices;
     if (!getListConstructElements(indicesList, indices))
       return op.emitError("Failed to match list of indices");
 
-    for (unsigned int i = 0; i < indices.size(); i++) {
-      auto ttype = cast<RankedTensorType>(
-          getTypeConverter()->convertType(indices[i].getType()));
-      if (ttype.getRank() != selfType.getRank() - i) {
-        // Can use tensor.gather instead for this. But will require that there
-        // are some broadcasting to get the shapes to match what is expected
-        return failure("Failed to rewrite Tensor_hacked_twin. Need the "
-                       "element gather for this");
-      }
-      for (int j = 1; j < ttype.getRank(); j++) {
-        if (ttype.getShape()[j] != 1)
-          return failure("Expected the axes >=1 to have size 1");
-      }
+    indices = getTypeConvertedValues(rewriter, op.getLoc(), getTypeConverter(),
+                                     indices);
+
+    if (auto indicesBroadcasted = torch_to_tcp::broadcastManyToMatchShape(
+            rewriter, op.getLoc(), indices)) {
+      indices = indicesBroadcasted.value();
+    } else {
+      return failure("failed to broadcast the shapes of the input indices");
     }
 
-    // ------ Rewriting the OP ---------
+    for (int i = 0; i < indices.size(); i++) {
+      indices[i] =
+          torch_to_tcp::broadcastRankInTrailingDims(rewriter, indices[i], 1);
+    }
 
-    indices = getTypeConvertedValues(rewriter, op.getLoc(), getTypeConverter(),
-                                     indices);
+    auto indicesType = cast<RankedTensorType>(indices[0].getType());
+    int indicesRank = indicesType.getRank();
+    SmallVector<int64_t> outIndexShape;
+    outIndexShape.insert(outIndexShape.begin(), indicesType.getShape().begin(),
+                         indicesType.getShape().end());
+    outIndexShape.back() = indices.size();
 
-    for (unsigned int i = 0; i < indices.size(); i++) {
-      auto idx = indices[i];
-      auto ttype = cast<RankedTensorType>(idx.getType());
-      auto selfType = cast<RankedTensorType>(self.getType());
-      SmallVector<int64_t> outShape(selfType.getShape());
-      outShape[i] = ttype.getNumElements();
-      auto outType = RankedTensorType::get(
-          outShape, cast<RankedTensorType>(self.getType()).getElementType());
-
-      auto expandedShape = torch_to_tcp::broadcastRankInLeadingDims(
-          rewriter, idx, outShape.size() - ttype.getRank());
-
-      SmallVector<Value> broadcastValues;
-      SmallVector<int64_t> broadcastAxes;
-      for (unsigned int j = 0; j < selfType.getRank(); j++) {
-        if (j != i) {
-          broadcastAxes.push_back(j);
-          broadcastValues.push_back(
-              rewriter.create<tensor::DimOp>(op.getLoc(), self, j));
-        }
-      }
+    auto outIndexType =
+        RankedTensorType::get(outIndexShape, indicesType.getElementType());
+    auto indexTensor =
+        rewriter
+            .create<tensor::ConcatOp>(
+                op.getLoc(), outIndexType,
+                rewriter.getI64IntegerAttr(indicesRank - 1), indices)
+            .getResult();
 
-      auto broadcastedShape = rewriter.create<tcp::BroadcastOp>(
-          op.getLoc(), RankedTensorType::get(outShape, ttype.getElementType()),
-          expandedShape, broadcastValues,
-          rewriter.getI64ArrayAttr(broadcastAxes));
+    auto outType =
+        cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
 
-      auto gather = rewriter.create<tcp::GatherOp>(op.getLoc(), outType, self,
-                                                   broadcastedShape.getResult(),
-                                                   rewriter.getIndexAttr(i));
-      self = gather.getResult();
-    }
+    auto gatherOp = rewriter.create<tcp::GatherNDOp>(op.getLoc(), outType,
+                                                     self, indexTensor);
+
+    rewriter.replaceOp(op, gatherOp);
 
-    rewriter.replaceOp(op, self);
     return success();
   }
 };
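On the torch.aten.index.tensor_hacked_twin test updated below, the net effect of this rewrite is: broadcast all index tensors to a common shape (1x5x20), append a size-1 trailing dim to each, concatenate them along that dim, and feed the result to a single tcp.gather_nd. Condensed from the updated FileCheck expectations (SSA names are illustrative):

    %cat = tensor.concat dim(3) %i0, %i1, %i2
        : (tensor<1x5x20x1xi64>, tensor<1x5x20x1xi64>, tensor<1x5x20x1xi64>) -> tensor<1x5x20x3xi64>
    %out = tcp.gather_nd %self, %cat : tensor<1x20x30xf32>, tensor<1x5x20x3xi64> -> tensor<1x5x20xf32>
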
diff --git a/lib/Conversion/TorchToTcp/Utils.cpp b/lib/Conversion/TorchToTcp/Utils.cpp
index 5249287f..da4c8ef3 100644
--- a/lib/Conversion/TorchToTcp/Utils.cpp
+++ b/lib/Conversion/TorchToTcp/Utils.cpp
@@ -72,6 +72,32 @@ Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
       input.getDefiningOp()->getLoc(), resultType, input, reassociationMap);
 }
 
+// The parameter input is expected to be of RankedTensorType.
+Value broadcastRankInTrailingDims(ConversionPatternRewriter &rewriter,
+                                  Value input, int64_t rankIncrease) {
+  if (rankIncrease == 0)
+    return input;
+  RankedTensorType inputType = input.getType().cast<RankedTensorType>();
+
+  SmallVector<ReassociationExprs> reassociationMap(inputType.getRank());
+  if (inputType.getRank() > 0) {
+    for (int64_t inputAxis = 0; inputAxis < inputType.getRank(); inputAxis++)
+      reassociationMap[inputAxis].push_back(
+          rewriter.getAffineDimExpr(inputAxis));
+    for (int64_t axis = 0; axis < rankIncrease; axis++)
+      reassociationMap.back().push_back(
+          rewriter.getAffineDimExpr(axis + inputType.getRank()));
+  }
+
+  SmallVector<int64_t> resultShape(inputType.getShape());
+  resultShape.insert(resultShape.end(), rankIncrease, 1);
+  auto resultType =
+      inputType.cloneWith(ArrayRef(resultShape), inputType.getElementType());
+
+  return rewriter.create<tensor::ExpandShapeOp>(
+      input.getDefiningOp()->getLoc(), resultType, input, reassociationMap);
+}
+
 Value broadcastRank0Dor1DToND(ConversionPatternRewriter &rewriter, Value input,
                               int64_t targetRank, int64_t axisInOutput) {
   RankedTensorType inputType = input.getType().cast<RankedTensorType>();
@@ -130,6 +156,98 @@ Value broadcastShapeExceptDims(ConversionPatternRewriter &rewriter,
       axesAttr);
 }
 
+// The parameter values is expected to be an array of RankedTensorType
+// tensors.
+std::optional<SmallVector<Value>>
+broadcastManyToMatchShape(ConversionPatternRewriter &rewriter, Location loc,
+                          ValueRange values) {
+  if (values.size() <= 1) {
+    return values;
+  }
+  SmallVector<Value> ret;
+
+  int64_t maxRank = 0;
+  for (auto v : values) {
+    assert(isa<RankedTensorType>(v.getType()) &&
+           "all values must be of RankedTensorType");
+    auto t = cast<RankedTensorType>(v.getType());
+    if (t.getRank() > maxRank)
+      maxRank = t.getRank();
+  }
+
+  for (auto v : values) {
+    auto type = cast<RankedTensorType>(v.getType());
+    v = broadcastRankInLeadingDims(rewriter, v, maxRank - type.getRank());
+    ret.push_back(v);
+  }
+
+  // figure out what the shape should be for each dim
+  struct ShapeInfo {
+    Value value;
+    bool found = false;
+    int64_t static_value = 1;
+  };
+  SmallVector<ShapeInfo> shapes(maxRank);
+
+  for (auto v : ret) {
+    auto t = cast<RankedTensorType>(v.getType());
+    auto shape = t.getShape();
+    for (int64_t i = 0; i < maxRank; i++) {
+      if (shape[i] != 1) {
+        // meaning that this is not something that is already 1, and therefore
+        // would get broadcast
+        if (shapes[i].found) {
+          // then there are multiple inputs which have non-1 values for this
+          // axis; we should check that the size is the same. If there are
+          // different shapes then this would result in an error when
+          // broadcasting
+          if (shape[i] != ShapedType::kDynamic &&
+              shapes[i].static_value != ShapedType::kDynamic &&
+              shapes[i].static_value != shape[i]) {
+            // the broadcast failed as there are two different shapes for this
+            llvm::errs()
+                << "failed with broadcasting, have two different shapes "
+                << shape[i] << " " << shapes[i].static_value << "\n";
+            return {};
+          }
+        } else {
+          shapes[i].found = true;
+          if (shape[i] == ShapedType::kDynamic) {
+            shapes[i].value = rewriter.create<tensor::DimOp>(loc, v, i);
+            shapes[i].static_value = ShapedType::kDynamic;
+          } else {
+            shapes[i].value = rewriter.create<arith::ConstantOp>(
+                loc, rewriter.getIndexAttr(shape[i]));
+            shapes[i].static_value = shape[i];
+          }
+        }
+      }
+    }
+  }
+
+  // do the broadcasts into the shapes
+  for (int64_t i = 0; i < ret.size(); i++) {
+    auto v = ret[i];
+    auto t = cast<RankedTensorType>(v.getType());
+    SmallVector<int64_t> axes;
+    SmallVector<Value> sizes;
+    SmallVector<int64_t> staticShape;
+    for (int64_t j = 0; j < maxRank; j++) {
+      if (t.getShape()[j] == 1 && shapes[j].found) {
+        axes.push_back(j);
+        sizes.push_back(shapes[j].value);
+      }
+      staticShape.push_back(shapes[j].static_value);
+    }
+    if (!axes.empty()) {
+      // there is something to broadcast here, so add the op
+      Type resultType = t.cloneWith(staticShape, t.getElementType());
+      ret[i] = rewriter.create<tcp::BroadcastOp>(
+          loc, resultType, ret[i], sizes, rewriter.getI64ArrayAttr(axes));
+    }
+  }
+
+  return ret;
+}
+
 // The parameters input are expected to be of RankedTensorType.
 std::pair<Value, Value>
 broadcastToMatchShape(ConversionPatternRewriter &rewriter, Value lhs,
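Concretely, for the index tensors in the updated TorchToTcp test (shapes 1x1x1, 5x1, and 20 after type conversion), broadcastManyToMatchShape first aligns ranks with leading size-1 dims (giving 1x1x1, 1x5x1, 1x1x20) and then emits one tcp.broadcast per tensor that still has size-1 axes. Illustrative output condensed from the FileCheck expectations (value names are illustrative):

    %b0 = tcp.broadcast %v0, %c5, %c20 {axes = [1, 2]} : tensor<1x1x1xi64>, index, index -> tensor<1x5x20xi64>
    %b1 = tcp.broadcast %v1, %c20 {axes = [2]} : tensor<1x5x1xi64>, index -> tensor<1x5x20xi64>
    %b2 = tcp.broadcast %v2, %c5 {axes = [1]} : tensor<1x1x20xi64>, index -> tensor<1x5x20xi64>
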
diff --git a/lib/Conversion/TorchToTcp/Utils.h b/lib/Conversion/TorchToTcp/Utils.h
index 23bf0293..da636590 100644
--- a/lib/Conversion/TorchToTcp/Utils.h
+++ b/lib/Conversion/TorchToTcp/Utils.h
@@ -31,6 +31,11 @@ getTcpSignedness(IntegerType::SignednessSemantics signednessInfo);
 Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
                                  Value input, int64_t rankIncrease);
 
+// Helper function to expand the rank of the input tensor. Works by appending
+// size-1 dims after the trailing dims using `tensor::ExpandShapeOp`.
+Value broadcastRankInTrailingDims(ConversionPatternRewriter &rewriter,
+                                  Value input, int64_t rankIncrease);
+
 // Broadcasts the rank of the input tensor from 0D or 1D to ND. If the input
 // tensor is 1D, `axisInOutput` specifies the axis where the input axis should
 // end up in the output.
@@ -49,6 +54,12 @@ std::pair<Value, Value>
 broadcastToMatchShape(ConversionPatternRewriter &rewriter, Value lhs,
                       Value rhs);
 
+// Helper function that broadcasts two or more tensors used for indexing to
+// the same shape. Lower-ranked tensors are first expanded with leading size-1
+// dims, then all size-1 axes are broadcast to the common shape. Returns
+// std::nullopt if the static shapes cannot be broadcast together.
+std::optional<SmallVector<Value>>
+broadcastManyToMatchShape(ConversionPatternRewriter &rewriter, Location loc,
+                          ValueRange values);
+
 // Helper function to broadcast a 0D or 1D input tensor to match rank and shape
 // of target. For the 1D case, this projects the input vector to the
 // `axisInOutput` in the result.
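For broadcastRankInTrailingDims, the emitted op folds the new trailing axes into the last reassociation group. For example, with rankIncrease = 1 (taken from the TorchToTcp test in this patch):

    %e = tensor.expand_shape %t [[0], [1], [2, 3]] output_shape [1, 5, 20, 1]
        : tensor<1x5x20xi64> into tensor<1x5x20x1xi64>
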
diff --git a/lib/Dialect/IR/TcpOps.cpp b/lib/Dialect/IR/TcpOps.cpp
index 80549217..fd01fbf5 100644
--- a/lib/Dialect/IR/TcpOps.cpp
+++ b/lib/Dialect/IR/TcpOps.cpp
@@ -212,6 +212,46 @@ LogicalResult GatherOp::verify() {
   return success();
 }
 
+LogicalResult GatherNDOp::verify() {
+  auto inputTensor = cast<RankedTensorType>(getInput().getType());
+  auto indicesTensor = cast<RankedTensorType>(getIndices().getType());
+
+  if (indicesTensor.getRank() < 1)
+    return emitError("indices tensor should have a rank of at least one");
+  if (indicesTensor.getShape()[indicesTensor.getRank() - 1] ==
+      ShapedType::kDynamic)
+    return emitError(
+        "last dimension of the indices tensor cannot be dynamic");
+  if (indicesTensor.getShape()[indicesTensor.getRank() - 1] >
+      inputTensor.getRank())
+    return emitError("the last dimension of the indices tensor is used to "
+                     "index into the input tensor, so it must not exceed the "
+                     "rank of the input tensor");
+
+  SmallVector<int64_t> outputShape;
+  for (int i = 0; i < indicesTensor.getRank() - 1; i++)
+    outputShape.push_back(indicesTensor.getShape()[i]);
+  for (int i = indicesTensor.getShape()[indicesTensor.getRank() - 1];
+       i < inputTensor.getRank(); i++)
+    outputShape.push_back(inputTensor.getShape()[i]);
+
+  auto outputType =
+      RankedTensorType::get(outputShape, inputTensor.getElementType());
+
+  if (outputType != getResult().getType()) {
+    std::string ss =
+        "Output shape of tcp.gather_nd does not match what is expected ";
+    llvm::raw_string_ostream rs(ss);
+    outputType.print(rs);
+    rs << " != ";
+    getResult().getType().print(rs);
+    rs.flush();
+    return emitError(ss);
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // BindSymbolicShapeOp
 //===----------------------------------------------------------------------===//
diff --git a/test/Conversion/TcpToLinalg/data_movement.mlir b/test/Conversion/TcpToLinalg/data_movement.mlir
index ee85853f..522be277 100644
--- a/test/Conversion/TcpToLinalg/data_movement.mlir
+++ b/test/Conversion/TcpToLinalg/data_movement.mlir
@@ -18,3 +18,24 @@ func.func @gather(%arg0 : tensor<1x4x3xf32>, %arg1 : tensor<1x4x2xi64>) -> tenso
   %0 = "tcp.gather"(%arg0, %arg1) {dim = 2 : index} : (tensor<1x4x3xf32>, tensor<1x4x2xi64>) -> tensor<1x4x2xf32>
   return %0 : tensor<1x4x2xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @gatherND
+// CHECK:         %[[ret:.+]] = linalg.generic
+// CHECK-DAG:       %[[idx0:.+]] = linalg.index 0 : index
+// CHECK-DAG:       %[[const0:.+]] = arith.constant 0 : index
+// CHECK-DAG:       %[[gather0:.+]] = tensor.extract %arg1[%[[idx0]], %[[const0]]] : tensor<3x2xi64>
+// CHECK-DAG:       %[[gather0cast:.+]] = arith.index_cast %[[gather0]] : i64 to index
+// CHECK-DAG:       %[[const1:.+]] = arith.constant 1 : index
+// CHECK-DAG:       %[[gather1:.+]] = tensor.extract %arg1[%[[idx0]], %[[const1]]] : tensor<3x2xi64>
+// CHECK-DAG:       %[[gather1cast:.+]] = arith.index_cast %[[gather1]] : i64 to index
+// CHECK-DAG:       %[[idx1:.+]] = linalg.index 1 : index
+// CHECK-DAG:       %[[idx2:.+]] = linalg.index 2 : index
+// CHECK-DAG:       %[[value:.+]] = tensor.extract %arg0[%[[gather0cast]], %[[gather1cast]], %[[idx1]], %[[idx2]]] : tensor<7x11x13x17xf32>
+// CHECK:           linalg.yield %[[value]] : f32
+// CHECK:         } -> tensor<3x13x17xf32>
+func.func @gatherND(%arg0 : tensor<7x11x13x17xf32>, %arg1 : tensor<3x2xi64>) -> tensor<3x13x17xf32> {
+  %0 = "tcp.gather_nd" (%arg0, %arg1) : (tensor<7x11x13x17xf32>, tensor<3x2xi64>) -> tensor<3x13x17xf32>
+  return %0 : tensor<3x13x17xf32>
+}
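As a negative example for the verifier above (hand-written, not part of the test suite): for 3x2 indices into a 7x11x13x17 input, the only valid result type is tensor<3x13x17xf32>, so the following would be rejected with the "Output shape of tcp.gather_nd does not match what is expected" diagnostic:

    %bad = "tcp.gather_nd" (%arg0, %arg1)
        : (tensor<7x11x13x17xf32>, tensor<3x2xi64>) -> tensor<3x11x17xf32>
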
diff --git a/test/Conversion/TorchToTcp/data_movement.mlir b/test/Conversion/TorchToTcp/data_movement.mlir
index b087f9a5..938595d1 100644
--- a/test/Conversion/TorchToTcp/data_movement.mlir
+++ b/test/Conversion/TorchToTcp/data_movement.mlir
@@ -65,15 +65,28 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !tor
   return %0 : !torch.vtensor<[4,2],f32>
 }
 
+
 // -----
 
-// CHECK-label: @torch.aten.index.tensor_hacked_twin
-// CHECK-DAG: %[[CAST0:.+]] = torch_c.to_builtin_tensor %arg0
-// CHECK-DAG: %[[GATHER0:.+]] = tcp.gather %[[CAST0]], %[[SELECT0:.+]] {dim = 0 : index} : tensor<1x20x30xf32>, tensor<1x20x30xi64> -> tensor<1x20x30xf32>
-// CHECK-DAG: %[[GATHER1:.+]] = tcp.gather %[[GATHER0]], %[[SELECT1:.+]] {dim = 1 : index} : tensor<1x20x30xf32>, tensor<1x5x30xi64> -> tensor<1x5x30xf32>
-// CHECK-DAG: %[[GATHER2:.+]] = tcp.gather %[[GATHER1]], %[[SELECT2:.+]] {dim = 2 : index} : tensor<1x5x30xf32>, tensor<1x5x20xi64> -> tensor<1x5x20xf32>
-// CHECK-DAG: %[[RET:.+]] = torch_c.from_builtin_tensor %[[GATHER2]]
-// CHECK: return %[[RET]]
+// CHECK-LABEL: @torch.aten.index.tensor_hacked_twin
+// CHECK:     %[[A0:.+]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,20,30],f32> -> tensor<1x20x30xf32>
+// CHECK-DAG: %[[v5:.+]] = torch_c.to_builtin_tensor %[[arange:.+]] : !torch.vtensor<[1,1,1],si64> -> tensor<1x1x1xi64>
+// CHECK-DAG: %[[v6:.+]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[5,1],si64> -> tensor<5x1xi64>
+// CHECK-DAG: %[[v7:.+]] = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<[20],si64> -> tensor<20xi64>
+// CHECK-DAG: %[[expanded:.+]] = tensor.expand_shape %[[v6]] {{\[\[}}0, 1], [2]] output_shape [1, 5, 1] : tensor<5x1xi64> into tensor<1x5x1xi64>
+// CHECK-DAG: %[[expanded_0:.+]] = tensor.expand_shape %[[v7]] {{\[\[}}0, 1, 2]] output_shape [1, 1, 20] : tensor<20xi64> into tensor<1x1x20xi64>
+// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[c20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[v8:.+]] = tcp.broadcast %[[v5]], %[[c5]], %[[c20]] {axes = [1, 2]} : tensor<1x1x1xi64>, index, index -> tensor<1x5x20xi64>
+// CHECK-DAG: %[[v9:.+]] = tcp.broadcast %[[expanded]], %[[c20]] {axes = [2]} : tensor<1x5x1xi64>, index -> tensor<1x5x20xi64>
+// CHECK-DAG: %[[v10:.+]] = tcp.broadcast %[[expanded_0]], %[[c5]] {axes = [1]} : tensor<1x1x20xi64>, index -> tensor<1x5x20xi64>
+// CHECK-DAG: %[[expanded_1:.+]] = tensor.expand_shape %[[v8]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 5, 20, 1] : tensor<1x5x20xi64> into tensor<1x5x20x1xi64>
+// CHECK-DAG: %[[expanded_2:.+]] = tensor.expand_shape %[[v9]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 5, 20, 1] : tensor<1x5x20xi64> into tensor<1x5x20x1xi64>
+// CHECK-DAG: %[[expanded_3:.+]] = tensor.expand_shape %[[v10]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 5, 20, 1] : tensor<1x5x20xi64> into tensor<1x5x20x1xi64>
+// CHECK:     %[[concat:.+]] = tensor.concat dim(3) %[[expanded_1]], %[[expanded_2]], %[[expanded_3]] : (tensor<1x5x20x1xi64>, tensor<1x5x20x1xi64>, tensor<1x5x20x1xi64>) -> tensor<1x5x20x3xi64>
+// CHECK:     %[[gather:.+]] = tcp.gather_nd %[[A0]], %[[concat]] : tensor<1x20x30xf32>, tensor<1x5x20x3xi64> -> tensor<1x5x20xf32>
+// CHECK:     %[[ret:.+]] = torch_c.from_builtin_tensor %[[gather]] : tensor<1x5x20xf32> -> !torch.vtensor<[1,5,20],f32>
+// CHECK:     return %[[ret]]
 func.func @torch.aten.index.tensor_hacked_twin(%arg0: !torch.vtensor<[1,20,30],f32>, %select1: !torch.vtensor<[5,1],si64>, %select2: !torch.vtensor<[20],si64>) -> !torch.vtensor<[1,5,20],f32> {
   // there is a strange pattern that is being generated when selecting one axis. It seems that it uses the Tensor_hacked_twin to select along all axis, but uses
   // arange to select all of the

From af01815fa18923e7e3c00038e7dc41d3fc05d0b4 Mon Sep 17 00:00:00 2001
From: Matthew Francis-Landau
Date: Mon, 21 Oct 2024 10:39:05 -0700
Subject: [PATCH 2/5] address comments

---
 include/mlir-tcp/Dialect/IR/TcpOps.td       |  2 +-
 lib/Conversion/TcpToLinalg/DataMovement.cpp | 25 ++++++++++++-----
 lib/Conversion/TorchToTcp/Utils.cpp         | 30 ++++++++++-----------
 3 files changed, 34 insertions(+), 23 deletions(-)

diff --git a/include/mlir-tcp/Dialect/IR/TcpOps.td b/include/mlir-tcp/Dialect/IR/TcpOps.td
index 4c900b32..0bafcba5 100644
--- a/include/mlir-tcp/Dialect/IR/TcpOps.td
+++ b/include/mlir-tcp/Dialect/IR/TcpOps.td
@@ -664,7 +664,7 @@ def Tcp_GatherOp : Tcp_Op<"gather", [Pure, AllElementTypesMatch<["input", "out"]
 
 def Tcp_GatherNDOp : Tcp_Op<"gather_nd", [Pure, AllElementTypesMatch<["input", "out"]>]> {
 
-  let summary = "Gather elements from input based on indices over numtiple dimentions";
+  let summary = "Gather elements from input based on indices over multiple dimentions";
 
   let description = [{
     Gathers elements from a given tensor based on indices that index along multiple dimensions.
diff --git a/lib/Conversion/TcpToLinalg/DataMovement.cpp b/lib/Conversion/TcpToLinalg/DataMovement.cpp
index c3a13297..c9e7442b 100644
--- a/lib/Conversion/TcpToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TcpToLinalg/DataMovement.cpp
@@ -91,6 +91,17 @@ class ConvertGatherOp : public OpConversionPattern<GatherOp> {
   }
 };
 
+/**
+ * tcp.gather_nd is lowered to linalg.generic, which allows us to define every
+ * element in the result tensor using a programmatic expression. The last
+ * dimension of the indices tensor is used to index into the input tensor.
+ *
+ * For example, we we have an indices tensor of shape 9x4x3x2 and an input
+ * tensor of shape 5x6x7x8, then the resulting tensor will be of shape
+ * 9x4x3x7x8, where the first three dimensions of the resulting tensor are used
+ * to index into the indices tensor, and the last dimension of the index
+ * tensor (the 2-sized dimension) is used to index into the input tensor.
+ */
 class ConvertGatherNDOp : public OpConversionPattern<GatherNDOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -105,12 +116,12 @@ class ConvertGatherNDOp : public OpConversionPattern<GatherNDOp> {
 
     auto inputTensor = adaptor.getInput();
     auto indicesTensor = adaptor.getIndices();
-    auto indiciesType = cast<RankedTensorType>(indicesTensor.getType());
+    auto indicesType = cast<RankedTensorType>(indicesTensor.getType());
     auto inputType = cast<RankedTensorType>(inputTensor.getType());
-    int numGatherAxes = indiciesType.getShape()[indiciesType.getRank() - 1];
+    int numGatherAxes = indicesType.getShape().back();
 
     SmallVector<Value> resultDimSizes;
-    for (int i = 0; i < indiciesType.getRank() - 1; i++) {
+    for (int i = 0; i < indicesType.getRank() - 1; i++) {
       resultDimSizes.push_back(
           rewriter.createOrFold<tensor::DimOp>(loc, indicesTensor, i));
     }
@@ -127,7 +138,7 @@ class ConvertGatherNDOp : public OpConversionPattern<GatherNDOp> {
 
     auto bodyBuilder = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
       SmallVector<Value> valueIndices, gatherIndices;
-      for (int i = 0; i < indiciesType.getRank() - 1; i++) {
+      for (int i = 0; i < indicesType.getRank() - 1; i++) {
         auto idx = b.create<linalg::IndexOp>(loc, b.getIndexType(),
                                              b.getI64IntegerAttr(i));
         gatherIndices.push_back(idx);
@@ -136,14 +147,14 @@ class ConvertGatherNDOp : public OpConversionPattern<GatherNDOp> {
         SmallVector<Value> gi = gatherIndices;
         auto gidx = b.create<arith::ConstantOp>(loc, b.getIndexAttr(i));
         gi.push_back(gidx);
-        assert(gi.size() == indiciesType.getRank());
+        assert(gi.size() == indicesType.getRank());
         auto idxExtract = b.create<tensor::ExtractOp>(
-            loc, indiciesType.getElementType(), indicesTensor, gi);
+            loc, indicesType.getElementType(), indicesTensor, gi);
         auto idxCast =
            b.create<arith::IndexCastOp>(loc, b.getIndexType(), idxExtract);
         valueIndices.push_back(idxCast);
       }
-      for (int i = indiciesType.getRank() - 1; i < resultTensorType.getRank();
+      for (int i = indicesType.getRank() - 1; i < resultTensorType.getRank();
            i++) {
         auto idx = b.create<linalg::IndexOp>(loc, b.getIndexType(),
                                              b.getI64IntegerAttr(i));
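In op form, the example from the comment above reads as follows (a hand-written illustration; this shape combination is not in the test suite):

    // out[i, j, k, a, b] == input[indices[i, j, k, 0], indices[i, j, k, 1], a, b]
    %out = "tcp.gather_nd" (%input, %indices)
        : (tensor<5x6x7x8xf32>, tensor<9x4x3x2xi64>) -> tensor<9x4x3x7x8xf32>
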
diff --git a/lib/Conversion/TorchToTcp/Utils.cpp b/lib/Conversion/TorchToTcp/Utils.cpp
index da4c8ef3..65fe948b 100644
--- a/lib/Conversion/TorchToTcp/Utils.cpp
+++ b/lib/Conversion/TorchToTcp/Utils.cpp
@@ -180,12 +180,12 @@ broadcastManyToMatchShape(ConversionPatternRewriter &rewriter, Location loc,
   }
 
   // figure out what the shape should be for each dim
-  struct ShapeInfo {
+  struct DimInfo {
     Value value;
     bool found = false;
-    int64_t static_value = 1;
+    int64_t staticValue = 1;
   };
-  SmallVector<ShapeInfo> shapes(maxRank);
+  SmallVector<DimInfo> resultShape(maxRank);
 
   for (auto v : ret) {
     auto t = cast<RankedTensorType>(v.getType());
@@ -194,29 +194,29 @@ broadcastManyToMatchShape(ConversionPatternRewriter &rewriter, Location loc,
       if (shape[i] != 1) {
         // meaning that this is not something that is already 1, and therefore
         // would get broadcast
-        if (shapes[i].found) {
+        if (resultShape[i].found) {
           // then there are multiple inputs which have non-1 values for this
           // axis; we should check that the size is the same. If there are
           // different shapes then this would result in an error when
           // broadcasting
           if (shape[i] != ShapedType::kDynamic &&
-              shapes[i].static_value != ShapedType::kDynamic &&
-              shapes[i].static_value != shape[i]) {
+              resultShape[i].staticValue != ShapedType::kDynamic &&
+              resultShape[i].staticValue != shape[i]) {
             // the broadcast failed as there are two different shapes for this
             llvm::errs()
                 << "failed with broadcasting, have two different shapes "
-                << shape[i] << " " << shapes[i].static_value << "\n";
+                << shape[i] << " " << resultShape[i].staticValue << "\n";
             return {};
           }
         } else {
-          shapes[i].found = true;
+          resultShape[i].found = true;
           if (shape[i] == ShapedType::kDynamic) {
-            shapes[i].value = rewriter.create<tensor::DimOp>(loc, v, i);
-            shapes[i].static_value = ShapedType::kDynamic;
+            resultShape[i].value = rewriter.create<tensor::DimOp>(loc, v, i);
+            resultShape[i].staticValue = ShapedType::kDynamic;
           } else {
-            shapes[i].value = rewriter.create<arith::ConstantOp>(
+            resultShape[i].value = rewriter.create<arith::ConstantOp>(
                 loc, rewriter.getIndexAttr(shape[i]));
-            shapes[i].static_value = shape[i];
+            resultShape[i].staticValue = shape[i];
           }
         }
       }
@@ -231,11 +231,11 @@ broadcastManyToMatchShape(ConversionPatternRewriter &rewriter, Location loc,
     SmallVector<Value> sizes;
     SmallVector<int64_t> staticShape;
     for (int64_t j = 0; j < maxRank; j++) {
-      if (t.getShape()[j] == 1 && shapes[j].found) {
+      if (t.getShape()[j] == 1 && resultShape[j].found) {
         axes.push_back(j);
-        sizes.push_back(shapes[j].value);
+        sizes.push_back(resultShape[j].value);
       }
-      staticShape.push_back(shapes[j].static_value);
+      staticShape.push_back(resultShape[j].staticValue);
     }
     if (!axes.empty()) {
       // there is something to broadcast here, so add the op

From f0dd927731b4a22ad4312071261b85532c95d79b Mon Sep 17 00:00:00 2001
From: Matthew Francis-Landau
Date: Mon, 4 Nov 2024 11:55:21 -0800
Subject: [PATCH 3/5] add cast to tensor hacked twin lowering

---
 lib/Conversion/TorchToTcp/DataMovement.cpp | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/lib/Conversion/TorchToTcp/DataMovement.cpp b/lib/Conversion/TorchToTcp/DataMovement.cpp
index c6466c8a..8efa4275 100644
--- a/lib/Conversion/TorchToTcp/DataMovement.cpp
+++ b/lib/Conversion/TorchToTcp/DataMovement.cpp
@@ -307,8 +307,14 @@ class ConvertAtenIndexTensorHackedTwin
     }
 
     for (int i = 0; i < indices.size(); i++) {
-      indices[i] =
+      auto v =
           torch_to_tcp::broadcastRankInTrailingDims(rewriter, indices[i], 1);
+      indices[i] = rewriter.createOrFold<tcp::CastOp>(
+          op.getLoc(),
+          RankedTensorType::get(cast<RankedTensorType>(v.getType()).getShape(),
+                                rewriter.getI64Type()),
+          v, SignednessAttr::get(op->getContext(), Signedness::Signed),
+          SignednessAttr::get(op->getContext(), Signedness::Signless));
     }
 
     auto indicesType = cast<RankedTensorType>(indices[0].getType());

From aacbd68b70bbe7c4ba7f46d78a5c76c2c319e312 Mon Sep 17 00:00:00 2001
From: Matthew Francis-Landau
Date: Mon, 4 Nov 2024 12:09:19 -0800
Subject: [PATCH 4/5] cast non int64 tensors to be int64

---
 lib/Conversion/TorchToTcp/DataMovement.cpp | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/lib/Conversion/TorchToTcp/DataMovement.cpp b/lib/Conversion/TorchToTcp/DataMovement.cpp
index 8efa4275..29b0c27b 100644
--- a/lib/Conversion/TorchToTcp/DataMovement.cpp
+++ b/lib/Conversion/TorchToTcp/DataMovement.cpp
@@ -307,14 +307,18 @@ class ConvertAtenIndexTensorHackedTwin
     }
 
     for (int i = 0; i < indices.size(); i++) {
-      auto v =
+      Value v =
          torch_to_tcp::broadcastRankInTrailingDims(rewriter, indices[i], 1);
-      indices[i] = rewriter.createOrFold<tcp::CastOp>(
-          op.getLoc(),
-          RankedTensorType::get(cast<RankedTensorType>(v.getType()).getShape(),
-                                rewriter.getI64Type()),
-          v, SignednessAttr::get(op->getContext(), Signedness::Signed),
-          SignednessAttr::get(op->getContext(), Signedness::Signless));
+      if (!cast<RankedTensorType>(v.getType()).getElementType().isInteger(64)) {
+        v = rewriter.createOrFold<tcp::CastOp>(
+            op.getLoc(),
+            RankedTensorType::get(
+                cast<RankedTensorType>(v.getType()).getShape(),
+                rewriter.getI64Type()),
+            v, SignednessAttr::get(op->getContext(), Signedness::Signed),
+            SignednessAttr::get(op->getContext(), Signedness::Signless));
+      }
+      indices[i] = v;
     }
 
     auto indicesType = cast<RankedTensorType>(indices[0].getType());

From 2d1b471485fdf9462ef8be8e7a8be068920885f9 Mon Sep 17 00:00:00 2001
From: Matthew Francis-Landau
Date: Mon, 4 Nov 2024 12:13:24 -0800
Subject: [PATCH 5/5] nits

---
 include/mlir-tcp/Dialect/IR/TcpOps.td       | 2 +-
 lib/Conversion/TcpToLinalg/DataMovement.cpp | 2 +-
 lib/Conversion/TorchToTcp/DataMovement.cpp  | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/include/mlir-tcp/Dialect/IR/TcpOps.td b/include/mlir-tcp/Dialect/IR/TcpOps.td
index 0bafcba5..bb30b76b 100644
--- a/include/mlir-tcp/Dialect/IR/TcpOps.td
+++ b/include/mlir-tcp/Dialect/IR/TcpOps.td
@@ -664,7 +664,7 @@ def Tcp_GatherOp : Tcp_Op<"gather", [Pure, AllElementTypesMatch<["input", "out"]
 
 def Tcp_GatherNDOp : Tcp_Op<"gather_nd", [Pure, AllElementTypesMatch<["input", "out"]>]> {
 
-  let summary = "Gather elements from input based on indices over multiple dimentions";
+  let summary = "Gather elements from input based on indices over multiple dimensions";
 
   let description = [{
     Gathers elements from a given tensor based on indices that index along multiple dimensions.
diff --git a/lib/Conversion/TcpToLinalg/DataMovement.cpp b/lib/Conversion/TcpToLinalg/DataMovement.cpp
index c9e7442b..7c3b35db 100644
--- a/lib/Conversion/TcpToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TcpToLinalg/DataMovement.cpp
@@ -96,7 +96,7 @@ class ConvertGatherOp : public OpConversionPattern<GatherOp> {
  * element in the result tensor using a programmatic expression. The last
  * dimension of the indices tensor is used to index into the input tensor.
  *
- * For example, we we have an indices tensor of shape 9x4x3x2 and an input
+ * For example, we have an indices tensor of shape 9x4x3x2 and an input
  * tensor of shape 5x6x7x8, then the resulting tensor will be of shape
 * 9x4x3x7x8, where the first three dimensions of the resulting tensor are used
 * to index into the indices tensor, and the last dimension of the index
diff --git a/lib/Conversion/TorchToTcp/DataMovement.cpp b/lib/Conversion/TorchToTcp/DataMovement.cpp
index 29b0c27b..3b2990c8 100644
--- a/lib/Conversion/TorchToTcp/DataMovement.cpp
+++ b/lib/Conversion/TorchToTcp/DataMovement.cpp
@@ -280,7 +280,7 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
 
 /**
  * The index.Tensor_hacked_twin takes a list of tensors which have to be
- * broadcast together to be the same shape, and then those are feed into a
+ * broadcast together to be the same shape, and then those are fed into a
  * gather which will select the different axes
  */
 class ConvertAtenIndexTensorHackedTwin