Commit

first commit
Yu-Zhewen committed Dec 12, 2024
1 parent ee61bdb commit 022bdf8
Showing 4 changed files with 316 additions and 10 deletions.
@@ -200,16 +200,34 @@ LogicalResult convertOp(AMDAIE::NpuAddressPatchOp op,
}

LogicalResult convertOp(AMDAIE::NpuDmaWaitOp op, TransactionBuilder &builder) {
for (Value token : op.getAsyncTokens()) {
auto pushToQueueOp =
dyn_cast_if_present<AMDAIE::NpuPushToQueueOp>(token.getDefiningOp());
// Batch DMA operations with the same row, channel, and direction into a
// single TCT sync operation, as long as they have consecutive columns.
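// Each batch entry pairs the first push-to-queue op with the number of
// consecutive columns it covers.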
SmallVector<std::pair<AMDAIE::NpuPushToQueueOp, uint32_t>> columnBatches;
for (Value asyncToken : op.getAsyncTokens()) {
auto pushToQueueOp = dyn_cast_if_present<AMDAIE::NpuPushToQueueOp>(
asyncToken.getDefiningOp());
if (!pushToQueueOp) {
return op.emitOpError()
<< "should operate on an `amdaie.push_to_queue` op";
<< "should operate on an `amdaie.push_to_queue` op async token";
}
if (!columnBatches.empty()) {
auto &[lastPushOp, lastColNum] = columnBatches.back();
if (lastPushOp.getRow() == pushToQueueOp.getRow() &&
lastPushOp.getCol() + lastColNum == pushToQueueOp.getCol() &&
lastPushOp.getDirection() == pushToQueueOp.getDirection() &&
lastPushOp.getChannel() == pushToQueueOp.getChannel()) {
++lastColNum;
continue;
}
}
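// Start a new batch at this op's column.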
columnBatches.push_back({pushToQueueOp, 1});
}

// Convert to TCT sync ops.
for (auto &[pushToQueueOp, colNum] : columnBatches) {
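// Each batch lowers to a single TCT sync op spanning colNum consecutive
// columns.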
if (failed(builder.appendTCTSync(
pushToQueueOp.getCol(), pushToQueueOp.getRow(),
static_cast<uint32_t>(pushToQueueOp.getDirection()), 1, 1,
static_cast<uint32_t>(pushToQueueOp.getDirection()), 1, colNum,
pushToQueueOp.getChannel()))) {
return failure();
}
@@ -18,7 +18,7 @@ namespace {

/// Utility function to determine whether a DMA wait op can be folded based on
/// its half DMA copy operation.
FailureOr<bool> canFoldBasedOnHalfDmaCpy(
FailureOr<bool> canFoldByConnection(
const AMDAIE::AMDAIEDeviceModel &deviceModel,
AMDAIE::NpuHalfDmaCpyNdOp &npuHalfDmaCpyNdOp,
DenseMap<std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>,
@@ -101,8 +101,9 @@ FailureOr<bool> canFoldBasedOnHalfDmaCpy(
/// Reverse traversal simplifies handling duplicate BD IDs, preventing
/// the need to revisit and modify earlier operations after processing later
/// ones.
LogicalResult foldDmaWaits(const AMDAIE::AMDAIEDeviceModel &deviceModel,
AMDAIE::ControlCodeOp controlCodeOp) {
LogicalResult foldDmaWaitsByConnection(
const AMDAIE::AMDAIEDeviceModel &deviceModel,
AMDAIE::ControlCodeOp controlCodeOp) {
IRRewriter rewriter(controlCodeOp->getContext());
std::vector<AMDAIE::NpuDmaWaitOp> waitOpsToErase;
DenseMap<std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>,
@@ -116,7 +117,7 @@ LogicalResult foldDmaWaits(const AMDAIE::AMDAIEDeviceModel &deviceModel,
if (auto npuHalfDmaCpyNdOp =
dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
token.getDefiningOp())) {
FailureOr<bool> result = canFoldBasedOnHalfDmaCpy(
FailureOr<bool> result = canFoldByConnection(
deviceModel, npuHalfDmaCpyNdOp, tileConnectToBdIdQueue);
if (failed(result)) return WalkResult::interrupt();
toErase &= *result;
Expand Down Expand Up @@ -152,6 +153,147 @@ LogicalResult foldDmaWaits(const AMDAIE::AMDAIEDeviceModel &deviceModel,
return success();
}

struct DmaColumnBatch {
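// Row, channel, and direction shared by all wait ops in the batch.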
uint32_t row;
uint32_t channel;
AMDAIE::DMAChannelDir direction;

// Sorted by column.
std::map<uint32_t, AMDAIE::NpuDmaWaitOp> colWaitOpMap;
};

/// Combines the async tokens of a column batch of DMA wait operations into a
/// single NpuDmaWaitOp per run of consecutive columns, erasing the original
/// wait ops.
void updateColumnBatchTokens(
IRRewriter &rewriter,
std::map<uint32_t, AMDAIE::NpuDmaWaitOp> &colWaitOpMap) {
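// Nothing to combine when the batch holds fewer than two wait ops.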
if (colWaitOpMap.size() < 2) return;

// Check if there is any discontinuity in the columns, and if so, split into
// separate batches.
SmallVector<SmallVector<AMDAIE::NpuDmaWaitOp>> waitOpsList;
uint32_t prevCol = 0;
for (auto &[col, waitOp] : colWaitOpMap) {
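// A gap in the column sequence starts a new group of wait ops.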
if (waitOpsList.empty() || col != prevCol + 1) {
waitOpsList.push_back({});
}
waitOpsList.back().push_back(waitOp);
prevCol = col;
}

for (SmallVector<AMDAIE::NpuDmaWaitOp> &waitOps : waitOpsList) {
// For each batch, combine the async tokens into a single NpuDmaWaitOp.
SmallVector<Value> asyncTokens;
for (AMDAIE::NpuDmaWaitOp waitOp : waitOps) {
asyncTokens.append(waitOp.getAsyncTokens().begin(),
waitOp.getAsyncTokens().end());
}
rewriter.setInsertionPointAfter(waitOps.back());
rewriter.create<AMDAIE::NpuDmaWaitOp>(waitOps.back().getLoc(), asyncTokens);
for (AMDAIE::NpuDmaWaitOp waitOp : waitOps) {
rewriter.eraseOp(waitOp);
}
}
}

/// Utility function to batch a DMA wait operation by column. The wait op joins
/// the current batch if it shares the same row, channel, and direction with the
/// preceding wait operations; otherwise the batch is flushed and a new batch is
/// started.
LogicalResult foldByColumn(IRRewriter &rewriter, DmaColumnBatch &dmaBatch,
AMDAIE::NpuHalfDmaCpyNdOp dmaOp,
AMDAIE::NpuDmaWaitOp waitOp) {
// Get the row and column.
std::optional<AMDAIE::BdIdOp> maybeBdIdOp = dmaOp.getBdIdOp();
if (!maybeBdIdOp) return dmaOp.emitOpError() << "must have a BD ID op";
AMDAIE::BdIdOp bdIdOp = maybeBdIdOp.value();
AMDAIE::TileOp tileOp =
dyn_cast_if_present<AMDAIE::TileOp>(bdIdOp.getTile().getDefiningOp());
if (!tileOp)
return bdIdOp.emitOpError() << "must operate on an `amdaie.tile`";
uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
uint32_t row = getConstantIndexOrAssert(tileOp.getRow());

// Get the channel.
std::optional<AMDAIE::ChannelOp> maybeChannelOp = dmaOp.getChannelOp();
if (!maybeChannelOp)
return dmaOp.emitOpError() << "found non-`amdaie.channel` channel";
AMDAIE::ChannelOp channelOp = maybeChannelOp.value();
std::optional<AMDAIE::DMAChannelDir> maybeDirection =
channelOp.getDirection();
std::optional<uint32_t> maybeChannel = channelOp.getValue();
if (!maybeDirection || !maybeChannel)
return channelOp.emitOpError() << "direction and channel needed";
AMDAIE::DMAChannelDir direction = maybeDirection.value();
uint32_t channel = maybeChannel.value();

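// Flush the current batch and start a new one if this wait op does not match
// the batch's row, channel, or direction.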
if (dmaBatch.colWaitOpMap.empty() || row != dmaBatch.row ||
channel != dmaBatch.channel || direction != dmaBatch.direction) {
updateColumnBatchTokens(rewriter, dmaBatch.colWaitOpMap);
dmaBatch = {row, channel, direction, {}};
}
dmaBatch.colWaitOpMap[col] = waitOp;
return success();
}

/// Traverses the control code forward, retaining a single DMA wait op for each
/// batch of wait ops on consecutive columns.
///
/// Example Input:
/// %0 = dma_cpy_nd(col=0)
/// %1 = dma_cpy_nd(col=1)
/// %2 = dma_cpy_nd(col=2)
/// %3 = dma_cpy_nd(col=3)
/// dma_wait(%0)
/// dma_wait(%1)
/// dma_wait(%2)
/// dma_wait(%3)
/// Example Output:
/// %0 = dma_cpy_nd(col=0)
/// %1 = dma_cpy_nd(col=1)
/// %2 = dma_cpy_nd(col=2)
/// %3 = dma_cpy_nd(col=3)
/// dma_wait(%0, %1, %2, %3)
LogicalResult foldDmaWaitsByColumn(const AMDAIE::AMDAIEDeviceModel &deviceModel,
AMDAIE::ControlCodeOp controlCodeOp) {
IRRewriter rewriter(controlCodeOp->getContext());
DmaColumnBatch dmaBatch = {};

WalkResult res = controlCodeOp->walk([&](Operation *op) {
auto waitOp = dyn_cast<AMDAIE::NpuDmaWaitOp>(op);
// Flush the current batch and skip this op if it is not a DMA wait op or does
// not have exactly one async token.
if (!waitOp || waitOp.getAsyncTokens().size() != 1) {
updateColumnBatchTokens(rewriter, dmaBatch.colWaitOpMap);
dmaBatch.colWaitOpMap.clear();
return WalkResult::advance();
}

// Get the half DMA copy operation.
Value token = waitOp.getAsyncTokens().front();
auto npuHalfDmaCpyNdOp =
dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(token.getDefiningOp());
if (!npuHalfDmaCpyNdOp) {
waitOp.emitOpError() << "expected to operate on an "
"`amdaie.npu.half_dma_cpy_nd`";
return WalkResult::interrupt();
}

// Check if the DMA wait op can be folded into the column batch.
if (succeeded(
foldByColumn(rewriter, dmaBatch, npuHalfDmaCpyNdOp, waitOp))) {
return WalkResult::advance();
} else {
return WalkResult::interrupt();
}
});

// Process the remaining wait ops.
updateColumnBatchTokens(rewriter, dmaBatch.colWaitOpMap);
if (res.wasInterrupted()) return failure();
return success();
}

class AMDAIEFoldDmaWaitsPass
: public impl::AMDAIEFoldDmaWaitsBase<AMDAIEFoldDmaWaitsPass> {
public:
Expand Down Expand Up @@ -181,7 +323,10 @@ void AMDAIEFoldDmaWaitsPass::runOnOperation() {

WalkResult res = parentOp->walk([&](AMDAIE::WorkgroupOp workgroupOp) {
AMDAIE::ControlCodeOp controlCodeOp = workgroupOp.getControlCode();
if (failed(foldDmaWaits(deviceModel, controlCodeOp))) {
if (failed(foldDmaWaitsByConnection(deviceModel, controlCodeOp))) {
return WalkResult::interrupt();
}
if (failed(foldDmaWaitsByColumn(deviceModel, controlCodeOp))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
@@ -153,6 +153,59 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}

// -----

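// Four async push-to-queue ops on consecutive columns (0-3), waited on by a
// single dma_wait, are batched into one TCT sync operation.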
// CHECK: 0x06030100
// CHECK: 0x00000105
// CHECK: 0x00000005
// CHECK: 0x00000080
// CHECK: 0x00000000
// CHECK: 0x00000000
// CHECK: 0x0001D214
// CHECK: 0x00000000
// CHECK: 0x80000000
// CHECK: 0x00000018
// CHECK: 0x00000000
// CHECK: 0x00000000
// CHECK: 0x0201D214
// CHECK: 0x00000000
// CHECK: 0x80000000
// CHECK: 0x00000018
// CHECK: 0x00000000
// CHECK: 0x00000000
// CHECK: 0x0401D214
// CHECK: 0x00000000
// CHECK: 0x80000000
// CHECK: 0x00000018
// CHECK: 0x00000000
// CHECK: 0x00000000
// CHECK: 0x0601D214
// CHECK: 0x00000000
// CHECK: 0x80000000
// CHECK: 0x00000018
// CHECK: 0x00000080
// CHECK: 0x00000010
// CHECK: 0x00000001
// CHECK: 0x00040100
// CHECK-LABEL: @async_push_to_queue_and_wait_col_num
// CHECK: npu_instructions = dense_resource<npu_instructions> : tensor<32xui32>
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @async_push_to_queue_and_wait_col_num() {
amdaie.workgroup {
amdaie.controlcode {
%0 = amdaie.npu.push_to_queue async {bd_id = 0 : ui32, channel = 0 : ui32, col = 0 : ui32, direction = 1 : i32, repeat_count = 1 : ui32, row = 0 : ui32}
%1 = amdaie.npu.push_to_queue async {bd_id = 0 : ui32, channel = 0 : ui32, col = 1 : ui32, direction = 1 : i32, repeat_count = 1 : ui32, row = 0 : ui32}
%2 = amdaie.npu.push_to_queue async {bd_id = 0 : ui32, channel = 0 : ui32, col = 2 : ui32, direction = 1 : i32, repeat_count = 1 : ui32, row = 0 : ui32}
%3 = amdaie.npu.push_to_queue async {bd_id = 0 : ui32, channel = 0 : ui32, col = 3 : ui32, direction = 1 : i32, repeat_count = 1 : ui32, row = 0 : ui32}
amdaie.npu.dma_wait(%0, %1, %2, %3 : !amdaie.async_token, !amdaie.async_token, !amdaie.async_token, !amdaie.async_token)
amdaie.end
}
}
return
}
}

// -----

// CHECK: 0x06030100
// CHECK: 0x00000105
// CHECK: 0x00000001
@@ -220,3 +220,93 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
return
}
}

// -----

// The first two DMA operations are expected to be batched into a single DMA wait, as they share the same row,
// channel, and direction, with consecutive columns (0 and 1). The third DMA operation is not batched because
// its column (3) is not consecutive with the previous operations.
// CHECK-LABEL: @fold_dma_waits_column_batch
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: %[[TILE_0_0:.+]] = amdaie.tile(%[[C0]], %[[C0]])
// CHECK: %[[TILE_1_0:.+]] = amdaie.tile(%[[C1]], %[[C0]])
// CHECK: %[[TILE_3_0:.+]] = amdaie.tile(%[[C3]], %[[C0]])
// CHECK: %[[BD_ID_0:.+]] = amdaie.bd_id(%[[TILE_0_0]], %[[C0]])
// CHECK: %[[TOKEN_0:.+]] = amdaie.npu.half_dma_cpy_nd async %{{.+}}(%{{.+}} [] [] [] bd_id = %[[BD_ID_0]]
// CHECK: %[[BD_ID_1:.+]] = amdaie.bd_id(%[[TILE_1_0]], %[[C0]])
// CHECK: %[[TOKEN_1:.+]] = amdaie.npu.half_dma_cpy_nd async %{{.+}}(%{{.+}} [] [] [] bd_id = %[[BD_ID_1]]
// CHECK: %[[BD_ID_2:.+]] = amdaie.bd_id(%[[TILE_3_0]], %[[C0]])
// CHECK: %[[TOKEN_2:.+]] = amdaie.npu.half_dma_cpy_nd async %{{.+}}(%{{.+}} [] [] [] bd_id = %[[BD_ID_2]]
// CHECK: amdaie.npu.dma_wait(%[[TOKEN_0]], %[[TOKEN_1]] : !amdaie.async_token, !amdaie.async_token)
// CHECK: amdaie.npu.dma_wait(%[[TOKEN_2]] : !amdaie.async_token)
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @fold_dma_waits_column_batch() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
amdaie.workgroup {
%tile_0_1 = amdaie.tile(%c0, %c1)
%tile_0_0 = amdaie.tile(%c0, %c0)
%tile_1_1 = amdaie.tile(%c1, %c1)
%tile_1_0 = amdaie.tile(%c1, %c0)
%tile_3_1 = amdaie.tile(%c3, %c1)
%tile_3_0 = amdaie.tile(%c3, %c0)
%buffer = amdaie.buffer(%tile_0_1) : memref<2048xi32, 1 : i32>
%buffer_0 = amdaie.buffer(%tile_0_1) : memref<2048xi32, 1 : i32>
%buffer_1 = amdaie.buffer(%tile_1_1) : memref<2048xi32, 1 : i32>
%buffer_2 = amdaie.buffer(%tile_1_1) : memref<2048xi32, 1 : i32>
%buffer_3 = amdaie.buffer(%tile_3_1) : memref<2048xi32, 1 : i32>
%buffer_4 = amdaie.buffer(%tile_3_1) : memref<2048xi32, 1 : i32>
%lock = amdaie.lock(%tile_0_1(4), 4)
%lock_5 = amdaie.lock(%tile_0_1(5), 0)
%lock_6 = amdaie.lock(%tile_1_1(4), 4)
%lock_7 = amdaie.lock(%tile_1_1(5), 0)
%lock_8 = amdaie.lock(%tile_3_1(4), 4)
%lock_9 = amdaie.lock(%tile_3_1(5), 0)
%0 = amdaie.logicalobjectfifo.from_buffers({%buffer, %buffer_0}, {%lock}, {%lock_5}) : memref<2048xi32, 1 : i32>, memref<2048xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>, 2>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<64x32xi32>
%2 = amdaie.logicalobjectfifo.placeholder{%tile_0_0} : !amdaie.logicalobjectfifo<memref<64x32xi32>>
%3 = amdaie.logicalobjectfifo.from_buffers({%buffer_1, %buffer_2}, {%lock_6}, {%lock_7}) : memref<2048xi32, 1 : i32>, memref<2048xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>, 2>
%4 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<64x32xi32>
%5 = amdaie.logicalobjectfifo.placeholder{%tile_1_0} : !amdaie.logicalobjectfifo<memref<64x32xi32>>
%6 = amdaie.logicalobjectfifo.from_buffers({%buffer_3, %buffer_4}, {%lock_8}, {%lock_9}) : memref<2048xi32, 1 : i32>, memref<2048xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>, 2>
%7 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<64x32xi32>
%8 = amdaie.logicalobjectfifo.placeholder{%tile_3_0} : !amdaie.logicalobjectfifo<memref<64x32xi32>>
%channel = amdaie.channel(%tile_0_0, 0, port_type = DMA, direction = MM2S)
%channel_10 = amdaie.channel(%tile_0_1, 0, port_type = DMA, direction = S2MM)
%channel_11 = amdaie.channel(%tile_1_0, 0, port_type = DMA, direction = MM2S)
%channel_12 = amdaie.channel(%tile_1_1, 0, port_type = DMA, direction = S2MM)
%channel_13 = amdaie.channel(%tile_3_0, 0, port_type = DMA, direction = MM2S)
%channel_14 = amdaie.channel(%tile_3_1, 0, port_type = DMA, direction = S2MM)
%9 = amdaie.flow({%channel} -> {%channel_10}) {is_packet_flow = false}
%10 = amdaie.connection(%0 {%channel_10}, %2 {%channel}, flow = %9) {connection_type = #amdaie<connection_type Packet>} : (!amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>, 2>, !amdaie.logicalobjectfifo<memref<64x32xi32>>)
%11 = amdaie.flow({%channel_11} -> {%channel_12}) {is_packet_flow = false}
%12 = amdaie.connection(%3 {%channel_12}, %5 {%channel_11}, flow = %11) {connection_type = #amdaie<connection_type Packet>} : (!amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>, 2>, !amdaie.logicalobjectfifo<memref<64x32xi32>>)
%13 = amdaie.flow({%channel_13} -> {%channel_14}) {is_packet_flow = false}
%14 = amdaie.connection(%6 {%channel_14}, %8 {%channel_13}, flow = %13) {connection_type = #amdaie<connection_type Packet>} : (!amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>, 2>, !amdaie.logicalobjectfifo<memref<64x32xi32>>)
amdaie.controlcode {
%15 = amdaie.logicalobjectfifo.from_memref %1, {%tile_0_0} : memref<64x32xi32> -> !amdaie.logicalobjectfifo<memref<2048xi32>>
memref.assume_alignment %1, 64 : memref<64x32xi32>
%16 = amdaie.logicalobjectfifo.from_memref %4, {%tile_1_0} : memref<64x32xi32> -> !amdaie.logicalobjectfifo<memref<2048xi32>>
memref.assume_alignment %4, 64 : memref<64x32xi32>
%17 = amdaie.logicalobjectfifo.from_memref %7, {%tile_3_0} : memref<64x32xi32> -> !amdaie.logicalobjectfifo<memref<2048xi32>>
memref.assume_alignment %7, 64 : memref<64x32xi32>
%bd_id = amdaie.bd_id(%tile_0_0, %c0)
%18 = amdaie.npu.half_dma_cpy_nd async %10(%15 [] [] [] bd_id = %bd_id channel = %channel) : !amdaie.logicalobjectfifo<memref<2048xi32>>
%bd_id_15 = amdaie.bd_id(%tile_1_0, %c0)
%19 = amdaie.npu.half_dma_cpy_nd async %12(%16 [] [] [] bd_id = %bd_id_15 channel = %channel_11) : !amdaie.logicalobjectfifo<memref<2048xi32>>
%bd_id_16 = amdaie.bd_id(%tile_3_0, %c0)
%20 = amdaie.npu.half_dma_cpy_nd async %14(%17 [] [] [] bd_id = %bd_id_16 channel = %channel_13) : !amdaie.logicalobjectfifo<memref<2048xi32>>
amdaie.npu.dma_wait(%18 : !amdaie.async_token)
amdaie.npu.dma_wait(%19 : !amdaie.async_token)
amdaie.npu.dma_wait(%20 : !amdaie.async_token)
amdaie.end
}
}
return
}
}
