Commit

separate refactor

Yu-Zhewen committed Dec 17, 2024
1 parent a10aedd commit cc1ae9e
Showing 2 changed files with 71 additions and 122 deletions.
(Diff legend: lines prefixed with "-" are removed by this commit, lines prefixed with "+" are added, and unprefixed lines are unchanged context.)
@@ -18,125 +18,83 @@ namespace {

using DmaBdIdKey = std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>;

- /// Utility function to erase the DMA wait operations in the queue, except for
- /// the last one.
- LogicalResult eraseQueueOperations(IRRewriter &rewriter,
- SmallVector<AMDAIE::NpuDmaWaitOp> &waitOps) {
- // Skip if there are less than two DMA wait operations in the queue.
- if (waitOps.size() < 2) return success();
-
- Operation *parentOp = waitOps.back()->getParentOp();
- // Do not modify the last wait op, it will be kept.
- waitOps.pop_back();
-
- for (AMDAIE::NpuDmaWaitOp waitOp : waitOps) {
- if (waitOp->getParentOp() != parentOp) {
- return waitOp.emitError(
- "DMA operations to be queued must belong to the same scope");
- }
- // Erase the wait op.
- SmallVector<Value> asyncTokens(waitOp.getAsyncTokens());
- rewriter.eraseOp(waitOp);
- for (Value token : asyncTokens) {
- auto dmaOp =
- dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(token.getDefiningOp());
- if (!dmaOp)
- waitOp.emitError("expected to operate on an `amdaie.half_dma_cpy_nd`");
- if (dmaOp.use_empty()) {
- rewriter.setInsertionPoint(dmaOp);
- TypeRange resultTypeRange = TypeRange{};
- // Nullify the result to avoid issuing a token.
- rewriter.create<AMDAIE::NpuHalfDmaCpyNdOp>(
- dmaOp.getLoc(), resultTypeRange, dmaOp.getConnection(),
- dmaOp.getInput(), dmaOp.getMixedOffsets(), dmaOp.getMixedSizes(),
- dmaOp.getMixedStrides(), dmaOp.getBdId(), dmaOp.getChannel(),
- dmaOp.getNextBd(), dmaOp.getStartBd());
- rewriter.eraseOp(dmaOp);
- }
- }
- }
- return success();
- }

- /// Utility function to determine whether a DMA wait op can be folded into a
- /// queue based on its half DMA copy operation.
+ /// Utility function to determine whether a DMA wait op can be folded based on
+ /// its half DMA copy operation.
FailureOr<bool> canFoldByQueue(
const AMDAIE::AMDAIEDeviceModel &deviceModel,
- const Operation *queueParentOp,
- const DenseMap<DmaBdIdKey, DenseSet<uint32_t>> &dmaBdIdsMap,
- DmaBdIdKey &currBdIdKey, uint32_t &currBdIdVal,
- AMDAIE::NpuHalfDmaCpyNdOp &currHalfDmaCpyNdOp) {
- // Check if the current operation is in the same scope as the rest of the
- // queue.
- bool isSameScope = currHalfDmaCpyNdOp->getParentOp() == queueParentOp;
-
+ AMDAIE::NpuHalfDmaCpyNdOp &npuHalfDmaCpyNdOp,
+ DenseMap<DmaBdIdKey, SmallVector<uint32_t>> &tileConnectToBdIdQueue) {
// Retrieve the connection op.
std::optional<AMDAIE::ConnectionOp> maybeConnectionOp =
- currHalfDmaCpyNdOp.getConnectionOp();
+ npuHalfDmaCpyNdOp.getConnectionOp();
if (!maybeConnectionOp) {
- return currHalfDmaCpyNdOp.emitOpError()
+ return npuHalfDmaCpyNdOp.emitOpError()
<< "expected to operate on an `amdaie.connection`";
}
AMDAIE::ConnectionOp connectionOp = maybeConnectionOp.value();

// Retrieve the flow op.
std::optional<AMDAIE::FlowOp> maybeFlowOp = connectionOp.getFlowOp();
if (!maybeFlowOp) {
- return connectionOp.emitOpError()
+ return connectionOp->emitOpError()
<< "expected to operate on an `amdaie.flow`";
}
AMDAIE::FlowOp flowOp = maybeFlowOp.value();
bool isPacketFlow = flowOp.getIsPacketFlow();

// Retrieve the BD ID op.
- std::optional<AMDAIE::BdIdOp> maybeBdIdOp = currHalfDmaCpyNdOp.getBdIdOp();
+ std::optional<AMDAIE::BdIdOp> maybeBdIdOp = npuHalfDmaCpyNdOp.getBdIdOp();
if (!maybeBdIdOp) {
- return currHalfDmaCpyNdOp.emitOpError()
+ return npuHalfDmaCpyNdOp.emitOpError()
<< "must have a BD ID op to lower to "
"`amdaie.npu.write_bd`";
}
AMDAIE::BdIdOp bdIdOp = maybeBdIdOp.value();
- currBdIdVal = getConstantIndexOrAssert(bdIdOp.getValue());

// Retrieve the tile op.
AMDAIE::TileOp tileOp =
dyn_cast_if_present<AMDAIE::TileOp>(bdIdOp.getTile().getDefiningOp());
if (!tileOp) {
return bdIdOp.emitOpError() << "must operate on an `amdaie.tile`";
}
- currBdIdKey = {tileOp, connectionOp};

// Get the maximum queue size.
uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
uint32_t row = getConstantIndexOrAssert(tileOp.getRow());
uint32_t maxQueueSize = deviceModel.getDmaMaxQueueSize(col, row);

- bool isDuplicateBdId = llvm::any_of(dmaBdIdsMap, [&](const auto &entry) {
- return entry.first.first == tileOp && entry.second.contains(currBdIdVal);
- });
- const DenseSet<uint32_t> &bdIds = dmaBdIdsMap.lookup(currBdIdKey);
-
- // Can't fold wait op if:
- // (1) the current BD ID on the same tile already occurs in the queue, or
- // (2) the current operation is a packet flow, or
- // (3) reaches the maximum queue size, or
- // (4) the queue is empty, or
- // (5) the current operation is not in the same scope as the queue.
- return !(isDuplicateBdId || isPacketFlow || bdIds.size() >= maxQueueSize ||
- bdIds.empty() || !isSameScope);
+ // Keep wait op if, either reaches the maximum queue size, or a
+ // duplicate BD ID in the same tile, or packet flow, or the queue is
+ // empty
+ uint32_t bdId = getConstantIndexOrAssert(bdIdOp.getValue());
+ bool isDuplicateBdId =
+ llvm::any_of(tileConnectToBdIdQueue, [&](const auto &entry) {
+ return entry.first.first == tileOp &&
+ llvm::is_contained(entry.second, bdId);
+ });
+ SmallVector<uint32_t> &bdIdQueue =
+ tileConnectToBdIdQueue[{tileOp, connectionOp}];
+ bool canFold = true;
+ if (isDuplicateBdId || isPacketFlow || bdIdQueue.size() >= maxQueueSize ||
+ bdIdQueue.empty()) {
+ bdIdQueue.clear();
+ canFold = false;
+ }
+ bdIdQueue.push_back(bdId);
+ return canFold;
}

/// Traverses the control code in reverse, ensuring that for each connection,
/// only one DMA wait op is retained for every maximum queue size.
///
/// Example Output: assuming a maximum queue size of 4.
- /// dma_cpy_nd(connection=0, bd_id=0)
- /// %0 = dma_cpy_nd(connection=0, bd_id=1)
+ /// dma_cpy_nd
+ /// %0 = dma_cpy_nd
/// dma_wait(%0)
- /// dma_cpy_nd(connection=0, bd_id=2)
- /// dma_cpy_nd(connection=0, bd_id=3)
- /// dma_cpy_nd(connection=0, bd_id=4)
- /// %1 = dma_cpy_nd(connection=0, bd_id=5)
+ /// dma_cpy_nd
+ /// dma_cpy_nd
+ /// dma_cpy_nd
+ /// %1 = dma_cpy_nd
/// dma_wait(%1)
/// From the bottom up, for every four DMA copy operations, only one DMA wait
/// operation is retained.
@@ -147,57 +105,49 @@ FailureOr<bool> canFoldByQueue(
LogicalResult foldDmaWaitsByQueue(const AMDAIE::AMDAIEDeviceModel &deviceModel,
AMDAIE::ControlCodeOp controlCodeOp) {
IRRewriter rewriter(controlCodeOp->getContext());
- SmallVector<SmallVector<AMDAIE::NpuDmaWaitOp>> waitOpQueues;
- DenseMap<DmaBdIdKey, DenseSet<uint32_t>> dmaBdIdsMap;
-
- auto updateWithCurrBdId =
- [&](bool canFold, DenseMap<DmaBdIdKey, DenseSet<uint32_t>> &dmaBdIdsMap,
- DmaBdIdKey &currBdIdKey, uint32_t currBdIdVal) {
- assert(currBdIdKey.first && "TileOp must not be null");
- assert(currBdIdKey.second && "ConnectionOp must not be null");
- if (!canFold) dmaBdIdsMap[currBdIdKey].clear();
- dmaBdIdsMap[currBdIdKey].insert(currBdIdVal);
- };

+ std::vector<AMDAIE::NpuDmaWaitOp> waitOpsToErase;
+ DenseMap<DmaBdIdKey, SmallVector<uint32_t>> tileConnectToBdIdQueue;
// Traverse the control code in reverse.
WalkResult res = controlCodeOp->walk<WalkOrder::PostOrder, ReverseIterator>(
[&](AMDAIE::NpuDmaWaitOp waitOp) {
- bool toFold = true;
- Operation *queueParentOp =
- waitOpQueues.empty() ? waitOp->getParentOp()
- : waitOpQueues.back().front()->getParentOp();
+ bool toErase = true;
for (Value token : waitOp.getAsyncTokens()) {
if (auto npuHalfDmaCpyNdOp =
dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
token.getDefiningOp())) {
- DmaBdIdKey currBdIdKey = {nullptr, nullptr};
- uint32_t currBdIdVal = 0;
- FailureOr<bool> result =
- canFoldByQueue(deviceModel, queueParentOp, dmaBdIdsMap,
- currBdIdKey, currBdIdVal, npuHalfDmaCpyNdOp);
+ FailureOr<bool> result = canFoldByQueue(
+ deviceModel, npuHalfDmaCpyNdOp, tileConnectToBdIdQueue);
if (failed(result)) return WalkResult::interrupt();
- toFold &= *result;
- updateWithCurrBdId(*result, dmaBdIdsMap, currBdIdKey, currBdIdVal);
+ toErase &= *result;
}
}
- // Store all the queues, and modify later to avoid invalidating the
- // iterator.
- if (toFold) {
- // Append the wait op to the last queue if it can be folded.
- waitOpQueues.back().push_back(waitOp);
- } else {
- // Create a new queue if the wait op cannot be folded.
- waitOpQueues.push_back(SmallVector<AMDAIE::NpuDmaWaitOp>{waitOp});
- }
+ // Erase later to avoid invalidating the iterator.
+ if (toErase) waitOpsToErase.push_back(waitOp);
return WalkResult::advance();
});
if (res.wasInterrupted()) return failure();
- for (SmallVector<AMDAIE::NpuDmaWaitOp> &waitOps : waitOpQueues) {
- // Since the controlcode is traversed in reverse order, we need to
- // restore the original order of the DMA operations.
- std::reverse(waitOps.begin(), waitOps.end());
- if (failed(eraseQueueOperations(rewriter, waitOps))) return failure();

+ for (AMDAIE::NpuDmaWaitOp waitOp : waitOpsToErase) {
+ SmallVector<Value> asyncTokens(waitOp.getAsyncTokens());
+ // Erase the wait op.
+ rewriter.eraseOp(waitOp);
+ for (Value token : asyncTokens) {
+ if (auto op = dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
+ token.getDefiningOp())) {
+ if (op.use_empty()) {
+ rewriter.setInsertionPoint(op);
+ TypeRange resultTypeRange = TypeRange{};
+ // Nullify the result to avoid issuing a token.
+ rewriter.create<AMDAIE::NpuHalfDmaCpyNdOp>(
+ op.getLoc(), resultTypeRange, op.getConnection(), op.getInput(),
+ op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides(),
+ op.getBdId(), op.getChannel(), op.getNextBd(), op.getStartBd());
+ rewriter.eraseOp(op);
+ }
+ }
+ }
}

return success();
}
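
For readers skimming the diff, the folding rule kept by this commit (the canFoldByQueue / foldDmaWaitsByQueue pair above) can be shown outside of MLIR. What follows is a minimal standalone C++ sketch under stated assumptions, not the pass itself: FakeDma, the string tileConn key, and the hard-coded maxQueueSize of 4 are hypothetical stand-ins for the real NpuHalfDmaCpyNdOp, the (TileOp, ConnectionOp) pair, and the value returned by the device model's getDmaMaxQueueSize; it also pretends every copy carries its own wait op so the fold/keep decision can be printed per BD ID.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for the MLIR op the pass inspects.
struct FakeDma {
  std::string tileConn;  // stands in for the (TileOp, ConnectionOp) pair
  uint32_t bdId;         // stands in for the BD ID operand
  bool isPacketFlow;     // stands in for FlowOp::getIsPacketFlow()
};

// Mirrors the retained canFoldByQueue logic: the wait can be folded away
// unless the BD ID repeats on the same tile, the flow is a packet flow,
// the per-(tile, connection) queue is full, or the queue is empty.
bool canFoldByQueue(const FakeDma &dma, uint32_t maxQueueSize,
                    std::map<std::string, std::vector<uint32_t>> &queues) {
  bool isDuplicateBdId = false;
  for (const auto &[key, bdIds] : queues) {
    // A duplicate BD ID on the same tile forces the wait to be kept.
    if (key == dma.tileConn &&
        std::find(bdIds.begin(), bdIds.end(), dma.bdId) != bdIds.end())
      isDuplicateBdId = true;
  }
  std::vector<uint32_t> &queue = queues[dma.tileConn];
  bool canFold = true;
  if (isDuplicateBdId || dma.isPacketFlow || queue.size() >= maxQueueSize ||
      queue.empty()) {
    queue.clear();
    canFold = false;
  }
  queue.push_back(dma.bdId);
  return canFold;
}

int main() {
  // Five DMA copies on one connection, visited in reverse like the control code.
  std::vector<FakeDma> dmas = {{"t0c0", 0, false}, {"t0c0", 1, false},
                               {"t0c0", 2, false}, {"t0c0", 3, false},
                               {"t0c0", 4, false}};
  std::map<std::string, std::vector<uint32_t>> queues;
  uint32_t maxQueueSize = 4;  // assumed; the real value comes from the device model
  for (auto it = dmas.rbegin(); it != dmas.rend(); ++it)
    std::cout << "bd_id=" << it->bdId << " fold wait: "
              << (canFoldByQueue(*it, maxQueueSize, queues) ? "yes" : "no")
              << "\n";
  // bd_id=4 sees an empty queue and its wait is kept, the next three waits
  // fold, and bd_id=0 sees a full queue of four, so its wait is kept again.
  return 0;
}

Walked from the bottom up, the sketch keeps one wait per batch of up to four copies on the same tile and connection, mirroring the "one wait per maximum queue size" example in the pass documentation above.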

(Second changed file)
@@ -17,7 +17,7 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

- using DmaChainKey = std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>;
+ using DmaChain = std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>;

/// Utility function to update `next_bd` and `start_bd` operands.
LogicalResult updateChainOperands(
@@ -83,9 +83,9 @@ LogicalResult updateChainOperands(
/// - Chain X: [0] (the newly added BD ID).
/// - Chain Y: [] (emptied after breaking).
void checkForChainsToBeBroken(
- uint32_t currBdId, const DmaChainKey &currDmaChain,
- const DenseMap<DmaChainKey, DenseSet<uint32_t>> &dmaChainToBdIds,
- SmallVector<DmaChainKey> &chainsToBreak) {
+ uint32_t currBdId, const DmaChain &currDmaChain,
+ const DenseMap<DmaChain, DenseSet<uint32_t>> &dmaChainToBdIds,
+ SmallVector<DmaChain> &chainsToBreak) {
for (auto &[entry, bdIds] : dmaChainToBdIds) {
if (entry.first == currDmaChain.first && bdIds.contains(currBdId)) {
// Break the chain that contains the duplicate BD ID.
@@ -120,10 +120,9 @@ LogicalResult insertDmaBdChain(const AMDAIE::AMDAIEDeviceModel &deviceModel,
}

// BD IDs that have been assigned in each tile.
- DenseMap<DmaChainKey, DenseSet<uint32_t>> dmaChainToBdIds;
+ DenseMap<DmaChain, DenseSet<uint32_t>> dmaChainToBdIds;
// Buffers the DMA ops that will be chained.
- DenseMap<DmaChainKey, SmallVector<AMDAIE::NpuHalfDmaCpyNdOp>>
- dmaChainToDmaOps;
+ DenseMap<DmaChain, SmallVector<AMDAIE::NpuHalfDmaCpyNdOp>> dmaChainToDmaOps;

res = controlCodeOp->walk<WalkOrder::PostOrder,
ReverseIterator>([&](Operation *op) {
@@ -186,8 +185,8 @@ LogicalResult insertDmaBdChain(const AMDAIE::AMDAIEDeviceModel &deviceModel,
// Any duplicate BD ID from the same tile indicates that the chain
// cannot grow further and requires breaking to release the
// conflicting BD ID.
- SmallVector<DmaChainKey> chainsToBreak;
- DmaChainKey currDmaChain = {tileOp, connectionOp};
+ SmallVector<DmaChain> chainsToBreak;
+ DmaChain currDmaChain = {tileOp, connectionOp};
checkForChainsToBeBroken(bdId, currDmaChain, dmaChainToBdIds,
chainsToBreak);

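The BD ID reuse check in the second file can be illustrated the same way. The sketch below is a hedged approximation of only the part of checkForChainsToBeBroken visible in the diff (the rest of the function body is collapsed): Chain, the string tile and connection names, and the printing in main are hypothetical stand-ins for DmaChain, the real TileOp and ConnectionOp values, and the chain rewriting the pass actually performs.

#include <cstdint>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for DmaChain = std::pair<TileOp, ConnectionOp>.
using Chain = std::pair<std::string, std::string>;  // (tile, connection)

// Sketch of the visible logic: any chain on the same tile that already
// contains the incoming BD ID must be broken before that ID can be reused.
void checkForChainsToBeBroken(
    uint32_t currBdId, const Chain &currChain,
    const std::map<Chain, std::set<uint32_t>> &chainToBdIds,
    std::vector<Chain> &chainsToBreak) {
  for (const auto &[chain, bdIds] : chainToBdIds) {
    if (chain.first == currChain.first && bdIds.count(currBdId)) {
      // Break the chain that contains the duplicate BD ID.
      chainsToBreak.push_back(chain);
    }
  }
}

int main() {
  // Chain Y on tile0 already holds BD IDs {0, 1}; chain X is still empty.
  std::map<Chain, std::set<uint32_t>> chainToBdIds = {
      {{"tile0", "connY"}, {0, 1}}, {{"tile0", "connX"}, {}}};
  // BD ID 0 arrives again, this time for chain X: chain Y owns the duplicate,
  // so it is reported for breaking; the caller then empties Y and assigns
  // BD ID 0 to X.
  std::vector<Chain> chainsToBreak;
  checkForChainsToBeBroken(0, {"tile0", "connX"}, chainToBdIds, chainsToBreak);
  for (const Chain &c : chainsToBreak)
    std::cout << "break chain on " << c.first << " / " << c.second << "\n";
  return 0;
}

With chain Y on tile0 already holding BD IDs 0 and 1, re-encountering BD ID 0 for chain X reports Y as a chain to break, after which BD ID 0 can be assigned to X, matching the Chain X / Chain Y example in the function's documentation.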