Skip to content

Commit

Permalink
AIRPingpongTransform: Tokens may not always get passed into async `sc…
Browse files Browse the repository at this point in the history
…f.for` via `init_args` (Xilinx#756)

* Tokens used inside scf.for but declared outside should be handled the same way as init_args

* Re-enable pingpong buffering for vecmat example
  • Loading branch information
erwei-xilinx authored Oct 29, 2024
1 parent 5131fd2 commit 3d1a4e1
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 11 deletions.
25 changes: 18 additions & 7 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,13 +610,21 @@ struct AnnotateFrontAndBackOpsInForPattern
continue;

if (!dep_list.size())
op.setAttr("async_front", rewriter.getBoolAttr(true));
for (auto token : iterTokens) {
for (auto dep : dep_list) {
if (token == dep) {
setBoolAttrForAsyncOp(rewriter, &op, "async_front");
}
}
setBoolAttrForAsyncOp(rewriter, &op, "async_front");
for (auto dep : dep_list) {
// Token is in iter_args
if (llvm::any_of(iterTokens,
[dep](Value token) { return token == dep; }))
setBoolAttrForAsyncOp(rewriter, &op, "async_front");
}
// Token is declared outside of for loop
if (llvm::any_of(dep_list, [for_op](Value token) {
auto tokenDefOp = token.getDefiningOp();
if (!tokenDefOp)
return false;
return !for_op->isProperAncestor(tokenDefOp);
})) {
setBoolAttrForAsyncOp(rewriter, &op, "async_front");
}
}

Expand Down Expand Up @@ -649,6 +657,9 @@ struct AnnotateFrontAndBackOpsInForPattern
}
for (auto op : back_candidates) {
setBoolAttrForAsyncOp(rewriter, op, "async_back");
if (op->hasAttr("async_front"))
// An op cannot be both "async_back" and "async_front".
op->removeAttr("async_front");
}

return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,55 @@ func.func @test(%arg0: memref<256x1024xbf16>, %arg1: memref<1024x1024xbf16>, %ar
}
return
}

// Label async_front based on tokens declared outside of for loop.
// CHECK-LABEL: test1
// CHECK: air.segment
// CHECK: air.wait_all async
// CHECK: air.wait_all async
// CHECK: scf.for
// CHECK: air.channel.get{{.*}}async_front = true
// CHECK: air.channel.get{{.*}}async_front = true
// CHECK: air.wait_all async{{.*}}async_back = true

func.func @test1(%arg0: memref<2048xi8>, %arg1: memref<2048x1024xi8>, %arg2: memref<1024xi32>) {
%c4 = arith.constant 4 : index
%0 = air.launch async (%arg3) in (%arg4=%c4) {
%1 = air.segment @vecmat_i8_0 async {
%c4096 = arith.constant 4096 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c256 = arith.constant 256 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c2048 = arith.constant 2048 : index
%c128 = arith.constant 128 : index
%2 = air.wait_all async
%3 = air.wait_all async
%11 = scf.for %arg5 = %c0 to %c2048 step %c128 iter_args(%arg9 = %3) -> (!air.async.token) {
%async_token, %results = air.execute -> (memref<128xi8, 1>) {
%alloc = memref.alloc() {hoist_alloc = true} : memref<128xi8, 1>
air.execute_terminator %alloc : memref<128xi8, 1>
}
%async_token_0, %results_1 = air.execute -> (memref<128x256xi8, 1>) {
%alloc = memref.alloc() {hoist_alloc = true} : memref<128x256xi8, 1>
air.execute_terminator %alloc : memref<128x256xi8, 1>
}
%4 = air.channel.get async [%2] @channel_1[] (%results[%arg5] [%c128] [%c1]) {id = 4 : i32} : (memref<128xi8, 1>)
%5 = air.channel.get async [%3] @channel_2[] (%results_1[%arg5, %c0] [%c128, %c256] [%c256, %c1]) {id = 5 : i32} : (memref<128x256xi8, 1>)
%6 = air.channel.put async [%4] @channel_0[] (%results[%c0, %arg5] [%c8, %c16] [%c16, %c1]) {id = 6 : i32} : (memref<128xi8, 1>)
%7 = air.channel.put async [%5] @channel_3[%c0, %c0] (%results_1[%c0, %c0, %arg5, %c0] [%c16, %c8, %c16, %c8] [%c8, %c4096, %c256, %c1]) {id = 7 : i32} : (memref<128x256xi8, 1>)
%8 = air.channel.put async [%5] @channel_3[%c1, %c0] (%results_1[%c0, %c0, %arg5, %c128] [%c16, %c8, %c16, %c8] [%c8, %c4096, %c256, %c1]) {id = 7 : i32} : (memref<128x256xi8, 1>)
%async_token_2 = air.execute {
memref.dealloc %results : memref<128xi8, 1>
}
%async_token_3 = air.execute {
memref.dealloc %results_1 : memref<128x256xi8, 1>
}
%9 = air.wait_all async [%async_token, %async_token_0, %4, %5, %6, %7, %8, %async_token_2, %async_token_3]
scf.yield %9 : !air.async.token
} {isolated = true, unroll = 2 : i32}
}
}
return
}
8 changes: 4 additions & 4 deletions test/xrt/26_vecmat_i8/aie.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,10 @@
"func.func(air-split-l2-memref)",
"air-isolate-async-dma-loop-nests",
"func.func(air-loop-fusion)",
# "air-label-scf-for-to-ping-pong",
# "air-ping-pong-transform{keep-memref-dealloc=true}",
# "canonicalize",
# "cse",
"air-label-scf-for-to-ping-pong",
"air-ping-pong-transform{keep-memref-dealloc=true}",
"canonicalize",
"cse",
"air-specialize-channel-wrap-and-stride",
"canonicalize",
"cse",
Expand Down

0 comments on commit 3d1a4e1

Please sign in to comment.