Skip to content

Commit

Permalink
Merge branch 'dliddell-aten-index-select' of https://github.com/davel…
Browse files Browse the repository at this point in the history
…iddell/torch-mlir into dliddell-aten-index-select
  • Loading branch information
Dave Liddell committed Feb 6, 2024
2 parents ce873e4 + f0a3f1a commit 6d14771
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2843,20 +2843,17 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
auto self = getSelf();
auto index = getIndex();
auto selfTy = cast<ValueTensorType>(self.getType());
auto indexTy = cast<ValueTensorType>(index.getType());
assert(index.getType().isa<IntegerType>());
auto resultTy = cast<ValueTensorType>(getType());

auto selfSizes = selfTy.getSizes();
auto indexSizes = indexTy.getSizes();
auto resultSizes = resultTy.getSizes();


if (selfTy.getDtype() != resultTy.getDtype())
return nullptr;
if (selfSizes.size() != resultSizes.size())
return nullptr;
if (indexSizes.size() != 1)
return nullptr;

// If the selection results in a tensor of the same dimensions as the
// input, the selection must have specified every index of the input,
Expand All @@ -2869,7 +2866,7 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
fullTensor &= resultSizes[i] != Torch::kUnknownSize;
}

if (fullTensor && indexSizes[0] == 1)
if (fullTensor)
return self;

// If the input tensor, index dimension, or indexes are non-constant,
Expand Down

0 comments on commit 6d14771

Please sign in to comment.