Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIRRtToNpuPass SHIM DMA BD optimization #550

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 119 additions & 72 deletions mlir/lib/Conversion/AIRRtToNpuPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -578,15 +578,15 @@ void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) {
auto const_stride = *getConstantIntValue(strides[i]);
if (const_wrap >= AIE2_WRAP_UPPER_BOUND) {
// Found dimension with illegal wrap. Tiling.
int inner_wrap = findLargestFactor(const_wrap, AIE2_WRAP_UPPER_BOUND - 1);
int new_wrap = mlir::ceilDiv(const_wrap, inner_wrap);
int outer_wrap = findLargestFactor(const_wrap, AIE2_WRAP_UPPER_BOUND - 1);
int inner_wrap = mlir::ceilDiv(const_wrap, outer_wrap);
wraps[i] = builder.create<arith::ConstantOp>(
loc, builder.getI64Type(),
IntegerAttr::get(builder.getI64Type(), inner_wrap));
wraps.insert(wraps.begin() + i,
builder.create<arith::ConstantOp>(
loc, builder.getI64Type(),
IntegerAttr::get(builder.getI64Type(), new_wrap)));
IntegerAttr::get(builder.getI64Type(), outer_wrap)));
auto new_const_stride =
(const_stride * inner_wrap) %
air::getTensorVolume(
Expand Down Expand Up @@ -1130,56 +1130,71 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase<AIRRtToNpuPass> {
}

std::optional<AIE::ShimDMAAllocationOp>
getAllocOpForSymbol(AIE::DeviceOp dev, StringRef sym_name) {
auto sym = dev.lookupSymbol(sym_name);
if (!sym)
return std::nullopt;

auto uses = SymbolTable::getSymbolUses(sym, dev);
for (auto use : *uses)
if (auto infoOp = dyn_cast<AIE::ShimDMAAllocationOp>(use.getUser()))
return infoOp;

getAllocOpForSymbol(SmallVector<AIE::ShimDMAAllocationOp> shimDmaAllocOps,
StringRef sym_name) {
for (auto shimDmaAllocOp : shimDmaAllocOps)
if (shimDmaAllocOp.getSymName() == sym_name)
return shimDmaAllocOp;
return std::nullopt;
}

std::optional<AIE::ObjectFifoCreateOp>
getObjectFifoCreateOpForSymbol(AIE::DeviceOp dev, StringRef sym_name) {
auto sym = dev.lookupSymbol(sym_name);
if (!sym)
return std::nullopt;

for (auto objFifoCreateOp : dev.getOps<AIE::ObjectFifoCreateOp>()) {
if (objFifoCreateOp.getSymName().str() == sym_name.str())
return objFifoCreateOp;
}

std::optional<AIE::ObjectFifoCreateOp> getObjectFifoCreateOpForSymbol(
SmallVector<AIE::ObjectFifoCreateOp> objectFifoCreateOps,
StringRef sym_name) {
for (auto objectFifoCreateOp : objectFifoCreateOps)
if (objectFifoCreateOp.getSymName().str() == sym_name.str())
return objectFifoCreateOp;
return std::nullopt;
}

void insertNpuSyncOpForResults(ModuleOp module) {
module.walk([&](mlir::func::FuncOp f) {
SmallVector<mlir::func::FuncOp> funcOps;
module.walk([&](mlir::func::FuncOp f) { funcOps.push_back(f); });
for (auto f : funcOps) {
SmallVector<AIEX::NpuDmaMemcpyNdOp> dmas;
f.walk([&](AIEX::NpuDmaMemcpyNdOp dma) { dmas.push_back(dma); });
auto d = f->getParentOfType<AIE::DeviceOp>();

SmallVector<AIE::ShimDMAAllocationOp> shimDmaAllocOps;
if (d)
d.walk([&](AIE::ShimDMAAllocationOp shimDmaAllocOp) {
shimDmaAllocOps.push_back(shimDmaAllocOp);
});
// Performance optimization: instead of repeating calls to
// getAllocOpForSymbol with the same symbol name, cache the result of the
// first call and use the cache for subsequent calls. This dramatically
// improves compile time for some designs.
llvm::DenseMap<StringRef, std::optional<AIE::ShimDMAAllocationOp>>
allocationCache;
auto getAllocOpForSymbolWithCaching = [&](StringRef sym_name) {
auto iter = allocationCache.find(sym_name);
if (iter != allocationCache.end()) {
return iter->second;
}
auto infaOp = getAllocOpForSymbol(shimDmaAllocOps, sym_name);
allocationCache.insert({sym_name, infaOp});
return infaOp;
};

if (!d)
return;
continue;
OpBuilder builder(f);
for (auto dma : dmas) {
if (auto infoOp = getAllocOpForSymbol(d, dma.getMetadata())) {
if (infoOp->getChannelDir() == AIE::DMAChannelDir::S2MM) {
// Found dma op copying results to host
OpBuilder builder(dma);
auto col = builder.getI32IntegerAttr(infoOp->getCol());
auto row = builder.getI32IntegerAttr(0);
auto dir = builder.getI32IntegerAttr(0);
auto chan = builder.getI32IntegerAttr(infoOp->getChannelIndex());
auto col_num = builder.getI32IntegerAttr(1);
auto row_num = builder.getI32IntegerAttr(1);
builder.setInsertionPointAfter(dma);
builder.create<AIEX::NpuSyncOp>(dma->getLoc(), col, row, dir, chan,
col_num, row_num);
}
}
auto infoOp = getAllocOpForSymbolWithCaching(dma.getMetadata());
if (!infoOp)
continue;
if (infoOp->getChannelDir() != AIE::DMAChannelDir::S2MM)
continue;
// Found dma op copying results to host
auto col = builder.getI32IntegerAttr(infoOp->getCol());
auto row = builder.getI32IntegerAttr(0);
auto dir = builder.getI32IntegerAttr(0);
auto chan = builder.getI32IntegerAttr(infoOp->getChannelIndex());
auto col_num = builder.getI32IntegerAttr(1);
auto row_num = builder.getI32IntegerAttr(1);
builder.setInsertionPointAfter(dma);
builder.create<AIEX::NpuSyncOp>(dma->getLoc(), col, row, dir, chan,
col_num, row_num);
}

// Attempt to make npu.sync ops contiguous if they are not operating on
Expand All @@ -1189,54 +1204,86 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase<AIRRtToNpuPass> {
if (auto sync = dyn_cast<AIEX::NpuSyncOp>(op))
previsouSyncs.push_back(sync);
else if (auto dma = dyn_cast<AIEX::NpuDmaMemcpyNdOp>(op)) {
auto infoOp = getAllocOpForSymbol(d, dma.getMetadata());
if (infoOp && infoOp->getChannelDir() == AIE::DMAChannelDir::S2MM &&
!previsouSyncs.empty()) {
auto infoOp = getAllocOpForSymbolWithCaching(dma.getMetadata());
if (!infoOp)
return;
if (previsouSyncs.empty())
return;
if (infoOp->getChannelDir() == AIE::DMAChannelDir::S2MM) {
for (auto prevSync : previsouSyncs)
prevSync->moveAfter(op);
} else if (infoOp &&
infoOp->getChannelDir() == AIE::DMAChannelDir::MM2S &&
!previsouSyncs.empty()) {
} else if (infoOp->getChannelDir() == AIE::DMAChannelDir::MM2S) {
previsouSyncs.clear();
}
}
});
});
}
}

// Renumber aiex.npu.dma_memcpy_nd ops per column of AIEs.
void renumberNpuDmaOps(Block *blk) {
std::map<int, int> chanToIdMap;
AIE::DeviceOp d = nullptr;
blk->walk([&](AIE::DeviceOp op) { d = op; });
SmallVector<AIE::ShimDMAAllocationOp> shimDmaAllocOps;
if (d)
d.walk([&](AIE::ShimDMAAllocationOp shimDmaAllocOp) {
shimDmaAllocOps.push_back(shimDmaAllocOp);
});
// Performance optimization: instead of repeating calls to
// getAllocOpForSymbol with the same symbol name, cache the result of the
// first call and use the cache for subsequent calls. This dramatically
// improves compile time for some designs.
llvm::DenseMap<StringRef, std::optional<AIE::ShimDMAAllocationOp>>
allocationCache;
auto getAllocOpForSymbolWithCaching = [&](StringRef sym_name) {
auto iter = allocationCache.find(sym_name);
if (iter != allocationCache.end()) {
return iter->second;
}
auto infaOp = getAllocOpForSymbol(shimDmaAllocOps, sym_name);
allocationCache.insert({sym_name, infaOp});
return infaOp;
};
SmallVector<AIE::ObjectFifoCreateOp> objectFifoCreateOps;
if (d)
d.walk([&](AIE::ObjectFifoCreateOp objectFifoCreateOp) {
objectFifoCreateOps.push_back(objectFifoCreateOp);
});
OpBuilder builder(blk->getParentOp());
blk->walk([&](Operation *op) {
if (auto dma = dyn_cast<AIEX::NpuDmaMemcpyNdOp>(op)) {
OpBuilder builder(dma);
int col = -1;
if (d) {
if (auto infoOp = getAllocOpForSymbol(d, dma.getMetadata())) {
col = infoOp->getCol();
} else if (auto objFifoCreateOp =
getObjectFifoCreateOpForSymbol(d, dma.getMetadata())) {
auto prodTileOp =
objFifoCreateOp->getProducerTile().getDefiningOp<AIE::TileOp>();
if (prodTileOp.isShimTile())
col = prodTileOp.colIndex();
for (auto consumerTileOp : objFifoCreateOp->getConsumerTiles()) {
auto consTileOp = consumerTileOp.getDefiningOp<AIE::TileOp>();
if (consTileOp.isShimTile()) {
col = consTileOp.colIndex();
}
auto dma = dyn_cast<AIEX::NpuDmaMemcpyNdOp>(op);
auto sync = dyn_cast<AIEX::NpuSyncOp>(op);
if (sync) {
chanToIdMap.clear();
return;
}
if (!dma)
return;
builder.setInsertionPoint(dma);
int col = -1;
if (d) {
if (auto infoOp = getAllocOpForSymbolWithCaching(dma.getMetadata())) {
col = infoOp->getCol();
} else if (auto objFifoCreateOp = getObjectFifoCreateOpForSymbol(
objectFifoCreateOps, dma.getMetadata())) {
auto prodTileOp =
objFifoCreateOp->getProducerTile().getDefiningOp<AIE::TileOp>();
if (prodTileOp.isShimTile())
col = prodTileOp.colIndex();
for (auto consumerTileOp : objFifoCreateOp->getConsumerTiles()) {
auto consTileOp = consumerTileOp.getDefiningOp<AIE::TileOp>();
if (consTileOp.isShimTile()) {
col = consTileOp.colIndex();
}
}
}
if (!chanToIdMap.count(col))
chanToIdMap[col] = 0;
dma->setAttr("id", mlir::IntegerAttr::get(
mlir::IntegerType::get(dma->getContext(), 64),
chanToIdMap[col]++));
} else if (isa<AIEX::NpuSyncOp>(op))
chanToIdMap.clear();
}
if (!chanToIdMap.count(col))
chanToIdMap[col] = 0;
dma->setAttr("id", mlir::IntegerAttr::get(
mlir::IntegerType::get(dma->getContext(), 64),
chanToIdMap[col]++));
});
}

Expand Down
18 changes: 9 additions & 9 deletions mlir/test/Conversion/AIRRtToNpu/airrt_to_npu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -455,10 +455,10 @@ module {
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 64, 0][4, 8, 64, 256][0, 256, 2048]) {id = 1 : i64, metadata = @airMemcpyId20} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 128, 0][4, 8, 64, 256][0, 256, 2048]) {id = 2 : i64, metadata = @airMemcpyId20} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 192, 0][4, 8, 64, 256][0, 256, 2048]) {id = 3 : i64, metadata = @airMemcpyId20} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 4 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 5 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 6 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 7 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 4 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 5 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 6 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 7 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG2]][0, 0, 0, 0][4, 4, 64, 64][131072, 64, 2048]) {id = 8 : i64, metadata = @airMemcpyId26} : memref<2048x2048xi32>

#map = affine_map<()[s0] -> (s0 * 64)>
Expand Down Expand Up @@ -521,9 +521,9 @@ module {

// CHECK-LABEL: aie.device(npu)
// CHECK: func.func @func10(%[[ARG0:.*]]: memref<2654208xi32>)
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 3, 768, 32][128, 884736, 1152]) {id = 0 : i64, metadata = @airMemcpyId21} : memref<2654208xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 3, 768, 32][128, 884736, 1152]) {id = 1 : i64, metadata = @airMemcpyId21} : memref<2654208xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 3, 768, 32][128, 884736, 1152]) {id = 2 : i64, metadata = @airMemcpyId21} : memref<2654208xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 768, 3, 32][128, 3456, 1152]) {id = 0 : i64, metadata = @airMemcpyId21} : memref<2654208xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 768, 3, 32][128, 3456, 1152]) {id = 1 : i64, metadata = @airMemcpyId21} : memref<2654208xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 768, 3, 32][128, 3456, 1152]) {id = 2 : i64, metadata = @airMemcpyId21} : memref<2654208xi32>

#map = affine_map<()[s0] -> (s0 * 64)>
module {
Expand Down Expand Up @@ -701,8 +701,8 @@ module {
// CHECK-SAME: %[[VAL_0:.*]]: memref<262144xi32>, %[[VAL_1:.*]]: memref<262144xi32>, %[[VAL_2:.*]]: memref<131072xi32>) {
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 0][2, 4, 256, 128][0, 128, 512]) {id = 0 : i64, metadata = @airMemcpyId7} : memref<262144xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 131072][2, 4, 256, 128][0, 128, 512]) {id = 1 : i64, metadata = @airMemcpyId7} : memref<262144xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 2, 512, 128][128, 131072, 256]) {id = 2 : i64, metadata = @airMemcpyId12} : memref<262144xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 2, 512, 128][128, 131072, 256]) {id = 3 : i64, metadata = @airMemcpyId12} : memref<262144xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 512, 2, 128][128, 512, 256]) {id = 2 : i64, metadata = @airMemcpyId12} : memref<262144xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 512, 2, 128][128, 512, 256]) {id = 3 : i64, metadata = @airMemcpyId12} : memref<262144xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 0][2, 2, 64, 128][65536, 128, 256]) {id = 4 : i64, metadata = @airMemcpyId45} : memref<131072xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 16384][2, 2, 64, 128][65536, 128, 256]) {id = 5 : i64, metadata = @airMemcpyId46} : memref<131072xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 32768][2, 2, 64, 128][65536, 128, 256]) {id = 0 : i64, metadata = @airMemcpyId47} : memref<131072xi32>
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Conversion/AIRRtToNpu/buffer_memref_to_args.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@ module {
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 131072][4, 8, 128, 128][0, 128, 1024]) {id = 1 : i64, metadata = @airMemcpyId10} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 262144][4, 8, 128, 128][0, 128, 1024]) {id = 2 : i64, metadata = @airMemcpyId10} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 393216][4, 8, 128, 128][0, 128, 1024]) {id = 3 : i64, metadata = @airMemcpyId10} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 4 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 5 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 6 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 7 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 4 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 5 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 6 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 7 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 0][4, 4, 128, 64][131072, 64, 1024]) {id = 8 : i64, metadata = @airMemcpyId26} : memref<2097152xi32>

module {
Expand Down
Loading