Skip to content

Commit

Permalink
Propagate link_with in forall lowering (Xilinx#484)
Browse files Browse the repository at this point in the history
Update air-par-to-herd pass to propagate link_with attribute during
lowering of scf.forall to match lowering of scf.parallel.
  • Loading branch information
fifield authored Mar 11, 2024
1 parent 78863f6 commit d98f0a3
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 134 deletions.
44 changes: 25 additions & 19 deletions mlir/lib/Conversion/ConvertToAIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1867,6 +1867,26 @@ void InsertEmptyLaunchOverHerd(air::HerdOp op) {
return;
}

// func.call itself has a `link_with` which we can absorb into air.herd.
// Walk through all the func.call operations (immediate/nested children)
// within parallel loop. Currently we only assume and enforce that we relay
// `link_with` information from just one func.call op.
static LogicalResult propagateLinkWith(Operation *op, air::HerdOp herdOp) {
auto moduleOp = op->getParentOfType<ModuleOp>();
op->walk([&](func::CallOp callOp) {
// Fetch name.
StringRef fnName = callOp.getCallee();
auto fnDecl = dyn_cast_or_null<func::FuncOp>(
SymbolTable::lookupSymbolIn(moduleOp, fnName));
assert(fnDecl && "expected function declaration");
assert(fnDecl->hasAttr("link_with") &&
"expected 'link_with' construct for the function declaration");
herdOp->setAttr("link_with", fnDecl->getAttr("link_with"));
return WalkResult::interrupt();
});
return success();
}

class ScfParToHerdConversion : public OpRewritePattern<scf::ParallelOp> {
public:
using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
Expand Down Expand Up @@ -1948,25 +1968,9 @@ class ScfParToHerdConversion : public OpRewritePattern<scf::ParallelOp> {
auto herdOp = rewriter.create<air::HerdOp>(op.getLoc(), dims, args);
auto moduleOp = SymbolTable::getNearestSymbolTable(op);
auto &body = op.getBody()->getOperations();
// func.call itself has a `link_with` which we can absorb into air.herd.
// This means that the onus of setting the path to microkernel is on IREE.
//
// NOTE: Microkernel being used is actually residing within MLIR-AIE.
//
// Walk through all the func.call operations (immediate/nested children)
// within scf.parallel. Currently we only assume and enforce that we relay
// `link_with` information from just one func.call op.
op->walk([&](func::CallOp callOp) {
// Fetch name.
StringRef fnName = callOp.getCallee();
auto fnDecl = dyn_cast_or_null<func::FuncOp>(
SymbolTable::lookupSymbolIn(moduleOp, fnName));
assert(fnDecl && "expected function declaration");
assert(fnDecl->hasAttr("link_with") &&
"expected 'link_with' construct for the function declaration");
herdOp->setAttr("link_with", fnDecl->getAttr("link_with"));
return WalkResult::interrupt();
});

propagateLinkWith(op, herdOp);

auto &bb = herdOp.getBody().front();
auto ivs = op.getInductionVars();

Expand Down Expand Up @@ -2085,6 +2089,8 @@ class ScfForallToHerdConversion : public OpRewritePattern<scf::ForallOp> {
auto &bb = herdOp.getBody().front();
auto ivs = op.getInductionVars();

propagateLinkWith(op, herdOp);

ivs[0].replaceAllUsesWith(herdOp.getIds()[idx0]);
if (op.getRank() == 2)
ivs[1].replaceAllUsesWith(herdOp.getIds()[idx1]);
Expand Down
115 changes: 0 additions & 115 deletions mlir/test/Conversion/ConvertToAIR/affine_par_to_herd_launch.mlir

This file was deleted.

45 changes: 45 additions & 0 deletions mlir/test/Conversion/ConvertToAIR/affine_parallel_to_herd.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
//===- affine_par_to_herd_launch.mlir --------------------------*- MLIR -*-===//
//
// Copyright (C) 2021-2022, Xilinx Inc. All rights reserved.
// Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.
// SPDX-License-Identifier: MIT
//
//===----------------------------------------------------------------------===//

// RUN: air-opt -split-input-file -verify-diagnostics -air-par-to-herd -cse %s | FileCheck %s

// CHECK-LABEL: func.func @par0
// CHECK: %[[C0:.*]] = arith.constant 1 : index
// CHECK: air.herd @herd_0 tile ({{.*}}, {{.*}}) in ({{.*}}=%[[C0]], {{.*}}=%[[C0]])
func.func @par0() {
affine.parallel (%x,%y) = (0,0) to (1,1) {
%2 = arith.addi %x, %y : index
affine.yield
}
return
}

// -----

func.func @par1() {
// expected-error@+1 {{'affine.parallel' op failed conversion to 'air.herd': only 2d loops are supported}}
affine.parallel (%x,%y,%z) = (0,0,0) to (1,2,3) {
%2 = arith.addi %x, %y : index
affine.yield
}
return
}

// -----

// CHECK-LABEL: func.func @par2
func.func @par2() {
// CHECK: %[[C0:.*]] = arith.constant 4 : index
// CHECK: %[[C1:.*]] = arith.constant 5 : index
// CHECK: air.herd @herd_0 tile ({{.*}}, {{.*}}) in ({{.*}}=%[[C0]], {{.*}}=%[[C1]])
affine.parallel (%x,%y) = (0,2) to (4,12) step (1,2) {
%2 = arith.addi %x, %y : index
affine.yield
}
return
}
31 changes: 31 additions & 0 deletions mlir/test/Conversion/ConvertToAIR/scf_forall_to_herd.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,34 @@ func.func @scf2() {
}
return
}

// -----

// This test demonstrates that while forming air.herd we look through func.call ops, fetch
// the corresponding function declaration's 'link_with' attribute and attach it to the newly
// formed air.herd op.

// CHECK-LABEL: module {
// CHECK: func.func private @matmul_i32_i32
// CHECK-SAME: attributes {link_with = "/path/to/mm_microkernel.o", llvm.bareptr = true}
// CHECK: func.func @matmul_small_dispatch_0_matmul_8x32x16_i32(
// CHECK: air.herd @herd_0
// CHECK-SAME: attributes {link_with = "/path/to/mm_microkernel.o"} {
// CHECK: func.call @matmul_i32_i32
// CHECK: air.herd_terminator
// CHECK: }
// CHECK: return
// CHECK: }
// CHECK: }
module {
func.func private @matmul_i32_i32(memref<i32, 2 : i32>, index, memref<i32, 2 : i32>, index, memref<i32, 2 : i32>, index) attributes {link_with = "/path/to/mm_microkernel.o", llvm.bareptr = true}
func.func @matmul_small_dispatch_0_matmul_8x32x16_i32(%base_buffer: memref<i32, 2 : i32>, %base_buffer_14: memref<i32, 2 : i32>, %base_buffer_18: memref<i32, 2 : i32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.forall (%x,%y) in (2, 2) {
%2 = arith.addi %x, %y : index
func.call @matmul_i32_i32(%base_buffer, %c0, %base_buffer_14, %c0, %base_buffer_18, %c0) : (memref<i32, 2 : i32>, index, memref<i32, 2 : i32>, index, memref<i32, 2 : i32>, index) -> ()
}
return
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,73 @@ func.func @scf2() {
}
return
}

// -----

// This test demonstrates that while forming air.herd we look through func.call ops, fetch
// the corresponding function declaration's 'link_with' attribute and attach it to the newly
// formed air.herd op.

// CHECK-LABEL: module {
// CHECK: func.func private @matmul_i32_i32
// CHECK-SAME: attributes {link_with = "/path/to/mm_microkernel.o", llvm.bareptr = true}
// CHECK: func.func @matmul_small_dispatch_0_matmul_8x32x16_i32(
// CHECK: air.herd @herd_0
// CHECK-SAME: attributes {link_with = "/path/to/mm_microkernel.o"} {
// CHECK: func.call @matmul_i32_i32
// CHECK: air.herd_terminator
// CHECK: }
// CHECK: return
// CHECK: }
// CHECK: }
module {
func.func private @matmul_i32_i32(memref<i32, 2 : i32>, index, memref<i32, 2 : i32>, index, memref<i32, 2 : i32>, index) attributes {link_with = "/path/to/mm_microkernel.o", llvm.bareptr = true}
func.func @matmul_small_dispatch_0_matmul_8x32x16_i32(%base_buffer: memref<i32, 2 : i32>, %base_buffer_14: memref<i32, 2 : i32>, %base_buffer_18: memref<i32, 2 : i32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.parallel (%x,%y) = (%c0,%c0) to (%c1,%c1) step (%c1, %c1) {
%2 = arith.addi %x, %y : index
func.call @matmul_i32_i32(%base_buffer, %c0, %base_buffer_14, %c0, %base_buffer_18, %c0) : (memref<i32, 2 : i32>, index, memref<i32, 2 : i32>, index, memref<i32, 2 : i32>, index) -> ()
scf.reduce
}
return
}
}

// -----

// This test demonstrates the relaying of `link_with` construct to air.herd op even if the
// func.call op is not an immediate child of scf.parallel.

// CHECK-LABEL: module {
// CHECK: func.func private @matmul_scalar_i32_i32
// CHECK-SAME: attributes {link_with = "/path/to/mm_microkernel.o", llvm.bareptr = true}
// CHECK: func.func @matmul_small_nested_scf_dispatch_0_matmul_8x32x16_i32(
// CHECK: air.herd @herd_0
// CHECK-SAME: attributes {link_with = "/path/to/mm_microkernel.o"} {
// CHECK: scf.for
// CHECK-SAME: {
// CHECK: func.call @matmul_scalar_i32_i32
// CHECK: }
// CHECK: air.herd_terminator
// CHECK: }
// CHECK: return
// CHECK: }
// CHECK: }
module {
func.func private @matmul_scalar_i32_i32(memref<i32, 2 : i32>, index, memref<i32, 2 : i32>, index, memref<i32, 2 : i32>, index) attributes {link_with = "/path/to/mm_microkernel.o", llvm.bareptr = true}
func.func @matmul_small_nested_scf_dispatch_0_matmul_8x32x16_i32(%base_buffer: memref<i32, 2 : i32>, %base_buffer_14: memref<i32, 2 : i32>, %base_buffer_18: memref<i32, 2 : i32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c32 = arith.constant 32 : index
scf.parallel (%x,%y) = (%c0,%c0) to (%c1,%c1) step (%c1, %c1) {
%2 = arith.addi %x, %y : index
scf.for %arg0 = %c0 to %c32 step %c4 {
func.call @matmul_scalar_i32_i32(%base_buffer, %c0, %base_buffer_14, %c0, %base_buffer_18, %c0) : (memref<i32, 2 : i32>, index, memref<i32, 2 : i32>, index, memref<i32, 2 : i32>, index) -> ()
}
scf.reduce
}
return
}
}

0 comments on commit d98f0a3

Please sign in to comment.