Skip to content

Commit

Permalink
Add mlir ir test showing wait all token folding
Browse files Browse the repository at this point in the history
  • Loading branch information
erwei-xilinx committed Dec 31, 2024
1 parent 1eba001 commit e3a86f4
Showing 1 changed file with 48 additions and 1 deletion.
49 changes: 48 additions & 1 deletion mlir/test/Dialect/AIR/air_canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,53 @@ func.func @wait_all_1() {
return
}

// CHECK: func.func @wait_all_2
// CHECK: scf.for
// CHECK: scf.for{{.*}}iter_args(%[[TOK0:.*]] = %{{.*}})
// CHECK: %[[GET0:.*]] = air.channel.get async [%[[TOK0]]] @channel_0
// CHECK: %[[GET1:.*]] = air.channel.get async [%[[TOK0]]] @channel_0
// CHECK: %[[GET2:.*]] = air.channel.get async [%[[TOK0]]] @channel_0
// CHECK: %[[GET3:.*]] = air.channel.get async [%[[TOK0]]] @channel_0
// CHECK: %[[PUT0:.*]] = air.channel.put async [%[[GET0]]] @channel_1
// CHECK: %[[PUT1:.*]] = air.channel.put async [%[[GET1]]] @channel_1
// CHECK: %[[YIELD:.*]] = air.wait_all async [%[[PUT0]], %[[PUT1]]]
// CHECK: scf.yield %[[YIELD]]
// CHECK: scf.yield

func.func @wait_all_2(%arg0: memref<1xi8>, %arg1: memref<1xi8>, %arg2: memref<1xi8>, %arg3: memref<1xi8>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%c64 = arith.constant 64 : index
%c256 = arith.constant 256 : index
%c512 = arith.constant 512 : index
%c4096 = arith.constant 4096 : index
%c16384 = arith.constant 16384 : index
%0 = air.wait_all async
%6 = scf.for %arg4 = %c0 to %c512 step %c256 iter_args(%arg5 = %0) -> (!air.async.token) {
%7 = air.wait_all async [%arg5]
%8 = scf.for %arg6 = %c0 to %c512 step %c256 iter_args(%arg7 = %7) -> (!air.async.token) {
%9 = air.channel.get async [%arg7] @channel_0[%c0, %c0] (%arg0[%c0, %c0, %c0, %c0] [%c1, %c1, %c64, %c64] [%c16384, %c4096, %c64, %c1]) {id = 6 : i32} : (memref<1xi8>)
%10 = air.wait_all async [%9]
%11 = air.channel.get async [%arg7] @channel_0[%c1, %c0] (%arg1[%c0, %c0, %c0, %c0] [%c1, %c1, %c64, %c64] [%c16384, %c4096, %c64, %c1]) {id = 7 : i32} : (memref<1xi8>)
%12 = air.wait_all async [%11]
%13 = air.channel.get async [%arg7] @channel_0[%c2, %c0] (%arg2[%c0, %c0, %c0, %c0] [%c1, %c1, %c64, %c64] [%c16384, %c4096, %c64, %c1]) {id = 8 : i32} : (memref<1xi8>)
%14 = air.wait_all async [%13]
%15 = air.channel.get async [%arg7] @channel_0[%c3, %c0] (%arg3[%c0, %c0, %c0, %c0] [%c1, %c1, %c64, %c64] [%c16384, %c4096, %c64, %c1]) {id = 9 : i32} : (memref<1xi8>)
%16 = air.wait_all async [%15]
%41 = air.wait_all async [%10, %12, %14, %16]
%42 = air.channel.put async [%41] @channel_1[%c0, %c0] (%arg0[%c0, %c0, %c0, %c0] [%c1, %c64, %c4, %c64] [%c16384, %c64, %c4096, %c1]) {id = 22 : i32} : (memref<1xi8>)
%43 = air.channel.put async [%41] @channel_1[%c1, %c0] (%arg1[%c0, %c0, %c0, %c0] [%c1, %c64, %c4, %c64] [%c16384, %c64, %c4096, %c1]) {id = 23 : i32} : (memref<1xi8>)
%46 = air.wait_all async [%42, %43]
scf.yield %46 : !air.async.token
}
scf.yield %8 : !air.async.token
}
return
}

// CHECK-LABEL: execute_0
// CHECK-NEXT: return
func.func @execute_0() {
Expand Down Expand Up @@ -616,4 +663,4 @@ func.func @func3(%arg0: memref<2048xi8>, %arg1: memref<2048x1024xi8>, %arg2: mem
scf.yield %21 : !air.async.token
}
return
}
}

0 comments on commit e3a86f4

Please sign in to comment.