diff --git a/lib/Conversion/TorchToTcp/DataMovement.cpp b/lib/Conversion/TorchToTcp/DataMovement.cpp
index 045a08c..0edea8d 100644
--- a/lib/Conversion/TorchToTcp/DataMovement.cpp
+++ b/lib/Conversion/TorchToTcp/DataMovement.cpp
@@ -285,23 +285,78 @@ class ConvertAtenIndexTensorHackedTwin
   LogicalResult
   matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // ------- Matching the OP -------
     auto self = adaptor.getSelf();
+    auto selfType = cast<RankedTensorType>(self.getType());
     auto indicesList = op.getIndices();
     SmallVector<Value> indices;
     if (!getListConstructElements(indicesList, indices))
       return op.emitError("Failed to match list of indices");
+
+    for (unsigned int i = 0; i < indices.size(); i++) {
+      auto ttype = cast<RankedTensorType>(
+          getTypeConverter()->convertType(indices[i].getType()));
+      if (ttype.getRank() != selfType.getRank() - i) {
+        // Could use tensor.gather for this instead, but that would require
+        // some broadcasting to get the shapes to match what is expected.
+        return rewriter.notifyMatchFailure(
+            op, "failed to rewrite Tensor_hacked_twin; this needs the "
+                "element-wise gather");
+      }
+      for (int j = 1; j < ttype.getRank(); j++) {
+        if (ttype.getShape()[j] != 1)
+          return rewriter.notifyMatchFailure(
+              op, "expected the axes >= 1 to have size 1");
+      }
+    }
+
+    // ------ Rewriting the OP ---------
     indices = getTypeConvertedValues(rewriter, op.getLoc(), getTypeConverter(),
                                      indices);
 
-    // possible that this should ignore the first batch dim?
-    if (indices.size() != cast<RankedTensorType>(self.getType()).getRank())
-      return op.emitError(
-          "Expected the number of indicies to equal rank of self");
-    for (unsigned int i = 0; i < indices.size(); i++) {
+
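+    // Lower to a chain of tcp.gather ops, one per index tensor: broadcast the
+    // i-th index tensor so that it matches the shape of `self` in every
+    // dimension except dim i (where it has one entry per selected index), then
+    // gather along dim i and feed the result into the next iteration.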
+    for (unsigned int i = 0; i < indices.size(); i++) {
+      auto idx = indices[i];
+      auto ttype = cast<RankedTensorType>(idx.getType());
+      auto selfType = cast<RankedTensorType>(self.getType());
+      SmallVector<int64_t> outShape(selfType.getShape());
+      outShape[i] = ttype.getNumElements();
+      auto outType = RankedTensorType::get(
+          outShape, cast<RankedTensorType>(self.getType()).getElementType());
+
+      auto expandedShape = torch_to_tcp::broadcastRankInLeadingDims(
+          rewriter, idx, outShape.size() - ttype.getRank());
+      SmallVector<Value> broadcastValues;
+      SmallVector<int64_t> broadcastAxes;
+
+      for (unsigned int j = 0; j < selfType.getRank(); j++) {
+        if (j != i) {
+          broadcastAxes.push_back(j);
+          broadcastValues.push_back(
+              rewriter.create<tensor::DimOp>(op.getLoc(), self, j));
+        }
+      }
+
+      auto broadcastedShape = rewriter.create<tcp::BroadcastOp>(
+          op.getLoc(), RankedTensorType::get(outShape, ttype.getElementType()),
+          expandedShape, broadcastValues,
+          rewriter.getI64ArrayAttr(broadcastAxes));
+
+      auto gather = rewriter.create<tcp::GatherOp>(
+          op.getLoc(), outType, self, broadcastedShape.getResult(),
+          rewriter.getIndexAttr(i));
+      self = gather.getResult();
+    }
+
+    /*for (unsigned int i = 0; i < indices.size(); i++) {
       auto idx = indices[i];
       int numNonOneAxis = 0;
       auto ttype = cast<RankedTensorType>(idx.getType());
+      if (ttype.getRank() != indices.size() - i) {
+        // There is a version of this op where everything comes in as a single
+        // dim, and then it should instead select the different indices from
+        // each?  So the difference would be whether it keeps the dim or
+        // shrinks it, but it is not 100% clear what the definition of the
+        // different semantics is.
+        return op.emitError("unsure what to do");
+      }
       for (int j = 0; j < ttype.getRank(); j++)
         if (ttype.getShape()[j] != 1)
           numNonOneAxis++;
@@ -328,7 +383,9 @@ class ConvertAtenIndexTensorHackedTwin
       auto gather = rewriter.create<tcp::GatherOp>(
          op.getLoc(), outType, self, idx, rewriter.getIndexAttr(i));
      self = gather.getResult();
-    }
+    }*/
+
+    // assert(op.getType() == self.getType());
 
     rewriter.replaceOp(op, self);
     return success();
diff --git a/lib/Conversion/TorchToTcp/Utils.cpp b/lib/Conversion/TorchToTcp/Utils.cpp
index c762ff1..16fdd5b 100644
--- a/lib/Conversion/TorchToTcp/Utils.cpp
+++ b/lib/Conversion/TorchToTcp/Utils.cpp
@@ -41,6 +41,8 @@ getTcpSignednessAttr(MLIRContext *context,
 // The parameter input is expected to be of RankedTensorType.
 Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
                                  Value input, int64_t rankIncrease) {
+  if (rankIncrease == 0)
+    return input;
   RankedTensorType inputType = input.getType().cast<RankedTensorType>();
 
   SmallVector<ReassociationIndices> reassociationMap(inputType.getRank());
diff --git a/lib/Dialect/IR/TcpOps.cpp b/lib/Dialect/IR/TcpOps.cpp
index 204bcc9..83c098d 100644
--- a/lib/Dialect/IR/TcpOps.cpp
+++ b/lib/Dialect/IR/TcpOps.cpp
@@ -179,15 +179,17 @@ LogicalResult GatherOp::verify() {
     return emitOpError("tcp.gather requires that the input tensor and "
                        "indices are the same rank");
 
   for (int i = 0; i < inputTensor.getRank(); i++) {
-    if (inputTensor.getShape()[i] != indicesTensor.getShape()[i]) {
-      if (!(inputTensor.getShape()[i] == ShapedType::kDynamic ||
-            indicesTensor.getShape()[i] == 1 ||
+    if (inputTensor.getShape()[i] != indicesTensor.getShape()[i] &&
+        !(inputTensor.getShape()[i] == ShapedType::kDynamic ||
            i == gatherDim)) {
         return emitOpError("indices tensor does not match expected shape");
-      }
     }
   }
 
+  if (getResult().getType().getShape() != indicesTensor.getShape()) {
+    return emitOpError(
+        "expected the shape of the indices to match the output shape");
+  }
+
   return success();
 }
diff --git a/test/AotCompile/model_loader_lib.py b/test/AotCompile/model_loader_lib.py
index 1b3e88c..6f9cc7c 100644
--- a/test/AotCompile/model_loader_lib.py
+++ b/test/AotCompile/model_loader_lib.py
@@ -590,3 +590,17 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     return TorchLoaderOutput(
         model=GatherSlices(), inputs=(x, y), dynamic_shapes=dynamic_shapes
     )
+
+def gather_slices_select_loader() -> TorchLoaderOutput:
+    class Model(torch.nn.Module):
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            i1 = torch.tensor([[0], [1], [2], [3]])
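+            # Advanced indexing: i1 has shape (4, 1) and the list index is
+            # treated as shape (3,), so the two index tensors broadcast to
+            # (4, 3) and the result has shape (4, 3).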
+            return x[i1, [2, 5, 7]]
+
+    x = torch.rand(4, 10)
+    # batch = Dim("batch", min=3)
+    # dynamic_shapes = {"x": {0: batch}}
+
+    return TorchLoaderOutput(
+        model=Model(), inputs=(x,),  # dynamic_shapes=dynamic_shapes
+    )
diff --git a/test/Conversion/TorchToTcp/data_movement.mlir b/test/Conversion/TorchToTcp/data_movement.mlir
index be5fa1d..dcf5cda 100644
--- a/test/Conversion/TorchToTcp/data_movement.mlir
+++ b/test/Conversion/TorchToTcp/data_movement.mlir
@@ -69,9 +69,9 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !tor
 // CHECK-label: @torch.aten.index.tensor_hacked_twin
 // CHECK-DAG: %[[CAST0:.+]] = torch_c.to_builtin_tensor %arg0
-// CHECK-DAG: %[[GATHER0:.+]] = tcp.gather %[[CAST0]], %[[SELECT0:.+]] {dim = 0 : index} : tensor<1x20x30xf32>, tensor<1xi64> -> tensor<1x20x30xf32>
-// CHECK-DAG: %[[GATHER1:.+]] = tcp.gather %[[GATHER0]], %[[SELECT1:.+]] {dim = 1 : index} : tensor<1x20x30xf32>, tensor<5xi64> -> tensor<1x5x30xf32>
-// CHECK-DAG: %[[GATHER2:.+]] = tcp.gather %[[GATHER1]], %[[SELECT2:.+]] {dim = 2 : index} : tensor<1x5x30xf32>, tensor<20xi64> -> tensor<1x5x20xf32>
+// CHECK-DAG: %[[GATHER0:.+]] = tcp.gather %[[CAST0]], %[[SELECT0:.+]] {dim = 0 : index} : tensor<1x20x30xf32>, tensor<1x20x30xi64> -> tensor<1x20x30xf32>
+// CHECK-DAG: %[[GATHER1:.+]] = tcp.gather %[[GATHER0]], %[[SELECT1:.+]] {dim = 1 : index} : tensor<1x20x30xf32>, tensor<1x5x30xi64> -> tensor<1x5x30xf32>
+// CHECK-DAG: %[[GATHER2:.+]] = tcp.gather %[[GATHER1]], %[[SELECT2:.+]] {dim = 2 : index} : tensor<1x5x30xf32>, tensor<1x5x20xi64> -> tensor<1x5x20xf32>
 // CHECK-DAG: %[[RET:.+]] = torch_c.from_builtin_tensor %[[GATHER2]]
 // CHECK: return %[[RET]]
 func.func @torch.aten.index.tensor_hacked_twin(%arg0: !torch.vtensor<[1,20,30],f32>, %select1: !torch.vtensor<[5,1],si64>, %select2: !torch.vtensor<[20],si64>) -> !torch.vtensor<[1,5,20],f32> {