From b83a3e4b3131641c2ceb59cffad6c7b54cbb7054 Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Mon, 18 Mar 2024 16:43:40 -0700 Subject: [PATCH] Fixup a minor issue with BufferMemrefToFuncArgs function which fails with bf16 (#504) --- mlir/lib/Conversion/AIRRtToIpuPass.cpp | 7 +- .../AIRRtToIpu/buffer_memref_to_args.mlir | 90 +++++++++++++++++++ 2 files changed, 93 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/AIRRtToIpuPass.cpp b/mlir/lib/Conversion/AIRRtToIpuPass.cpp index f18638b96..6e2abd7bb 100644 --- a/mlir/lib/Conversion/AIRRtToIpuPass.cpp +++ b/mlir/lib/Conversion/AIRRtToIpuPass.cpp @@ -1152,10 +1152,9 @@ struct AIRRtToIpuPass : public impl::AIRRtToIpuBase { memref = cast.getOperand(0); } // push back if unique - if (std::find(memrefs.begin(), memrefs.end(), dma.getMemref()) == - memrefs.end()) { - memrefs.push_back(dma.getMemref()); - memrefTypes.push_back(dma.getMemref().getType()); + if (std::find(memrefs.begin(), memrefs.end(), memref) == memrefs.end()) { + memrefs.push_back(memref); + memrefTypes.push_back(memref.getType()); } }); diff --git a/mlir/test/Conversion/AIRRtToIpu/buffer_memref_to_args.mlir b/mlir/test/Conversion/AIRRtToIpu/buffer_memref_to_args.mlir index 5e70056a8..c9599e426 100644 --- a/mlir/test/Conversion/AIRRtToIpu/buffer_memref_to_args.mlir +++ b/mlir/test/Conversion/AIRRtToIpu/buffer_memref_to_args.mlir @@ -111,3 +111,93 @@ module { return } } + +// ----- + +// Bf16 datatype support: the bf16 buffers used by the DMAs are expected to become i32 memref arguments of the function. 
+ +// CHECK-LABEL: aie.device(ipu) +// CHECK: func.func @func2(%[[VAL_0:.*]]: memref<2097152xi32>, %[[VAL_1:.*]]: memref<2097152xi32>, %[[VAL_2:.*]]: memref<2097152xi32>) { +// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 0][4, 8, 128, 128][0, 128, 1024]) {id = 0 : i64, metadata = @airMemcpyId10} : memref<2097152xi32> +// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 131072][4, 8, 128, 128][0, 128, 1024]) {id = 1 : i64, metadata = @airMemcpyId10} : memref<2097152xi32> +// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 262144][4, 8, 128, 128][0, 128, 1024]) {id = 2 : i64, metadata = @airMemcpyId10} : memref<2097152xi32> +// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 393216][4, 8, 128, 128][0, 128, 1024]) {id = 3 : i64, metadata = @airMemcpyId10} : memref<2097152xi32> +// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 4 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> +// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 5 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> +// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 6 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> +// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 7 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> +// CHECK: aiex.ipu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 0][4, 4, 128, 64][131072, 64, 1024]) {id = 8 : i64, metadata = @airMemcpyId26} : memref<2097152xi32> + +module { + aie.device(ipu) { + aie.shim_dma_allocation @airMemcpyId26(S2MM, 0, 0) + memref.global "public" @airMemcpyId26 : memref<128x128xbf16, 1 : i32> + aie.shim_dma_allocation @airMemcpyId4(MM2S, 0, 0) + memref.global "public" @airMemcpyId4 : memref<128x256xbf16, 1 : i32> + aie.shim_dma_allocation @airMemcpyId10(MM2S, 0, 0) + memref.global "public" @airMemcpyId10 : 
memref<128x256xbf16, 1 : i32> + aie.shim_dma_allocation @airMemcpyId7(MM2S, 1, 0) + memref.global "public" @airMemcpyId7 : memref<256x128xbf16, 1 : i32> + aie.shim_dma_allocation @airMemcpyId13(MM2S, 1, 0) + memref.global "public" @airMemcpyId13 : memref<256x128xbf16, 1 : i32> + } {sym_name = "segment_0"} + func.func @func2() { + %c128_i64 = arith.constant 128 : i64 + %c8_i64 = arith.constant 8 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c256_i64 = arith.constant 256 : i64 + %c26_i32 = arith.constant 26 : i32 + %c7_i32 = arith.constant 7 : i32 + %c4_i32 = arith.constant 4 : i32 + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %c0 = arith.constant 0 : index + %0 = memref.alloc() : memref<2048x2048xbf16> + %1 = airrt.wait_all : !airrt.event + airrt.wait_all %1 + memref.assume_alignment %0, 64 : memref<2048x2048xbf16> + %2 = airrt.wait_all : !airrt.event + %3 = memref.alloc() : memref<2048x2048xbf16> + %4 = airrt.wait_all : !airrt.event + airrt.wait_all %4 + memref.assume_alignment %3, 64 : memref<2048x2048xbf16> + %5 = airrt.wait_all : !airrt.event + %6 = memref.alloc() : memref<2048x2048xbf16> + %7 = airrt.wait_all : !airrt.event + airrt.wait_all %7 + memref.assume_alignment %6, 64 : memref<2048x2048xbf16> + %8 = airrt.wait_all : !airrt.event + %9 = airrt.wait_all %8, %5, %2 : !airrt.event + affine.for %arg0 = 0 to 4 { + affine.for %arg1 = 0 to 4 { + %10 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%arg0] + %11 = airrt.wait_all : !airrt.event + %12 = airrt.wait_all %11 : !airrt.event + %13 = arith.index_cast %arg0 : index to i64 + %14 = arith.index_cast %arg1 : index to i64 + %15 = arith.index_cast %10 : index to i64 + %16 = airrt.dma_memcpy_nd(%c4_i32, %13, %14, %0[%c0_i64, %c0_i64, %15, %c0_i64], [%c1_i64, %c8_i64, %c128_i64, %c256_i64], [%c0_i64, %c256_i64, %c2048_i64]) {metadata = @airMemcpyId10} : (i32, i64, i64, memref<2048x2048xbf16>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64]) : !airrt.event + %17 = 
affine.apply affine_map<()[s0] -> (s0 * 128)>()[%arg1] + %18 = airrt.wait_all : !airrt.event + %19 = airrt.wait_all %18 : !airrt.event + %20 = arith.index_cast %arg0 : index to i64 + %21 = arith.index_cast %arg1 : index to i64 + %22 = arith.index_cast %17 : index to i64 + %23 = airrt.dma_memcpy_nd(%c7_i32, %20, %21, %3[%c0_i64, %c0_i64, %c0_i64, %22], [%c1_i64, %c1_i64, %c2048_i64, %c128_i64], [%c0_i64, %c0_i64, %c2048_i64]) {metadata = @airMemcpyId13} : (i32, i64, i64, memref<2048x2048xbf16>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64]) : !airrt.event + %24 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%arg0] + %25 = airrt.wait_all : !airrt.event + %26 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%arg1] + %27 = airrt.wait_all : !airrt.event + %28 = airrt.wait_all %27, %25 : !airrt.event + %29 = arith.index_cast %arg0 : index to i64 + %30 = arith.index_cast %arg1 : index to i64 + %31 = arith.index_cast %24 : index to i64 + %32 = arith.index_cast %26 : index to i64 + %33 = airrt.dma_memcpy_nd(%c26_i32, %29, %30, %6[%c0_i64, %c0_i64, %31, %32], [%c1_i64, %c1_i64, %c128_i64, %c128_i64], [%c0_i64, %c0_i64, %c2048_i64]) {metadata = @airMemcpyId26} : (i32, i64, i64, memref<2048x2048xbf16>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64]) : !airrt.event + %p = airrt.segment_load "segment_0" : i64 + %34 = airrt.wait_all : !airrt.event + } + } + return + } +}