Skip to content

Commit

Permalink
Add a pass to fold DMA waits (#962)
Browse files Browse the repository at this point in the history
Each DMA channel has a task queue with the depth of 4. DMA wait is only
required for every 4 pushes, reducing unnecessary synchronization.

Example:
https://gist.github.com/Yu-Zhewen/5f569b56c7b1f1a8715a7c4c3bf9e609

Results compared to 7c4b985:
| Test (MxKxN) | Instruction Size Before (Words) | Instruction Size
After (Words) |

|---------------|---------------------------------|--------------------------------|
| 512x4096x512 | 1228 | 1132 |
| 512x512x4096 | 820 | 772 |
| 4096x512x512 | 4628 | 4244 |

This optimization is orthogonal to DMA chaining #931.

---------

Co-authored-by: James Newling <[email protected]>
  • Loading branch information
Yu-Zhewen and newling authored Dec 9, 2024
1 parent f5ab91e commit 2243dd8
Show file tree
Hide file tree
Showing 12 changed files with 448 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,9 @@ LogicalResult convertOp(AMDAIE::NpuDmaWaitOp op, TransactionBuilder &builder) {
LogicalResult convertOp(AMDAIE::NpuPushToQueueOp op,
TransactionBuilder &builder) {
uint32_t repeatCount = op.getRepeatCount() - 1;
if (failed(builder.appendPushToQueueOp(op.getCol(), op.getRow(),
op.getDirection(), op.getChannel(),
op.getBdId(), repeatCount, true))) {
if (failed(builder.appendPushToQueueOp(
op.getCol(), op.getRow(), op.getDirection(), op.getChannel(),
op.getBdId(), repeatCount, static_cast<bool>(op.getAsyncToken())))) {
return failure();
}
return success();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/Transforms/Passes.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEDmaUtils.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "iree-amd-aie/aie_runtime/iree_aie_runtime.h"
#include "mlir/IR/Iterators.h"
#define DEBUG_TYPE "iree-amdaie-fold-dma-waits"

namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Utility function to determine whether a DMA wait op can be folded based on
/// its half DMA copy operation.
FailureOr<bool> canFoldBasedOnHalfDmaCpy(
const AMDAIE::AMDAIEDeviceModel &deviceModel,
AMDAIE::NpuHalfDmaCpyNdOp &npuHalfDmaCpyNdOp,
DenseMap<std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>,
SmallVector<uint32_t>> &tileConnectToBdIdQueue) {
// Retrieve the connection op.
std::optional<AMDAIE::ConnectionOp> maybeConnectionOp =
npuHalfDmaCpyNdOp.getConnectionOp();
if (!maybeConnectionOp) {
return npuHalfDmaCpyNdOp.emitOpError()
<< "expected to operate on an `amdaie.connection`";
}
AMDAIE::ConnectionOp connectionOp = maybeConnectionOp.value();

// Retrieve the flow op.
std::optional<AMDAIE::FlowOp> maybeFlowOp = connectionOp.getFlowOp();
if (!maybeFlowOp) {
return connectionOp->emitOpError()
<< "expected to operate on an `amdaie.flow`";
}
AMDAIE::FlowOp flowOp = maybeFlowOp.value();
bool isPacketFlow = flowOp.getIsPacketFlow();

// Retrieve the BD ID op.
std::optional<AMDAIE::BdIdOp> maybeBdIdOp = npuHalfDmaCpyNdOp.getBdIdOp();
if (!maybeBdIdOp) {
return npuHalfDmaCpyNdOp.emitOpError()
<< "must have a BD ID op to lower to "
"`amdaie.npu.write_bd`";
}
AMDAIE::BdIdOp bdIdOp = maybeBdIdOp.value();

// Retrieve the tile op.
AMDAIE::TileOp tileOp =
dyn_cast_if_present<AMDAIE::TileOp>(bdIdOp.getTile().getDefiningOp());
if (!tileOp) {
return bdIdOp.emitOpError() << "must operate on an `amdaie.tile`";
}

// Get the maximum queue size.
uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
uint32_t row = getConstantIndexOrAssert(tileOp.getRow());
uint32_t maxQueueSize = deviceModel.getDmaMaxQueueSize(col, row);

// Keep wait op if, either reaches the maximum queue size, or a
// duplicate BD ID in the same tile, or packet flow, or the queue is
// empty
uint32_t bdId = getConstantIndexOrAssert(bdIdOp.getValue());
bool isDuplicateBdId =
llvm::any_of(tileConnectToBdIdQueue, [&](const auto &entry) {
return entry.first.first == tileOp &&
llvm::is_contained(entry.second, bdId);
});
SmallVector<uint32_t> &bdIdQueue =
tileConnectToBdIdQueue[{tileOp, connectionOp}];
bool canFold = true;
if (isDuplicateBdId || isPacketFlow || bdIdQueue.size() >= maxQueueSize ||
bdIdQueue.empty()) {
bdIdQueue.clear();
canFold = false;
}
bdIdQueue.push_back(bdId);
return canFold;
}

/// Traverses the control code in reverse, ensuring that for each connection,
/// only one DMA wait op is retained for every maximum queue size.
///
/// Example Output: assuming a maximum queue size of 4.
/// dma_cpy_nd
/// %0 = dma_cpy_nd
/// dma_wait(%0)
/// dma_cpy_nd
/// dma_cpy_nd
/// dma_cpy_nd
/// %1 = dma_cpy_nd
/// dma_wait(%1)
/// From the bottom up, for every four DMA copy operations, only one DMA wait
/// operation is retained.
///
/// Reverse traversal simplifies handling duplicate BD IDs, preventing
/// the need to revisit and modify earlier operations after processing later
/// ones.
LogicalResult foldDmaWaits(const AMDAIE::AMDAIEDeviceModel &deviceModel,
AMDAIE::ControlCodeOp controlCodeOp) {
IRRewriter rewriter(controlCodeOp->getContext());
std::vector<AMDAIE::NpuDmaWaitOp> waitOpsToErase;
DenseMap<std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>,
SmallVector<uint32_t>>
tileConnectToBdIdQueue;
// Traverse the control code in reverse.
WalkResult res = controlCodeOp->walk<WalkOrder::PostOrder, ReverseIterator>(
[&](AMDAIE::NpuDmaWaitOp waitOp) {
bool toErase = true;
for (Value token : waitOp.getAsyncTokens()) {
if (auto npuHalfDmaCpyNdOp =
dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
token.getDefiningOp())) {
FailureOr<bool> result = canFoldBasedOnHalfDmaCpy(
deviceModel, npuHalfDmaCpyNdOp, tileConnectToBdIdQueue);
if (failed(result)) return WalkResult::interrupt();
toErase &= *result;
}
}
// Erase later to avoid invalidating the iterator.
if (toErase) waitOpsToErase.push_back(waitOp);
return WalkResult::advance();
});
if (res.wasInterrupted()) return failure();

for (AMDAIE::NpuDmaWaitOp waitOp : waitOpsToErase) {
SmallVector<Value> asyncTokens(waitOp.getAsyncTokens());
// Erase the wait op.
rewriter.eraseOp(waitOp);
for (Value token : asyncTokens) {
if (auto op = dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
token.getDefiningOp())) {
if (op.use_empty()) {
rewriter.setInsertionPoint(op);
TypeRange resultTypeRange = TypeRange{};
// Nullify the result to avoid issuing a token.
rewriter.create<AMDAIE::NpuHalfDmaCpyNdOp>(
op.getLoc(), resultTypeRange, op.getConnection(), op.getInput(),
op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides(),
op.getBdId(), op.getChannel());
rewriter.eraseOp(op);
}
}
}
}

return success();
}

class AMDAIEFoldDmaWaitsPass
: public impl::AMDAIEFoldDmaWaitsBase<AMDAIEFoldDmaWaitsPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AMDAIEDialect>();
}

AMDAIEFoldDmaWaitsPass() = default;
AMDAIEFoldDmaWaitsPass(const AMDAIEFoldDmaWaitsPass &pass){};
void runOnOperation() override;
};

void AMDAIEFoldDmaWaitsPass::runOnOperation() {
Operation *parentOp = getOperation();

auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(parentOp);
std::optional<AMDAIEDevice> maybeDevice = getConfigAMDAIEDevice(targetAttr);
if (!maybeDevice) {
parentOp->emitOpError()
<< "has no AMDAIEDevice in the target attribute configuration. This "
"device-specific information is required to fold DMA wait "
"ops.";
return signalPassFailure();
}
AMDAIE::AMDAIEDeviceModel deviceModel =
AMDAIE::getDeviceModel(maybeDevice.value());

WalkResult res = parentOp->walk([&](AMDAIE::WorkgroupOp workgroupOp) {
AMDAIE::ControlCodeOp controlCodeOp = workgroupOp.getControlCode();
if (failed(foldDmaWaits(deviceModel, controlCodeOp))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (res.wasInterrupted()) return signalPassFailure();
}

} // namespace

std::unique_ptr<Pass> createAMDAIEFoldDmaWaitsPass() {
return std::make_unique<AMDAIEFoldDmaWaitsPass>();
}

} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ iree_cc_library(
"AMDAIEDmaToCircularDma.cpp"
"AMDAIEFlattenLogicalObjectFifo.cpp"
"AMDAIELinalgFunctionOutlining.cpp"
"AMDAIEFoldDmaWaits.cpp"
"AMDAIEFuseConsumerIntoLoop.cpp"
"AMDAIEFuseFillIntoForall.cpp"
"AMDAIEFusePackIntoLoop.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ namespace mlir::iree_compiler::AMDAIE {
#define GEN_PASS_DEF_AMDAIEDMATOCIRCULARDMA
#define GEN_PASS_DEF_AMDAIEFLATTENLOGICALOBJECTFIFO
#define GEN_PASS_DEF_AMDAIELINALGFUNCTIONOUTLINING
#define GEN_PASS_DEF_AMDAIEFOLDDMAWAITS
#define GEN_PASS_DEF_AMDAIEFUSECONSUMERINTOLOOP
#define GEN_PASS_DEF_AMDAIEFUSEFILLINTOFORALL
#define GEN_PASS_DEF_AMDAIEFUSEPACKINTOLOOP
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@ void addAMDAIEObjectFifoLoweringPasses(
passManager.addPass(createAMDAIEAssignPacketIdsPass());

passManager.addPass(createAMDAIENpuDmaToHalfDmaCpyNdPass());
passManager.addPass(createAMDAIEFoldDmaWaitsPass());
passManager.addPass(createAMDAIEControlCodeLoweringPass());
passManager.addPass(createAMDAIEControlCodeToTransactionPass());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,9 @@ std::unique_ptr<Pass> createAMDAIEHoistLogicalObjFifoPass();
std::unique_ptr<Pass> createAMDAIEInsertLoopsForVectorizationPass(
AMDAIEInsertLoopsForVectorizationOptions options = {});

/// Create a pass to remove redundant DMA wait operations.
std::unique_ptr<Pass> createAMDAIEFoldDmaWaitsPass();

/// Create a pass to fuse the pack operations into the for loops.
std::unique_ptr<Pass> createAMDAIEFusePackIntoLoopPass(
AMDAIEFusePackIntoLoopOptions options = {});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,12 @@ def AMDAIELinalgFunctionOutlining :
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIELinalgFunctionOutliningPass()";
}

def AMDAIEFoldDmaWaits :
Pass<"iree-amdaie-fold-dma-waits", ""> {
let summary = "Remove redundant dma wait operations in controlcode.";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEFoldDmaWaitsPass()";
}

def AMDAIEFuseConsumerIntoLoop :
InterfacePass<"iree-amdaie-fuse-consumer-into-loop", "mlir::FunctionOpInterface"> {
let summary = "Fuse the consumer operation into the innermost last scf loop.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ iree_lit_test_suite(
"dma_loop_subsumption_circular.mlir"
"dma_loop_subsumption.mlir"
"dma_to_circular_dma.mlir"
"fold_dma_waits.mlir"
"flatten_logical_objectfifo.mlir"
"linalg_function_outlining.mlir"
"fuse_consumer_into_loop.mlir"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
// CHECK: 0x00000000
// CHECK: 0x0001D214
// CHECK: 0x00000000
// CHECK: 0x80000000
// CHECK: 0x00000000
// CHECK: 0x00000018
// CHECK-LABEL: @push_to_queue_default_values
// CHECK: npu_instructions = dense_resource<npu_instructions> : tensor<10xui32>
Expand All @@ -102,7 +102,7 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
// CHECK: 0x00000000
// CHECK: 0x0601D21C
// CHECK: 0x00000000
// CHECK: 0x803F0002
// CHECK: 0x003F0002
// CHECK: 0x00000018
// CHECK-LABEL: @push_to_queue
// CHECK: npu_instructions = dense_resource<npu_instructions> : tensor<10xui32>
Expand Down
Loading

0 comments on commit 2243dd8

Please sign in to comment.