Skip to content

Commit

Permalink
separate canFold decisions with update
Browse files · Browse the repository at this point in the history
  • Loading branch information
Yu-Zhewen committed Dec 17, 2024
1 parent b13c577 commit a10aedd
Showing 1 changed file with 81 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,20 @@ LogicalResult eraseQueueOperations(IRRewriter &rewriter,
/// Utility function to determine whether a DMA wait op can be folded into a
/// queue based on its half DMA copy operation.
FailureOr<bool> canFoldByQueue(
const AMDAIE::AMDAIEDeviceModel &deviceModel, Operation *queueParentOp,
AMDAIE::NpuHalfDmaCpyNdOp &npuHalfDmaCpyNdOp,
DenseMap<DmaBdIdKey, DenseSet<uint32_t>> &dmaBdIdsMap) {
const AMDAIE::AMDAIEDeviceModel &deviceModel,
const Operation *queueParentOp,
const DenseMap<DmaBdIdKey, DenseSet<uint32_t>> &dmaBdIdsMap,
DmaBdIdKey &currBdIdKey, uint32_t &currBdIdVal,
AMDAIE::NpuHalfDmaCpyNdOp &currHalfDmaCpyNdOp) {
// Check if the current operation is in the same scope as the rest of the
// queue.
bool isSameScope = npuHalfDmaCpyNdOp->getParentOp() == queueParentOp;
bool isSameScope = currHalfDmaCpyNdOp->getParentOp() == queueParentOp;

// Retrieve the connection op.
std::optional<AMDAIE::ConnectionOp> maybeConnectionOp =
npuHalfDmaCpyNdOp.getConnectionOp();
currHalfDmaCpyNdOp.getConnectionOp();
if (!maybeConnectionOp) {
return npuHalfDmaCpyNdOp.emitOpError()
return currHalfDmaCpyNdOp.emitOpError()
<< "expected to operate on an `amdaie.connection`";
}
AMDAIE::ConnectionOp connectionOp = maybeConnectionOp.value();
Expand All @@ -87,45 +89,41 @@ FailureOr<bool> canFoldByQueue(
bool isPacketFlow = flowOp.getIsPacketFlow();

// Retrieve the BD ID op.
std::optional<AMDAIE::BdIdOp> maybeBdIdOp = npuHalfDmaCpyNdOp.getBdIdOp();
std::optional<AMDAIE::BdIdOp> maybeBdIdOp = currHalfDmaCpyNdOp.getBdIdOp();
if (!maybeBdIdOp) {
return npuHalfDmaCpyNdOp.emitOpError()
return currHalfDmaCpyNdOp.emitOpError()
<< "must have a BD ID op to lower to "
"`amdaie.npu.write_bd`";
}
AMDAIE::BdIdOp bdIdOp = maybeBdIdOp.value();
currBdIdVal = getConstantIndexOrAssert(bdIdOp.getValue());

// Retrieve the tile op.
AMDAIE::TileOp tileOp =
dyn_cast_if_present<AMDAIE::TileOp>(bdIdOp.getTile().getDefiningOp());
if (!tileOp) {
return bdIdOp.emitOpError() << "must operate on an `amdaie.tile`";
}
currBdIdKey = {tileOp, connectionOp};

// Get the maximum queue size.
uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
uint32_t row = getConstantIndexOrAssert(tileOp.getRow());
uint32_t maxQueueSize = deviceModel.getDmaMaxQueueSize(col, row);

uint32_t bdId = getConstantIndexOrAssert(bdIdOp.getValue());
bool isDuplicateBdId = llvm::any_of(dmaBdIdsMap, [&](const auto &entry) {
return entry.first.first == tileOp && entry.second.contains(bdId);
return entry.first.first == tileOp && entry.second.contains(currBdIdVal);
});
DenseSet<uint32_t> &bdIds = dmaBdIdsMap[{tileOp, connectionOp}];
bool canFold = true;
const DenseSet<uint32_t> &bdIds = dmaBdIdsMap.lookup(currBdIdKey);

// Can't fold wait op if:
// (1) the current BD ID on the same tile already occurs in the queue, or
// (2) the current operation is a packet flow, or
// (3) reaches the maximum queue size, or
// (4) the queue is empty, or
// (5) the current operation is not in the same scope as the queue.
if (isDuplicateBdId || isPacketFlow || bdIds.size() >= maxQueueSize ||
bdIds.empty() || !isSameScope) {
bdIds.clear();
canFold = false;
}
bdIds.insert(bdId);
return canFold;
return !(isDuplicateBdId || isPacketFlow || bdIds.size() >= maxQueueSize ||
bdIds.empty() || !isSameScope);
}

/// Traverses the control code in reverse, ensuring that for each connection,
Expand All @@ -151,6 +149,16 @@ LogicalResult foldDmaWaitsByQueue(const AMDAIE::AMDAIEDeviceModel &deviceModel,
IRRewriter rewriter(controlCodeOp->getContext());
SmallVector<SmallVector<AMDAIE::NpuDmaWaitOp>> waitOpQueues;
DenseMap<DmaBdIdKey, DenseSet<uint32_t>> dmaBdIdsMap;

auto updateWithCurrBdId =
[&](bool canFold, DenseMap<DmaBdIdKey, DenseSet<uint32_t>> &dmaBdIdsMap,
DmaBdIdKey &currBdIdKey, uint32_t currBdIdVal) {
assert(currBdIdKey.first && "TileOp must not be null");
assert(currBdIdKey.second && "ConnectionOp must not be null");
if (!canFold) dmaBdIdsMap[currBdIdKey].clear();
dmaBdIdsMap[currBdIdKey].insert(currBdIdVal);
};

// Traverse the control code in reverse.
WalkResult res = controlCodeOp->walk<WalkOrder::PostOrder, ReverseIterator>(
[&](AMDAIE::NpuDmaWaitOp waitOp) {
Expand All @@ -162,10 +170,14 @@ LogicalResult foldDmaWaitsByQueue(const AMDAIE::AMDAIEDeviceModel &deviceModel,
if (auto npuHalfDmaCpyNdOp =
dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
token.getDefiningOp())) {
FailureOr<bool> result = canFoldByQueue(
deviceModel, queueParentOp, npuHalfDmaCpyNdOp, dmaBdIdsMap);
DmaBdIdKey currBdIdKey = {nullptr, nullptr};
uint32_t currBdIdVal = 0;
FailureOr<bool> result =
canFoldByQueue(deviceModel, queueParentOp, dmaBdIdsMap,
currBdIdKey, currBdIdVal, npuHalfDmaCpyNdOp);
if (failed(result)) return WalkResult::interrupt();
toFold &= *result;
updateWithCurrBdId(*result, dmaBdIdsMap, currBdIdKey, currBdIdVal);
}
}
// Store all the queues, and modify later to avoid invalidating the
Expand All @@ -190,8 +202,8 @@ LogicalResult foldDmaWaitsByQueue(const AMDAIE::AMDAIEDeviceModel &deviceModel,
}

/// For each batch, combine the async tokens into a single NpuDmaWaitOp.
LogicalResult updateBatchTokens(IRRewriter &rewriter,
SmallVector<AMDAIE::NpuDmaWaitOp> &waitOps) {
LogicalResult eraseBatchOperations(IRRewriter &rewriter,
SmallVector<AMDAIE::NpuDmaWaitOp> &waitOps) {
// Skip if there are less than two DMA wait operations.
if (waitOps.size() < 2) return success();

Expand All @@ -215,21 +227,24 @@ LogicalResult updateBatchTokens(IRRewriter &rewriter,
/// Utility function to determine if a DMA wait operation can be folded into
/// a batch based on its half DMA copy operation.
FailureOr<bool> canFoldByBatch(
Operation *batchParentOp, AMDAIE::NpuHalfDmaCpyNdOp npuHalfDmaCpyNdOp,
DenseSet<AMDAIE::ConnectionOp> &connectionOps,
DenseMap<DmaBdIdKey, DenseSet<uint32_t>> &dmaBdIdsMap) {
const Operation *batchParentOp,
const DenseSet<AMDAIE::ConnectionOp> &connectionOps,
const DenseMap<DmaBdIdKey, DenseSet<uint32_t>> &dmaBdIdsMap,
DmaBdIdKey &currBdIdKey, uint32_t &currBdIdVal,
AMDAIE::NpuHalfDmaCpyNdOp currHalfDmaCpyNdOp) {
// Check if the current operation is in the same scope as the rest of the
// batch.
bool isSameScope = npuHalfDmaCpyNdOp->getParentOp() == batchParentOp;
bool isSameScope = currHalfDmaCpyNdOp->getParentOp() == batchParentOp;

// Retrieve the connection op.
std::optional<AMDAIE::ConnectionOp> maybeConnectionOp =
npuHalfDmaCpyNdOp.getConnectionOp();
currHalfDmaCpyNdOp.getConnectionOp();
if (!maybeConnectionOp) {
return npuHalfDmaCpyNdOp.emitOpError()
return currHalfDmaCpyNdOp.emitOpError()
<< "expected to operate on an `amdaie.connection`";
}
AMDAIE::ConnectionOp connectionOp = maybeConnectionOp.value();
bool isDuplicateConnection = connectionOps.contains(connectionOp);

// Retrieve the flow op.
std::optional<AMDAIE::FlowOp> maybeFlowOp = connectionOp.getFlowOp();
Expand All @@ -241,48 +256,35 @@ FailureOr<bool> canFoldByBatch(
bool isPacketFlow = flowOp.getIsPacketFlow();

// Retrieve the BD ID op.
std::optional<AMDAIE::BdIdOp> maybeBdIdOp = npuHalfDmaCpyNdOp.getBdIdOp();
std::optional<AMDAIE::BdIdOp> maybeBdIdOp = currHalfDmaCpyNdOp.getBdIdOp();
if (!maybeBdIdOp) {
return npuHalfDmaCpyNdOp.emitOpError()
return currHalfDmaCpyNdOp.emitOpError()
<< "must have a BD ID op to lower to "
"`amdaie.npu.write_bd`";
}
AMDAIE::BdIdOp bdIdOp = maybeBdIdOp.value();
currBdIdVal = getConstantIndexOrAssert(bdIdOp.getValue());

// Retrieve the tile op.
AMDAIE::TileOp tileOp =
dyn_cast_if_present<AMDAIE::TileOp>(bdIdOp.getTile().getDefiningOp());
if (!tileOp) {
return bdIdOp.emitOpError() << "must operate on an `amdaie.tile`";
}
currBdIdKey = {tileOp, connectionOp};

bool isDuplicateConnection = connectionOps.contains(connectionOp);
uint32_t bdId = getConstantIndexOrAssert(bdIdOp.getValue());
bool isDuplicateBdId = llvm::any_of(dmaBdIdsMap, [&](const auto &entry) {
return entry.first.first == tileOp && entry.second.contains(bdId);
return entry.first.first == tileOp && entry.second.contains(currBdIdVal);
});

bool canFold = true;
// Can't fold wait op if:
// (1) the current connection op already occurs in the batch, or
// (2) the current BD ID on the same tile already occurs in the batch, or
// (3) the current operation is a packet flow, or
// (4) the batch is empty, or
// (5) the current operation is not in the same scope as the batch.
if (isDuplicateConnection || isDuplicateBdId || isPacketFlow ||
connectionOps.empty() || !isSameScope) {
// Clear the BD IDs for all the connections in the batch.
for (auto &entry : dmaBdIdsMap) {
ConnectionOp connectionOp = entry.first.second;
DenseSet<uint32_t> &bdIds = entry.second;
if (connectionOps.contains(connectionOp)) bdIds.clear();
}
connectionOps.clear();
canFold = false;
}
connectionOps.insert(connectionOp);
dmaBdIdsMap[{tileOp, connectionOp}].insert(bdId);
return canFold;
return !(isDuplicateConnection || isDuplicateBdId || isPacketFlow ||
connectionOps.empty() || !isSameScope);
}

/// Traverses the control code in reverse, ensuring that only one DMA wait op is
Expand Down Expand Up @@ -311,6 +313,27 @@ LogicalResult foldDmaWaitsByBatch(AMDAIE::ControlCodeOp controlCodeOp) {
SmallVector<AMDAIE::NpuDmaWaitOp> waitOps;
DenseSet<AMDAIE::ConnectionOp> connectionOps;
DenseMap<DmaBdIdKey, DenseSet<uint32_t>> dmaBdIdsMap;

auto updateWithCurrBdId =
[&](bool canFold, DenseSet<AMDAIE::ConnectionOp> &connectionOps,
DenseMap<DmaBdIdKey, DenseSet<uint32_t>> &dmaBdIdsMap,
DmaBdIdKey &currBdIdKey, uint32_t currBdIdVal) {
assert(currBdIdKey.first && "TileOp must not be null");
assert(currBdIdKey.second && "ConnectionOp must not be null");
if (!canFold) {
// Clear the BD IDs for all the connections in the batch.
for (auto &entry : dmaBdIdsMap) {
ConnectionOp connectionOp = entry.first.second;
DenseSet<uint32_t> &bdIds = entry.second;
if (connectionOps.contains(connectionOp)) bdIds.clear();
}
connectionOps.clear();
}
connectionOps.insert(currBdIdKey.second);
dmaBdIdsMap[currBdIdKey].insert(currBdIdVal);
};

// Traverse the control code in reverse.
WalkResult res = controlCodeOp->walk<WalkOrder::PostOrder, ReverseIterator>(
[&](AMDAIE::NpuDmaWaitOp waitOp) {
bool toBatch = true;
Expand All @@ -320,18 +343,23 @@ LogicalResult foldDmaWaitsByBatch(AMDAIE::ControlCodeOp controlCodeOp) {
if (auto npuHalfDmaCpyNdOp =
dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
token.getDefiningOp())) {
FailureOr<bool> result = canFoldByBatch(
batchParentOp, npuHalfDmaCpyNdOp, connectionOps, dmaBdIdsMap);
DmaBdIdKey currBdIdKey = {nullptr, nullptr};
uint32_t currBdIdVal = 0;
FailureOr<bool> result =
canFoldByBatch(batchParentOp, connectionOps, dmaBdIdsMap,
currBdIdKey, currBdIdVal, npuHalfDmaCpyNdOp);
if (failed(result)) return WalkResult::interrupt();
toBatch &= *result;
updateWithCurrBdId(*result, connectionOps, dmaBdIdsMap, currBdIdKey,
currBdIdVal);
}
}
// Process the previous batch of wait ops, and start a new batch.
if (!toBatch) {
// Since the controlcode is traversed in reverse order, we need to
// restore the original order of the DMA operations.
std::reverse(waitOps.begin(), waitOps.end());
if (failed(updateBatchTokens(rewriter, waitOps)))
if (failed(eraseBatchOperations(rewriter, waitOps)))
return WalkResult::interrupt();
waitOps.clear();
}
Expand All @@ -342,7 +370,7 @@ LogicalResult foldDmaWaitsByBatch(AMDAIE::ControlCodeOp controlCodeOp) {
if (res.wasInterrupted()) return failure();
// Process the remaining wait ops.
std::reverse(waitOps.begin(), waitOps.end());
if (failed(updateBatchTokens(rewriter, waitOps))) return failure();
if (failed(eraseBatchOperations(rewriter, waitOps))) return failure();
return success();
}

Expand Down

0 comments on commit a10aedd

Please sign in to comment.