Skip to content

Commit

Permalink
Add lowering of air.herd_load to npu.rtp_write
Browse files Browse the repository at this point in the history
  • Loading branch information
fifield committed May 21, 2024
1 parent 41f0db0 commit 00e807e
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 54 deletions.
23 changes: 17 additions & 6 deletions mlir/lib/Conversion/AIRLoweringPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,24 @@ class AIRHerdConversion : public ConversionPattern {
return failure();
}

{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(op->getBlock());
rewriter.create<airrt::HerdLoadOp>(op->getLoc(), rewriter.getI64Type(),
herd_name_attr.getValue().str(),
/* operands */ SmallVector<Value>());
// The first two herd operands are the herd size. Of the rest, integer
// operands are passed as arguments (runtime parameters) to the herd load
// op.
SmallVector<Value> args;
if (herd.getNumKernelOperands() + 2 != operands.size()) {
assert(0 && "error lowering air.herd: unexpected number of operands");
return failure();
}
for (int i = 2, e = operands.size(); i < e; i++) {
Value o = operands[i];
auto iTy = dyn_cast<IntegerType>(o.getType());
if (!iTy)
continue;
args.push_back(o);
}

rewriter.create<airrt::HerdLoadOp>(op->getLoc(), rewriter.getI64Type(),
herd_name_attr.getValue().str(), args);

SmallVector<Value, 4> deps;
for (auto &o : operands)
Expand Down
42 changes: 41 additions & 1 deletion mlir/lib/Conversion/AIRRtToNpuPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,45 @@ struct HerdLoadToNpuPattern : public OpConversionPattern<HerdLoadOp> {
LogicalResult
matchAndRewrite(HerdLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto module = op->getParentOfType<ModuleOp>();

// get the size metadata associated with this herd load
int64_t size_x = -1;
int64_t size_y = -1;
module.walk([&](HerdMetadataOp metadata) {
// return the first match by name
if (metadata.getSymName() == op.getSymName()) {
size_x = metadata->getAttrOfType<IntegerAttr>("size_x").getInt();
size_y = metadata->getAttrOfType<IntegerAttr>("size_y").getInt();
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (size_x == -1 || size_y == -1)
return failure();

for (int i = 0, e = op.getNumOperands(); i < e; i++) {
Value oper = adaptor.getOperands()[i];

// only support 32-bit integers
auto iTy = dyn_cast<IntegerType>(oper.getType());
if (!iTy || iTy.getWidth() != 32)
continue;

for (int x = 0; x < size_x; x++) {
for (int y = 0; y < size_y; y++) {
std::string name =
"__air_herd_rtp_" + std::to_string(x) + "_" + std::to_string(y);
auto constOp =
dyn_cast_if_present<arith::ConstantOp>(oper.getDefiningOp());
if (!constOp)
continue;
uint32_t v = cast<IntegerAttr>(constOp.getValue()).getInt();
rewriter.create<AIEX::NpuWriteRTPOp>(op.getLoc(), name, x, y, i, v);
}
}
}
rewriter.eraseOp(op);
return success();
}
Expand Down Expand Up @@ -452,7 +491,8 @@ void hoistTargetOpsToNewAffineFor(OpBuilder builder, affine::AffineForOp for_op,
}
}

template <typename T> void push_back_if_unique(SmallVector<T> &vec, T entry) {
template <typename T>
void push_back_if_unique(SmallVector<T> &vec, T entry) {
if (std::find(vec.begin(), vec.end(), entry) == vec.end()) {
vec.push_back(entry);
}
Expand Down
85 changes: 54 additions & 31 deletions mlir/lib/Conversion/AIRToAIEPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,16 +312,27 @@ void outlineAIECores(OpBuilder &builder, AIE::DeviceOp aie_device,
remap.map(h.getSize()[1],
core_builder.create<arith::ConstantIndexOp>(hloc, herd_size_y));

for (unsigned i = 0; i < h.getNumKernelOperands(); i++) {
auto a = h.getKernelArgument(i);
int64_t rtp_buffer_size = 0; // size in i32s
for (unsigned i = 0, e = h.getNumKernelOperands(); i < e; i++) {
BlockArgument a = h.getKernelArgument(i);
if (auto intTy = llvm::dyn_cast<IntegerType>(a.getType()))
if (intTy.getWidth() == 32)
rtp_buffer_size++;
}
uint64_t buffer_id = 0;
AIE::BufferOp rtp_buffer = allocateBufferOp(
buffer_id,
MemRefType::get({rtp_buffer_size}, core_builder.getI32Type()), tile,
core_builder.getStringAttr("__air_herd_rtp"), x, y);

for (unsigned i = 0, e = h.getNumKernelOperands(); i < e; i++) {
BlockArgument a = h.getKernelArgument(i);

if (auto intTy = llvm::dyn_cast<IntegerType>(a.getType())) {
uint64_t buffer_id = i;
AIE::BufferOp b = allocateBufferOp(
buffer_id, MemRefType::get({}, intTy), tile,
core_builder.getStringAttr("__air_herd_rtp_"), x, y);
auto c = core_builder.create<memref::LoadOp>(hloc, intTy, b,
SmallVector<Value>{});
auto const_i = core_builder.create<arith::ConstantIndexOp>(hloc, i);
SmallVector<Value> offsets{const_i};
auto c = core_builder.create<memref::LoadOp>(hloc, intTy, rtp_buffer,
offsets);
remap.map(a, c);
continue;
}
Expand Down Expand Up @@ -513,7 +524,8 @@ void createAIEModulesAndOutlineCores(
OpBuilder builder(aie_dev);
outlineAIECores(builder, aie_dev, h, tileToHerdMap, options);
}
// Outline any L1 memref allocs used by herds but located outside of any herd
// Outline any L1 memref allocs used by herds but located outside of any
// herd
std::vector<Value> sharedL1Memrefs;
for (auto &p : aie_modules) {
// for (auto h : herds) {
Expand Down Expand Up @@ -793,7 +805,8 @@ void lowerScfAirTokens(AIE::DeviceOp m) {
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));
}

// struct LowerPipeGetPutPattern : public OpRewritePattern<air::PipelinePutOp> {
// struct LowerPipeGetPutPattern : public OpRewritePattern<air::PipelinePutOp>
// {
// using OpRewritePattern<air::PipelinePutOp>::OpRewritePattern;

// LowerPipeGetPutPattern(MLIRContext *ctx,
Expand All @@ -813,9 +826,10 @@ void lowerScfAirTokens(AIE::DeviceOp m) {
// auto row_offset = r ? *r : 0;

// auto other_x =
// cast<arith::ConstantIndexOp>(put.getDst0().getDefiningOp()); auto other_y
// = cast<arith::ConstantIndexOp>(put.getDst1().getDefiningOp()); auto
// other_core = getPhysTileOp(aie_device, other_x.value() + col_offset,
// cast<arith::ConstantIndexOp>(put.getDst0().getDefiningOp()); auto
// other_y = cast<arith::ConstantIndexOp>(put.getDst1().getDefiningOp());
// auto other_core = getPhysTileOp(aie_device, other_x.value() +
// col_offset,
// other_y.value() + row_offset)
// .getCoreOp();
// assert(other_core);
Expand Down Expand Up @@ -876,7 +890,8 @@ void lowerScfAirTokens(AIE::DeviceOp m) {
// shared aie.buffer + aie.lock. This is a single-buffered implementation
// with exclusive access to the buffer controlled by the lock. i.e. FIXME.
// void lowerPipelineGetPut(AIE::DeviceOp &m,
// std::map<AIE::TileOp, air::HerdOp> tileToHerdMap) {
// std::map<AIE::TileOp, air::HerdOp> tileToHerdMap)
// {
// auto ctx = m->getContext();
// RewritePatternSet patterns(ctx);
// patterns.insert<LowerPipeGetPutPattern>(ctx, tileToHerdMap);
Expand Down Expand Up @@ -924,7 +939,8 @@ void lowerScfAirTokens(AIE::DeviceOp m) {
// tile.getCol() - col_offset, tile.getRow() - row_offset);

// rewriter.setInsertionPoint(cast);
// rewriter.create<memref::TensorStoreOp>(cast.getLoc(), cast.getOperand(),
// rewriter.create<memref::TensorStoreOp>(cast.getLoc(),
// cast.getOperand(),
// buffer);
// rewriter.replaceOp(cast, buffer->getResults());
// return success();
Expand Down Expand Up @@ -999,7 +1015,8 @@ struct AllocL2BuffersPattern : public OpRewritePattern<memref::AllocOp> {
LogicalResult matchAndRewrite(memref::AllocOp alloc,
PatternRewriter &rewriter) const override {

// L2 memref allocs should exist inside of device op but outside of core op
// L2 memref allocs should exist inside of device op but outside of core
// op
AIE::DeviceOp device = alloc->getParentOfType<AIE::DeviceOp>();
if (!device)
return failure();
Expand Down Expand Up @@ -1463,10 +1480,10 @@ struct LowerAIRChannelsPattern : public OpRewritePattern<air::ChannelOp> {
std::map<Operation *, AIE::ObjectFifoCreateOp> &linksToComplete;
};

// This function replaces ChannelPutOp/ChannelGetOp with AIE_CreateObjectFifoOps
// and with ObjectFifoAcquireOp<Producer/Consumer>. It also erases memref allocs
// as the objFifo lowering allocates its own memory. It replaces the associated
// memref deallocs with ObjectFifoReleaseOps.
// This function replaces ChannelPutOp/ChannelGetOp with
// AIE_CreateObjectFifoOps and with ObjectFifoAcquireOp<Producer/Consumer>. It
// also erases memref allocs as the objFifo lowering allocates its own memory.
// It replaces the associated memref deallocs with ObjectFifoReleaseOps.
void lowerAIRChannels(
AIE::DeviceOp &d, ShimTileAllocator &s,
std::map<AIE::BufferOp, AIE::TileOp> &bufferToMemtileMap) {
Expand Down Expand Up @@ -1835,8 +1852,8 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {
isa<arith::ConstantIndexOp>(operand.getDefiningOp())) {
operandOps.push_back(operand.getDefiningOp());
}
// Substituting index operands, such as strides and offsets, to constant
// zero for convenience. TODO: generalize this
// Substituting index operands, such as strides and offsets, to
// constant zero for convenience. TODO: generalize this
else if (llvm::isa<IndexType>(operand.getType())) {
remap.map(operand, builder.create<arith::ConstantIndexOp>(
builder.getUnknownLoc(), 0));
Expand Down Expand Up @@ -2137,8 +2154,8 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {
// verifyMemcpyOps(dma_memcpy_ops,
// aie_device.getTargetModel().getTargetArch());

// Step 2: Pair up memcpy ops into flow ops. Each entry in memcpy_flows is a
// bundle of memcpy ops which share the same aie.flow.
// Step 2: Pair up memcpy ops into flow ops. Each entry in memcpy_flows is
// a bundle of memcpy ops which share the same aie.flow.
std::vector<MemcpyBundleAsFlow> memcpy_flows;
for (auto o : dma_memcpy_ops) {
if (auto dma = dyn_cast<air::DmaMemcpyNdOp>(o)) {
Expand Down Expand Up @@ -2173,9 +2190,9 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {

// Step 3: Allocate tile DMA channels, shim DMA channels and shim tiles
// AIR channel to AIE flow mapping strategy: allocate L1 DMAs first,
// followed by L2 and then L3, where outer memory hierarchies reuse existing
// AIE flows as possible.
// if (groupingMemcpysByLoop(memcpy_flows))
// followed by L2 and then L3, where outer memory hierarchies reuse
// existing AIE flows as possible. if
// (groupingMemcpysByLoop(memcpy_flows))
// groupedByLoopDMAChannelAllocation(memcpy_flows, shim_dma_alloc,
// memtile_dma_alloc, tile_dma_alloc);
// else
Expand Down Expand Up @@ -2542,6 +2559,12 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {
name = attr.getValue().str();

auto herd_meta = builder.create<airrt::HerdMetadataOp>(loc, name);
herd_meta->setAttr("size_x", builder.getI64IntegerAttr(herd.getNumCols()));
herd_meta->setAttr("size_y", builder.getI64IntegerAttr(herd.getNumRows()));
if (auto co = herd.getColOffset())
herd_meta->setAttr("loc_x", builder.getI64IntegerAttr(*co));
if (auto ro = herd.getRowOffset())
herd_meta->setAttr("loc_y", builder.getI64IntegerAttr(*ro));
return herd_meta;
}

Expand Down Expand Up @@ -2712,8 +2735,8 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {
? ndcpy.getDstStrides()
: ndcpy.getSrcStrides();

// Skip over repeat pattern at highest dimension; repeat pattern handled at
// AIE::DMAStartOp.
// Skip over repeat pattern at highest dimension; repeat pattern handled
// at AIE::DMAStartOp.
if (!strides.empty() && !sizes.empty() && !offsets.empty())
if (auto const_highest_stride = getConstantIntValue(strides[0]))
if (*const_highest_stride == 0) {
Expand Down Expand Up @@ -2777,8 +2800,8 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {
const auto &target_model = device.getTargetModel();
OpBuilder builder(device);

// Unlike shimDmaAlloc, tileDmaAlloc is local to device because it does not
// need to export to airrt.metadata
// Unlike shimDmaAlloc, tileDmaAlloc is local to device because it does
// not need to export to airrt.metadata
TileDMAAllocator tileDmaAlloc(device);
MemTileDMAAllocator memTileDmaAlloc(device);

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/AIRToAIESchedulingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ std::stringstream air::generateBufferNameInStringStream(StringRef prefix,
std::stringstream ss;
if (attr) {
if (x >= 0 && y >= 0)
ss << attr.getValue().str() << BufferId++ << "_" << x << "_" << y;
ss << attr.getValue().str() << "_" << x << "_" << y;
else
ss << attr.getValue().str() << BufferId++;
} else {
Expand Down
30 changes: 30 additions & 0 deletions mlir/test/Conversion/AIRLowering/air_herd_rtp.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//===- air_herd_rtp.mlir ---------------------------------------*- MLIR -*-===//
//
// Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
// SPDX-License-Identifier: MIT
//
//===----------------------------------------------------------------------===//

// RUN: air-opt %s -air-to-std | FileCheck %s
// CHECK-LABEL: func.func @herd1
// CHECK: %{{.*}} = airrt.herd_load "herd" (%{{.*}}, %{{.*}}) : (i32, i32) -> i64
func.func @herd1(%arg0: i32, %arg1: i32, %arg2: i32) {
%cst1 = arith.constant 1 : index
%cst2 = arith.constant 2 : index
air.herd @herd tile(%tx, %ty) in (%size_x = %cst1, %size_y = %cst2) args(%a = %arg0, %b = %arg1, %c = %arg2) : i32, i32, i32 {
%src0 = memref.alloc() : memref<1xi32, 2>
%src1 = memref.alloc() : memref<1xi32, 2>
%zero = arith.constant 0 : index
%0 = memref.load %src0[%zero] : memref<1xi32, 2>
%1 = memref.load %src1[%zero] : memref<1xi32, 2>
%2 = arith.addi %0, %a : i32
%3 = arith.addi %1, %b : i32
%4 = arith.addi %2, %3 : i32
%5 = arith.addi %4, %c : i32
%dst0 = memref.alloc() : memref<1xi32, 2>
memref.store %5, %dst0[%zero] : memref<1xi32, 2>
air.herd_terminator
}
return
}

24 changes: 24 additions & 0 deletions mlir/test/Conversion/AIRRtToNpu/herd_load_to_npu.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//===- herd_load_to_npu.mlir -----------------------------------*- MLIR -*-===//
//
// Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
// SPDX-License-Identifier: MIT
//
//===----------------------------------------------------------------------===//

// RUN: air-opt -airrt-to-npu %s | FileCheck %s
// CHECK: aiex.npu.rtp_write(0, 0, 0, 11) {buffer_sym_name = "__air_herd_rtp_0_0"}
// CHECK: aiex.npu.rtp_write(0, 1, 0, 11) {buffer_sym_name = "__air_herd_rtp_0_1"}
// CHECK: aiex.npu.rtp_write(0, 0, 1, 22) {buffer_sym_name = "__air_herd_rtp_0_0"}
// CHECK: aiex.npu.rtp_write(0, 1, 1, 22) {buffer_sym_name = "__air_herd_rtp_0_1"}
airrt.module_metadata{
airrt.segment_metadata attributes {sym_name = "segment"} {
airrt.herd_metadata {size_x = 1 : i64, size_y = 2 : i64, sym_name = "herd"}
}
}
func.func @func1(%arg0: i32, %arg1: i32) {
%c1_i32 = arith.constant 11 : i32
%c2_i32 = arith.constant 22 : i32
%h = airrt.herd_load "herd" (%c1_i32, %c2_i32) : (i32, i32) -> i64
%c1 = arith.constant 1 : index
return
}
37 changes: 22 additions & 15 deletions mlir/test/Conversion/AIRToAIE/air_herd_to_aie_rtp.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,39 @@

// RUN: air-opt %s -air-to-aie | FileCheck %s

func.func @herd1(%arg0: i32, %arg1: i32) {
%cst1 = arith.constant 1 : index
func.func @herd1(%arg0: i32, %arg1: i32, %arg2: i32) {
// CHECK-LABEL: aie.device
// CHECK: %[[VAR1:.*]] = aie.tile(1, 1)
// CHECK: %[[BUF1:.*]] = aie.buffer(%[[VAR1]]) {{{.*}}} : memref<1xi32, 2>
// CHECK: %[[BUF2:.*]] = aie.buffer(%[[VAR1]]) {{{.*}}} : memref<1xi32, 2>
// CHECK: %[[BUF3:.*]] = aie.buffer(%[[VAR1]]) {{{.*}}} : memref<1xi32, 2>
// CHECK: %[[RTP1:.*]] = aie.buffer(%[[VAR1]]) {{{.*}}} : memref<i32>
// CHECK: %[[RTP2:.*]] = aie.buffer(%[[VAR1]]) {{{.*}}} : memref<i32>
// CHECK: %[[VAR2:.*]] = aie.core(%[[VAR1]]) {
air.herd tile(%tx, %ty) in (%size_x = %cst1, %size_y = %cst1) args(%a = %arg0, %b = %arg1) : i32, i32{
// CHECK: load %[[RTP2]][] : memref<i32>
// CHECK: load %[[RTP1]][] : memref<i32>
// CHECK: %[[VAR2:.*]] = aie.tile(1, 2)

// CHECK: %[[RTP2:.*]] = aie.buffer(%[[VAR2]]) {{{.*}}sym_name = "__air_herd_rtp_0_1"{{.*}}} : memref<3xi32>
// CHECK: aie.core(%[[VAR2]])
// CHECK: load %[[RTP2]][%c0] : memref<3xi32>
// CHECK: load %[[RTP2]][%c1] : memref<3xi32>
// CHECK: load %[[RTP2]][%c2] : memref<3xi32>

// CHECK: %[[RTP1:.*]] = aie.buffer(%[[VAR1]]) {{{.*}}sym_name = "__air_herd_rtp_0_0"{{.*}}} : memref<3xi32>
// CHECK: aie.core(%[[VAR1]])
// CHECK: load %[[RTP1]][%c0] : memref<3xi32>
// CHECK: load %[[RTP1]][%c1] : memref<3xi32>
// CHECK: load %[[RTP1]][%c2] : memref<3xi32>
%cst1 = arith.constant 1 : index
%cst2 = arith.constant 2 : index
%cst12 = arith.constant 12 : i32
%cst23 = arith.constant 23 : i32
%cst34 = arith.constant 34 : i32
air.herd @herd tile(%tx, %ty) in (%size_x = %cst1, %size_y = %cst2) args(%a = %cst12, %b = %cst23, %c = %cst34) : i32, i32, i32 {
%src0 = memref.alloc() : memref<1xi32, 2>
%src1 = memref.alloc() : memref<1xi32, 2>
%zero = arith.constant 0 : index
// CHECK: load %[[BUF1]]
%0 = memref.load %src0[%zero] : memref<1xi32, 2>
// CHECK: load %[[BUF2]]
%1 = memref.load %src1[%zero] : memref<1xi32, 2>
%2 = arith.addi %0, %a : i32
%3 = arith.addi %1, %b : i32
%4 = arith.addi %2, %3 : i32
%5 = arith.addi %4, %c : i32
%dst0 = memref.alloc() : memref<1xi32, 2>
// CHECK: memref.store {{.*}}, %[[BUF3]]
memref.store %4, %dst0[%zero] : memref<1xi32, 2>
memref.store %5, %dst0[%zero] : memref<1xi32, 2>
air.herd_terminator
}
return
Expand Down

0 comments on commit 00e807e

Please sign in to comment.