Skip to content

Commit

Permalink
Incorporated code review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Dave Liddell committed Feb 7, 2024
1 parent 4371df0 commit 3f7f475
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 38 deletions.
13 changes: 11 additions & 2 deletions include/torch-mlir/Dialect/Torch/IR/TorchOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,11 @@ bool potentiallyMutatesListOperands(Operation *op);
/// the value as a signed integer, which implies that if the attribute has
/// a 64-bit unsigned value, it will be converted to an int64_t in the manner
/// that uint64_t is cast to int64_t in C++.
int64_t getIntAttrAsSigned(IntegerAttr intAttr);
inline int64_t getIntAttrAsSigned(IntegerAttr intAttr) {
if (intAttr.getType().isUnsignedInteger())
return intAttr.getValue().getZExtValue();
return intAttr.getValue().getSExtValue();
}

/// Returns the value from an `IntegerAttr` as an integral index.
///
Expand All @@ -321,7 +325,12 @@ int64_t getIntAttrAsSigned(IntegerAttr intAttr);
///
/// No bounds checking is performed on the index to ensure that it is within
/// the legal range for `dimSize`.
int64_t getIntAttrAsIndex(IntegerAttr intAttr, int dimSize = -1);
inline int64_t getIntAttrAsIndex(IntegerAttr intAttr, int dimSize = -1) {
int64_t signedIndex = getIntAttrAsSigned(intAttr);
if (dimSize < 0 || signedIndex > 0)
return signedIndex;
return dimSize + signedIndex; // count backwards from dimSize
}

} // namespace Torch
} // namespace torch
Expand Down
45 changes: 9 additions & 36 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,25 +199,6 @@ static Value getScalarFloatValue(Value input, Location loc,
return nullptr;
}

int64_t mlir::torch::Torch::getIntAttrAsSigned(IntegerAttr intAttr) {
if (intAttr.getType().isSignedInteger())
return intAttr.getSInt();
if (intAttr.getType().isUnsignedInteger())
return int64_t(intAttr.getUInt());
if (intAttr.getType().isSignlessInteger())
return intAttr.getInt(); // signless returns as int64_t
assert(false && "Unhandled integer attribute type");
return 0;
}

int64_t mlir::torch::Torch::getIntAttrAsIndex(IntegerAttr intAttr,
int dimSize) {
int64_t signedIndex = getIntAttrAsSigned(intAttr);
if (dimSize < 0 || signedIndex > 0)
return signedIndex;
return dimSize - (-signedIndex);
}

//===----------------------------------------------------------------------===//
// MethodOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2861,16 +2842,12 @@ OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) {
OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
auto self = getSelf();
auto index = getIndex();
auto selfTy = dyn_cast_or_null<ValueTensorType>(self.getType());
auto indexTy = dyn_cast_or_null<ValueTensorType>(index.getType());
auto resultTy = dyn_cast_or_null<ValueTensorType>(getType());
if (!selfTy || !indexTy || !resultTy)
return nullptr;

if (!selfTy.hasSizes() || !indexTy.hasSizes() || !resultTy.hasSizes())
return nullptr;

if (!selfTy.hasDtype() || !indexTy.hasDtype() || !resultTy.hasDtype())
auto selfTy = dyn_cast<ValueTensorType>(self.getType());
auto indexTy = dyn_cast<ValueTensorType>(index.getType());
auto resultTy = dyn_cast<ValueTensorType>(getType());
if (!selfTy || !indexTy || !resultTy || !selfTy.hasSizes() ||
!indexTy.hasSizes() || !resultTy.hasSizes() || !selfTy.hasDtype() ||
!indexTy.hasDtype() || !resultTy.hasDtype())
return nullptr;

auto selfSizes = selfTy.getSizes();
Expand Down Expand Up @@ -2916,19 +2893,15 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
int64_t dimInt = dimAttr.getInt();
// If the selected dim is negative, count backwards from the last dim
if (dimInt < 0)
dimInt = selfSizes.size() - (-dimInt);
dimInt = selfSizes.size() + dimInt;
assert(uint64_t(dimInt) < selfSizes.size() &&
"Selected dim > number of dims");

bool scalarFold = true;
for (int i = 0, s = selfSizes.size(); i < s; ++i) {
scalarFold &= selfSizes[i] == 1 || i == dimInt;
scalarFold &= resultSizes[i] == 1;
if ((selfSizes[i] != 1 && i != dimInt) || resultSizes[i] != 1)
return nullptr;
}

if (!scalarFold)
return nullptr;

// Get the single index value for the selected dimension
auto splatValue = indexAttr.getSplatValue<IntegerAttr>();
int64_t indexInt = getIntAttrAsIndex(splatValue, selfSizes[dimInt]);
Expand Down

0 comments on commit 3f7f475

Please sign in to comment.