Skip to content

Commit

Permalink
Switch to using a more efficient impl. of 'largest factor' (#545)
Browse files Browse the repository at this point in the history
  • Loading branch information
erwei-xilinx authored Apr 23, 2024
1 parent 0b3ba11 commit 17a2124
Showing 1 changed file with 29 additions and 12 deletions.
41 changes: 29 additions & 12 deletions mlir/lib/Conversion/AIRRtToIpuPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,17 +531,35 @@ bool violatesAIE2WrapLimit(airrt::DmaMemcpyNdOp dma) {
return false;
}

// A naive implementation to find largest factor, smaller than a given int, for
// a given integer.
int getLargestFactorSmallerThan(int inputInt, int smallerThanInt = 0) {
int factor = 1;
for (int i = 2; i < inputInt; i++) {
if (smallerThanInt && i >= smallerThanInt)
break;
if (inputInt % i == 0)
factor = i;
// Find the largest factor of 'num' which is not larger than 'max'. Ref:
// https://github.com/nod-ai/iree-amd-aie/blob/main/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEUtils.cpp#L334
//
// Parameters:
//   num - the integer whose factor is wanted.
//   max - inclusive upper bound on the returned factor; must be > 0
//         (enforced with an assert).
//
// Returns 'num' itself when 'num' <= 'max'. Otherwise returns the largest
// divisor of 'num' that is <= 'max', which is 1 when no larger divisor fits.
int findLargestFactor(int num, int max) {
  assert(max > 0 && "No factors less than or equal to 0 exist");

  // Do O(1) instead of O(sqrt(num)) computation for this common case.
  if (num <= max) {
    return num;
  }

  // Factors come in pairs (lowFactor, highFactor) with
  // lowFactor * highFactor == num. Walk the low half upwards and remember the
  // best low factor seen, in case no high factor fits under 'max'.
  int largestLowFactor = 1;
  for (int lowFactor = 2; lowFactor <= max; ++lowFactor) {
    const int highFactor = num / lowFactor;

    // This early exit is what makes this O(sqrt(num)) instead of O(num).
    if (highFactor < lowFactor)
      return largestLowFactor;

    const bool areActuallyFactors = num % lowFactor == 0;
    if (areActuallyFactors) {
      // We're certain that here lowFactor <= highFactor, and highFactor is
      // descending in this loop. So we can return immediately if highFactor is
      // good.
      if (highFactor <= max)
        return highFactor;
      largestLowFactor = lowFactor;
    }
  }
  return largestLowFactor;
}

void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) {
Expand All @@ -560,8 +578,7 @@ void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) {
auto const_stride = *getConstantIntValue(strides[i]);
if (const_wrap >= AIE2_WRAP_UPPER_BOUND) {
// Found dimension with illegal wrap. Tiling.
int inner_wrap =
getLargestFactorSmallerThan(const_wrap, AIE2_WRAP_UPPER_BOUND);
int inner_wrap = findLargestFactor(const_wrap, AIE2_WRAP_UPPER_BOUND - 1);
int new_wrap = mlir::ceilDiv(const_wrap, inner_wrap);
wraps[i] = builder.create<arith::ConstantOp>(
loc, builder.getI64Type(),
Expand Down

0 comments on commit 17a2124

Please sign in to comment.