Skip to content

Commit

Permalink
Expand on broadcast detection conditions (#471)
Browse files Browse the repository at this point in the history
* Expand on broadcast detection conditions

* Clang format

* Clang format

* Clang format
  • Loading branch information
erwei-xilinx authored Mar 6, 2024
1 parent 77f8c15 commit f93afb7
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 62 deletions.
54 changes: 30 additions & 24 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2138,23 +2138,46 @@ struct BroadcastDetection {
auto dma_op = dma_op_history[i];
SmallVector<Value, 1> loop_dep_history = dma_op_loop_dep_history[i];
air::HerdOp hl_op = nullptr;
bool hasDepInHerdRows = false;
bool hasDepInHerdCols = false;
bool isVariantWrtHerdRows = false;
bool isVariantWrtHerdCols = false;
// Create an affine set to represent the broadcast pattern
auto ctx = dma_op->getContext();
for (auto v : loop_dep_history) {
// Classify the DMA as row-wise or column-wise broadcastable according to
// which herd tile ids (row, column) it depends on.
if (getHerdArgOwner(v)) {
hl_op = getHerdArgOwner(v);
if (v == hl_op.getIds()[0]) {
hasDepInHerdRows = true;
isVariantWrtHerdRows = true;
}
if (v == hl_op.getIds()[1]) {
hasDepInHerdCols = true;
isVariantWrtHerdCols = true;
}
}
}

if (hl_op && hasDepInHerdRows && !hasDepInHerdCols) {
// If no dependence on a herd tile id was found above, fall back to checking
// whether the DMA's row or column offset is a fixed (constant) value.
int src_memspace = dma_op.getSrcMemref()
.getType()
.cast<MemRefType>()
.getMemorySpaceAsInt();
int dst_memspace = dma_op.getDstMemref()
.getType()
.cast<MemRefType>()
.getMemorySpaceAsInt();
auto externalOffsets = src_memspace == (int)air::MemorySpace::L1
? dma_op.getDstOffsets()
: dma_op.getSrcOffsets();
if (!hl_op && externalOffsets.size() ==
dma_op->getParentOfType<air::HerdOp>().getNumDims()) {
hl_op = dma_op->getParentOfType<air::HerdOp>();
if (getConstantIntValue(externalOffsets[0]))
isVariantWrtHerdRows = true;
if (getConstantIntValue(externalOffsets[1]))
isVariantWrtHerdCols = true;
}

if (hl_op && isVariantWrtHerdRows && !isVariantWrtHerdCols) {
auto numColsOp = dyn_cast<arith::ConstantIndexOp>(
hl_op.getSizeOperands()[1].getDefiningOp());
auto numCols = numColsOp.value();
Expand All @@ -2169,7 +2192,7 @@ struct BroadcastDetection {
dma_op->setAttr("broadcast_pattern",
mlir::IntegerSetAttr::get(int_set));
}
} else if (hl_op && !hasDepInHerdRows && hasDepInHerdCols) {
} else if (hl_op && !isVariantWrtHerdRows && isVariantWrtHerdCols) {
auto numRowsOp = dyn_cast<arith::ConstantIndexOp>(
hl_op.getSizeOperands()[0].getDefiningOp());
auto numRows = numRowsOp.value();
Expand All @@ -2184,23 +2207,6 @@ struct BroadcastDetection {
dma_op->setAttr("broadcast_pattern",
mlir::IntegerSetAttr::get(int_set));
}
} else if (hl_op && !hasDepInHerdRows && !hasDepInHerdCols) {
auto numRowsOp = dyn_cast<arith::ConstantIndexOp>(
hl_op.getSizeOperands()[0].getDefiningOp());
auto numRows = numRowsOp.value();
auto numColsOp = dyn_cast<arith::ConstantIndexOp>(
hl_op.getSizeOperands()[1].getDefiningOp());
auto numCols = numColsOp.value();
if (numCols > 1 && numRows > 1) {
SmallVector<AffineExpr, 5> constraints{
getAffineDimExpr(0, ctx), numRows - 1 - getAffineDimExpr(0, ctx),
getAffineDimExpr(1, ctx), numCols - 1 - getAffineDimExpr(1, ctx),
getAffineSymbolExpr(0, ctx)};
SmallVector<bool, 5> eqflags{false, false, false, false, true};
auto int_set = IntegerSet::get(2, 1, constraints, eqflags);
dma_op->setAttr("broadcast_pattern",
mlir::IntegerSetAttr::get(int_set));
}
}
}
}
Expand Down
119 changes: 81 additions & 38 deletions mlir/test/Transform/AIRDependencyScheduleOpt/broadcast_detection.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,59 +6,102 @@
//
//===----------------------------------------------------------------------===//

// RUN: air-opt %s -air-dependency -air-broadcast-detection | FileCheck %s
// RUN: air-opt %s -air-dependency -air-broadcast-detection --split-input-file | FileCheck %s

// Detects broadcast pattern for DMAs
// CHECK: [[$SET0:#set[0-9]*]] = affine_set<(d0, d1)[s0] : (d0 - s0 == 0, d1 >= 0, -d1 + 1 >= 0, s0 >= 0, -s0 + 1 >= 0)>
// CHECK: [[$SET1:#set[0-9]*]] = affine_set<(d0, d1)[s0] : (d0 >= 0, -d0 + 1 >= 0, d1 - s0 == 0, s0 >= 0, -s0 + 1 >= 0)>
// CHECK-LABEL: func.func @matmul
// CHECK: %[[EVENT0:.*]] = air.dma_memcpy_nd {{.*}}broadcast_pattern = [[$SET0]]{{.*}}
// CHECK: %[[EVENT1:.*]] = air.dma_memcpy_nd {{.*}}broadcast_pattern = [[$SET1]]{{.*}}

#map = affine_map<()[s0] -> (s0 * 32)>
func.func @matmul(%arg0: memref<512x512xbf16>, %arg1: memref<512x512xbf16>, %arg2: memref<512x512xbf16>) {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c64 = arith.constant 64 : index
%0 = memref.alloc() {alignment = 128 : i64} : memref<512x512xbf16>
memref.copy %arg2, %0 : memref<512x512xbf16> to memref<512x512xbf16>
scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c512, %c512) step (%c64, %c64) {
scf.for %arg5 = %c0 to %c512 step %c64 {
%1 = memref.alloc() : memref<64x64xbf16, 1>
%2 = memref.alloc() : memref<64x64xbf16, 1>
%3 = memref.alloc() : memref<64x64xbf16, 1>
air.dma_memcpy_nd (%1[] [] [], %arg0[%arg3, %arg5] [%c64, %c64] [%c512, %c1]) {id = 1 : i32} : (memref<64x64xbf16, 1>, memref<512x512xbf16>)
air.dma_memcpy_nd (%2[] [] [], %arg1[%arg5, %arg4] [%c64, %c64] [%c512, %c1]) {id = 2 : i32} : (memref<64x64xbf16, 1>, memref<512x512xbf16>)
air.dma_memcpy_nd (%3[] [] [], %0[%arg3, %arg4] [%c64, %c64] [%c512, %c1]) {id = 3 : i32} : (memref<64x64xbf16, 1>, memref<512x512xbf16>)
air.herd tile (%arg6, %arg7) in (%arg8=%c2, %arg9=%c2) args(%arg10=%1, %arg11=%2, %arg12=%3) : memref<64x64xbf16, 1>, memref<64x64xbf16, 1>, memref<64x64xbf16, 1> attributes {sym_name = "herd_0"} {
%c1_0 = arith.constant 1 : index
%c0_1 = arith.constant 0 : index
%c64_2 = arith.constant 64 : index
%c32 = arith.constant 32 : index
%4 = affine.apply #map()[%arg6]
%5 = affine.apply #map()[%arg7]
scf.for %arg13 = %c0_1 to %c64_2 step %c32 {
%6 = memref.alloc() : memref<32x32xbf16, 2>
%7 = memref.alloc() : memref<32x32xbf16, 2>
%8 = memref.alloc() : memref<32x32xbf16, 2>
air.dma_memcpy_nd (%6[] [] [], %arg10[%4, %arg13] [%c32, %c32] [%c64_2, %c1_0]) {id = 4 : i32} : (memref<32x32xbf16, 2>, memref<64x64xbf16, 1>)
air.dma_memcpy_nd (%7[] [] [], %arg11[%arg13, %5] [%c32, %c32] [%c64_2, %c1_0]) {id = 5 : i32} : (memref<32x32xbf16, 2>, memref<64x64xbf16, 1>)
air.dma_memcpy_nd (%8[] [] [], %arg12[%4, %5] [%c32, %c32] [%c64_2, %c1_0]) {id = 6 : i32} : (memref<32x32xbf16, 2>, memref<64x64xbf16, 1>)
linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%6, %7 : memref<32x32xbf16, 2>, memref<32x32xbf16, 2>) outs(%8 : memref<32x32xbf16, 2>)
air.dma_memcpy_nd (%arg12[%4, %5] [%c32, %c32] [%c64_2, %c1_0], %8[] [] []) {id = 7 : i32} : (memref<64x64xbf16, 1>, memref<32x32xbf16, 2>)
memref.dealloc %6 : memref<32x32xbf16, 2>
memref.dealloc %7 : memref<32x32xbf16, 2>
memref.dealloc %8 : memref<32x32xbf16, 2>
}
air.herd_terminator
}
air.dma_memcpy_nd (%0[%arg3, %arg4] [%c64, %c64] [%c512, %c1], %3[] [] []) {id = 8 : i32} : (memref<512x512xbf16>, memref<64x64xbf16, 1>)
memref.dealloc %1 : memref<64x64xbf16, 1>
memref.dealloc %2 : memref<64x64xbf16, 1>
memref.dealloc %3 : memref<64x64xbf16, 1>
}
}
return
}

// -----

// CHECK: [[$SET0:#set[0-9]*]] = affine_set<(d0, d1)[s0] : (d0 >= 0, -d0 + 3 >= 0, d1 - s0 == 0, s0 >= 0, -s0 + 3 >= 0)>
// CHECK-LABEL: func.func @func0
// CHECK: %[[EVENT0:.*]] = air.dma_memcpy_nd {{.*}} {id = 1 : i32} : (memref<256x64xbf16, 1>, memref<1024x256xbf16>)
// CHECK: %[[EVENT1:.*]] = air.dma_memcpy_nd {{.*}}broadcast_pattern = [[$SET0]]{{.*}}

#map = affine_map<()[s0] -> (s0 * 256)>
#map1 = affine_map<()[s0] -> (s0 * 64)>
module {
func.func @matmul(%arg0: memref<512x512xbf16>, %arg1: memref<512x512xbf16>, %arg2: memref<512x512xbf16>) {
func.func @func0(%arg0: memref<256x1024xbf16>, %arg1: memref<1024x256xbf16>, %arg2: memref<256x256xbf16>) {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c64 = arith.constant 64 : index
%0 = memref.alloc() {alignment = 128 : i64} : memref<512x512xbf16>
memref.copy %arg2, %0 : memref<512x512xbf16> to memref<512x512xbf16>
scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c512, %c512) step (%c64, %c64) {
scf.for %arg5 = %c0 to %c512 step %c64 {
%1 = memref.alloc() : memref<64x64xbf16, 1>
%2 = memref.alloc() : memref<64x64xbf16, 1>
%3 = memref.alloc() : memref<64x64xbf16, 1>
air.dma_memcpy_nd (%1[] [] [], %arg0[%arg3, %arg5] [%c64, %c64] [%c512, %c1]) {id = 1 : i32} : (memref<64x64xbf16, 1>, memref<512x512xbf16>)
air.dma_memcpy_nd (%2[] [] [], %arg1[%arg5, %arg4] [%c64, %c64] [%c512, %c1]) {id = 2 : i32} : (memref<64x64xbf16, 1>, memref<512x512xbf16>)
air.dma_memcpy_nd (%3[] [] [], %0[%arg3, %arg4] [%c64, %c64] [%c512, %c1]) {id = 3 : i32} : (memref<64x64xbf16, 1>, memref<512x512xbf16>)
air.herd tile (%arg6, %arg7) in (%arg8=%c2, %arg9=%c2) args(%arg10=%1, %arg11=%2, %arg12=%3) : memref<64x64xbf16, 1>, memref<64x64xbf16, 1>, memref<64x64xbf16, 1> attributes {sym_name = "herd_0"} {
air.launch (%arg3, %arg4) in (%arg5=%c1, %arg6=%c1) args(%arg7=%arg1) : memref<1024x256xbf16> attributes {id = 3 : i32} {
air.segment @segment_0 args(%arg8=%arg4, %arg9=%arg7) : index, memref<1024x256xbf16> attributes {id = 2 : i32} {
%c4 = arith.constant 4 : index
%0 = affine.apply #map()[%arg8]
air.herd @herd_0 tile (%arg10, %arg11) in (%arg12=%c4, %arg13=%c4) args(%arg14=%0, %arg15=%arg9) : index, memref<1024x256xbf16> attributes {id = 1 : i32} {
%c1_0 = arith.constant 1 : index
%c0_1 = arith.constant 0 : index
%c64_2 = arith.constant 64 : index
%c32 = arith.constant 32 : index
%4 = affine.apply #map()[%arg6]
%5 = affine.apply #map()[%arg7]
scf.for %arg13 = %c0_1 to %c64_2 step %c32 {
%6 = memref.alloc() : memref<32x32xbf16, 2>
%7 = memref.alloc() : memref<32x32xbf16, 2>
%8 = memref.alloc() : memref<32x32xbf16, 2>
air.dma_memcpy_nd (%6[] [] [], %arg10[%4, %arg13] [%c32, %c32] [%c64_2, %c1_0]) {id = 4 : i32} : (memref<32x32xbf16, 2>, memref<64x64xbf16, 1>)
air.dma_memcpy_nd (%7[] [] [], %arg11[%arg13, %5] [%c32, %c32] [%c64_2, %c1_0]) {id = 5 : i32} : (memref<32x32xbf16, 2>, memref<64x64xbf16, 1>)
air.dma_memcpy_nd (%8[] [] [], %arg12[%4, %5] [%c32, %c32] [%c64_2, %c1_0]) {id = 6 : i32} : (memref<32x32xbf16, 2>, memref<64x64xbf16, 1>)
linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%6, %7 : memref<32x32xbf16, 2>, memref<32x32xbf16, 2>) outs(%8 : memref<32x32xbf16, 2>)
air.dma_memcpy_nd (%arg12[%4, %5] [%c32, %c32] [%c64_2, %c1_0], %8[] [] []) {id = 7 : i32} : (memref<64x64xbf16, 1>, memref<32x32xbf16, 2>)
memref.dealloc %6 : memref<32x32xbf16, 2>
memref.dealloc %7 : memref<32x32xbf16, 2>
memref.dealloc %8 : memref<32x32xbf16, 2>
%c0 = arith.constant 0 : index
%c256 = arith.constant 256 : index
%c64 = arith.constant 64 : index
%c1024 = arith.constant 1024 : index
%1 = affine.apply #map1()[%arg11]
%2 = arith.addi %arg14, %1 : index
scf.for %arg16 = %c0 to %c1024 step %c256 {
%alloc = memref.alloc() : memref<256x64xbf16, 1>
air.dma_memcpy_nd (%alloc[] [] [], %arg15[%arg16, %2] [%c256, %c64] [%c256, %c1_0]) {id = 2 : i32} : (memref<256x64xbf16, 1>, memref<1024x256xbf16>)
scf.for %arg17 = %c0 to %c256 step %c64 {
%alloc_1 = memref.alloc() : memref<64x64xbf16, 2>
air.dma_memcpy_nd (%alloc_1[] [] [], %alloc[%arg17, %c0] [%c64, %c64] [%c64, %c1_0]) {id = 4 : i32} : (memref<64x64xbf16, 2>, memref<256x64xbf16, 1>)
memref.dealloc %alloc_1 : memref<64x64xbf16, 2>
}
memref.dealloc %alloc : memref<256x64xbf16, 1>
}
air.herd_terminator
}
air.dma_memcpy_nd (%0[%arg3, %arg4] [%c64, %c64] [%c512, %c1], %3[] [] []) {id = 8 : i32} : (memref<512x512xbf16>, memref<64x64xbf16, 1>)
memref.dealloc %1 : memref<64x64xbf16, 1>
memref.dealloc %2 : memref<64x64xbf16, 1>
memref.dealloc %3 : memref<64x64xbf16, 1>
air.segment_terminator
}
air.launch_terminator
}
return
}
Expand Down

0 comments on commit f93afb7

Please sign in to comment.