Skip to content

Commit

Permalink
Extend canonicalizeFalseDependencies method to also analyze scf.for a…
Browse files Browse the repository at this point in the history
…nd scf.parallel loops (#832)
  • Loading branch information
erwei-xilinx authored Dec 31, 2024
1 parent af36e31 commit a2e4acc
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 6 deletions.
31 changes: 25 additions & 6 deletions mlir/lib/Dialect/AIR/IR/AIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ CanonicalizeAsyncLoopCarriedDepsInRegion(OpT op, PatternRewriter &rewriter) {
return success();
}

// Break any faulty async dependencies.
// Break any wrong async dependencies.
template <class T>
static LogicalResult canonicalizeFalseDependencies(T op,
PatternRewriter &rewriter) {
Expand All @@ -320,11 +320,29 @@ static LogicalResult canonicalizeFalseDependencies(T op,
};
auto getAllMemrefsTouchedbyOp = [getMemrefsFromVec](Operation *o) {
llvm::SetVector<Value> memrefs;
auto memrefOpers = getMemrefsFromVec(o->getOperands());
memrefs.insert(memrefOpers.begin(), memrefOpers.end());
for (auto &region : o->getRegions()) {
SmallVector<Value> vals = o->getOperands();
vals.insert(vals.end(), o->getResults().begin(), o->getResults().end());
SmallVector<Region *> regions;
for (auto &region : o->getRegions())
regions.push_back(&region);
// If air.wait_all, then we analyze the dependency by collecting all
// operations that depend on it.
auto waitAllOp = dyn_cast_if_present<air::WaitAllOp>(o);
if (waitAllOp && waitAllOp.getAsyncToken()) {
for (auto user : waitAllOp.getAsyncToken().getUsers()) {
vals.insert(vals.end(), user->getOperands().begin(),
user->getOperands().end());
vals.insert(vals.end(), user->getResults().begin(),
user->getResults().end());
for (auto &region : user->getRegions())
regions.push_back(&region);
}
}
auto memrefvals = getMemrefsFromVec(vals);
memrefs.insert(memrefvals.begin(), memrefvals.end());
for (auto region : regions) {
llvm::SetVector<Value> usedVals;
getUsedValuesDefinedAbove(region, usedVals);
getUsedValuesDefinedAbove(*region, usedVals);
auto usedMemrefs = getMemrefsFromVec(usedVals.takeVector());
memrefs.insert(usedMemrefs.begin(), usedMemrefs.end());
}
Expand All @@ -334,7 +352,7 @@ static LogicalResult canonicalizeFalseDependencies(T op,
auto memrefsTouchedByOp = getAllMemrefsTouchedbyOp(op.getOperation());
if (memrefsTouchedByOp.empty())
return failure();
auto depList = asyncOp.getAsyncDependencies();
SmallVector<Value> depList = asyncOp.getAsyncDependencies();
for (int i = depList.size() - 1; i >= 0; i--) {
auto tokDefOp = depList[i].getDefiningOp();
if (!tokDefOp)
Expand Down Expand Up @@ -1244,6 +1262,7 @@ void WaitAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(FoldWaitAll);
patterns.add(CanonicalizeAsyncOpDeps<WaitAllOp>);
patterns.add(canonicalizeFalseDependencies<WaitAllOp>);
}

// Get strides from MemRefType.
Expand Down
112 changes: 112 additions & 0 deletions mlir/test/Dialect/AIR/air_canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -505,3 +505,115 @@ func.func @func1() {
}
return
}

// CHECK: func.func @func2
// CHECK: %[[TOK0:.*]], %[[RES0:.*]] = air.execute
// CHECK-NEXT: memref.alloc()
// CHECK-NEXT: air.execute_terminator
// CHECK: scf.for %{{.*}} = %c0{{.*}} to %c2048{{.*}} step %c128{{.*}} iter_args(%[[TOK1:.*]] = %[[TOK0]])
// CHECK-NEXT: air.channel.put async [%[[TOK1]]] @channel_3

func.func @func2(%arg0: memref<2048xi8>, %arg1: memref<2048x1024xi8>, %arg2: memref<1024xi32>) {
%c4096 = arith.constant 4096 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c256_2 = arith.constant 256 : index
%c1_3 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c0_4 = arith.constant 0 : index
%c2048_5 = arith.constant 2048 : index
%c128_6 = arith.constant 128 : index
%async_token_7, %results_8 = air.execute -> (memref<2048xi8, 1>) {
%alloc = memref.alloc() : memref<2048xi8, 1>
air.execute_terminator %alloc : memref<2048xi8, 1>
}
%async_token_9, %results_10 = air.execute -> (memref<2048x256xi8, 1>) {
%alloc = memref.alloc() : memref<2048x256xi8, 1>
air.execute_terminator %alloc : memref<2048x256xi8, 1>
}
%async_token_11, %results_12 = air.execute -> (memref<256xi32, 1>) {
%alloc = memref.alloc() : memref<256xi32, 1>
air.execute_terminator %alloc : memref<256xi32, 1>
}
%async_token_13, %results_14 = air.execute -> (index) {
air.execute_terminator %c0_4 : index
}
%10 = air.wait_all async [%async_token_7, %async_token_9, %async_token_11, %async_token_13]
%11 = scf.for %arg8 = %c0_4 to %c2048_5 step %c128_6 iter_args(%arg9 = %10) -> (!air.async.token) {
%21 = air.channel.put async [%arg9] @channel_3[%c0_4, %c0_4] (%results_10[%c0_4, %c0_4, %arg8, %results_14] [%c16, %c8, %c16, %c8] [%c8, %c4096, %c256_2, %c1_3]) {id = 7 : i32} : (memref<2048x256xi8, 1>)
scf.yield %21 : !air.async.token
}
return
}

// CHECK: func.func @func3
// CHECK: %[[TOK0:.*]], %[[RES0:.*]] = air.execute
// CHECK-NEXT: memref.alloc()
// CHECK-NEXT: air.execute_terminator
// CHECK: scf.for %{{.*}} = %c0{{.*}} to %c2048{{.*}} step %c128{{.*}} iter_args(%[[TOK1:.*]] = %[[TOK0]])
// CHECK-NEXT: air.channel.get async [%[[TOK1]]] @channel_1
// CHECK: %[[TOK2:.*]], %[[RES1:.*]] = air.execute
// CHECK-NEXT: memref.alloc()
// CHECK-NEXT: air.execute_terminator
// CHECK: %[[TOK6:.*]] = scf.for %{{.*}} = %c0{{.*}} to %c2048{{.*}} step %c128{{.*}} iter_args(%[[TOK3:.*]] = %[[TOK2]])
// CHECK-NEXT: air.channel.get async [%[[TOK3]]] @channel_2
// CHECK: scf.for %{{.*}} = %c0{{.*}} to %c2048{{.*}} step %c128{{.*}} iter_args(%[[TOK5:.*]] = %[[TOK0]])
// CHECK-NEXT: air.channel.put async [%[[TOK5]]] @channel_0
// CHECK: scf.for %{{.*}} = %c0{{.*}} to %c2048{{.*}} step %c128{{.*}} iter_args(%[[TOK7:.*]] = %[[TOK6]])
// CHECK-NEXT: air.channel.put async [%[[TOK7]]] @channel_3
// CHECK: scf.for %{{.*}} = %c0{{.*}} to %c2048{{.*}} step %c128{{.*}} iter_args(%[[TOK8:.*]] = %[[TOK6]])
// CHECK-NEXT: air.channel.put async [%[[TOK8]]] @channel_3

func.func @func3(%arg0: memref<2048xi8>, %arg1: memref<2048x1024xi8>, %arg2: memref<1024xi32>) {
%c4096 = arith.constant 4096 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c256_2 = arith.constant 256 : index
%c1_3 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c0_4 = arith.constant 0 : index
%c2048_5 = arith.constant 2048 : index
%c128_6 = arith.constant 128 : index
%async_token_7, %results_8 = air.execute -> (memref<2048xi8, 1>) {
%alloc = memref.alloc() : memref<2048xi8, 1>
air.execute_terminator %alloc : memref<2048xi8, 1>
}
%6 = scf.for %arg8 = %c0_4 to %c2048_5 step %c128_6 iter_args(%arg9 = %async_token_7) -> (!air.async.token) {
%21 = air.channel.get async [%arg9] @channel_1[] (%results_8[%arg8] [%c128_6] [%c1_3]) {id = 4 : i32} : (memref<2048xi8, 1>)
scf.yield %21 : !air.async.token
}
%async_token_9, %results_10 = air.execute -> (memref<2048x256xi8, 1>) {
%alloc = memref.alloc() : memref<2048x256xi8, 1>
air.execute_terminator %alloc : memref<2048x256xi8, 1>
}
%7 = scf.for %arg8 = %c0_4 to %c2048_5 step %c128_6 iter_args(%arg9 = %async_token_9) -> (!air.async.token) {
%21 = air.channel.get async [%arg9] @channel_2[] (%results_10[%arg8, %c0_4] [%c128_6, %c256_2] [%c256_2, %c1_3]) {id = 5 : i32} : (memref<2048x256xi8, 1>)
scf.yield %21 : !air.async.token
}
%async_token_11, %results_12 = air.execute -> (memref<256xi32, 1>) {
%alloc = memref.alloc() : memref<256xi32, 1>
air.execute_terminator %alloc : memref<256xi32, 1>
}
%8 = scf.for %arg8 = %c0_4 to %c2048_5 step %c128_6 iter_args(%arg9 = %async_token_7) -> (!air.async.token) {
%21 = air.channel.put async [%arg9] @channel_0[] (%results_8[%c0_4, %arg8] [%c8, %c16] [%c16, %c1_3]) {id = 6 : i32} : (memref<2048xi8, 1>)
scf.yield %21 : !air.async.token
}
%9 = air.wait_all async [%6, %7, %async_token_11]
%async_token_13, %results_14 = air.execute -> (index) {
air.execute_terminator %c0_4 : index
}
%10 = air.wait_all async [%9, %async_token_13]
%11 = scf.for %arg8 = %c0_4 to %c2048_5 step %c128_6 iter_args(%arg9 = %10) -> (!air.async.token) {
%21 = air.channel.put async [%arg9] @channel_3[%c0_4, %c0_4] (%results_10[%c0_4, %c0_4, %arg8, %results_14] [%c16, %c8, %c16, %c8] [%c8, %c4096, %c256_2, %c1_3]) {id = 7 : i32} : (memref<2048x256xi8, 1>)
scf.yield %21 : !air.async.token
}
%async_token_15, %results_16 = air.execute -> (index) {
air.execute_terminator %c128_6 : index
}
%12 = air.wait_all async [%9, %async_token_15]
%13 = scf.for %arg8 = %c0_4 to %c2048_5 step %c128_6 iter_args(%arg9 = %12) -> (!air.async.token) {
%21 = air.channel.put async [%arg9] @channel_3[%c1_3, %c0_4] (%results_10[%c0_4, %c0_4, %arg8, %results_16] [%c16, %c8, %c16, %c8] [%c8, %c4096, %c256_2, %c1_3]) {id = 7 : i32} : (memref<2048x256xi8, 1>)
scf.yield %21 : !air.async.token
}
return
}

0 comments on commit a2e4acc

Please sign in to comment.