Skip to content

Commit

Permalink
Ensure that bcast specialization only propagates through single-operand affine.apply ops (Xilinx#485)
Browse files Browse the repository at this point in the history
  • Loading branch information
erwei-xilinx authored Mar 11, 2024
1 parent d98f0a3 commit b332bc0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 15 deletions.
2 changes: 1 addition & 1 deletion mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2150,7 +2150,6 @@ struct BroadcastDetection {
air::HerdOp hl_op = nullptr;
bool isVariantWrtHerdRows = false;
bool isVariantWrtHerdCols = false;
// Create an affine set to represent the broadcast pattern
auto ctx = dma_op->getContext();
for (auto v : loop_dep_history) {
// Check row-wise or col-wise broadcastable based on variance wrt herd
Expand Down Expand Up @@ -2183,6 +2182,7 @@ struct BroadcastDetection {
isVariantWrtHerdCols = true;
}

// Create an affine set to represent the broadcast pattern
if (hl_op && isVariantWrtHerdRows && !isVariantWrtHerdCols) {
auto numColsOp = dyn_cast<arith::ConstantIndexOp>(
hl_op.getSizeOperands()[1].getDefiningOp());
Expand Down
21 changes: 7 additions & 14 deletions mlir/lib/Transform/AIRMiscPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,20 +516,13 @@ class AIRSpecializeDmaBroadcast
// operations in op history, last-in-first-out
for (std::vector<Operation *>::reverse_iterator i = op_history.rbegin();
i != op_history.rend(); ++i) {
if (auto air_region_op = dyn_cast<air::ExecuteOp>(*i)) {
assert(air_region_op.getBody().front().getOperations().size() ==
2 &&
"air::ExecuteOp should have only one child operation beside "
"the terminator");
// Get current scalar op
Operation *op = nullptr;
for (auto &child_op :
air_region_op.getBody().front().getOperations()) {
if (!dyn_cast<air::ExecuteTerminatorOp>(child_op))
op = &child_op;
}
if (auto exec_op = dyn_cast<air::ExecuteOp>(*i)) {
Operation *op = exec_op.getChildOp();
// If the async op is affine.apply
if (auto apply_op = dyn_cast<affine::AffineApplyOp>(op)) {
// Can only propagate affine.apply ops with single operand.
if (apply_op.getNumOperands() != 1)
return;
auto map = apply_op.getAffineMap();
for (unsigned j = 0; j < current_shape_expr.size(); j++) {
if (current_shape_expr[j]) {
Expand All @@ -538,8 +531,8 @@ class AIRSpecializeDmaBroadcast
// Remove dependence from scalar op to memcpyOp if present
auto async_memcpyOp =
dyn_cast<air::AsyncOpInterface>(memcpyOp.getOperation());
eraseAsyncDependencyFromAsyncOp(
async_memcpyOp, air_region_op.getAsyncToken());
eraseAsyncDependencyFromAsyncOp(async_memcpyOp,
exec_op.getAsyncToken());
}
}
}
Expand Down

0 comments on commit b332bc0

Please sign in to comment.