From 2e2d2f129d19ef1b76b7ce05393daf93b371a74c Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Tue, 12 Mar 2024 17:52:30 -0700 Subject: [PATCH] Enable user-provided maximum column size for air-collapse-herd (#486) * Enable user-provided maximum column size for air-collapse-herd * Set default to be 'disabled' * Add registration method for options --- mlir/include/air/Transform/Passes.td | 4 ++ mlir/lib/Transform/AIRMiscPasses.cpp | 8 +++- .../AIRMiscPasses/air_collapse_herd.mlir | 48 ++++++++++++++----- 3 files changed, 46 insertions(+), 14 deletions(-) diff --git a/mlir/include/air/Transform/Passes.td b/mlir/include/air/Transform/Passes.td index c726c4cf7..b4fec5b8f 100644 --- a/mlir/include/air/Transform/Passes.td +++ b/mlir/include/air/Transform/Passes.td @@ -1159,6 +1159,10 @@ def AIRLabelBroadcastChannelWithTilePass : Pass<"air-label-broadcast-channel-wit def AIRCollapseHerdPass : Pass<"air-collapse-herd", "func::FuncOp"> { let summary = "Collapse a multi-dimensional air.herd into a single column."; let constructor = "xilinx::air::createAIRCollapseHerdPass()"; + let options = [ + Option<"clMaxColSize", "max-col-size", "int", /*default=*/"-1", + "The maximum column size after collapse, before collapse is cancelled. Disabled by default."> + ]; } def AIRUnrollOuterPerfectlyNestedLoopsPass : Pass<"air-unroll-outer-affine-loops", "func::FuncOp"> { diff --git a/mlir/lib/Transform/AIRMiscPasses.cpp b/mlir/lib/Transform/AIRMiscPasses.cpp index db6bf6965..0983300ae 100644 --- a/mlir/lib/Transform/AIRMiscPasses.cpp +++ b/mlir/lib/Transform/AIRMiscPasses.cpp @@ -947,6 +947,8 @@ class AIRCollapseHerdPass public: AIRCollapseHerdPass() = default; AIRCollapseHerdPass(const AIRCollapseHerdPass &pass){}; + AIRCollapseHerdPass(const ::xilinx::air::AIRCollapseHerdPassOptions &options) + : AIRCollapseHerdPassBase(options) {} void runOnOperation() override; @@ -956,8 +958,12 @@ class AIRCollapseHerdPass void AIRCollapseHerdPass::runOnOperation() { SmallVector herds; auto func = getOperation(); + int maximumColumnSize = clMaxColSize; + if (clMaxColSize == -1) + maximumColumnSize = INT_MAX; // max-col-size disabled. func.walk([&](air::HerdOp op) { - if (op.getNumCols() != 1 && op.getNumDims() == 2) + if (op.getNumCols() != 1 && op.getNumDims() == 2 && + op.getNumRows() * op.getNumCols() <= maximumColumnSize) herds.push_back(op); }); diff --git a/mlir/test/Transform/AIRMiscPasses/air_collapse_herd.mlir b/mlir/test/Transform/AIRMiscPasses/air_collapse_herd.mlir index 4284e9b4a..a34880315 100644 --- a/mlir/test/Transform/AIRMiscPasses/air_collapse_herd.mlir +++ b/mlir/test/Transform/AIRMiscPasses/air_collapse_herd.mlir @@ -5,16 +5,21 @@ // //===----------------------------------------------------------------------===// -// RUN: air-opt %s -air-collapse-herd -canonicalize --split-input-file | FileCheck %s +// RUN: air-opt %s -air-collapse-herd="max-col-size=4" -canonicalize --split-input-file | FileCheck %s +// RUN: air-opt %s -air-collapse-herd -canonicalize --split-input-file | FileCheck %s --check-prefix=MAXCOL -// CHECK: func.func @test0 +// CHECK-LABEL: func.func @test0 // CHECK: %[[CST1:.*]] = arith.constant 1 : index -// CHECK: %[[CST9:.*]] = arith.constant 9 : index -// CHECK: air.herd tile (%[[VAL0:.*]], %[[VAL1:.*]]) in (%[[VAL2:.*]]=%[[CST1]], %[[VAL3:.*]]=%[[CST9]]) +// CHECK: %[[CST4:.*]] = arith.constant 4 : index +// CHECK: air.herd tile (%[[VAL0:.*]], %[[VAL1:.*]]) in (%[[VAL2:.*]]=%[[CST1]], %[[VAL3:.*]]=%[[CST4]]) +// MAXCOL-LABEL: func.func @test0 +// MAXCOL: %[[CST1:.*]] = arith.constant 1 : index +// MAXCOL: %[[CST4:.*]] = arith.constant 4 : index +// MAXCOL: air.herd tile (%[[VAL0:.*]], %[[VAL1:.*]]) in (%[[VAL2:.*]]=%[[CST1]], %[[VAL3:.*]]=%[[CST4]]) func.func @test0() -> () { - %c3 = arith.constant 3 : index - air.herd tile (%x, %y) in (%sx=%c3, %sy=%c3) { + %c2 = arith.constant 2 : index + air.herd tile (%x, %y) in (%sx=%c2, %sy=%c2) { } return } @@ -25,12 +30,12 @@ func.func @test0() -> () { // CHECK: [[$SET1:#set[0-9]+]] = affine_set<()[s0, s1] : (s0 >= 0, -s0 + 2 >= 0, s1 == 0)> // CHECK: func.func @test1 // CHECK: %[[CST1:.*]] = arith.constant 1 : index -// CHECK: %[[CST9:.*]] = arith.constant 9 : index -// CHECK: air.herd tile (%[[ARG0:.*]], %[[ARG1:.*]]) in (%[[ARG2:.*]]=%[[CST1]], %[[ARG3:.*]]=%[[CST9]]) +// CHECK: %[[CST4:.*]] = arith.constant 4 : index +// CHECK: air.herd tile (%[[ARG0:.*]], %[[ARG1:.*]]) in (%[[ARG2:.*]]=%[[CST1]], %[[ARG3:.*]]=%[[CST4]]) // CHECK: %[[CST0:.*]] = arith.constant 0 : index -// CHECK: %[[CST3:.*]] = arith.constant 3 : index -// CHECK: %[[VAL0:.*]] = arith.remsi %[[ARG1]], %[[CST3]] : index -// CHECK: %[[VAL1:.*]] = arith.divsi %[[ARG1]], %[[CST3]] : index +// CHECK: %[[CST2:.*]] = arith.constant 2 : index +// CHECK: %[[VAL0:.*]] = arith.remsi %[[ARG1]], %[[CST2]] : index +// CHECK: %[[VAL1:.*]] = arith.divsi %[[ARG1]], %[[CST2]] : index // CHECK: affine.if [[$SET0]]()[%[[VAL1]], %[[VAL0]]] { // CHECK: } else { // CHECK: } @@ -40,8 +45,8 @@ func.func @test0() -> () { #set0 = affine_set<()[s0, s1] : (s0 == 0, s1 >= 0, -s1 + 2 >= 0)> #set1 = affine_set<()[s0, s1] : (s0 >= 0, -s0 + 2 >= 0, s1 == 0)> func.func @test1() -> () { - %c3 = arith.constant 3 : index - air.herd tile (%x, %y) in (%sx=%c3, %sy=%c3) { + %c2 = arith.constant 2 : index + air.herd tile (%x, %y) in (%sx=%c2, %sy=%c2) { %c0 = arith.constant 0 : index affine.if #set0()[%x, %y] { %alloc = memref.alloc() : memref<8x16xi32, 2> @@ -68,3 +73,20 @@ func.func @test1() -> () { } return } + +// ----- + +// CHECK-LABEL: func.func @test2 +// CHECK: %[[CST3:.*]] = arith.constant 3 : index +// CHECK: air.herd tile (%[[VAL0:.*]], %[[VAL1:.*]]) in (%[[VAL2:.*]]=%[[CST3]], %[[VAL3:.*]]=%[[CST3]]) +// MAXCOL-LABEL: func.func @test2 +// MAXCOL: %[[CST1:.*]] = arith.constant 1 : index +// MAXCOL: %[[CST9:.*]] = arith.constant 9 : index +// MAXCOL: air.herd tile (%[[VAL0:.*]], %[[VAL1:.*]]) in (%[[VAL2:.*]]=%[[CST1]], %[[VAL3:.*]]=%[[CST9]]) + +func.func @test2() -> () { + %c3 = arith.constant 3 : index + air.herd tile (%x, %y) in (%sx=%c3, %sy=%c3) { + } + return +}