Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a pass to fold DMA waits #962

Merged
merged 11 commits into from
Dec 9, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,9 @@ LogicalResult convertOp(AMDAIE::NpuDmaWaitOp op, TransactionBuilder &builder) {
LogicalResult convertOp(AMDAIE::NpuPushToQueueOp op,
TransactionBuilder &builder) {
uint32_t repeatCount = op.getRepeatCount() - 1;
if (failed(builder.appendPushToQueueOp(op.getCol(), op.getRow(),
op.getDirection(), op.getChannel(),
op.getBdId(), repeatCount, true))) {
if (failed(builder.appendPushToQueueOp(
op.getCol(), op.getRow(), op.getDirection(), op.getChannel(),
op.getBdId(), repeatCount, static_cast<bool>(op.getAsyncToken())))) {
return failure();
}
return success();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/Transforms/Passes.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEDmaUtils.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "iree-amd-aie/aie_runtime/iree_aie_runtime.h"
#include "mlir/IR/Iterators.h"
#define DEBUG_TYPE "iree-amdaie-fold-dma-waits"

namespace mlir::iree_compiler::AMDAIE {

namespace {

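/// Traverses the control code in reverse and, for each (tile, connection)
/// pair, retains only one DMA wait op per batch of at most `maxQueueSize`
/// queued DMA operations. A wait op is also retained for packet flows and when
/// a BD ID is reused on the same tile.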
LogicalResult foldDmaWaits(const AMDAIE::AMDAIEDeviceModel &deviceModel,
AMDAIE::ControlCodeOp controlCodeOp) {
IRRewriter rewriter(controlCodeOp->getContext());
std::vector<AMDAIE::NpuDmaWaitOp> waitOpsToErase;
DenseMap<std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>,
SmallVector<uint32_t>>
tileConnectionToBdIdQueueMap;
// Traverse the control code in reverse.
WalkResult res = controlCodeOp->walk<WalkOrder::PostOrder, ReverseIterator>(
[&](AMDAIE::NpuDmaWaitOp waitOp) {
bool toErase = true;
for (Value token : waitOp.getAsyncTokens()) {
if (auto npuHalfDmaCpyNdOp =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still kinda think a function at the level

FailureOr<bool> canFoldBasedOnNpuHalfDmaCpyNdOp(...)

would make for slightly easier to read (less indented) code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, made it a function now

dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
token.getDefiningOp())) {
// Retrieve the connection op.
std::optional<AMDAIE::ConnectionOp> maybeConnectionOp =
npuHalfDmaCpyNdOp.getConnectionOp();
if (!maybeConnectionOp) {
npuHalfDmaCpyNdOp.emitOpError()
<< "expected to operate on an `amdaie.connection`";
return WalkResult::interrupt();
}
AMDAIE::ConnectionOp connectionOp = maybeConnectionOp.value();
// Retrieve the flow op.
std::optional<AMDAIE::FlowOp> maybeFlowOp =
maybeConnectionOp->getFlowOp();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
maybeConnectionOp->getFlowOp();
connectionOp.getFlowOp();

if (!maybeFlowOp) {
maybeConnectionOp->emitOpError()
<< "expected to operate on an `amdaie.flow`";
return WalkResult::interrupt();
}
bool isPacketFlow = maybeFlowOp->getIsPacketFlow();
// Retrieve the BD ID op.
std::optional<AMDAIE::BdIdOp> maybeBdIdOp =
npuHalfDmaCpyNdOp.getBdIdOp();
if (!maybeBdIdOp) {
npuHalfDmaCpyNdOp.emitOpError()
<< "must have a BD ID op to lower to "
"`amdaie.npu.write_bd`";
return WalkResult::interrupt();
}
AMDAIE::BdIdOp bdIdOp = maybeBdIdOp.value();
// Retrieve the tile op.
AMDAIE::TileOp tileOp = dyn_cast_if_present<AMDAIE::TileOp>(
bdIdOp.getTile().getDefiningOp());
if (!tileOp) {
bdIdOp.emitOpError() << "must operate on an `amdaie.tile`";
return WalkResult::interrupt();
}
// Get the maximum queue size.
uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
uint32_t row = getConstantIndexOrAssert(tileOp.getRow());
uint32_t maxQueueSize = deviceModel.getDmaMaxQueueSize(col, row);
// Keep wait op if, either reaches the maximum queue size, or there
// is a duplicate BD ID in the same tile.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// is a duplicate BD ID in the same tile.
// is a duplicate BD ID in the same tile, or packet flow, or the queue is empty

?

uint32_t bdId = getConstantIndexOrAssert(bdIdOp.getValue());
bool isDuplicateBdId = llvm::any_of(
tileConnectionToBdIdQueueMap, [&](const auto &entry) {
return entry.first.first == tileOp &&
llvm::is_contained(entry.second, bdId);
});
SmallVector<uint32_t> &bdIdQueue =
tileConnectionToBdIdQueueMap[std::make_pair(tileOp,
connectionOp)];
if (isDuplicateBdId || isPacketFlow ||
bdIdQueue.size() >= maxQueueSize) {
bdIdQueue.clear();
}
if (bdIdQueue.empty()) toErase = false;
bdIdQueue.push_back(bdId);
}
}
// Erase later to avoid invalidating the iterator.
if (toErase) waitOpsToErase.push_back(waitOp);
return WalkResult::advance();
});
if (res.wasInterrupted()) return failure();

for (AMDAIE::NpuDmaWaitOp waitOp : waitOpsToErase) {
SmallVector<Value> asyncTokens(waitOp.getAsyncTokens());
// Erase the wait op.
rewriter.eraseOp(waitOp);
for (Value token : asyncTokens) {
if (auto op = dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
token.getDefiningOp())) {
if (op.use_empty()) {
rewriter.setInsertionPoint(op);
TypeRange resultTypeRange = TypeRange{};
// Nullify the result to avoid issuing a token.
rewriter.create<AMDAIE::NpuHalfDmaCpyNdOp>(
op.getLoc(), resultTypeRange, op.getConnection(), op.getInput(),
op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides(),
op.getBdId(), op.getChannel());
rewriter.eraseOp(op);
}
}
}
}

return success();
}

/// Pass that folds redundant `amdaie.npu.dma_wait` ops inside the control code
/// of each workgroup. The transformation itself lives in `foldDmaWaits`; this
/// class only provides the pass plumbing.
class AMDAIEFoldDmaWaitsPass
    : public impl::AMDAIEFoldDmaWaitsBase<AMDAIEFoldDmaWaitsPass> {
 public:
  AMDAIEFoldDmaWaitsPass() = default;

  // User-provided copy constructor with an empty body: nothing is copied from
  // the source object and the base subobject is default-initialized.
  AMDAIEFoldDmaWaitsPass(const AMDAIEFoldDmaWaitsPass &) {}

  /// Declares the dialects whose ops this pass may create.
  void getDependentDialects(DialectRegistry &dialects) const override {
    dialects.insert<AMDAIEDialect>();
  }

  /// Pass entry point; defined out of line.
  void runOnOperation() override;
};

void AMDAIEFoldDmaWaitsPass::runOnOperation() {
Operation *parentOp = getOperation();

auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(parentOp);
std::optional<AMDAIEDevice> maybeDevice = getConfigAMDAIEDevice(targetAttr);
if (!maybeDevice) {
parentOp->emitOpError()
<< "has no AMDAIEDevice in the target attribute configuration. This "
"device-specific information is required to fold DMA wait "
"ops.";
return signalPassFailure();
}
AMDAIE::AMDAIEDeviceModel deviceModel =
AMDAIE::getDeviceModel(maybeDevice.value());

WalkResult res = parentOp->walk([&](AMDAIE::WorkgroupOp workgroupOp) {
AMDAIE::ControlCodeOp controlCodeOp = workgroupOp.getControlCode();
if (failed(foldDmaWaits(deviceModel, controlCodeOp))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (res.wasInterrupted()) return signalPassFailure();
}

} // namespace

/// Factory for the fold-DMA-waits pass; returns a freshly constructed
/// `AMDAIEFoldDmaWaitsPass` owned by the caller.
std::unique_ptr<Pass> createAMDAIEFoldDmaWaitsPass() {
  std::unique_ptr<Pass> foldDmaWaitsPass =
      std::make_unique<AMDAIEFoldDmaWaitsPass>();
  return foldDmaWaitsPass;
}

} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ iree_cc_library(
"AMDAIEDmaToCircularDma.cpp"
"AMDAIEFlattenLogicalObjectFifo.cpp"
"AMDAIELinalgFunctionOutlining.cpp"
"AMDAIEFoldDmaWaits.cpp"
"AMDAIEFuseConsumerIntoLoop.cpp"
"AMDAIEFuseFillIntoForall.cpp"
"AMDAIEFusePackIntoLoop.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ namespace mlir::iree_compiler::AMDAIE {
#define GEN_PASS_DEF_AMDAIEDMATOCIRCULARDMA
#define GEN_PASS_DEF_AMDAIEFLATTENLOGICALOBJECTFIFO
#define GEN_PASS_DEF_AMDAIELINALGFUNCTIONOUTLINING
#define GEN_PASS_DEF_AMDAIEFOLDDMAWAITS
#define GEN_PASS_DEF_AMDAIEFUSECONSUMERINTOLOOP
#define GEN_PASS_DEF_AMDAIEFUSEFILLINTOFORALL
#define GEN_PASS_DEF_AMDAIEFUSEPACKINTOLOOP
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,7 @@ void addAMDAIEObjectFifoLoweringPasses(
passManager.addPass(createAMDAIEAssignPacketIdsPass());

passManager.addPass(createAMDAIENpuDmaToHalfDmaCpyNdPass());
passManager.addPass(createAMDAIEFoldDmaWaitsPass());
passManager.addPass(createAMDAIEControlCodeLoweringPass());
passManager.addPass(createAMDAIEControlCodeToTransactionPass());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,9 @@ std::unique_ptr<Pass> createAMDAIEHoistLogicalObjFifoPass();
std::unique_ptr<Pass> createAMDAIEInsertLoopsForVectorizationPass(
AMDAIEInsertLoopsForVectorizationOptions options = {});

/// Create a pass that folds (removes) redundant DMA wait operations in the
/// control code, so that fewer waits are kept per connection.
std::unique_ptr<Pass> createAMDAIEFoldDmaWaitsPass();

/// Create a pass to fuse the pack operations into the for loops.
std::unique_ptr<Pass> createAMDAIEFusePackIntoLoopPass(
AMDAIEFusePackIntoLoopOptions options = {});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,12 @@ def AMDAIELinalgFunctionOutlining :
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIELinalgFunctionOutliningPass()";
}

// Declares the `iree-amdaie-fold-dma-waits` pass, which is instantiated through
// `createAMDAIEFoldDmaWaitsPass()`.
// NOTE(review): the empty operation-name argument presumably leaves the pass
// un-anchored to a specific op type — confirm against the MLIR pass docs.
def AMDAIEFoldDmaWaits :
Pass<"iree-amdaie-fold-dma-waits", ""> {
let summary = "Remove redundant dma wait operations in controlcode.";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEFoldDmaWaitsPass()";
}

def AMDAIEFuseConsumerIntoLoop :
InterfacePass<"iree-amdaie-fuse-consumer-into-loop", "mlir::FunctionOpInterface"> {
let summary = "Fuse the consumer operation into the innermost last scf loop.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ iree_lit_test_suite(
"dma_loop_subsumption_circular.mlir"
"dma_loop_subsumption.mlir"
"dma_to_circular_dma.mlir"
"fold_dma_waits.mlir"
"flatten_logical_objectfifo.mlir"
"linalg_function_outlining.mlir"
"fuse_consumer_into_loop.mlir"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
// CHECK: 0x00000000
// CHECK: 0x0001D214
// CHECK: 0x00000000
// CHECK: 0x80000000
// CHECK: 0x00000000
// CHECK: 0x00000018
// CHECK-LABEL: @push_to_queue_default_values
// CHECK: npu_instructions = dense_resource<npu_instructions> : tensor<10xui32>
Expand All @@ -102,7 +102,7 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
// CHECK: 0x00000000
// CHECK: 0x0601D21C
// CHECK: 0x00000000
// CHECK: 0x803F0002
// CHECK: 0x003F0002
// CHECK: 0x00000018
// CHECK-LABEL: @push_to_queue
// CHECK: npu_instructions = dense_resource<npu_instructions> : tensor<10xui32>
Expand Down
Loading
Loading