Handle index_select by lowering it to tcp.gather (#88)
This PR adds support for lowering `torch.index_select` to `tcp.gather`.
The logic behind this lowering is documented
[here](https://github.com/cruise-automation/mlir-tcp/blob/main/docs/gather.md#gather-slices-along-a-given-dim).
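
For intuition, the equivalence this lowering relies on can be sketched in plain PyTorch (illustration only, not part of this commit): the 1-D index vector is rank-expanded, broadcast along every dim except the selection dim, and then used in an element-wise gather.

```python
import torch

x = torch.randn(4, 3)
idx = torch.tensor([2, 0])

# torch.index_select picks whole slices along `dim`.
ref = torch.index_select(x, 1, idx)  # shape [4, 2]

# The lowering's view of the same op: expand the 1-D indices to rank 2,
# broadcast them across every dim except `dim`, then gather element-wise.
bcast_idx = idx.reshape(1, -1).expand(x.size(0), -1)  # [2] -> [1, 2] -> [4, 2]
out = torch.gather(x, 1, bcast_idx)

assert torch.equal(ref, out)
```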

Tested with:

```shell
bazel test //test:Conversion/TorchToTcp/data_movement.mlir.test
bazel test //test/AotCompile:gather_slices_compile_execute_test
```
navahgar authored Aug 20, 2024
1 parent f3f29bf commit 0061ac5
Showing 6 changed files with 157 additions and 20 deletions.
45 changes: 45 additions & 0 deletions lib/Conversion/TorchToTcp/DataMovement.cpp
@@ -236,6 +236,48 @@ class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
}
};

class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenIndexSelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto input = adaptor.getSelf();
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();

auto indices = adaptor.getIndex();
auto indicesRank = cast<RankedTensorType>(indices.getType()).getRank();
// As per the semantics of the `torch.index_select` op, the index operand
// is always a 1-D tensor. We enforce that here, since some of the
// utilities used below only work in that case.
if (indicesRank != 1)
return rewriter.notifyMatchFailure(op, "indices need to be 1-D");

RankedTensorType resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));

int64_t dim = 0;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "dim on torch.index_select must be an int constant");
dim = Torch::toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(
op, "dim on torch.index_select is statically invalid");

auto indicesRankBroadcasted = torch_to_tcp::broadcastRank0Dor1DToND(
rewriter, indices, inputRank, dim);
auto indicesBroadcasted = torch_to_tcp::broadcastShapeExceptDims(
rewriter, indicesRankBroadcasted, input,
llvm::SmallDenseSet<int64_t>{dim});
rewriter.replaceOpWithNewOp<tcp::GatherOp>(
op, resultType, input, indicesBroadcasted, rewriter.getIndexAttr(dim));
return success();
}
};

} // namespace

void torch_to_tcp::populateDataMovementPatternsAndLegality(
@@ -249,4 +291,7 @@ void torch_to_tcp::populateDataMovementPatternsAndLegality(
torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<ConvertAtenGatherOp,
AtenGatherOp>(
typeConverter, patterns, target, convertTorchOpsSet);
torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<ConvertAtenIndexSelectOp,
AtenIndexSelectOp>(
typeConverter, patterns, target, convertTorchOpsSet);
}
80 changes: 60 additions & 20 deletions lib/Conversion/TorchToTcp/Utils.cpp
@@ -62,6 +62,64 @@ Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
input.getDefiningOp()->getLoc(), resultType, input, reassociationMap);
}

Value broadcastRank0Dor1DToND(ConversionPatternRewriter &rewriter, Value input,
int64_t targetRank, int64_t axisInOutput) {
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();
assert(inputRank < 2 && "Only 0D and 1D tensors are supported!");

// Case 1: 0D -> ND
// [] -> [1, 1, 1, 1]
// reassociation map = [[]]
// Case 2: 1D -> ND
// [C] -> [1, C, 1, 1] if axisInOutput = 1
// reassociation map = [[0, 1, 2, 3]]
SmallVector<ReassociationExprs> reassociationMap(inputRank);
SmallVector<int64_t> resultShape(targetRank, 1);
if (inputRank == 1) {
for (int64_t axis = 0; axis < targetRank; ++axis)
reassociationMap[0].push_back(rewriter.getAffineDimExpr(axis));
resultShape[axisInOutput] = inputType.getShape()[0];
}
Type expandResultType =
inputType.cloneWith(ArrayRef(resultShape), inputType.getElementType());
return rewriter.create<tensor::ExpandShapeOp>(input.getDefiningOp()->getLoc(),
expandResultType, input,
reassociationMap);
}

Value broadcastShapeExceptDims(ConversionPatternRewriter &rewriter, Value input,
Value target,
llvm::SmallDenseSet<int64_t> dimsToExclude) {
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
auto inputShape = inputType.getShape();

RankedTensorType targetType = cast<RankedTensorType>(target.getType());
auto targetShape = targetType.getShape();

SmallVector<int64_t> axes;
SmallVector<Value> dimSizes;
SmallVector<int64_t> resultShape;
// Visit dims in increasing order so that `axes` ends up sorted.
for (int64_t axis = 0; axis < targetType.getRank(); ++axis) {
if (dimsToExclude.contains(axis)) {
resultShape.push_back(inputShape[axis]);
} else {
resultShape.push_back(targetShape[axis]);
axes.push_back(axis);
dimSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
input.getDefiningOp()->getLoc(), target, axis));
}
}
auto axesAttr = rewriter.getI64ArrayAttr(axes);

Type broadcastResultType =
inputType.cloneWith(resultShape, inputType.getElementType());
return rewriter.create<tcp::BroadcastOp>(input.getDefiningOp()->getLoc(),
broadcastResultType, input, dimSizes,
axesAttr);
}

// The input parameters are expected to be of RankedTensorType.
std::pair<Value, Value>
broadcastToMatchShape(ConversionPatternRewriter &rewriter, Value lhs,
@@ -135,27 +193,9 @@ Value broadcast0DOr1DToNDAndMatchShape(ConversionPatternRewriter &rewriter,
// This utility only accepts 0D and 1D inputs
assert(inputRank < 2 && "Only 0D and 1D tensors are supported!");

Value result = input;

// First: Broadcast Rank
// Case 1: 0D -> ND
// [] -> [1, 1, 1, 1]
// reassociation map = [[]]
// Case 2: 1D -> ND
// [C] -> [1, C, 1, 1] if axisInOutput = 1
// reassociation map = [[0, 1, 2, 3]]
SmallVector<ReassociationExprs> reassociationMap(inputRank);
SmallVector<int64_t> resultShape(targetRank, 1);
if (inputRank == 1) {
for (int64_t axis = 0; axis < targetRank; ++axis)
reassociationMap[0].push_back(rewriter.getAffineDimExpr(axis));
resultShape[axisInOutput] = inputType.getShape()[0];
}
Type expandResultType =
targetType.cloneWith(ArrayRef(resultShape), resultType);
result = rewriter.create<tensor::ExpandShapeOp>(
result.getDefiningOp()->getLoc(), expandResultType, input,
reassociationMap);
Value result =
broadcastRank0Dor1DToND(rewriter, input, targetRank, axisInOutput);

// Second: Broadcast Shape
// Case 1: 0D -> ND
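
As a rough PyTorch sketch of what the two new utilities compute (illustration only; `reshape` and `expand` stand in for `tensor.expand_shape` and `tcp.broadcast`):

```python
import torch

# broadcastRank0Dor1DToND: place a 1-D tensor's only axis at `axisInOutput`
# of a rank-`targetRank` shape (a 0-D input becomes an all-ones shape).
idx = torch.tensor([2, 0])      # shape [2]
rank_bcast = idx.reshape(1, 2)  # targetRank = 2, axisInOutput = 1

# broadcastShapeExceptDims: match the target's size on every dim except
# those in `dimsToExclude` (here dim 1 is excluded).
target = torch.randn(4, 3)
shape_bcast = rank_bcast.expand(target.size(0), -1)

print(shape_bcast.shape)  # torch.Size([4, 2])
```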
12 changes: 12 additions & 0 deletions lib/Conversion/TorchToTcp/Utils.h
Expand Up @@ -28,6 +28,18 @@ getTcpSignednessAttr(MLIRContext *context,
Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
Value input, int64_t rankIncrease);

// Broadcasts the rank of the input tensor from 0D or 1D to ND. If the input
// tensor is 1D, `axisInOutput` specifies the output axis at which the input's
// single axis is placed.
Value broadcastRank0Dor1DToND(ConversionPatternRewriter &rewriter, Value input,
int64_t targetRank, int64_t axisInOutput);

// Broadcasts the shape of the input tensor to match the shape of the target
// tensor in all dims except the dims specified in `dimsToExclude`.
Value broadcastShapeExceptDims(ConversionPatternRewriter &rewriter, Value input,
Value target,
llvm::SmallDenseSet<int64_t> dimsToExclude);

// Helper function to do both rank and shape all-dim broadcasting
// of the inputs to match each other.
std::pair<Value, Value>
1 change: 1 addition & 0 deletions test/AotCompile/BUILD
@@ -35,6 +35,7 @@ AOT_TEST_SUITE = [
("broadcast_unit_dim_to_static_with_rank_increase", False),
("broadcast_unit_dim_to_dynamic_with_rank_increase", False),
("gather_elements", False),
("gather_slices", False),
]

py_library(
21 changes: 21 additions & 0 deletions test/AotCompile/model_loader_lib.py
@@ -551,3 +551,24 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return TorchLoaderOutput(
model=GatherElements(), inputs=(x, y), dynamic_shapes=dynamic_shapes
)


def gather_slices_loader() -> TorchLoaderOutput:
class GatherSlices(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return torch.index_select(x, 1, y)

# Sample inputs
x = torch.randn(4, 3)
y = torch.tensor([2, 0])

# Dynamic dim constraints
batch = Dim("batch", min=3)
dynamic_shapes = {"x": {0: batch}, "y": {}}

return TorchLoaderOutput(
model=GatherSlices(), inputs=(x, y), dynamic_shapes=dynamic_shapes
)
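
One hypothetical way to exercise this loader outside the bazel harness (assuming `TorchLoaderOutput` exposes `model`, `inputs`, and `dynamic_shapes` as attributes, per the constructor above; the actual test entry point may differ):

```python
import torch

loader = gather_slices_loader()
ep = torch.export.export(
    loader.model, loader.inputs, dynamic_shapes=loader.dynamic_shapes
)
print(ep.module()(*loader.inputs).shape)  # torch.Size([4, 2])
```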
18 changes: 18 additions & 0 deletions test/Conversion/TorchToTcp/data_movement.mlir
@@ -46,3 +46,21 @@ func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.v
%0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[1,4,3],f32>, !torch.int, !torch.vtensor<[1,4,2],si64>, !torch.bool -> !torch.vtensor<[1,4,2],f32>
return %0 : !torch.vtensor<[1,4,2],f32>
}

// -----

// CHECK-LABEL: @torch.aten.index_select
// CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4,3],f32>,
// CHECK-SAME: %[[ARG1:.+]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,2],f32>
// CHECK: %[[EXPAND_SHAPE:.+]] = tensor.expand_shape
// CHECK-SAME: tensor<2xi64> into tensor<1x2xi64>
// CHECK: %[[BROADCAST:.+]] = tcp.broadcast %[[EXPAND_SHAPE]], %{{.*}} {axes = [0]} : tensor<1x2xi64>, index -> tensor<4x2xi64>
// CHECK: %[[GATHER:.+]] = tcp.gather %{{.*}}, %[[BROADCAST]] {dim = 1 : index} :
// CHECK-SAME: tensor<4x3xf32>, tensor<4x2xi64> -> tensor<4x2xf32>
// CHECK: %[[V3:.+]] = torch_c.from_builtin_tensor %[[GATHER]] : tensor<4x2xf32> -> !torch.vtensor<[4,2],f32>
// CHECK: return %[[V3]] : !torch.vtensor<[4,2],f32>
func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,2],f32> {
%int-1 = torch.constant.int -1
%0 = torch.aten.index_select %arg0, %int-1, %arg1: !torch.vtensor<[4,3],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,2],f32>
return %0 : !torch.vtensor<[4,2],f32>
}
