Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a pass to fold DMA waits #962

Merged
merged 11 commits into from
Dec 9, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,9 @@ LogicalResult convertOp(AMDAIE::NpuDmaWaitOp op, TransactionBuilder &builder) {
LogicalResult convertOp(AMDAIE::NpuPushToQueueOp op,
TransactionBuilder &builder) {
uint32_t repeatCount = op.getRepeatCount() - 1;
if (failed(builder.appendPushToQueueOp(op.getCol(), op.getRow(),
op.getDirection(), op.getChannel(),
op.getBdId(), repeatCount, true))) {
if (failed(builder.appendPushToQueueOp(
op.getCol(), op.getRow(), op.getDirection(), op.getChannel(),
op.getBdId(), repeatCount, static_cast<bool>(op.getAsyncToken())))) {
return failure();
}
return success();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/Transforms/AMDAIEDmaUtils.h"
#include "iree-amd-aie/Transforms/AMDAIEUtils.h"
#include "iree-amd-aie/Transforms/Passes.h"
#include "iree-amd-aie/aie_runtime/iree_aie_runtime.h"
#include "mlir/IR/Iterators.h"
#define DEBUG_TYPE "iree-amdaie-simplify-dma-waits"

namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Traverses the control code in reverse, ensuring that for each connection,
/// only one DMA wait op is retained for every maximum queue size.
LogicalResult simplifyDmaWaits(AMDAIE::AMDAIEDeviceModel deviceModel,
AMDAIE::WorkgroupOp workgroupOp) {
jtuyls marked this conversation as resolved.
Show resolved Hide resolved
IRRewriter rewriter(workgroupOp->getContext());
std::vector<AMDAIE::NpuDmaWaitOp> waitOpsToErase;
DenseMap<AMDAIE::ConnectionOp, SmallVector<uint32_t>> connectionToBdIdQueues;
AMDAIE::ControlCodeOp controlCodeOp = workgroupOp.getControlCode();
WalkResult res = controlCodeOp->walk<WalkOrder::PostOrder, ReverseIterator>(
[&](AMDAIE::NpuDmaWaitOp waitOp) {
bool toErase = true;
for (Value token : waitOp.getAsyncTokens()) {
if (auto npuHalfDmaCpyNdOp =
dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
token.getDefiningOp())) {
// Retrieve the connection op.
std::optional<AMDAIE::ConnectionOp> maybeConnectionOp =
npuHalfDmaCpyNdOp.getConnectionOp();
if (!maybeConnectionOp) {
npuHalfDmaCpyNdOp.emitOpError()
<< "expected to operate on an `amdaie.connection`";
return WalkResult::interrupt();
}
AMDAIE::ConnectionOp connectionOp = maybeConnectionOp.value();
// Retrieve the flow op.
std::optional<AMDAIE::FlowOp> maybeFlowOp =
maybeConnectionOp->getFlowOp();
if (!maybeFlowOp) {
maybeConnectionOp->emitOpError()
<< "expected to operate on an `amdaie.flow`";
return WalkResult::interrupt();
}
if (maybeFlowOp->getIsPacketFlow()) return WalkResult::advance();
// Retrieve the BD ID op.
std::optional<AMDAIE::BdIdOp> maybeBdIdOp =
npuHalfDmaCpyNdOp.getBdIdOp();
if (!maybeBdIdOp) {
npuHalfDmaCpyNdOp.emitOpError()
<< "must have a BD ID op to lower to "
"`amdaie.npu.write_bd`";
return WalkResult::interrupt();
}
AMDAIE::BdIdOp bdIdOp = maybeBdIdOp.value();
// Retrieve the tile op.
AMDAIE::TileOp tileOp = dyn_cast_if_present<AMDAIE::TileOp>(
bdIdOp.getTile().getDefiningOp());
if (!tileOp) {
bdIdOp.emitOpError() << "must operate on an `amdaie.tile`";
return WalkResult::interrupt();
}
// Get the maximum queue size.
uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
uint32_t row = getConstantIndexOrAssert(tileOp.getRow());
uint32_t maxQueueSize = deviceModel.getDmaMaxQueueSize(col, row);
// Keep wait op if reaches the maximum queue size or there is a
// duplicate BD ID.
uint32_t bdId = getConstantIndexOrAssert(bdIdOp.getValue());
auto &bdIdQueue = connectionToBdIdQueues[connectionOp];
jtuyls marked this conversation as resolved.
Show resolved Hide resolved
if (bdIdQueue.size() >= maxQueueSize) bdIdQueue.clear();
if (bdIdQueue.empty() || llvm::is_contained(bdIdQueue, bdId)) {
toErase = false;
bdIdQueue = {bdId};
} else {
bdIdQueue.push_back(bdId);
}
jtuyls marked this conversation as resolved.
Show resolved Hide resolved
}
}
// Erase later to avoid invalidating the iterator.
if (toErase) waitOpsToErase.push_back(waitOp);
return WalkResult::advance();
});
if (res.wasInterrupted()) return failure();

for (AMDAIE::NpuDmaWaitOp waitOp : waitOpsToErase) {
SmallVector<Value> asyncTokens(waitOp.getAsyncTokens());
// Erase the wait op.
rewriter.eraseOp(waitOp);
for (Value token : asyncTokens) {
if (auto op = dyn_cast_if_present<AMDAIE::NpuHalfDmaCpyNdOp>(
token.getDefiningOp())) {
if (op.use_empty()) {
rewriter.setInsertionPoint(op);
TypeRange resultTypeRange = TypeRange{};
// Nullify the result to avoid issuing a token.
rewriter.create<AMDAIE::NpuHalfDmaCpyNdOp>(
op.getLoc(), resultTypeRange, op.getConnection(), op.getInput(),
op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides(),
op.getBdId(), op.getChannel());
rewriter.eraseOp(op);
}
}
}
}

return success();
}

class AMDAIESimplifyDmaWaitsPass
jtuyls marked this conversation as resolved.
Show resolved Hide resolved
: public impl::AMDAIESimplifyDmaWaitsBase<AMDAIESimplifyDmaWaitsPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AMDAIEDialect>();
}

AMDAIESimplifyDmaWaitsPass() = default;
AMDAIESimplifyDmaWaitsPass(const AMDAIESimplifyDmaWaitsPass &pass){};
void runOnOperation() override;
};

void AMDAIESimplifyDmaWaitsPass::runOnOperation() {
Operation *parentOp = getOperation();

auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(parentOp);
std::optional<AMDAIEDevice> maybeDevice = getConfigAMDAIEDevice(targetAttr);
if (!maybeDevice) {
parentOp->emitOpError()
<< "has no AMDAIEDevice in the target attribute configuration. This "
"device-specific information is required to simplify DMA wait "
"ops.";
return signalPassFailure();
}
AMDAIE::AMDAIEDeviceModel deviceModel =
AMDAIE::getDeviceModel(maybeDevice.value());

WalkResult res = parentOp->walk([&](AMDAIE::WorkgroupOp workgroupOp) {
if (failed(simplifyDmaWaits(deviceModel, workgroupOp))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (res.wasInterrupted()) return signalPassFailure();
}

} // namespace

std::unique_ptr<Pass> createAMDAIESimplifyDmaWaitsPass() {
return std::make_unique<AMDAIESimplifyDmaWaitsPass>();
}

} // namespace mlir::iree_compiler::AMDAIE
jtuyls marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ iree_cc_library(
"AMDAIEPeelForLoop.cpp"
"AMDAIEPropagateDataLayout.cpp"
"AMDAIERemoveMemorySpace.cpp"
"AMDAIESimplifyDmaWaits.cpp"
"AMDAIESinkIntoCore.cpp"
"AMDAIESplitLogicalObjFifos.cpp"
"AMDAIESplitLogicalObjFifosForConnectionReuse.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ namespace mlir::iree_compiler::AMDAIE {
#define GEN_PASS_DEF_AMDAIEPEELFORLOOP
#define GEN_PASS_DEF_AMDAIEPROPAGATEDATALAYOUT
#define GEN_PASS_DEF_AMDAIEREMOVEMEMORYSPACE
#define GEN_PASS_DEF_AMDAIESIMPLIFYDMAWAITS
#define GEN_PASS_DEF_AMDAIESINKINTOCORE
#define GEN_PASS_DEF_AMDAIESPLITLOGICALOBJFIFOS
#define GEN_PASS_DEF_AMDAIESPLITLOGICALOBJFIFOSFORCONNECTIONREUSE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,7 @@ void addAMDAIEObjectFifoLoweringPasses(
passManager.addPass(createAMDAIEAssignPacketIdsPass());

passManager.addPass(createAMDAIENpuDmaToHalfDmaCpyNdPass());
passManager.addPass(createAMDAIESimplifyDmaWaitsPass());
passManager.addPass(createAMDAIEControlCodeLoweringPass());
passManager.addPass(createAMDAIEControlCodeToTransactionPass());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,9 @@ std::unique_ptr<Pass> createAMDAIEPeelForLoopPass(
/// Create a pass to remove memory space annotation from all types.
std::unique_ptr<Pass> createAMDAIERemoveMemorySpacePass();

/// Create a pass to remove redundant DMA wait operations.
std::unique_ptr<Pass> createAMDAIESimplifyDmaWaitsPass();

/// Create a pass to sink all dependencies into `amdaie.core` operations.
std::unique_ptr<Pass> createAMDAIESinkIntoCorePass();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,12 @@ def AMDAIERemoveMemorySpace : Pass<"iree-amdaie-remove-memoryspace"> {
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIERemoveMemorySpacePass()";
}

def AMDAIESimplifyDmaWaits :
Pass<"iree-amdaie-simplify-dma-waits", ""> {
let summary = "Remove redundant dma wait operations in controlcode.";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIESimplifyDmaWaitsPass()";
}

def AMDAIESinkIntoCore :
Pass<"iree-amdaie-sink-into-core", "ModuleOp"> {
let summary = "Clone constants and other ops into amdaie.cores";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ iree_lit_test_suite(
"peel_for_loop.mlir"
"propagate_data_layout.mlir"
"remove_memory_space.mlir"
"simplify_dma_waits.mlir"
"sink_into_core.mlir"
"split_logicalobjfifos.mlir"
"split_logicalobjfifos_for_connection_reuse.mlir"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
// CHECK: 0x00000000
// CHECK: 0x0001D214
// CHECK: 0x00000000
// CHECK: 0x80000000
// CHECK: 0x00000000
// CHECK: 0x00000018
// CHECK-LABEL: @push_to_queue_default_values
// CHECK: npu_instructions = dense_resource<npu_instructions> : tensor<10xui32>
Expand All @@ -102,7 +102,7 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
// CHECK: 0x00000000
// CHECK: 0x0601D21C
// CHECK: 0x00000000
// CHECK: 0x803F0002
// CHECK: 0x003F0002
// CHECK: 0x00000018
// CHECK-LABEL: @push_to_queue
// CHECK: npu_instructions = dense_resource<npu_instructions> : tensor<10xui32>
Expand Down
Loading
Loading