AIRSpecializeWrapsAndStrides: Refactor BD canonicalization to reuse the core method impl. (Xilinx#657)

* Fix an error in the offset calculation when the corresponding stride isn't 1 (see the worked example after this list)

* Remove code repetition by reusing the eraseWrapNStrideDim method; implement eraseWrapNStrideDim as a function instead of a lambda

* New unit test
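
The offset fix in one worked example (a sketch using the constants from the new unit test, not output of the pass itself): when a dimension with offset o and stride s is folded into the next dimension with offset o' and stride s', its address contribution o * s must be re-expressed in units of the surviving stride, so the new offset is (o * s) / s' + o'. Folding the (offset 3, stride 192) dim of the new test into the (offset 0, stride 32) dim gives 3 * 192 / 32 + 0 = 18; the previous code skipped the division by s', which is only correct when the surviving stride is 1.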
erwei-xilinx authored Jul 9, 2024
1 parent e4c9ab9 commit b53f1cf
Showing 2 changed files with 125 additions and 106 deletions.
221 changes: 115 additions & 106 deletions mlir/lib/Util/Util.cpp
@@ -819,6 +819,106 @@ void air::getDefiningOpsToOperands(Operation *op,
}
}

// Erase dims; recalculate offsets if needed.
LogicalResult eraseWrapNStrideDim(OpBuilder builder,
SmallVector<int> erase_dims,
SmallVector<Value> &offsets,
SmallVector<Value> &sizes,
SmallVector<Value> &strides) {
auto original_insert_point = builder.saveInsertionPoint();
bool erased = false;
// Multiply adjacent wraps.
auto multiplyAdjWraps = [](OpBuilder builder, int erase_dim,
SmallVector<Value> &sizes) {
auto const_size = getConstantIntValue(sizes[erase_dim]);
if (!const_size)
return false; // non-static wrap, NYI.
auto const_size_next = getConstantIntValue(sizes[erase_dim + 1]);
if (!const_size_next)
return false; // non-static wrap, NYI.
sizes[erase_dim + 1] = builder.create<arith::ConstantIndexOp>(
builder.getUnknownLoc(), (*const_size) * (*const_size_next));
return true;
};
for (auto i : erase_dims) {
auto const_offset = getConstantIntValue(offsets[i]);
if (const_offset && *const_offset == 0) {
erased |= multiplyAdjWraps(builder, i, sizes);
offsets.erase(offsets.begin() + i);
sizes.erase(sizes.begin() + i);
strides.erase(strides.begin() + i);
erased = true;
continue;
}
// Propagate any non-zero offset and non-unit wrap to the next dimension
if (offsets.begin() + i + 1 == offsets.end())
continue;
auto const_stride = getConstantIntValue(strides[i]);
assert(const_stride && "non-static stride, NYI.");
auto const_offset_next = getConstantIntValue(offsets[i + 1]);
if (!const_offset_next)
continue;
auto const_stride_next = getConstantIntValue(strides[i + 1]);
assert(const_stride_next && "non-static stride, NYI.");
erased |= multiplyAdjWraps(builder, i, sizes);
if (const_offset) {
offsets[i + 1] = builder.create<arith::ConstantIndexOp>(
builder.getUnknownLoc(),
(*const_stride) * (*const_offset) / (*const_stride_next) +
(*const_offset_next));
} else {
// Get affine.apply which produces the offset ssa
Operation *offset_producer = offsets[i].getDefiningOp();
if (offset_producer && isa<arith::IndexCastOp>(offset_producer)) {
auto castOp = dyn_cast<arith::IndexCastOp>(offset_producer);
offsets[i] = castOp.getIn();
offset_producer = castOp.getIn().getDefiningOp();
}
if (!offset_producer) {
if (!affine::getForInductionVarOwner(offsets[i]))
continue;
auto afo = affine::getForInductionVarOwner(offsets[i]);
builder.setInsertionPointToStart(afo.getBody());
// Create a new affine.apply on affine.for ind. vars, as handle for
// subsequent offset composition.
auto sym0_expr = getAffineSymbolExpr(0, builder.getContext());
auto iv_map = AffineMap::get(0, 1, sym0_expr);
offset_producer = builder.create<affine::AffineApplyOp>(
builder.getUnknownLoc(), iv_map, offsets[i]);
}
if (auto exec = dyn_cast<air::ExecuteOp>(offset_producer))
offset_producer = exec.getChildOp();
auto affine_apply = dyn_cast<affine::AffineApplyOp>(offset_producer);
assert(affine_apply && "ssa offset not produced by affine.apply, NYI.");
// Compose affine map
auto offset_expr = getAffineSymbolExpr(0, builder.getContext());
auto stride_expr =
getAffineConstantExpr(*const_stride, builder.getContext());
auto next_stride_expr =
getAffineConstantExpr(*const_stride_next, builder.getContext());
offset_expr = offset_expr * stride_expr;
offset_expr = offset_expr.ceilDiv(next_stride_expr);
offset_expr = offset_expr + getAffineConstantExpr(*const_offset_next,
builder.getContext());
SmallVector<AffineExpr, 8> symReplacements(
affine_apply.getAffineMap().getResults().begin(),
affine_apply.getAffineMap().getResults().end());
offset_expr = offset_expr.replaceDimsAndSymbols({}, symReplacements);
auto next_offset_map = AffineMap::get(0, 1, offset_expr);
affine_apply.setMap(next_offset_map);
offsets[i + 1] = offsets[i];
}
offsets.erase(offsets.begin() + i);
sizes.erase(sizes.begin() + i);
strides.erase(strides.begin() + i);
erased = true;
}
builder.restoreInsertionPoint(original_insert_point);
if (erased)
return success();
return failure();
}

// Canonicalize wrap and stride lists by removing redundant dimensions.
LogicalResult air::canonicalizeWrapAndStrideList(OpBuilder builder,
SmallVector<Value> &offsets,
@@ -838,88 +938,6 @@ LogicalResult air::canonicalizeWrapAndStrideList(OpBuilder builder,
listsHaveChanged = true;
}

// Erase dims; recalculate offsets if needed.
auto eraseWrapNStrideDim = [](OpBuilder builder, SmallVector<int> erase_dims,
SmallVector<Value> &offsets,
SmallVector<Value> &sizes,
SmallVector<Value> &strides) {
auto original_insert_point = builder.saveInsertionPoint();
bool erased = false;
for (auto i : erase_dims) {
auto const_offset = getConstantIntValue(offsets[i]);
if (const_offset && *const_offset == 0) {
offsets.erase(offsets.begin() + i);
sizes.erase(sizes.begin() + i);
strides.erase(strides.begin() + i);
erased = true;
continue;
}
// Propagate any non-zero offset to the next dimension
if (offsets.begin() + i + 1 == offsets.end())
continue;
auto const_stride = getConstantIntValue(strides[i]);
assert(const_stride && "non-static stride, NYI.");
auto const_offset_next = getConstantIntValue(offsets[i + 1]);
if (!const_offset_next)
continue;
auto const_stride_next = getConstantIntValue(strides[i + 1]);
assert(const_stride_next && "non-static stride, NYI.");
if (const_offset) {
offsets[i + 1] = builder.create<arith::ConstantIndexOp>(
builder.getUnknownLoc(),
(*const_stride) * (*const_offset) / (*const_stride_next) +
(*const_offset_next));
} else {
// Get affine.apply which produces the offset ssa
Operation *offset_producer = offsets[i].getDefiningOp();
if (offset_producer && isa<arith::IndexCastOp>(offset_producer)) {
auto castOp = dyn_cast<arith::IndexCastOp>(offset_producer);
offsets[i] = castOp.getIn();
offset_producer = castOp.getIn().getDefiningOp();
}
if (!offset_producer) {
if (!affine::getForInductionVarOwner(offsets[i]))
continue;
auto afo = affine::getForInductionVarOwner(offsets[i]);
builder.setInsertionPointToStart(afo.getBody());
// Create a new affine.apply on affine.for ind. vars, as handle for
// subsequent offset composition.
auto sym0_expr = getAffineSymbolExpr(0, builder.getContext());
auto iv_map = AffineMap::get(0, 1, sym0_expr);
offset_producer = builder.create<affine::AffineApplyOp>(
builder.getUnknownLoc(), iv_map, offsets[i]);
}
if (auto exec = dyn_cast<air::ExecuteOp>(offset_producer))
offset_producer = exec.getChildOp();
auto affine_apply = dyn_cast<affine::AffineApplyOp>(offset_producer);
assert(affine_apply && "ssa offset not produced by affine.apply, NYI.");
// Compose affine map
auto offset_expr = getAffineSymbolExpr(0, builder.getContext());
auto stride_expr =
getAffineConstantExpr(*const_stride, builder.getContext());
auto next_stride_expr =
getAffineConstantExpr(*const_stride_next, builder.getContext());
offset_expr = offset_expr * stride_expr;
offset_expr = offset_expr.ceilDiv(next_stride_expr);
offset_expr = offset_expr + getAffineConstantExpr(*const_offset_next,
builder.getContext());
SmallVector<AffineExpr, 8> symReplacements(
affine_apply.getAffineMap().getResults().begin(),
affine_apply.getAffineMap().getResults().end());
offset_expr = offset_expr.replaceDimsAndSymbols({}, symReplacements);
auto next_offset_map = AffineMap::get(0, 1, offset_expr);
affine_apply.setMap(next_offset_map);
offsets[i + 1] = offsets[i];
}
offsets.erase(offsets.begin() + i);
sizes.erase(sizes.begin() + i);
strides.erase(strides.begin() + i);
erased = true;
}
builder.restoreInsertionPoint(original_insert_point);
return erased;
};

// Canonicalize dimensions with size = 1
SmallVector<int> erase_dims;
for (int i = sizes.size() - 1; i >= 0; i--) {
@@ -932,35 +950,26 @@ LogicalResult air::canonicalizeWrapAndStrideList(OpBuilder builder,
}

listsHaveChanged |=
eraseWrapNStrideDim(builder, erase_dims, offsets, sizes, strides);
eraseWrapNStrideDim(builder, erase_dims, offsets, sizes, strides)
.succeeded();
erase_dims.clear();

// Canonicalize adjacent dimensions
if (!sizes.empty()) {
for (int i = sizes.size() - 1; i >= 1; i--) {
if (getConstantIntValue(offsets[i]) &&
getConstantIntValue(offsets[i - 1]) &&
getConstantIntValue(sizes[i]) && getConstantIntValue(sizes[i - 1]) &&
getConstantIntValue(strides[i]) &&
getConstantIntValue(strides[i - 1])) {
auto const_offset = *getConstantIntValue(offsets[i]);
auto const_offset_next = *getConstantIntValue(offsets[i - 1]);
auto const_size = *getConstantIntValue(sizes[i]);
auto const_size_next = *getConstantIntValue(sizes[i - 1]);
auto const_stride = *getConstantIntValue(strides[i]);
auto const_stride_next = *getConstantIntValue(strides[i - 1]);
if (const_stride_next == const_size * const_stride) {
sizes[i] = builder.create<arith::ConstantIndexOp>(
builder.getUnknownLoc(), const_size * const_size_next);
offsets[i] = builder.create<arith::ConstantIndexOp>(
builder.getUnknownLoc(),
const_stride_next * const_offset_next + const_offset);
offsets.erase(offsets.begin() + i - 1);
sizes.erase(sizes.begin() + i - 1);
strides.erase(strides.begin() + i - 1);
listsHaveChanged = true;
}
}
auto const_offset = getConstantIntValue(offsets[i]);
auto const_size = getConstantIntValue(sizes[i]);
auto const_stride = getConstantIntValue(strides[i]);
auto const_offset_prev = getConstantIntValue(offsets[i - 1]);
auto const_stride_prev = getConstantIntValue(strides[i - 1]);
if (!(const_offset && const_size && const_stride && const_offset_prev &&
const_stride_prev))
continue;
if (*const_stride_prev == *const_size * *const_stride)
listsHaveChanged |=
eraseWrapNStrideDim(builder, SmallVector<int>{i - 1}, offsets,
sizes, strides)
.succeeded();
}
}
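
For reference, a minimal self-contained sketch of the same canonicalization arithmetic on plain integer vectors (the Dim struct, helper names, and the driver below are illustrative assumptions, not the MLIR implementation above):

#include <cassert>
#include <cstdio>
#include <vector>

// One wrap-and-stride dimension; its address contribution is offset * stride.
struct Dim {
  long offset, size, stride;
};

// Fold dim i into dim i + 1, mirroring eraseWrapNStrideDim's constant path:
// the surviving dim absorbs the erased wrap, and a non-zero offset is
// rescaled into units of the surviving stride. The division by the
// surviving stride is exactly what the offset fix in this commit adds.
void foldDim(std::vector<Dim> &dims, int i) {
  assert(i + 1 < (int)dims.size() && "cannot fold the innermost dim");
  Dim &outer = dims[i];
  Dim &inner = dims[i + 1];
  inner.size *= outer.size;
  inner.offset += outer.offset * outer.stride / inner.stride;
  dims.erase(dims.begin() + i);
}

// Drop size-1 dims, then merge adjacent dims whose strides compose
// contiguously, i.e. stride[i] == size[i + 1] * stride[i + 1].
void canonicalize(std::vector<Dim> &dims) {
  for (int i = (int)dims.size() - 2; i >= 0; --i)
    if (dims[i].size == 1)
      foldDim(dims, i);
  for (int i = (int)dims.size() - 2; i >= 0; --i)
    if (dims[i].stride == dims[i + 1].size * dims[i + 1].stride)
      foldDim(dims, i);
}

int main() {
  // The three static outer dims of the new unit test:
  // offsets [0, 3, 0], wraps [1, 3, 6], strides [1152, 192, 32].
  std::vector<Dim> dims = {{0, 1, 1152}, {3, 3, 192}, {0, 6, 32}};
  canonicalize(dims);
  for (const Dim &d : dims)
    std::printf("[%ld] [%ld] [%ld]\n", d.offset, d.size, d.stride);
  // Prints [18] [18] [32], matching the middle dim of the new CHECK line.
  return 0;
}

The real implementation additionally composes affine.apply maps for non-constant offsets and, as part of the specialization pass, folds surrounding loop induction variables into extra dimensions; the sketch covers only the constant-folded path.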

@@ -67,6 +67,7 @@ module {
// CHECK: put @channel_3[%c0, %c0] (%arg1[%c0, %c0, %c0] [%c4, %c32, %c32] [%c32, %c128, %c1]) : (memref<128x128xf32>)
// CHECK: put @channel_4[%c0, %c0] (%arg1[%c0, %c0, %c0] [%c4, %c128, %c32] [%c32, %c128, %c1]) : (memref<128x128xf32>)
// CHECK: get @channel_5[%c0, %c0] (%arg1[%c0, %c0, %c0, %c0] [%c4, %c4, %c32, %c32] [%c4096, %c32, %c128, %c1]) : (memref<128x128xf32>)
// CHECK: put @channel_5[] (%alloc_0[%c0, %c18, %c0] [%c4, %c18, %c8] [%c8, %c32, %c1]) : (memref<1x6x6x32xi8, 1>)

func.func @test1(%arg0: memref<128xf32>, %arg1: memref<128x128xf32>) -> memref<128xf32> {
%c0 = arith.constant 0 : index
@@ -93,6 +94,15 @@ module {
air.channel.get @channel_5[%c0, %c0] (%arg1[%arg2, %arg3] [%c32, %c32] [%c128, %c1]) : (memref<128x128xf32>)
}
}
%c3 = arith.constant 3 : index
%c6 = arith.constant 6 : index
%c8 = arith.constant 8 : index
%c192 = arith.constant 192 : index
%c1152 = arith.constant 1152 : index
%alloc_0 = memref.alloc() : memref<1x6x6x32xi8, 1>
scf.for %arg2 = %c0 to %c32 step %c8 {
air.channel.put @channel_5[] (%alloc_0[%c0, %c3, %c0, %arg2] [%c1, %c3, %c6, %c8] [%c1152, %c192, %c32, %c1]) : (memref<1x6x6x32xi8, 1>)
}
return %alloc : memref<128xf32>
}
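
Walking through the new test (informally; the dimension bookkeeping here is inferred from the CHECK line): the size-1 outer dim (stride 1152) is erased, then the (offset 3, wrap 3, stride 192) dim merges into the (offset 0, wrap 6, stride 32) dim because 192 = 6 * 32, producing wrap 3 * 6 = 18 and offset 3 * 192 / 32 + 0 = 18. The scf.for over %arg2 (0 to 32, step 8) is folded into an outermost (wrap 4, stride 8) dim by the specialization pass, yielding the expected [%c0, %c18, %c0] [%c4, %c18, %c8] [%c8, %c32, %c1].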
