Skip to content

Commit

Permalink
Add mlir ir test showing updated async token inheritance after L2 spl…
Browse files Browse the repository at this point in the history
…itting
  • Loading branch information
erwei-xilinx committed Dec 31, 2024
1 parent 99d310d commit 14ef130
Showing 1 changed file with 133 additions and 0 deletions.
133 changes: 133 additions & 0 deletions mlir/test/Transform/AIRMiscPasses/air_split_l2_memref.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1746,3 +1746,136 @@ module {
return
}
}

// -----

// Scf.for and scf.parallel nest: check for async token inheritance.

// CHECK: air.segment
// CHECK: air.herd
// CHECK: scf.for %{{.*}} = %c0{{.*}} to %c512{{.*}} step %c256{{.*}}
// CHECK-NEXT: scf.for %{{.*}} = %c0{{.*}} to %c512{{.*}} step %c256{{.*}}
// CHECK: air.channel.put async @channel_0
// CHECK: }
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: scf.for %{{.*}} = %c0{{.*}} to %c512{{.*}} step %c256{{.*}} iter_args(%{{.*}} = %{{.*}})
// CHECK: scf.for %{{.*}} = %c0{{.*}} to %c512{{.*}} step %c256{{.*}} iter_args(%[[VAL0:.*]] = %{{.*}})
// CHECK: %[[GET0:.*]] = air.channel.get async [%[[VAL0]]] @channel_0
// CHECK-NEXT: %[[WAITALL0:.*]] = air.wait_all async [%[[GET0]]]
// CHECK: %[[GET1:.*]] = air.channel.get async [%[[VAL0]]] @channel_0
// CHECK-NEXT: %[[WAITALL1:.*]] = air.wait_all async [%[[GET1]]]
// CHECK: %[[GET2:.*]] = air.channel.get async [%[[VAL0]]] @channel_0
// CHECK-NEXT: %[[WAITALL2:.*]] = air.wait_all async [%[[GET2]]]
// CHECK: %[[GET3:.*]] = air.channel.get async [%[[VAL0]]] @channel_0
// CHECK-NEXT: %[[WAITALL3:.*]] = air.wait_all async [%[[GET3]]]
// CHECK: %[[GET4:.*]] = air.channel.get async [%[[VAL0]]] @channel_0
// CHECK-NEXT: %[[WAITALL4:.*]] = air.wait_all async [%[[GET4]]]
// CHECK: %[[GET5:.*]] = air.channel.get async [%[[VAL0]]] @channel_0
// CHECK-NEXT: %[[WAITALL5:.*]] = air.wait_all async [%[[GET5]]]
// CHECK: %[[GET6:.*]] = air.channel.get async [%[[VAL0]]] @channel_0
// CHECK-NEXT: %[[WAITALL6:.*]] = air.wait_all async [%[[GET6]]]
// CHECK: %[[GET7:.*]] = air.channel.get async [%[[VAL0]]] @channel_0
// CHECK-NEXT: %[[WAITALL7:.*]] = air.wait_all async [%[[GET7]]]
// CHECK: %[[GET8:.*]] = air.channel.get async [%[[VAL0]]] @channel_0
// CHECK-NEXT: %[[WAITALL8:.*]] = air.wait_all async [%[[GET8]]]
// CHECK: %[[GET9:.*]] = air.channel.get async [%[[VAL0]]] @channel_0
// CHECK-NEXT: %[[WAITALL9:.*]] = air.wait_all async [%[[GET9]]]
// CHECK: %[[GET10:.*]] = air.channel.get async [%[[VAL0]]] @channel_0
// CHECK-NEXT: %[[WAITALL10:.*]] = air.wait_all async [%[[GET10]]]
// CHECK: %[[GET11:.*]] = air.channel.get async [%[[VAL0]]] @channel_0
// CHECK-NEXT: %[[WAITALL11:.*]] = air.wait_all async [%[[GET11]]]
// CHECK: %[[GET12:.*]] = air.channel.get async [%[[VAL0]]] @channel_0
// CHECK-NEXT: %[[WAITALL12:.*]] = air.wait_all async [%[[GET12]]]
// CHECK: %[[GET13:.*]] = air.channel.get async [%[[VAL0]]] @channel_0
// CHECK-NEXT: %[[WAITALL13:.*]] = air.wait_all async [%[[GET13]]]
// CHECK: %[[GET14:.*]] = air.channel.get async [%[[VAL0]]] @channel_0
// CHECK-NEXT: %[[WAITALL14:.*]] = air.wait_all async [%[[GET14]]]
// CHECK: %[[GET15:.*]] = air.channel.get async [%[[VAL0]]] @channel_0
// CHECK-NEXT: %[[WAITALL15:.*]] = air.wait_all async [%[[GET15]]]
// CHECK-NEXT: %[[YIELDED:.*]] = air.wait_all async [%[[WAITALL0]], %[[WAITALL1]], %[[WAITALL2]], %[[WAITALL3]], %[[WAITALL4]], %[[WAITALL5]], %[[WAITALL6]], %[[WAITALL7]], %[[WAITALL8]], %[[WAITALL9]], %[[WAITALL10]], %[[WAITALL11]], %[[WAITALL12]], %[[WAITALL13]], %[[WAITALL14]], %[[WAITALL15]]]
// CHECK-NEXT: %[[PUT0:.*]] = air.channel.put async [%[[YIELDED]]] @channel_2
// CHECK-NEXT: %[[PUT1:.*]] = air.channel.put async [%[[YIELDED]]] @channel_2
// CHECK-NEXT: %[[PUT2:.*]] = air.channel.put async [%[[YIELDED]]] @channel_2
// CHECK-NEXT: %[[PUT3:.*]] = air.channel.put async [%[[YIELDED]]] @channel_2
// CHECK-NEXT: %[[YIELDED:.*]] = air.wait_all async [%[[PUT0]], %[[PUT1]], %[[PUT2]], %[[PUT3]]]
// CHECK-NEXT: scf.yield %[[YIELDED]]
// CHECK: scf.yield

module {
air.channel @channel_0 [4, 4]
air.channel @channel_1 [1, 1]
func.func @test13(%arg0: memref<512x512xbf16>) {
%c1 = arith.constant 1 : index
%0 = air.launch async (%arg1) in (%arg2=%c1) args(%arg3=%arg0) : memref<512x512xbf16> attributes {id = 1 : i32} {
%c0 = arith.constant 0 : index
%c1_0 = arith.constant 1 : index
%c512 = arith.constant 512 : index
%c256 = arith.constant 256 : index
%1 = air.wait_all async
%2 = scf.for %arg4 = %c0 to %c512 step %c256 iter_args(%arg5 = %1) -> (!air.async.token) {
%4 = scf.for %arg6 = %c0 to %c512 step %c256 iter_args(%arg7 = %arg5) -> (!air.async.token) {
%5 = air.channel.get async [%arg7] @channel_1[] (%arg3[%arg4, %arg6] [%c256, %c256] [%c512, %c1_0]) {id = 3 : i32} : (memref<512x512xbf16>)
scf.yield %5 : !air.async.token
}
scf.yield %4 : !air.async.token
}
%3 = air.segment @segment_0 async attributes {id = 2 : i32} {
%c4096 = arith.constant 4096 : index
%c16384 = arith.constant 16384 : index
%c64 = arith.constant 64 : index
%c4 = arith.constant 4 : index
%c0_1 = arith.constant 0 : index
%c1_2 = arith.constant 1 : index
%c512_3 = arith.constant 512 : index
%c256_4 = arith.constant 256 : index
%async_token, %results = air.execute -> (memref<4x4x64x64xbf16, 1 : i32>) {
%alloc = memref.alloc() : memref<4x4x64x64xbf16, 1 : i32>
air.execute_terminator %alloc : memref<4x4x64x64xbf16, 1 : i32>
}
%async_token_5, %results_6 = air.execute -> (memref<4x4x16x16x4x4xbf16, 2 : i32>) {
%alloc = memref.alloc() : memref<4x4x16x16x4x4xbf16, 2 : i32>
air.execute_terminator %alloc : memref<4x4x16x16x4x4xbf16, 2 : i32>
}
%4 = air.herd @herd_0 async tile (%arg4, %arg5) in (%arg6=%c4, %arg7=%c4) args(%arg8=%results_6) : memref<4x4x16x16x4x4xbf16, 2 : i32> attributes {id = 5 : i32} {
%c1_9 = arith.constant 1 : index
%c4_10 = arith.constant 4 : index
%c16 = arith.constant 16 : index
%c4096_11 = arith.constant 4096 : index
%c16384_12 = arith.constant 16384 : index
%c512_13 = arith.constant 512 : index
%c0_14 = arith.constant 0 : index
%c256_15 = arith.constant 256 : index
scf.for %arg9 = %c0_14 to %c512_13 step %c256_15 {
scf.for %arg10 = %c0_14 to %c512_13 step %c256_15 {
%6 = air.channel.put async @channel_0[%arg4, %arg5] (%arg8[%arg4, %arg5, %c0_14, %c0_14, %c0_14, %c0_14] [%c1_9, %c1_9, %c16, %c4_10, %c16, %c4_10] [%c16384_12, %c4096_11, %c16, %c4_10, %c256_15, %c1_9]) {id = 23 : i32} : (memref<4x4x16x16x4x4xbf16, 2 : i32>)
}
}
}
%5 = scf.for %arg4 = %c0_1 to %c512_3 step %c256_4 iter_args(%arg5 = %async_token) -> (!air.async.token) {
%6 = scf.for %arg6 = %c0_1 to %c512_3 step %c256_4 iter_args(%arg7 = %arg5) -> (!air.async.token) {
%7 = scf.parallel (%arg8, %arg9) = (%c0_1, %c0_1) to (%c4, %c4) step (%c1_2, %c1_2) init (%arg7) -> !air.async.token {
%9 = air.channel.get async [%arg7] @channel_0[%arg8, %arg9] (%results[%arg8, %arg9, %c0_1, %c0_1] [%c1_2, %c1_2, %c64, %c64] [%c16384, %c4096, %c64, %c1_2]) {id = 22 : i32} : (memref<4x4x64x64xbf16, 1 : i32>)
scf.reduce(%9 : !air.async.token) {
^bb0(%arg10: !air.async.token, %arg11: !air.async.token):
%10 = air.wait_all async [%arg10, %arg11]
scf.reduce.return %10 : !air.async.token
}
}
%8 = air.channel.put async [%7] @channel_1[] (%results[%c0_1, %c0_1, %c0_1, %c0_1] [%c4, %c64, %c4, %c64] [%c16384, %c64, %c4096, %c1_2]) {id = 24 : i32} : (memref<4x4x64x64xbf16, 1 : i32>)
scf.yield %8 : !air.async.token
}
scf.yield %6 : !air.async.token
}
%async_token_7 = air.execute {
memref.dealloc %results_6 : memref<4x4x16x16x4x4xbf16, 2 : i32>
}
%async_token_8 = air.execute {
memref.dealloc %results : memref<4x4x64x64xbf16, 1 : i32>
}
}
}
return
}
}

0 comments on commit 14ef130

Please sign in to comment.