Commit

cp
matthewfl committed Sep 19, 2024
1 parent 543f8ee commit 0e8e84a
Showing 5 changed files with 88 additions and 13 deletions.
69 changes: 63 additions & 6 deletions lib/Conversion/TorchToTcp/DataMovement.cpp
@@ -285,23 +285,78 @@ class ConvertAtenIndexTensorHackedTwin
LogicalResult
matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// ------- Matching the OP -------
auto self = adaptor.getSelf();
auto selfType = cast<RankedTensorType>(self.getType());
auto indicesList = op.getIndices();
SmallVector<Value> indices;
if (!getListConstructElements(indicesList, indices))
return op.emitError("Failed to match list of indices");

for(unsigned int i = 0; i < indices.size(); i++) {
auto ttype = cast<RankedTensorType>(getTypeConverter()->convertType(indices[i].getType()));
if(ttype.getRank() != selfType.getRank() - i) {
// tensor.gather could be used for this case instead, but it would require
// some broadcasting to get the index shapes to match what is expected.
return rewriter.notifyMatchFailure(
op, "failed to rewrite Tensor_hacked_twin; an element-wise gather is needed for this case");
}
for(int j = 1; j < ttype.getRank(); j++) {
if(ttype.getShape()[j] != 1)
return rewriter.notifyMatchFailure(op, "expected all axes >= 1 to have size 1");
}
}

// ------ Rewriting the OP ---------

indices = getTypeConvertedValues(rewriter, op.getLoc(), getTypeConverter(),
indices);

// Possibly this should ignore the first batch dim?
if (indices.size() != cast<RankedTensorType>(self.getType()).getRank())
return op.emitError(
"expected the number of indices to equal the rank of self");

// Lower to a chain of tcp.gather ops, one per indexed dimension; each step
// gathers along dimension i with the i-th index tensor broadcast to the
// shape of that step's result.
for(unsigned int i = 0; i < indices.size(); i++) {
auto idx = indices[i];
auto ttype = cast<RankedTensorType>(idx.getType());
auto selfType = cast<RankedTensorType>(self.getType());
SmallVector<int64_t> outShape(selfType.getShape());
outShape[i] = ttype.getNumElements();
auto outType = RankedTensorType::get(
outShape, cast<RankedTensorType>(self.getType()).getElementType());

auto expandedShape = torch_to_tcp::broadcastRankInLeadingDims(
rewriter, idx, outShape.size() - ttype.getRank());
SmallVector<Value> broadcastValues;
SmallVector<int64_t> broadcastAxes;

for(unsigned int j = 0; j < selfType.getRank(); j++) {
if(j != i) {
broadcastAxes.push_back(j);
broadcastValues.push_back(rewriter.create<tensor::DimOp>(op.getLoc(), self, j));
}
}

// Broadcast the index tensor along all non-gathered axes so its shape
// matches the shape of this step's gather result.
auto broadcastedIndices = rewriter.create<tcp::BroadcastOp>(
op.getLoc(),
RankedTensorType::get(outShape, ttype.getElementType()),
expandedShape,
broadcastValues,
rewriter.getI64ArrayAttr(broadcastAxes));

auto gather = rewriter.create<tcp::GatherOp>(
op.getLoc(), outType, self, broadcastedIndices.getResult(),
rewriter.getIndexAttr(i));
self = gather.getResult();
}

/*for (unsigned int i = 0; i < indices.size(); i++) {
auto idx = indices[i];
int numNonOneAxis = 0;
auto ttype = cast<RankedTensorType>(idx.getType());
if(ttype.getRank() != indices.size() - i) {
// There is a version of this op where everything comes in as a single dim and it
// should instead select the different indices from each? The difference would be
// whether it keeps the dim or shrinks it, but the definitions of the two semantics
// are not 100% clear.
return op.emitError("unsure what to do");
}
for (int j = 0; j < ttype.getRank(); j++)
if (ttype.getShape()[j] != 1)
numNonOneAxis++;
@@ -328,7 +383,9 @@ class ConvertAtenIndexTensorHackedTwin
auto gather = rewriter.create<tcp::GatherOp>(
op.getLoc(), outType, self, idx, rewriter.getIndexAttr(i));
self = gather.getResult();
}*/

// assert(op.getType() == self.getType());

rewriter.replaceOp(op, self);
return success();
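For reference (not part of the commit): a small PyTorch sketch of the indexing semantics this pattern targets and of the per-dimension gather chain it lowers to, using the same shapes as the MLIR test below. The index tensors follow the constraints checked in the matching section: the i-th index tensor has rank equal to rank(self) - i, with every trailing axis of size 1.

import torch

# Index tensors shaped the way the matcher requires.
x = torch.rand(1, 20, 30)
i0 = torch.zeros(1, 1, 1, dtype=torch.int64)   # rank 3 == rank(x) - 0, selects along dim 0
i1 = torch.arange(5).reshape(5, 1)             # rank 2 == rank(x) - 1, selects along dim 1
i2 = torch.arange(20)                          # rank 1 == rank(x) - 2, selects along dim 2

# aten.index.Tensor_hacked_twin is advanced indexing with these tensors.
expected = x[i0, i1, i2]                       # shape (1, 5, 20)

# The lowering emits one gather per dimension; with the size-1 trailing axes
# above, that chain is equivalent to successive index_select calls.
out = x
for dim, idx in enumerate([i0, i1, i2]):
    out = torch.index_select(out, dim, idx.flatten())

assert torch.equal(expected, out)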
2 changes: 2 additions & 0 deletions lib/Conversion/TorchToTcp/Utils.cpp
@@ -41,6 +41,8 @@ getTcpSignednessAttr(MLIRContext *context,
// The parameter input is expected to be of RankedTensorType.
Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
Value input, int64_t rankIncrease) {
if(rankIncrease == 0)
return input;
RankedTensorType inputType = input.getType().cast<RankedTensorType>();

SmallVector<ReassociationExprs> reassociationMap(inputType.getRank());
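As an aside, a rough Python analogue (an illustrative assumption, not this repository's API) of what broadcastRankInLeadingDims does, showing that the new rankIncrease == 0 check is simply a no-op:

import torch

def broadcast_rank_in_leading_dims(t: torch.Tensor, rank_increase: int) -> torch.Tensor:
    # Rough analogue of the C++ helper: prepend `rank_increase` size-1 dims.
    if rank_increase == 0:
        # The early return added in this commit: nothing to expand.
        return t
    return t.reshape((1,) * rank_increase + tuple(t.shape))

idx = torch.arange(5).reshape(5, 1)                   # shape (5, 1)
print(broadcast_rank_in_leading_dims(idx, 1).shape)   # torch.Size([1, 5, 1])
print(broadcast_rank_in_leading_dims(idx, 0).shape)   # unchanged: torch.Size([5, 1])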
10 changes: 6 additions & 4 deletions lib/Dialect/IR/TcpOps.cpp
@@ -179,15 +179,17 @@ LogicalResult GatherOp::verify() {
return emitOpError("tcp.gather requires that the input tensor and indices are the same rank");

for(int i = 0; i < inputTensor.getRank(); i++) {
if(inputTensor.getShape()[i] != indicesTensor.getShape()[i] &&
!(inputTensor.getShape()[i] == ShapedType::kDynamic || i == gatherDim)) {
return emitOpError("indices tensor does not match expected shape");
}
}

if(getResult().getType().getShape() != indicesTensor.getShape()) {
return emitOpError("Expect the shape of the indicies to match the output shape");
}

return success();
}
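The shape contract the tightened verifier enforces is: the indices tensor has the same rank as the input, agrees with the input on every non-gather dimension (unless the input dimension is dynamic), and the result takes the shape of the indices. A short sketch of the same contract, assuming torch.gather as a stand-in for tcp.gather's semantics:

import torch

x = torch.rand(1, 20, 30)
# Same rank as x; matches x on every dimension except the gather dim (dim 1).
idx = torch.zeros(1, 5, 30, dtype=torch.int64)

out = torch.gather(x, dim=1, index=idx)
assert out.shape == idx.shape   # result shape == indices shape, as verify() requires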

14 changes: 14 additions & 0 deletions test/AotCompile/model_loader_lib.py
@@ -590,3 +590,17 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return TorchLoaderOutput(
        model=GatherSlices(), inputs=(x, y), dynamic_shapes=dynamic_shapes
    )


def gather_slices_select_loader() -> TorchLoaderOutput:
    class Model(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            i1 = torch.tensor([[0], [1], [2], [3]])
            return x[i1, [2, 5, 7]]

    x = torch.rand(4, 10)
    # batch = Dim("batch", min=3)
    # dynamic_shapes = {"x": {0: batch}}

    return TorchLoaderOutput(
        model=Model(), inputs=(x,),  # dynamic_shapes=dynamic_shapes
    )
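For context, the new loader exercises mixed-rank index tensors; a quick check (shapes only, values are random) of what x[i1, [2, 5, 7]] evaluates to:

import torch

x = torch.rand(4, 10)
i1 = torch.tensor([[0], [1], [2], [3]])    # shape (4, 1)
out = x[i1, [2, 5, 7]]                     # (4, 1) broadcasts with (3,) -> (4, 3)
print(out.shape)                           # torch.Size([4, 3])
assert torch.equal(out, x[:, [2, 5, 7]])   # row r keeps columns 2, 5, 7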
6 changes: 3 additions & 3 deletions test/Conversion/TorchToTcp/data_movement.mlir
@@ -69,9 +69,9 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !tor

// CHECK-LABEL: @torch.aten.index.tensor_hacked_twin
// CHECK-DAG: %[[CAST0:.+]] = torch_c.to_builtin_tensor %arg0
// CHECK-DAG: %[[GATHER0:.+]] = tcp.gather %[[CAST0]], %[[SELECT0:.+]] {dim = 0 : index} : tensor<1x20x30xf32>, tensor<1xi64> -> tensor<1x20x30xf32>
// CHECK-DAG: %[[GATHER1:.+]] = tcp.gather %[[GATHER0]], %[[SELECT1:.+]] {dim = 1 : index} : tensor<1x20x30xf32>, tensor<5xi64> -> tensor<1x5x30xf32>
// CHECK-DAG: %[[GATHER2:.+]] = tcp.gather %[[GATHER1]], %[[SELECT2:.+]] {dim = 2 : index} : tensor<1x5x30xf32>, tensor<20xi64> -> tensor<1x5x20xf32>
// CHECK-DAG: %[[GATHER0:.+]] = tcp.gather %[[CAST0]], %[[SELECT0:.+]] {dim = 0 : index} : tensor<1x20x30xf32>, tensor<1x20x30xi64> -> tensor<1x20x30xf32>
// CHECK-DAG: %[[GATHER1:.+]] = tcp.gather %[[GATHER0]], %[[SELECT1:.+]] {dim = 1 : index} : tensor<1x20x30xf32>, tensor<1x5x30xi64> -> tensor<1x5x30xf32>
// CHECK-DAG: %[[GATHER2:.+]] = tcp.gather %[[GATHER1]], %[[SELECT2:.+]] {dim = 2 : index} : tensor<1x5x30xf32>, tensor<1x5x20xi64> -> tensor<1x5x20xf32>
// CHECK-DAG: %[[RET:.+]] = torch_c.from_builtin_tensor %[[GATHER2]]
// CHECK: return %[[RET]]
func.func @torch.aten.index.tensor_hacked_twin(%arg0: !torch.vtensor<[1,20,30],f32>, %select1: !torch.vtensor<[5,1],si64>, %select2: !torch.vtensor<[20],si64>) -> !torch.vtensor<[1,5,20],f32> {
