Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Yu-Zhewen committed Dec 17, 2024
1 parent ac2a7ec commit fb6d4d2
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

using DmaQueue = std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>;
using DmaQueueKey = std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>;

/// Utility function to determine whether a DMA wait op can be folded into a
/// queue based on its half DMA copy operation.
FailureOr<bool> canFoldByQueue(
const AMDAIE::AMDAIEDeviceModel &deviceModel,
AMDAIE::NpuHalfDmaCpyNdOp &npuHalfDmaCpyNdOp,
DenseMap<DmaQueue, SmallVector<uint32_t>> &dmaQueueToBdIds) {
DenseMap<DmaQueueKey, SmallVector<uint32_t>> &dmaQueueToBdIds) {
// Retrieve the connection op.
std::optional<AMDAIE::ConnectionOp> maybeConnectionOp =
npuHalfDmaCpyNdOp.getConnectionOp();
Expand Down Expand Up @@ -104,7 +104,7 @@ LogicalResult foldDmaWaitsByQueue(const AMDAIE::AMDAIEDeviceModel &deviceModel,
AMDAIE::ControlCodeOp controlCodeOp) {
IRRewriter rewriter(controlCodeOp->getContext());
std::vector<AMDAIE::NpuDmaWaitOp> waitOpsToErase;
DenseMap<DmaQueue, SmallVector<uint32_t>> dmaQueueToBdIds;
DenseMap<DmaQueueKey, SmallVector<uint32_t>> dmaQueueToBdIds;
// Traverse the control code in reverse.
WalkResult res = controlCodeOp->walk<WalkOrder::PostOrder, ReverseIterator>(
[&](AMDAIE::NpuDmaWaitOp waitOp) {
Expand Down Expand Up @@ -168,17 +168,15 @@ LogicalResult updateBatchTokens(IRRewriter &rewriter,

rewriter.setInsertionPointAfter(waitOps.back());
rewriter.create<AMDAIE::NpuDmaWaitOp>(waitOps.back().getLoc(), asyncTokens);
for (AMDAIE::NpuDmaWaitOp waitOp : waitOps) {
rewriter.eraseOp(waitOp);
}
for (AMDAIE::NpuDmaWaitOp waitOp : waitOps) rewriter.eraseOp(waitOp);
return success();
}

/// Utility function to determine if a DMA wait operation can be folded into a
/// a batch based on its half DMA copy operation.
FailureOr<bool> canFoldByBatch(
AMDAIE::NpuHalfDmaCpyNdOp npuHalfDmaCpyNdOp,
SmallVector<AMDAIE::ConnectionOp> &connectionOps) {
FailureOr<bool> canFoldByBatch(Operation *batchParentOp,
AMDAIE::NpuHalfDmaCpyNdOp npuHalfDmaCpyNdOp,
DenseSet<AMDAIE::ConnectionOp> &connectionOps) {
// Retrieve the connection op.
std::optional<AMDAIE::ConnectionOp> maybeConnectionOp =
npuHalfDmaCpyNdOp.getConnectionOp();
Expand All @@ -199,17 +197,19 @@ FailureOr<bool> canFoldByBatch(

bool canFold = true;
// Can't fold if the current connection op already occurs in the batch, or
// if the current operation is a packet flow, or if the batch is empty.
if (llvm::is_contained(connectionOps, connectionOp) || isPacketFlow ||
connectionOps.empty()) {
// if the current operation is a packet flow, or if the batch is empty, or
// if the current operation is not in the same scope as the batch.
if (connectionOps.contains(connectionOp) || isPacketFlow ||
connectionOps.empty() ||
(batchParentOp != npuHalfDmaCpyNdOp->getParentOp())) {
connectionOps.clear();
canFold = false;
}
connectionOps.push_back(connectionOp);
connectionOps.insert(connectionOp);
return canFold;
}

/// Traverses the control code forward, ensuring that only one DMA wait op is
/// Traverses the control code in reverse, ensuring that only one DMA wait op is
/// retained for every batch of DMA copy operations.
///
/// Example Input:
Expand All @@ -227,34 +227,42 @@ FailureOr<bool> canFoldByBatch(
/// %2 = dma_cpy_nd(connection2)
/// %3 = dma_cpy_nd(connection3)
/// dma_wait(%0, %1, %2, %3)
/// Reverse traversal simplifies handling duplicate connections, preventing
/// the need to revisit and modify earlier operations after processing later
/// ones.
LogicalResult foldDmaWaitsByBatch(AMDAIE::ControlCodeOp controlCodeOp) {
IRRewriter rewriter(controlCodeOp->getContext());
SmallVector<AMDAIE::NpuDmaWaitOp> waitOps;
SmallVector<AMDAIE::ConnectionOp> connectionOps;
WalkResult res = controlCodeOp->walk([&](AMDAIE::NpuDmaWaitOp waitOp) {
bool toBatch = true;
for (Value token : waitOp.getAsyncTokens()) {
if (auto npuHalfDmaCpyNdOp =
dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
token.getDefiningOp())) {
FailureOr<bool> result =
canFoldByBatch(npuHalfDmaCpyNdOp, connectionOps);
if (failed(result)) return WalkResult::interrupt();
toBatch &= *result;
}
}
// Process the previous batch of wait ops, and start a new batch.
if (!toBatch) {
if (failed(updateBatchTokens(rewriter, waitOps)))
return WalkResult::interrupt();
waitOps.clear();
}
waitOps.push_back(waitOp);
return WalkResult::advance();
});
DenseSet<AMDAIE::ConnectionOp> connectionOps;
WalkResult res = controlCodeOp->walk<WalkOrder::PostOrder, ReverseIterator>(
[&](AMDAIE::NpuDmaWaitOp waitOp) {
bool toBatch = true;
Operation *batchParentOp =
waitOps.empty() ? waitOp->getParentOp() : waitOps[0]->getParentOp();
for (Value token : waitOp.getAsyncTokens()) {
if (auto npuHalfDmaCpyNdOp =
dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
token.getDefiningOp())) {
FailureOr<bool> result =
canFoldByBatch(batchParentOp, npuHalfDmaCpyNdOp, connectionOps);
if (failed(result)) return WalkResult::interrupt();
toBatch &= *result;
}
}
// Process the previous batch of wait ops, and start a new batch.
if (!toBatch) {
std::reverse(waitOps.begin(), waitOps.end());
if (failed(updateBatchTokens(rewriter, waitOps)))
return WalkResult::interrupt();
waitOps.clear();
}
waitOps.push_back(waitOp);
return WalkResult::advance();
});

if (res.wasInterrupted()) return failure();
// Process the remaining wait ops.
std::reverse(waitOps.begin(), waitOps.end());
if (failed(updateBatchTokens(rewriter, waitOps))) return failure();
return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

using DmaChain = std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>;
using DmaChainKey = std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>;

/// Utility function to update `next_bd` and `start_bd` operands.
LogicalResult updateChainOperands(
Expand Down Expand Up @@ -83,9 +83,9 @@ LogicalResult updateChainOperands(
/// - Chain X: [0] (the newly added BD ID).
/// - Chain Y: [] (emptied after breaking).
void checkForChainsToBeBroken(
uint32_t currBdId, const DmaChain &currDmaChain,
const DenseMap<DmaChain, DenseSet<uint32_t>> &dmaChainToBdIds,
SmallVector<DmaChain> &chainsToBreak) {
uint32_t currBdId, const DmaChainKey &currDmaChain,
const DenseMap<DmaChainKey, DenseSet<uint32_t>> &dmaChainToBdIds,
SmallVector<DmaChainKey> &chainsToBreak) {
for (auto &[entry, bdIds] : dmaChainToBdIds) {
if (entry.first == currDmaChain.first && bdIds.contains(currBdId)) {
// Break the chain that contains the duplicate BD ID.
Expand Down Expand Up @@ -120,9 +120,10 @@ LogicalResult insertDmaBdChain(const AMDAIE::AMDAIEDeviceModel &deviceModel,
}

// BD IDs that have been assigned in each tile.
DenseMap<DmaChain, DenseSet<uint32_t>> dmaChainToBdIds;
DenseMap<DmaChainKey, DenseSet<uint32_t>> dmaChainToBdIds;
// Buffers the DMA ops that will be chained.
DenseMap<DmaChain, SmallVector<AMDAIE::NpuHalfDmaCpyNdOp>> dmaChainToDmaOps;
DenseMap<DmaChainKey, SmallVector<AMDAIE::NpuHalfDmaCpyNdOp>>
dmaChainToDmaOps;

res = controlCodeOp->walk<WalkOrder::PostOrder,
ReverseIterator>([&](Operation *op) {
Expand Down Expand Up @@ -185,8 +186,8 @@ LogicalResult insertDmaBdChain(const AMDAIE::AMDAIEDeviceModel &deviceModel,
// Any duplicate BD ID from the same tile indicates that the chain
// cannot grow further and requires breaking to release the
// conflicting BD ID.
SmallVector<DmaChain> chainsToBreak;
DmaChain currDmaChain = {tileOp, connectionOp};
SmallVector<DmaChainKey> chainsToBreak;
DmaChainKey currDmaChain = {tileOp, connectionOp};
checkForChainsToBeBroken(bdId, currDmaChain, dmaChainToBdIds,
chainsToBreak);

Expand Down
Loading

0 comments on commit fb6d4d2

Please sign in to comment.