AIRSegmentLoopFusion: A number of fixups and improvements around async dependency (Xilinx#748)

* Trace dependence token users through air.wait_all; move the fused loop to just before the last loop being fused, to preserve SSA dominance (see the sketch below); emit a more informative failure message when a broken dependence is detected after fusion

* Add a unit test covering a complex case which was failing before
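
A simplified sketch of the dominance issue (a hedged illustration modeled on the new test below; the constants, channels, and shapes are illustrative, not taken from the pass itself). A value defined between two fusable loops and used by the second loop forces the fused loop to be created before the last loop rather than the first:

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c128 = arith.constant 128 : index
%c256 = arith.constant 256 : index
%c512 = arith.constant 512 : index
%tok0, %buf = air.execute -> (memref<128x512xi32, 1 : i32>) {
  %alloc = memref.alloc() : memref<128x512xi32, 1 : i32>
  air.execute_terminator %alloc : memref<128x512xi32, 1 : i32>
}
%r1 = scf.for %i = %c0 to %c512 step %c256 iter_args(%a = %tok0) -> (!air.async.token) {
  %g = air.channel.get async [%a] @channel_2[] (%buf[%c0, %i] [%c128, %c256] [%c512, %c1]) : (memref<128x512xi32, 1 : i32>)
  scf.yield %g : !air.async.token
}
// %idx is defined between the two loops and used inside the second loop.
%tok1, %idx = air.execute -> (index) {
  %x = arith.constant 32 : index
  air.execute_terminator %x : index
}
%r2 = scf.for %j = %c0 to %c512 step %c256 iter_args(%b = %tok1) -> (!air.async.token) {
  %p = air.channel.put async [%b] @channel_3[%c0, %idx] (%buf[%c0, %j] [%c128, %c256] [%c512, %c1]) : (memref<128x512xi32, 1 : i32>)
  scf.yield %p : !air.async.token
}
// A fused loop created at the position of %r1 would reference %idx above
// its definition; creating it right before %r2 (the last loop being
// fused) keeps every operand dominating the fused loop.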
erwei-xilinx authored Oct 22, 2024
1 parent 7372bd7 commit 24cb14e
Showing 4 changed files with 122 additions and 3 deletions.
1 change: 1 addition & 0 deletions mlir/include/air/Util/Dependency.h
@@ -69,6 +69,7 @@ Value getAsyncTokenFromOp(Operation *op);
void addAsyncDependencyIfNew(Operation *op, Value token);
bool isAsyncOp(Operation *op);
bool areAsyncDependent(Operation *a, Operation *b);
bool isAsyncDependent(Operation *a, Operation *b);
scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op,
SmallVector<Operation *> target_ops);
LogicalResult unrollAIRChannelPutGetInScfParallel(OpBuilder builder,
30 changes: 27 additions & 3 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
@@ -4759,6 +4759,22 @@ struct ShrinkMemrefSizesByAccessPattern
}
};

// Get all users of the async op's async token that have type T, looking
// through air.wait_all ops.
template <typename T>
SmallVector<T> getTokenUsersOfType(air::AsyncOpInterface asyncOp) {
SmallVector<T> tokenUsers;
Value token = asyncOp.getAsyncToken();
for (auto token_user : token.getUsers()) {
if (auto token_user_of_type = dyn_cast<T>(token_user))
tokenUsers.push_back(token_user_of_type);
else if (auto token_user_wait_all = dyn_cast<air::WaitAllOp>(token_user))
for (auto wa_user : token_user_wait_all.getAsyncToken().getUsers())
if (auto token_user_of_type = dyn_cast<T>(wa_user))
tokenUsers.push_back(token_user_of_type);
}
return tokenUsers;
}
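// Illustration (comment only, not part of the original commit): given IR of
// the shape
//   %t, %buf = air.execute { ... memref.alloc ... }
//   %w = air.wait_all async [%t, %t2]
//   scf.for ... iter_args(%arg = %w) -> (!air.async.token) { ... }
// the scf.for is a user of %w rather than of %t, so a plain walk of
// token.getUsers() misses it; getTokenUsersOfType<scf::ForOp> also hops
// through the air.wait_all and returns the loop.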

struct AIRSegmentLoopFusionPattern : public OpRewritePattern<air::SegmentOp> {
using OpRewritePattern<air::SegmentOp>::OpRewritePattern;

@@ -4866,15 +4882,15 @@ struct AIRSegmentLoopFusionPattern : public OpRewritePattern<air::SegmentOp> {
if (llvm::any_of(
alloc_dealloc_execs,
[&](std::pair<air::ExecuteOp, air::ExecuteOp> exec_pair) {
- return exec_pair.first == iaDefOp;
+ return isAsyncDependent(exec_pair.first, iaDefOp);
}))
fusableForOps.push_back(forOp);
}
}
if (fusableForOps.empty())
return failure();

- rewriter.setInsertionPoint(equalIterationForOps[0]);
+ rewriter.setInsertionPoint(equalIterationForOps.back());
auto new_loop_op_init_arg =
rewriter
.create<air::WaitAllOp>(
@@ -4891,7 +4907,8 @@ struct AIRSegmentLoopFusionPattern : public OpRewritePattern<air::SegmentOp> {
for (auto execOpPair : alloc_dealloc_execs) {
bool canMove = false;
air::ExecuteOp alloc_exec = execOpPair.first;
- for (auto token_user : alloc_exec.getAsyncToken().getUsers())
+ auto token_users = getTokenUsersOfType<scf::ForOp>(alloc_exec);
+ for (auto token_user : token_users)
if (llvm::any_of(equalIterationForOps, [&](scf::ForOp fusableForOp) {
return fusableForOp == token_user;
}))
@@ -4970,6 +4987,13 @@ struct AIRSegmentLoopFusionPattern : public OpRewritePattern<air::SegmentOp> {
put_parent = put_parent->getParentOp();
}
Operation *get_parent = getOp;
if (!get_parent) {
putOp->emitOpError(
"is producing data for memref in the fused scf.for loop, but no "
"consumer is found for this data within the fused loop. This "
"likely indicates a failure in the compiler pass.");
return;
}
while (get_parent->getParentOp() != new_loop_op) {
get_parent = get_parent->getParentOp();
}
22 changes: 22 additions & 0 deletions mlir/lib/Util/Dependency.cpp
@@ -628,6 +628,28 @@ bool areAsyncDependent(Operation *a, Operation *b) {
return false;
}

// Returns true if b is asynchronously dependent on a. This function performs
// deep dependency tracing, propagating through air.wait_all ops.
bool isAsyncDependent(Operation *a, Operation *b) {
if (a == b)
return true;
Value token_a = getAsyncTokenFromOp(a);
SmallVector<Value> dep_b = getAsyncDependenciesFromOp(b);
if (!token_a)
return false;
if (dep_b.empty())
return false;
for (auto dep : dep_b) {
if (dep == token_a)
return true;
else if (auto dep_wa_defop = dep.getDefiningOp<air::WaitAllOp>()) {
if (isAsyncDependent(a, dep_wa_defop))
return true;
}
}
return false;
}
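// Illustration (comment only, not part of the original commit): given
//   %t0 = air.wait_all async            // op a
//   %t1 = air.wait_all async [%t0]
//   %t2 = air.wait_all async [%t1]      // op b
// isAsyncDependent(a, b) returns true even though %t0 does not appear in
// b's dependence list directly: the trace recurses through the defining
// air.wait_all of each dependence token.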

// Splits an SCF for loop into two for loops, by hoisting the target
// operations into a new for loop at the same scope.
scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op,
@@ -862,3 +862,75 @@ func.func @func9(%arg0: memref<512x256xi8>, %arg1: memref<256x32xi8>) {
}
return
}

// scf.parallel unrolling pre-processing, with loop tiling.

// CHECK-LABEL: func.func @func10
// CHECK: air.segment
// CHECK: scf.for %{{.*}} = %c0{{.*}} to %c512{{.*}} step %c256{{.*}}
// CHECK: air.channel.get async [{{.*}}] @channel_2[]
// CHECK: scf.for %{{.*}} = %c0{{.*}} to %c256{{.*}} step %c32{{.*}}
// CHECK-NEXT: air.channel.put async [{{.*}}] @channel_3[%c0{{.*}}, %c0{{.*}}]
// CHECK-NEXT: scf.yield
// CHECK: scf.for %{{.*}} = %c0{{.*}} to %c256{{.*}} step %c32{{.*}}
// CHECK-NEXT: air.channel.put async [{{.*}}] @channel_3[%c0{{.*}}, %c1{{.*}}]
// CHECK-NEXT: scf.yield
// CHECK: scf.for %{{.*}} = %c0{{.*}} to %c256{{.*}} step %c32{{.*}}
// CHECK-NEXT: air.channel.put async [{{.*}}] @channel_3[%c0{{.*}}, %c2{{.*}}]
// CHECK-NEXT: scf.yield
// CHECK: scf.for %{{.*}} = %c0{{.*}} to %c256{{.*}} step %c32{{.*}}
// CHECK-NEXT: air.channel.put async [{{.*}}] @channel_3[%c0{{.*}}, %c3{{.*}}]
// CHECK-NEXT: scf.yield
// CHECK: scf.yield

#map15 = affine_map<()[s0] -> (s0 * 32)>
#map16 = affine_map<()[s0] -> (s0 * 8)>
func.func @func10(%arg0: memref<8x512xi32>, %arg1: memref<256x512xi32>, %arg2: memref<8x256xi32>) {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%0 = air.launch async (%arg3, %arg4) in (%arg5=%c1, %arg6=%c2) attributes {id = 1 : i32} {
%1 = air.segment @segment_0 async attributes {id = 2 : i32} {
%c64 = arith.constant 64 : index
%c2048 = arith.constant 2048 : index
%c128 = arith.constant 128 : index
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
%c1_0 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c256 = arith.constant 256 : index
%async_token, %results = air.execute -> (memref<128x512xi32, 1 : i32>) {
%alloc = memref.alloc() : memref<128x512xi32, 1 : i32>
air.execute_terminator %alloc : memref<128x512xi32, 1 : i32>
}
%2 = scf.for %arg7 = %c0 to %c512 step %c256 iter_args(%arg8 = %async_token) -> (!air.async.token) {
%4 = air.channel.get async [%arg8] @channel_2[] (%results[%c0, %arg7] [%c128, %c256] [%c512, %c1_0]) {id = 5 : i32} : (memref<128x512xi32, 1 : i32>)
scf.yield %4 : !air.async.token
}
%3 = scf.parallel (%arg7) = (%c0) to (%c4) step (%c1_0) init (%async_token) -> !air.async.token {
%async_token_2, %results_3 = air.execute -> (index) {
%6 = affine.apply #map15()[%arg7]
air.execute_terminator %6 : index
}
%4 = air.wait_all async [%async_token, %async_token_2]
%5 = scf.for %arg8 = %c0 to %c64 step %c4 iter_args(%arg9 = %4) -> (!air.async.token) {
%async_token_4, %results_5 = air.execute [%arg9] -> (index) {
%7 = affine.apply #map16()[%arg8]
air.execute_terminator %7 : index
}
%6 = air.channel.put async [%async_token_4] @channel_3[%c0, %arg7] (%results[%c0, %c0, %results_3, %results_5] [%c4, %c8, %c4, %c8] [%c8, %c2048, %c512, %c1_0]) {id = 7 : i32} : (memref<128x512xi32, 1 : i32>)
scf.yield %6 : !air.async.token
}
scf.reduce(%5 : !air.async.token) {
^bb0(%arg8: !air.async.token, %arg9: !air.async.token):
%6 = air.wait_all async [%arg8, %arg9]
scf.reduce.return %6 : !air.async.token
}
}
%async_token_1 = air.execute {
memref.dealloc %results : memref<128x512xi32, 1 : i32>
}
}
}
return
}
