From e3a86f48740bc394a68150c349f4d4493960934e Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Tue, 31 Dec 2024 01:11:25 -0800 Subject: [PATCH] Add mlir ir test showing wait all token folding --- mlir/test/Dialect/AIR/air_canonicalize.mlir | 49 ++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/mlir/test/Dialect/AIR/air_canonicalize.mlir b/mlir/test/Dialect/AIR/air_canonicalize.mlir index b68524c78..c6c27e50b 100644 --- a/mlir/test/Dialect/AIR/air_canonicalize.mlir +++ b/mlir/test/Dialect/AIR/air_canonicalize.mlir @@ -210,6 +210,53 @@ func.func @wait_all_1() { return } +// CHECK: func.func @wait_all_2 +// CHECK: scf.for +// CHECK: scf.for{{.*}}iter_args(%[[TOK0:.*]] = %{{.*}}) +// CHECK: %[[GET0:.*]] = air.channel.get async [%[[TOK0]]] @channel_0 +// CHECK: %[[GET1:.*]] = air.channel.get async [%[[TOK0]]] @channel_0 +// CHECK: %[[GET2:.*]] = air.channel.get async [%[[TOK0]]] @channel_0 +// CHECK: %[[GET3:.*]] = air.channel.get async [%[[TOK0]]] @channel_0 +// CHECK: %[[PUT0:.*]] = air.channel.put async [%[[GET0]]] @channel_1 +// CHECK: %[[PUT1:.*]] = air.channel.put async [%[[GET1]]] @channel_1 +// CHECK: %[[YIELD:.*]] = air.wait_all async [%[[PUT0]], %[[PUT1]]] +// CHECK: scf.yield %[[YIELD]] +// CHECK: scf.yield + +func.func @wait_all_2(%arg0: memref<1xi8>, %arg1: memref<1xi8>, %arg2: memref<1xi8>, %arg3: memref<1xi8>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c4096 = arith.constant 4096 : index + %c16384 = arith.constant 16384 : index + %0 = air.wait_all async + %6 = scf.for %arg4 = %c0 to %c512 step %c256 iter_args(%arg5 = %0) -> (!air.async.token) { + %7 = air.wait_all async [%arg5] + %8 = scf.for %arg6 = %c0 to %c512 step %c256 iter_args(%arg7 = %7) -> (!air.async.token) { + %9 = air.channel.get async [%arg7] @channel_0[%c0, %c0] (%arg0[%c0, %c0, %c0, %c0] [%c1, %c1, %c64, %c64] [%c16384, %c4096, %c64, %c1]) {id = 6 : i32} : (memref<1xi8>) + %10 = air.wait_all async [%9] + %11 = air.channel.get async [%arg7] @channel_0[%c1, %c0] (%arg1[%c0, %c0, %c0, %c0] [%c1, %c1, %c64, %c64] [%c16384, %c4096, %c64, %c1]) {id = 7 : i32} : (memref<1xi8>) + %12 = air.wait_all async [%11] + %13 = air.channel.get async [%arg7] @channel_0[%c2, %c0] (%arg2[%c0, %c0, %c0, %c0] [%c1, %c1, %c64, %c64] [%c16384, %c4096, %c64, %c1]) {id = 8 : i32} : (memref<1xi8>) + %14 = air.wait_all async [%13] + %15 = air.channel.get async [%arg7] @channel_0[%c3, %c0] (%arg3[%c0, %c0, %c0, %c0] [%c1, %c1, %c64, %c64] [%c16384, %c4096, %c64, %c1]) {id = 9 : i32} : (memref<1xi8>) + %16 = air.wait_all async [%15] + %41 = air.wait_all async [%10, %12, %14, %16] + %42 = air.channel.put async [%41] @channel_1[%c0, %c0] (%arg0[%c0, %c0, %c0, %c0] [%c1, %c64, %c4, %c64] [%c16384, %c64, %c4096, %c1]) {id = 22 : i32} : (memref<1xi8>) + %43 = air.channel.put async [%41] @channel_1[%c1, %c0] (%arg1[%c0, %c0, %c0, %c0] [%c1, %c64, %c4, %c64] [%c16384, %c64, %c4096, %c1]) {id = 23 : i32} : (memref<1xi8>) + %46 = air.wait_all async [%42, %43] + scf.yield %46 : !air.async.token + } + scf.yield %8 : !air.async.token + } + return +} + // CHECK-LABEL: execute_0 // CHECK-NEXT: return func.func @execute_0() { @@ -616,4 +663,4 @@ func.func @func3(%arg0: memref<2048xi8>, %arg1: memref<2048x1024xi8>, %arg2: mem scf.yield %21 : !air.async.token } return -} \ No newline at end of file +}