folder (clang-formatted) and lit test
Dave Liddell committed Feb 6, 2024
1 parent cb52c4b commit 7d00efa
Showing 4 changed files with 99 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -9784,6 +9784,7 @@ def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
  let hasFolder = 1;
}

def Torch_Aten_IndexPutImplOp : Torch_Op<"aten._index_put_impl", [
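Context for this one-line change: in MLIR ODS, `let hasFolder = 1;` asks mlir-tblgen to declare a fold hook on the generated C++ op class; the hand-written definition then lands in TorchOps.cpp below. A sketch of the effect, not the exact tblgen output:

```cpp
// Effect of `let hasFolder = 1;` on the generated AtenIndexSelectOp class
// (sketch; the exact generated code differs in detail):
//
//   ::mlir::OpFoldResult fold(FoldAdaptor adaptor);
//
// FoldAdaptor hands the hook each operand's constant Attribute (or null
// when the operand is not constant). Returning a Value or Attribute
// replaces the op during folding and canonicalization; returning a null
// OpFoldResult leaves the op untouched.
```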
78 changes: 78 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2814,6 +2814,84 @@ OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) {
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenIndexSelectOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
  auto self = getSelf();
  auto index = getIndex();
  auto selfTy = cast<ValueTensorType>(self.getType());
  auto indexTy = cast<ValueTensorType>(index.getType());
  auto resultTy = cast<ValueTensorType>(getType());

  auto selfSizes = selfTy.getSizes();
  auto indexSizes = indexTy.getSizes();
  auto resultSizes = resultTy.getSizes();

  if (selfTy.getDtype() != resultTy.getDtype())
    return nullptr;
  if (selfSizes.size() != resultSizes.size())
    return nullptr;
  if (indexSizes.size() != 1)
    return nullptr;

  // If the selection results in a tensor of the same dimensions as the
  // input, the selection must have specified every index of the input,
  // so the result is exactly the same as the input.

  bool fullTensor = true;
  for (int i = 0, s = selfSizes.size(); i < s; ++i) {
    fullTensor &= selfSizes[i] == resultSizes[i];
    fullTensor &= selfSizes[i] != Torch::kUnknownSize;
    fullTensor &= resultSizes[i] != Torch::kUnknownSize;
  }

  if (fullTensor && indexSizes[0] == 1)
    return self;

  // If the input tensor, index dimension, or indexes are non-constant,
  // can't fold.

  auto selfAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
  auto dimAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
  auto indexAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getIndex());

  if (!selfAttr || !dimAttr || !indexAttr)
    return {};

  // If the input's dimensions are all 1 except for one dimension, and if
  // there is a single index in the index list (as detected by the result
  // dimension being 1), then fold to a <1x1x...x1> tensor literal containing
  // a single element. Handles float and int types.

  int64_t dimInt = dimAttr.getInt();

  bool scalarFold = true;
  for (int i = 0, s = selfSizes.size(); i < s; ++i) {
    scalarFold &= selfSizes[i] == 1 || i == dimInt;
    scalarFold &= resultSizes[i] == 1;
  }

  if (!scalarFold)
    return nullptr;

  auto indexInt = indexAttr.getSplatValue<IntegerAttr>().getInt();
  auto splattr = selfAttr.getValues<Attribute>()[indexInt];

  auto dty = resultTy.getDtype();
  auto attrTy = resultTy.toBuiltinTensor().clone(dty);
  if (auto floatAttr = dyn_cast<FloatAttr>(splattr))
    return DenseElementsAttr::get(
        attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble()));

  if (auto intAttr = dyn_cast<IntegerAttr>(splattr))
    return DenseElementsAttr::get(attrTy,
                                  IntegerAttr::get(dty, intAttr.getInt()));

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenItemOp
//===----------------------------------------------------------------------===//
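How the two fold paths fire, traced against the code above (a sketch in comments; the values come from the lit tests added at the bottom of this commit):

```cpp
// Tracing the fold (values taken from the lit tests below; sketch only).
//
// Full-tensor path ($noop test): self : [1,2,3], result : [1,2,3],
// index : [1]. Every selfSizes[i] equals resultSizes[i] and all sizes are
// static, so fullTensor stays true; with indexSizes[0] == 1 the fold
// returns `self` and the op disappears.
//
// Scalar path ($const test):
//   self  = dense<[10,20,30,40,50,60,70,80,90,100]> : tensor<10xsi64>
//   dim   = 0, index = dense<5> : tensor<1xsi64>
// selfSizes = [10] and resultSizes = [1], so scalarFold holds
// (selfSizes[0] != 1 but i == dimInt, and resultSizes[0] == 1).
// indexInt = 5, splattr = selfAttr.getValues<Attribute>()[5] = 60, and the
// fold returns dense<60> : tensor<1xsi64>, which the canonicalizer
// materializes as a torch.vtensor.literal.
```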
2 changes: 1 addition & 1 deletion projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -616,7 +616,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_folder=True)
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)", has_folder=True)
emit_with_mutating_variants("aten::_index_put_impl : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
emit("aten::item : (Tensor) -> (Scalar)", has_folder=True)
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
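Note that GeneratedTorchOps.td is generated from this emit table rather than edited by hand (in torch-mlir the regeneration is normally run via build_tools/update_torch_ods.sh), which is why `has_folder=True` here and `let hasFolder = 1;` above change together.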
19 changes: 19 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
@@ -2193,3 +2193,22 @@ func.func @torch.aten.detach$canonicalize(%arg0: !torch.tensor<[1],f32>) -> !torch.tensor {
  %1 = torch.aten.detach %arg0 : !torch.tensor<[1],f32> -> !torch.tensor
  return %1 : !torch.tensor
}

// CHECK-LABEL: func.func @torch.aten.index_select$noop(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,2,3],si64>
// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[1,2,3],si64>
func.func @torch.aten.index_select$noop(%arg0 : !torch.vtensor<[1,2,3],si64>, %arg1 : !torch.int, %arg2 : !torch.vtensor<[1],si64>) -> (!torch.vtensor<[1,2,3],si64>) {
  %0 = torch.aten.index_select %arg0, %arg1, %arg2 : !torch.vtensor<[1,2,3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1,2,3],si64>
  return %0 : !torch.vtensor<[1,2,3],si64>
}

// CHECK-LABEL: func.func @torch.aten.index_select$const(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,2,3],si64>
// CHECK: %[[RET:.*]] = torch.vtensor.literal(dense<60> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
// CHECK: return %[[RET]] : !torch.vtensor<[1],si64>
func.func @torch.aten.index_select$const(%arg0 : !torch.vtensor<[1,2,3],si64>, %arg1 : !torch.int, %arg2 : !torch.vtensor<[1],si64>) -> (!torch.vtensor<[1],si64>) {
  %tensor = torch.vtensor.literal(dense<[10,20,30,40,50,60,70,80,90,100]> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
  %dim = torch.constant.int 0
  %index = torch.vtensor.literal(dense<5> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
  %0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
  return %0 : !torch.vtensor<[1],si64>
}
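To run these cases, the file's existing RUN line (outside this hunk) drives the canonicalize pass through torch-mlir-opt and FileCheck; the usual entry point is the check-torch-mlir test target.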
