Skip to content

Commit

Permalink
fix test and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Yu-Zhewen committed Dec 17, 2024
1 parent fb6d4d2 commit b13c577
Showing 1 changed file with 131 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,58 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

using DmaQueueKey = std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>;
using DmaBdIdKey = std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>;

/// Utility function to erase the DMA wait operations in the queue, except for
/// the last one.
LogicalResult eraseQueueOperations(IRRewriter &rewriter,
SmallVector<AMDAIE::NpuDmaWaitOp> &waitOps) {
// Skip if there are less than two DMA wait operations in the queue.
if (waitOps.size() < 2) return success();

Operation *parentOp = waitOps.back()->getParentOp();
// Do not modify the last wait op, it will be kept.
waitOps.pop_back();

for (AMDAIE::NpuDmaWaitOp waitOp : waitOps) {
if (waitOp->getParentOp() != parentOp) {
return waitOp.emitError(
"DMA operations to be queued must belong to the same scope");
}
// Erase the wait op.
SmallVector<Value> asyncTokens(waitOp.getAsyncTokens());
rewriter.eraseOp(waitOp);
for (Value token : asyncTokens) {
auto dmaOp =
dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(token.getDefiningOp());
if (!dmaOp)
waitOp.emitError("expected to operate on an `amdaie.half_dma_cpy_nd`");
if (dmaOp.use_empty()) {
rewriter.setInsertionPoint(dmaOp);
TypeRange resultTypeRange = TypeRange{};
// Nullify the result to avoid issuing a token.
rewriter.create<AMDAIE::NpuHalfDmaCpyNdOp>(
dmaOp.getLoc(), resultTypeRange, dmaOp.getConnection(),
dmaOp.getInput(), dmaOp.getMixedOffsets(), dmaOp.getMixedSizes(),
dmaOp.getMixedStrides(), dmaOp.getBdId(), dmaOp.getChannel(),
dmaOp.getNextBd(), dmaOp.getStartBd());
rewriter.eraseOp(dmaOp);
}
}
}
return success();
}

/// Utility function to determine whether a DMA wait op can be folded into a
/// queue based on its half DMA copy operation.
FailureOr<bool> canFoldByQueue(
const AMDAIE::AMDAIEDeviceModel &deviceModel,
const AMDAIE::AMDAIEDeviceModel &deviceModel, Operation *queueParentOp,
AMDAIE::NpuHalfDmaCpyNdOp &npuHalfDmaCpyNdOp,
DenseMap<DmaQueueKey, SmallVector<uint32_t>> &dmaQueueToBdIds) {
DenseMap<DmaBdIdKey, DenseSet<uint32_t>> &dmaBdIdsMap) {
// Check if the current operation is in the same scope as the rest of the
// queue.
bool isSameScope = npuHalfDmaCpyNdOp->getParentOp() == queueParentOp;

// Retrieve the connection op.
std::optional<AMDAIE::ConnectionOp> maybeConnectionOp =
npuHalfDmaCpyNdOp.getConnectionOp();
Expand Down Expand Up @@ -63,22 +107,24 @@ FailureOr<bool> canFoldByQueue(
uint32_t row = getConstantIndexOrAssert(tileOp.getRow());
uint32_t maxQueueSize = deviceModel.getDmaMaxQueueSize(col, row);

// Keep wait op if, either reaches the maximum queue size, or a
// duplicate BD ID in the same tile, or packet flow, or the queue is
// empty
uint32_t bdId = getConstantIndexOrAssert(bdIdOp.getValue());
bool isDuplicateBdId = llvm::any_of(dmaQueueToBdIds, [&](const auto &entry) {
return entry.first.first == tileOp &&
llvm::is_contained(entry.second, bdId);
bool isDuplicateBdId = llvm::any_of(dmaBdIdsMap, [&](const auto &entry) {
return entry.first.first == tileOp && entry.second.contains(bdId);
});
SmallVector<uint32_t> &bdIds = dmaQueueToBdIds[{tileOp, connectionOp}];
DenseSet<uint32_t> &bdIds = dmaBdIdsMap[{tileOp, connectionOp}];
bool canFold = true;
// Can't fold wait op if:
// (1) the current BD ID on the same tile already occurs in the queue, or
// (2) the current operation is a packet flow, or
// (3) reaches the maximum queue size, or
// (4) the queue is empty, or
// (5) the current operation is not in the same scope as the queue.
if (isDuplicateBdId || isPacketFlow || bdIds.size() >= maxQueueSize ||
bdIds.empty()) {
bdIds.empty() || !isSameScope) {
bdIds.clear();
canFold = false;
}
bdIds.push_back(bdId);
bdIds.insert(bdId);
return canFold;
}

Expand All @@ -103,49 +149,43 @@ FailureOr<bool> canFoldByQueue(
LogicalResult foldDmaWaitsByQueue(const AMDAIE::AMDAIEDeviceModel &deviceModel,
AMDAIE::ControlCodeOp controlCodeOp) {
IRRewriter rewriter(controlCodeOp->getContext());
std::vector<AMDAIE::NpuDmaWaitOp> waitOpsToErase;
DenseMap<DmaQueueKey, SmallVector<uint32_t>> dmaQueueToBdIds;
SmallVector<SmallVector<AMDAIE::NpuDmaWaitOp>> waitOpQueues;
DenseMap<DmaBdIdKey, DenseSet<uint32_t>> dmaBdIdsMap;
// Traverse the control code in reverse.
WalkResult res = controlCodeOp->walk<WalkOrder::PostOrder, ReverseIterator>(
[&](AMDAIE::NpuDmaWaitOp waitOp) {
bool toErase = true;
bool toFold = true;
Operation *queueParentOp =
waitOpQueues.empty() ? waitOp->getParentOp()
: waitOpQueues.back().front()->getParentOp();
for (Value token : waitOp.getAsyncTokens()) {
if (auto npuHalfDmaCpyNdOp =
dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
token.getDefiningOp())) {
FailureOr<bool> result =
canFoldByQueue(deviceModel, npuHalfDmaCpyNdOp, dmaQueueToBdIds);
FailureOr<bool> result = canFoldByQueue(
deviceModel, queueParentOp, npuHalfDmaCpyNdOp, dmaBdIdsMap);
if (failed(result)) return WalkResult::interrupt();
toErase &= *result;
toFold &= *result;
}
}
// Erase later to avoid invalidating the iterator.
if (toErase) waitOpsToErase.push_back(waitOp);
// Store all the queues, and modify later to avoid invalidating the
// iterator.
if (toFold) {
// Append the wait op to the last queue if it can be folded.
waitOpQueues.back().push_back(waitOp);
} else {
// Create a new queue if the wait op cannot be folded.
waitOpQueues.push_back(SmallVector<AMDAIE::NpuDmaWaitOp>{waitOp});
}
return WalkResult::advance();
});
if (res.wasInterrupted()) return failure();

for (AMDAIE::NpuDmaWaitOp waitOp : waitOpsToErase) {
SmallVector<Value> asyncTokens(waitOp.getAsyncTokens());
// Erase the wait op.
rewriter.eraseOp(waitOp);
for (Value token : asyncTokens) {
if (auto op = dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
token.getDefiningOp())) {
if (op.use_empty()) {
rewriter.setInsertionPoint(op);
TypeRange resultTypeRange = TypeRange{};
// Nullify the result to avoid issuing a token.
rewriter.create<AMDAIE::NpuHalfDmaCpyNdOp>(
op.getLoc(), resultTypeRange, op.getConnection(), op.getInput(),
op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides(),
op.getBdId(), op.getChannel(), op.getNextBd(), op.getStartBd());
rewriter.eraseOp(op);
}
}
}
for (SmallVector<AMDAIE::NpuDmaWaitOp> &waitOps : waitOpQueues) {
// Since the controlcode is traversed in reverse order, we need to
// restore the original order of the DMA operations.
std::reverse(waitOps.begin(), waitOps.end());
if (failed(eraseQueueOperations(rewriter, waitOps))) return failure();
}

return success();
}

Expand Down Expand Up @@ -174,9 +214,14 @@ LogicalResult updateBatchTokens(IRRewriter &rewriter,

/// Utility function to determine if a DMA wait operation can be folded into a
/// a batch based on its half DMA copy operation.
FailureOr<bool> canFoldByBatch(Operation *batchParentOp,
AMDAIE::NpuHalfDmaCpyNdOp npuHalfDmaCpyNdOp,
DenseSet<AMDAIE::ConnectionOp> &connectionOps) {
FailureOr<bool> canFoldByBatch(
Operation *batchParentOp, AMDAIE::NpuHalfDmaCpyNdOp npuHalfDmaCpyNdOp,
DenseSet<AMDAIE::ConnectionOp> &connectionOps,
DenseMap<DmaBdIdKey, DenseSet<uint32_t>> &dmaBdIdsMap) {
// Check if the current operation is in the same scope as the rest of the
// batch.
bool isSameScope = npuHalfDmaCpyNdOp->getParentOp() == batchParentOp;

// Retrieve the connection op.
std::optional<AMDAIE::ConnectionOp> maybeConnectionOp =
npuHalfDmaCpyNdOp.getConnectionOp();
Expand All @@ -195,17 +240,48 @@ FailureOr<bool> canFoldByBatch(Operation *batchParentOp,
AMDAIE::FlowOp flowOp = maybeFlowOp.value();
bool isPacketFlow = flowOp.getIsPacketFlow();

// Retrieve the BD ID op.
std::optional<AMDAIE::BdIdOp> maybeBdIdOp = npuHalfDmaCpyNdOp.getBdIdOp();
if (!maybeBdIdOp) {
return npuHalfDmaCpyNdOp.emitOpError()
<< "must have a BD ID op to lower to "
"`amdaie.npu.write_bd`";
}
AMDAIE::BdIdOp bdIdOp = maybeBdIdOp.value();

// Retrieve the tile op.
AMDAIE::TileOp tileOp =
dyn_cast_if_present<AMDAIE::TileOp>(bdIdOp.getTile().getDefiningOp());
if (!tileOp) {
return bdIdOp.emitOpError() << "must operate on an `amdaie.tile`";
}

bool isDuplicateConnection = connectionOps.contains(connectionOp);
uint32_t bdId = getConstantIndexOrAssert(bdIdOp.getValue());
bool isDuplicateBdId = llvm::any_of(dmaBdIdsMap, [&](const auto &entry) {
return entry.first.first == tileOp && entry.second.contains(bdId);
});

bool canFold = true;
// Can't fold if the current connection op already occurs in the batch, or
// if the current operation is a packet flow, or if the batch is empty, or
// if the current operation is not in the same scope as the batch.
if (connectionOps.contains(connectionOp) || isPacketFlow ||
connectionOps.empty() ||
(batchParentOp != npuHalfDmaCpyNdOp->getParentOp())) {
// Can't fold wait op if:
// (1) the current connection op already occurs in the batch, or
// (2) the current BD ID on the same tile already occurs in the batch, or
// (3) the current operation is a packet flow, or
// (4) the batch is empty, or
// (5) the current operation is not in the same scope as the batch.
if (isDuplicateConnection || isDuplicateBdId || isPacketFlow ||
connectionOps.empty() || !isSameScope) {
// Clear the BD IDs for all the connections in the batch.
for (auto &entry : dmaBdIdsMap) {
ConnectionOp connectionOp = entry.first.second;
DenseSet<uint32_t> &bdIds = entry.second;
if (connectionOps.contains(connectionOp)) bdIds.clear();
}
connectionOps.clear();
canFold = false;
}
connectionOps.insert(connectionOp);
dmaBdIdsMap[{tileOp, connectionOp}].insert(bdId);
return canFold;
}

Expand Down Expand Up @@ -234,6 +310,7 @@ LogicalResult foldDmaWaitsByBatch(AMDAIE::ControlCodeOp controlCodeOp) {
IRRewriter rewriter(controlCodeOp->getContext());
SmallVector<AMDAIE::NpuDmaWaitOp> waitOps;
DenseSet<AMDAIE::ConnectionOp> connectionOps;
DenseMap<DmaBdIdKey, DenseSet<uint32_t>> dmaBdIdsMap;
WalkResult res = controlCodeOp->walk<WalkOrder::PostOrder, ReverseIterator>(
[&](AMDAIE::NpuDmaWaitOp waitOp) {
bool toBatch = true;
Expand All @@ -243,14 +320,16 @@ LogicalResult foldDmaWaitsByBatch(AMDAIE::ControlCodeOp controlCodeOp) {
if (auto npuHalfDmaCpyNdOp =
dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
token.getDefiningOp())) {
FailureOr<bool> result =
canFoldByBatch(batchParentOp, npuHalfDmaCpyNdOp, connectionOps);
FailureOr<bool> result = canFoldByBatch(
batchParentOp, npuHalfDmaCpyNdOp, connectionOps, dmaBdIdsMap);
if (failed(result)) return WalkResult::interrupt();
toBatch &= *result;
}
}
// Process the previous batch of wait ops, and start a new batch.
if (!toBatch) {
// Since the controlcode is traversed in reverse order, we need to
// restore the original order of the DMA operations.
std::reverse(waitOps.begin(), waitOps.end());
if (failed(updateBatchTokens(rewriter, waitOps)))
return WalkResult::interrupt();
Expand Down

0 comments on commit b13c577

Please sign in to comment.