From f93afb746b77c8d6561f5b7088c4e70bfb99063e Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Wed, 6 Mar 2024 13:14:19 -0800 Subject: [PATCH] Expand on broadcast detection conditions (#471) * Expand on broadcast detection conditions * Clang format * Clang format * Clang format --- .../Transform/AIRDependencyScheduleOpt.cpp | 54 ++++---- .../broadcast_detection.mlir | 119 ++++++++++++------ 2 files changed, 111 insertions(+), 62 deletions(-) diff --git a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp index 0f35b9a87..6be236c81 100644 --- a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp +++ b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp @@ -2138,23 +2138,46 @@ struct BroadcastDetection { auto dma_op = dma_op_history[i]; SmallVector loop_dep_history = dma_op_loop_dep_history[i]; air::HerdOp hl_op = nullptr; - bool hasDepInHerdRows = false; - bool hasDepInHerdCols = false; + bool isVariantWrtHerdRows = false; + bool isVariantWrtHerdCols = false; // Create an affine set to represent the broadcast pattern auto ctx = dma_op->getContext(); for (auto v : loop_dep_history) { + // Check row-wise or col-wise broadcastable based on variance wrt herd + // dimensions. if (getHerdArgOwner(v)) { hl_op = getHerdArgOwner(v); if (v == hl_op.getIds()[0]) { - hasDepInHerdRows = true; + isVariantWrtHerdRows = true; } if (v == hl_op.getIds()[1]) { - hasDepInHerdCols = true; + isVariantWrtHerdCols = true; } } } - - if (hl_op && hasDepInHerdRows && !hasDepInHerdCols) { + // If not variant wrt herd, then check for fixed row-wise or col-wise + // offset. + int src_memspace = dma_op.getSrcMemref() + .getType() + .cast() + .getMemorySpaceAsInt(); + int dst_memspace = dma_op.getDstMemref() + .getType() + .cast() + .getMemorySpaceAsInt(); + auto externalOffsets = src_memspace == (int)air::MemorySpace::L1 + ? dma_op.getDstOffsets() + : dma_op.getSrcOffsets(); + if (!hl_op && externalOffsets.size() == + dma_op->getParentOfType().getNumDims()) { + hl_op = dma_op->getParentOfType(); + if (getConstantIntValue(externalOffsets[0])) + isVariantWrtHerdRows = true; + if (getConstantIntValue(externalOffsets[1])) + isVariantWrtHerdCols = true; + } + + if (hl_op && isVariantWrtHerdRows && !isVariantWrtHerdCols) { auto numColsOp = dyn_cast( hl_op.getSizeOperands()[1].getDefiningOp()); auto numCols = numColsOp.value(); @@ -2169,7 +2192,7 @@ struct BroadcastDetection { dma_op->setAttr("broadcast_pattern", mlir::IntegerSetAttr::get(int_set)); } - } else if (hl_op && !hasDepInHerdRows && hasDepInHerdCols) { + } else if (hl_op && !isVariantWrtHerdRows && isVariantWrtHerdCols) { auto numRowsOp = dyn_cast( hl_op.getSizeOperands()[0].getDefiningOp()); auto numRows = numRowsOp.value(); @@ -2184,23 +2207,6 @@ struct BroadcastDetection { dma_op->setAttr("broadcast_pattern", mlir::IntegerSetAttr::get(int_set)); } - } else if (hl_op && !hasDepInHerdRows && !hasDepInHerdCols) { - auto numRowsOp = dyn_cast( - hl_op.getSizeOperands()[0].getDefiningOp()); - auto numRows = numRowsOp.value(); - auto numColsOp = dyn_cast( - hl_op.getSizeOperands()[1].getDefiningOp()); - auto numCols = numColsOp.value(); - if (numCols > 1 && numRows > 1) { - SmallVector constraints{ - getAffineDimExpr(0, ctx), numRows - 1 - getAffineDimExpr(0, ctx), - getAffineDimExpr(1, ctx), numCols - 1 - getAffineDimExpr(1, ctx), - getAffineSymbolExpr(0, ctx)}; - SmallVector eqflags{false, false, false, false, true}; - auto int_set = IntegerSet::get(2, 1, constraints, eqflags); - dma_op->setAttr("broadcast_pattern", - mlir::IntegerSetAttr::get(int_set)); - } } } } diff --git a/mlir/test/Transform/AIRDependencyScheduleOpt/broadcast_detection.mlir b/mlir/test/Transform/AIRDependencyScheduleOpt/broadcast_detection.mlir index 0c133caf3..cd8d186fc 100644 --- a/mlir/test/Transform/AIRDependencyScheduleOpt/broadcast_detection.mlir +++ b/mlir/test/Transform/AIRDependencyScheduleOpt/broadcast_detection.mlir @@ -6,59 +6,102 @@ // //===----------------------------------------------------------------------===// -// RUN: air-opt %s -air-dependency -air-broadcast-detection | FileCheck %s +// RUN: air-opt %s -air-dependency -air-broadcast-detection --split-input-file | FileCheck %s // Detects broadcast pattern for DMAs // CHECK: [[$SET0:#set[0-9]*]] = affine_set<(d0, d1)[s0] : (d0 - s0 == 0, d1 >= 0, -d1 + 1 >= 0, s0 >= 0, -s0 + 1 >= 0)> // CHECK: [[$SET1:#set[0-9]*]] = affine_set<(d0, d1)[s0] : (d0 >= 0, -d0 + 1 >= 0, d1 - s0 == 0, s0 >= 0, -s0 + 1 >= 0)> +// CHECK-LABEL: func.func @matmul // CHECK: %[[EVENT0:.*]] = air.dma_memcpy_nd {{.*}}broadcast_pattern = [[$SET0]]{{.*}} // CHECK: %[[EVENT1:.*]] = air.dma_memcpy_nd {{.*}}broadcast_pattern = [[$SET1]]{{.*}} #map = affine_map<()[s0] -> (s0 * 32)> +func.func @matmul(%arg0: memref<512x512xbf16>, %arg1: memref<512x512xbf16>, %arg2: memref<512x512xbf16>) { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c512 = arith.constant 512 : index + %c64 = arith.constant 64 : index + %0 = memref.alloc() {alignment = 128 : i64} : memref<512x512xbf16> + memref.copy %arg2, %0 : memref<512x512xbf16> to memref<512x512xbf16> + scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c512, %c512) step (%c64, %c64) { + scf.for %arg5 = %c0 to %c512 step %c64 { + %1 = memref.alloc() : memref<64x64xbf16, 1> + %2 = memref.alloc() : memref<64x64xbf16, 1> + %3 = memref.alloc() : memref<64x64xbf16, 1> + air.dma_memcpy_nd (%1[] [] [], %arg0[%arg3, %arg5] [%c64, %c64] [%c512, %c1]) {id = 1 : i32} : (memref<64x64xbf16, 1>, memref<512x512xbf16>) + air.dma_memcpy_nd (%2[] [] [], %arg1[%arg5, %arg4] [%c64, %c64] [%c512, %c1]) {id = 2 : i32} : (memref<64x64xbf16, 1>, memref<512x512xbf16>) + air.dma_memcpy_nd (%3[] [] [], %0[%arg3, %arg4] [%c64, %c64] [%c512, %c1]) {id = 3 : i32} : (memref<64x64xbf16, 1>, memref<512x512xbf16>) + air.herd tile (%arg6, %arg7) in (%arg8=%c2, %arg9=%c2) args(%arg10=%1, %arg11=%2, %arg12=%3) : memref<64x64xbf16, 1>, memref<64x64xbf16, 1>, memref<64x64xbf16, 1> attributes {sym_name = "herd_0"} { + %c1_0 = arith.constant 1 : index + %c0_1 = arith.constant 0 : index + %c64_2 = arith.constant 64 : index + %c32 = arith.constant 32 : index + %4 = affine.apply #map()[%arg6] + %5 = affine.apply #map()[%arg7] + scf.for %arg13 = %c0_1 to %c64_2 step %c32 { + %6 = memref.alloc() : memref<32x32xbf16, 2> + %7 = memref.alloc() : memref<32x32xbf16, 2> + %8 = memref.alloc() : memref<32x32xbf16, 2> + air.dma_memcpy_nd (%6[] [] [], %arg10[%4, %arg13] [%c32, %c32] [%c64_2, %c1_0]) {id = 4 : i32} : (memref<32x32xbf16, 2>, memref<64x64xbf16, 1>) + air.dma_memcpy_nd (%7[] [] [], %arg11[%arg13, %5] [%c32, %c32] [%c64_2, %c1_0]) {id = 5 : i32} : (memref<32x32xbf16, 2>, memref<64x64xbf16, 1>) + air.dma_memcpy_nd (%8[] [] [], %arg12[%4, %5] [%c32, %c32] [%c64_2, %c1_0]) {id = 6 : i32} : (memref<32x32xbf16, 2>, memref<64x64xbf16, 1>) + linalg.matmul {cast = #linalg.type_fn} ins(%6, %7 : memref<32x32xbf16, 2>, memref<32x32xbf16, 2>) outs(%8 : memref<32x32xbf16, 2>) + air.dma_memcpy_nd (%arg12[%4, %5] [%c32, %c32] [%c64_2, %c1_0], %8[] [] []) {id = 7 : i32} : (memref<64x64xbf16, 1>, memref<32x32xbf16, 2>) + memref.dealloc %6 : memref<32x32xbf16, 2> + memref.dealloc %7 : memref<32x32xbf16, 2> + memref.dealloc %8 : memref<32x32xbf16, 2> + } + air.herd_terminator + } + air.dma_memcpy_nd (%0[%arg3, %arg4] [%c64, %c64] [%c512, %c1], %3[] [] []) {id = 8 : i32} : (memref<512x512xbf16>, memref<64x64xbf16, 1>) + memref.dealloc %1 : memref<64x64xbf16, 1> + memref.dealloc %2 : memref<64x64xbf16, 1> + memref.dealloc %3 : memref<64x64xbf16, 1> + } + } + return +} + +// ----- + +// CHECK: [[$SET0:#set[0-9]*]] = affine_set<(d0, d1)[s0] : (d0 >= 0, -d0 + 3 >= 0, d1 - s0 == 0, s0 >= 0, -s0 + 3 >= 0)> +// CHECK-LABEL: func.func @func0 +// CHECK: %[[EVENT0:.*]] = air.dma_memcpy_nd {{.*}} {id = 1 : i32} : (memref<256x64xbf16, 1>, memref<1024x256xbf16>) +// CHECK: %[[EVENT1:.*]] = air.dma_memcpy_nd {{.*}}broadcast_pattern = [[$SET0]]{{.*}} + +#map = affine_map<()[s0] -> (s0 * 256)> +#map1 = affine_map<()[s0] -> (s0 * 64)> module { - func.func @matmul(%arg0: memref<512x512xbf16>, %arg1: memref<512x512xbf16>, %arg2: memref<512x512xbf16>) { + func.func @func0(%arg0: memref<256x1024xbf16>, %arg1: memref<1024x256xbf16>, %arg2: memref<256x256xbf16>) { %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c0 = arith.constant 0 : index - %c512 = arith.constant 512 : index - %c64 = arith.constant 64 : index - %0 = memref.alloc() {alignment = 128 : i64} : memref<512x512xbf16> - memref.copy %arg2, %0 : memref<512x512xbf16> to memref<512x512xbf16> - scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c512, %c512) step (%c64, %c64) { - scf.for %arg5 = %c0 to %c512 step %c64 { - %1 = memref.alloc() : memref<64x64xbf16, 1> - %2 = memref.alloc() : memref<64x64xbf16, 1> - %3 = memref.alloc() : memref<64x64xbf16, 1> - air.dma_memcpy_nd (%1[] [] [], %arg0[%arg3, %arg5] [%c64, %c64] [%c512, %c1]) {id = 1 : i32} : (memref<64x64xbf16, 1>, memref<512x512xbf16>) - air.dma_memcpy_nd (%2[] [] [], %arg1[%arg5, %arg4] [%c64, %c64] [%c512, %c1]) {id = 2 : i32} : (memref<64x64xbf16, 1>, memref<512x512xbf16>) - air.dma_memcpy_nd (%3[] [] [], %0[%arg3, %arg4] [%c64, %c64] [%c512, %c1]) {id = 3 : i32} : (memref<64x64xbf16, 1>, memref<512x512xbf16>) - air.herd tile (%arg6, %arg7) in (%arg8=%c2, %arg9=%c2) args(%arg10=%1, %arg11=%2, %arg12=%3) : memref<64x64xbf16, 1>, memref<64x64xbf16, 1>, memref<64x64xbf16, 1> attributes {sym_name = "herd_0"} { + air.launch (%arg3, %arg4) in (%arg5=%c1, %arg6=%c1) args(%arg7=%arg1) : memref<1024x256xbf16> attributes {id = 3 : i32} { + air.segment @segment_0 args(%arg8=%arg4, %arg9=%arg7) : index, memref<1024x256xbf16> attributes {id = 2 : i32} { + %c4 = arith.constant 4 : index + %0 = affine.apply #map()[%arg8] + air.herd @herd_0 tile (%arg10, %arg11) in (%arg12=%c4, %arg13=%c4) args(%arg14=%0, %arg15=%arg9) : index, memref<1024x256xbf16> attributes {id = 1 : i32} { %c1_0 = arith.constant 1 : index - %c0_1 = arith.constant 0 : index - %c64_2 = arith.constant 64 : index - %c32 = arith.constant 32 : index - %4 = affine.apply #map()[%arg6] - %5 = affine.apply #map()[%arg7] - scf.for %arg13 = %c0_1 to %c64_2 step %c32 { - %6 = memref.alloc() : memref<32x32xbf16, 2> - %7 = memref.alloc() : memref<32x32xbf16, 2> - %8 = memref.alloc() : memref<32x32xbf16, 2> - air.dma_memcpy_nd (%6[] [] [], %arg10[%4, %arg13] [%c32, %c32] [%c64_2, %c1_0]) {id = 4 : i32} : (memref<32x32xbf16, 2>, memref<64x64xbf16, 1>) - air.dma_memcpy_nd (%7[] [] [], %arg11[%arg13, %5] [%c32, %c32] [%c64_2, %c1_0]) {id = 5 : i32} : (memref<32x32xbf16, 2>, memref<64x64xbf16, 1>) - air.dma_memcpy_nd (%8[] [] [], %arg12[%4, %5] [%c32, %c32] [%c64_2, %c1_0]) {id = 6 : i32} : (memref<32x32xbf16, 2>, memref<64x64xbf16, 1>) - linalg.matmul {cast = #linalg.type_fn} ins(%6, %7 : memref<32x32xbf16, 2>, memref<32x32xbf16, 2>) outs(%8 : memref<32x32xbf16, 2>) - air.dma_memcpy_nd (%arg12[%4, %5] [%c32, %c32] [%c64_2, %c1_0], %8[] [] []) {id = 7 : i32} : (memref<64x64xbf16, 1>, memref<32x32xbf16, 2>) - memref.dealloc %6 : memref<32x32xbf16, 2> - memref.dealloc %7 : memref<32x32xbf16, 2> - memref.dealloc %8 : memref<32x32xbf16, 2> + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %1 = affine.apply #map1()[%arg11] + %2 = arith.addi %arg14, %1 : index + scf.for %arg16 = %c0 to %c1024 step %c256 { + %alloc = memref.alloc() : memref<256x64xbf16, 1> + air.dma_memcpy_nd (%alloc[] [] [], %arg15[%arg16, %2] [%c256, %c64] [%c256, %c1_0]) {id = 2 : i32} : (memref<256x64xbf16, 1>, memref<1024x256xbf16>) + scf.for %arg17 = %c0 to %c256 step %c64 { + %alloc_1 = memref.alloc() : memref<64x64xbf16, 2> + air.dma_memcpy_nd (%alloc_1[] [] [], %alloc[%arg17, %c0] [%c64, %c64] [%c64, %c1_0]) {id = 4 : i32} : (memref<64x64xbf16, 2>, memref<256x64xbf16, 1>) + memref.dealloc %alloc_1 : memref<64x64xbf16, 2> + } + memref.dealloc %alloc : memref<256x64xbf16, 1> } air.herd_terminator } - air.dma_memcpy_nd (%0[%arg3, %arg4] [%c64, %c64] [%c512, %c1], %3[] [] []) {id = 8 : i32} : (memref<512x512xbf16>, memref<64x64xbf16, 1>) - memref.dealloc %1 : memref<64x64xbf16, 1> - memref.dealloc %2 : memref<64x64xbf16, 1> - memref.dealloc %3 : memref<64x64xbf16, 1> + air.segment_terminator } + air.launch_terminator } return }