
Commit

Fixes, refactoring, handling of -ve indices, added lit test
Dave Liddell committed Feb 7, 2024
1 parent 0eb426a commit 4371df0
Showing 3 changed files with 68 additions and 9 deletions.
29 changes: 29 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/TorchOps.h
@@ -294,6 +294,35 @@ bool isListPotentiallyMutated(Value list);
/// the list.
bool potentiallyMutatesListOperands(Operation *op);

/// Returns the value from an `IntegerAttr` as an `int64_t`.
///
/// @param intAttr the `IntegerAttr` from which to extract the value
/// @return the value as an `int64_t`
///
/// Regardless of the attribute's signedness, this function returns the value
/// as a signed integer: if the attribute holds a 64-bit unsigned value, it is
/// converted to an int64_t the same way a uint64_t is cast to int64_t in C++.
int64_t getIntAttrAsSigned(IntegerAttr intAttr);

/// Returns the value from an `IntegerAttr` as an integral index.
///
/// @param intAttr the `IntegerAttr` from which to extract the index
/// @param dimSize the size of the dimension that the attribute indexes into
/// @return the index value
///
/// Use this function when the given `IntegerAttr` represents an index into
/// a range, such as an index into a tensor dimension. If `dimSize` is given,
/// negative index values are converted into positive values by counting
/// elements from the "right" side of the dimension, as in Python, NumPy, etc.
/// For example, an index of -2 with a `dimSize` of 10 returns 8, because 8 is
/// the 2nd index from the high end of the range 0 to 9. If `dimSize` is not
/// given, negative indices are returned unchanged.
///
/// No bounds checking is performed; callers are responsible for ensuring the
/// index is within the legal range for `dimSize`.
int64_t getIntAttrAsIndex(IntegerAttr intAttr, int dimSize = -1);

} // namespace Torch
} // namespace torch
} // namespace mlir
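
To illustrate the intended semantics of these two helpers, here is a minimal sketch (not part of the commit); the wrapper function, the Builder usage, and the attribute values are assumptions made for the example:

#include "mlir/IR/Builders.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

// Hypothetical example exercising the new helpers.
void exampleIndexHelpers(mlir::MLIRContext &ctx) {
  mlir::Builder b(&ctx);
  using namespace mlir::torch::Torch;

  // A signless i64 attribute holding -2.
  mlir::IntegerAttr negTwo = b.getI64IntegerAttr(-2);

  int64_t raw = getIntAttrAsSigned(negTwo);                 // -2
  int64_t idx = getIntAttrAsIndex(negTwo, /*dimSize=*/10);  // 8: 2nd from the end
  int64_t asIs = getIntAttrAsIndex(negTwo);                 // -2: no dimSize given
  (void)raw;
  (void)idx;
  (void)asIs;
}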
37 changes: 28 additions & 9 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -199,6 +199,25 @@ static Value getScalarFloatValue(Value input, Location loc,
return nullptr;
}

int64_t mlir::torch::Torch::getIntAttrAsSigned(IntegerAttr intAttr) {
  if (intAttr.getType().isSignedInteger())
    return intAttr.getSInt();
  if (intAttr.getType().isUnsignedInteger())
    return int64_t(intAttr.getUInt());
  if (intAttr.getType().isSignlessInteger())
    return intAttr.getInt(); // signless returns as int64_t
  assert(false && "Unhandled integer attribute type");
  return 0;
}

int64_t mlir::torch::Torch::getIntAttrAsIndex(IntegerAttr intAttr,
                                              int dimSize) {
  int64_t signedIndex = getIntAttrAsSigned(intAttr);
  // Non-negative indices pass through unchanged, as do all indices when no
  // dimSize is supplied. Note >= so that an index of 0 is not remapped.
  if (dimSize < 0 || signedIndex >= 0)
    return signedIndex;
  // Negative index: count backwards from the high end of the dimension.
  return dimSize - (-signedIndex);
}

//===----------------------------------------------------------------------===//
// MethodOp
//===----------------------------------------------------------------------===//
@@ -2895,6 +2914,11 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
// a single element. Handles float and int types.

  int64_t dimInt = dimAttr.getInt();
  // If the selected dim is negative, count backwards from the last dim.
  if (dimInt < 0)
    dimInt = selfSizes.size() - (-dimInt);
  assert(uint64_t(dimInt) < selfSizes.size() &&
         "Selected dim > number of dims");

bool scalarFold = true;
for (int i = 0, s = selfSizes.size(); i < s; ++i) {
@@ -2905,17 +2929,12 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
if (!scalarFold)
return nullptr;

// Get the single index value for the selected dimension
auto splatValue = indexAttr.getSplatValue<IntegerAttr>();
uint64_t indexInt = 0;
if (splatValue.getType().isSignedInteger())
indexInt = uint64_t(splatValue.getSInt());
else if (splatValue.getType().isUnsignedInteger())
indexInt = splatValue.getUInt();
else if (splatValue.getType().isSignlessInteger())
indexInt = uint64_t(splatValue.getInt());
else
return nullptr;
int64_t indexInt = getIntAttrAsIndex(splatValue, selfSizes[dimInt]);

// Extract the single constant value from the input tensor and turn the
// extracted value into a single-element tensor of the output shape and dtype
auto splattr = selfAttr.getValues<Attribute>()[indexInt];

auto dty = resultTy.getDtype();
11 changes: 11 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
@@ -2245,3 +2245,14 @@ func.func @torch.aten.index_select$const_f32_ui() -> !torch.vtensor<[1],f32> {
%0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],f32>, !torch.int, !torch.vtensor<[1],ui64> -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32>
}

// CHECK-LABEL: func.func @torch.aten.index_select$const_f32_si_neg(
// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<7.{{.*}}> : tensor<1xf32>) : !torch.vtensor<[1],f32>
// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],f32>
func.func @torch.aten.index_select$const_f32_si_neg() -> !torch.vtensor<[1],f32> {
%tensor = torch.vtensor.literal(dense<[1.1,2.2,3.3,4.4,5.5,6.6,7.7,8.8,9.9,10.0]> : tensor<10xf32>) : !torch.vtensor<[10],f32>
%dim = torch.constant.int -1
%index = torch.vtensor.literal(dense<-4> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],f32>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32>
}
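
For reference (not part of the commit), the new test exercises both negative-value paths: the dim of -1 normalizes to 0 for the rank-1 tensor, and the index of -4 with a dimSize of 10 normalizes to 6, selecting the value 7.7. After canonicalization the function should therefore fold to roughly the following (exact float formatting may differ):

func.func @torch.aten.index_select$const_f32_si_neg() -> !torch.vtensor<[1],f32> {
  %0 = torch.vtensor.literal(dense<7.700000e+00> : tensor<1xf32>) : !torch.vtensor<[1],f32>
  return %0 : !torch.vtensor<[1],f32>
}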
