Skip to content

Commit 37bb7f1

Browse files
authored
Add DMA composition pass (#729)
Add a pass that iteratively calls the `DmaLoopSubsumption` and `CombineStridedOps` pattern rewriters as both can enable new composition opportunities for each other.
1 parent cc68d51 commit 37bb7f1

13 files changed

+311
-49
lines changed

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIECombineStridedOps.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "iree-amd-aie/Transforms/AMDAIEDmaUtils.h"
1616
#include "iree-amd-aie/Transforms/AMDAIEUtils.h"
1717
#include "iree-amd-aie/Transforms/Passes.h"
18+
#include "iree-amd-aie/Transforms/Transforms.h"
1819
#include "llvm/ADT/STLExtras.h"
1920
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2021

@@ -159,7 +160,7 @@ void AMDAIECombineStridedOpsPass::runOnOperation() {
159160
Operation *parentOp = getOperation();
160161
MLIRContext *context = &getContext();
161162
RewritePatternSet patterns(context);
162-
patterns.insert<CombineStridedOps>(context);
163+
populateStridedOpCombinationPattern(patterns);
163164
if (failed(applyPatternsAndFoldGreedily(parentOp, std::move(patterns)))) {
164165
parentOp->emitOpError("failed to combine strided operations");
165166
return signalPassFailure();
@@ -168,6 +169,10 @@ void AMDAIECombineStridedOpsPass::runOnOperation() {
168169

169170
} // namespace
170171

172+
/// Registers the `CombineStridedOps` rewrite pattern in `patterns`, constructing
/// it with the MLIR context owned by the pattern set.
void populateStridedOpCombinationPattern(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.insert<CombineStridedOps>(ctx);
}
175+
171176
std::unique_ptr<Pass> createAMDAIECombineStridedOpsPass() {
172177
return std::make_unique<AMDAIECombineStridedOpsPass>();
173178
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// Copyright 2024 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file composes more complex strided DMA ops by iteratively:
10+
// 1. Combining ops in the same block.
11+
// 2. Subsuming loop iterations into the strided access pattern.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "iree-amd-aie/Transforms/AMDAIEDmaUtils.h"
16+
#include "iree-amd-aie/Transforms/AMDAIEUtils.h"
17+
#include "iree-amd-aie/Transforms/Passes.h"
18+
#include "iree-amd-aie/Transforms/Transforms.h"
19+
#include "iree-amd-aie/aie_runtime/iree_aie_runtime.h"
20+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21+
22+
#define DEBUG_TYPE "iree-amdaie-dma-composition"
23+
24+
namespace mlir::iree_compiler::AMDAIE {
25+
26+
namespace {
27+
28+
class AMDAIEDmaCompositionPass
29+
: public impl::AMDAIEDmaCompositionBase<AMDAIEDmaCompositionPass> {
30+
public:
31+
AMDAIEDmaCompositionPass() = default;
32+
AMDAIEDmaCompositionPass(const AMDAIEDmaCompositionPass &pass){};
33+
AMDAIEDmaCompositionPass(const AMDAIEDmaCompositionOptions &options)
34+
: AMDAIEDmaCompositionBase(options) {}
35+
void runOnOperation() override;
36+
};
37+
38+
void AMDAIEDmaCompositionPass::runOnOperation() {
39+
Operation *parentOp = getOperation();
40+
MLIRContext *context = &getContext();
41+
RewritePatternSet patterns(context);
42+
{
43+
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(parentOp);
44+
std::optional<AMDAIEDevice> maybeDevice = getConfigAMDAIEDevice(targetAttr);
45+
if (!maybeDevice) {
46+
parentOp->emitOpError()
47+
<< "has no AMDAIEDevice in the target attribute configuration. This "
48+
"device-specific information is required to determine when loops "
49+
"can be subsumed into DMA operations, and must be attached to a "
50+
"containing ModuleOp.";
51+
return signalPassFailure();
52+
}
53+
AMDAIE::AMDAIEDeviceModel deviceModel =
54+
AMDAIE::getDeviceModel(maybeDevice.value());
55+
populateDmaLoopSubsumptionPattern(patterns, std::move(deviceModel),
56+
onlyZeroStrideOnOuterDim);
57+
}
58+
populateStridedOpCombinationPattern(patterns);
59+
if (failed(applyPatternsAndFoldGreedily(parentOp, std::move(patterns)))) {
60+
parentOp->emitOpError("failed to compose strided operations");
61+
return signalPassFailure();
62+
}
63+
64+
IRRewriter rewriter(parentOp->getContext());
65+
if (failed(moveNpuDmaSyncUsersAfterAncestorInSameBlock(rewriter, parentOp))) {
66+
parentOp->emitOpError() << "failed to move DMA users to correct scope "
67+
"after strided op composition";
68+
return signalPassFailure();
69+
}
70+
}
71+
72+
} // namespace
73+
74+
/// Factory returning a new `AMDAIEDmaCompositionPass` constructed from
/// `options`.
std::unique_ptr<Pass> createAMDAIEDmaCompositionPass(
    AMDAIEDmaCompositionOptions options) {
  std::unique_ptr<Pass> pass =
      std::make_unique<AMDAIEDmaCompositionPass>(options);
  return pass;
}
78+
79+
} // namespace mlir::iree_compiler::AMDAIE

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEDmaLoopSubsumption.cpp

+14-42
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "iree-amd-aie/Transforms/AMDAIEDmaUtils.h"
2626
#include "iree-amd-aie/Transforms/AMDAIEUtils.h"
2727
#include "iree-amd-aie/Transforms/Passes.h"
28+
#include "iree-amd-aie/Transforms/Transforms.h"
2829
#include "iree-amd-aie/aie_runtime/iree_aie_runtime.h"
2930
#include "mlir/Dialect/Affine/IR/AffineOps.h"
3031
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -49,15 +50,6 @@ int64_t calculateNbIterations(int64_t lowerBound, int64_t upperBound,
4950

5051
namespace {
5152

52-
/// Return an ancestor of 'op' in 'block', or nullptr if no such ancestor.
53-
Operation *getAncestorInBlock(Operation *op, Block *block) {
54-
if (!op || !block) return nullptr;
55-
auto parent = op;
56-
while (parent && (parent->getBlock() != block))
57-
parent = parent->getParentOp();
58-
return parent;
59-
}
60-
6153
/// Utility affine expression visitor to retrieve the scale and optional bias
6254
/// from the expression.
6355
struct RetrieveScaleAndBias
@@ -112,31 +104,6 @@ struct RetrieveScaleAndBias
112104
}
113105
};
114106

115-
/// Utility to clean up the DMA users after loop subsumption + hoisting. This
116-
/// will hoist `amdaie.npu.dma_cpy_nd`'s users like `npu.dma_wait` as well.
117-
LogicalResult moveUsersToHoistedDMAScope(Operation *parentOp) {
118-
IRRewriter rewriter(parentOp->getContext());
119-
// Move `amdaie.npu.dma_wait` operation after the parent op in the same block
120-
// as the input `amdaie.npu.dma_cpy_nd` operation. This parent op will
121-
// typically be a loop out of which the DMA operation has been hoisted. Moving
122-
// the wait operation after this loop is important to avoid a deadlock with
123-
// whatever operations are still remaining inside the loop's scope.
124-
WalkResult res = parentOp->walk([&](AMDAIE::NpuDmaWaitOp npuDmaWaitOp) {
125-
Operation *dmaOp = npuDmaWaitOp.getDma().getDefiningOp();
126-
Operation *ancestorInSameBlock =
127-
getAncestorInBlock(npuDmaWaitOp, dmaOp->getBlock());
128-
if (!ancestorInSameBlock) {
129-
npuDmaWaitOp->emitOpError(
130-
"doesn't have an ancestor in the same scope as the source DMA op");
131-
return WalkResult::interrupt();
132-
}
133-
rewriter.moveOpAfter(npuDmaWaitOp, ancestorInSameBlock);
134-
return WalkResult::advance();
135-
});
136-
if (res.wasInterrupted()) return failure();
137-
return success();
138-
}
139-
140107
struct SubsumeLoopIntoDMA
141108
: public OpInterfaceRewritePattern<AMDAIE::DoublyStridedOpInterface> {
142109
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
@@ -594,7 +561,7 @@ class AMDAIEDmaLoopSubsumptionPass
594561
}
595562

596563
AMDAIEDmaLoopSubsumptionPass() = default;
597-
AMDAIEDmaLoopSubsumptionPass(const AMDAIEDmaLoopSubsumptionPass &pass) {};
564+
AMDAIEDmaLoopSubsumptionPass(const AMDAIEDmaLoopSubsumptionPass &pass){};
598565
AMDAIEDmaLoopSubsumptionPass(const AMDAIEDmaLoopSubsumptionOptions &options)
599566
: AMDAIEDmaLoopSubsumptionBase(options) {}
600567
void runOnOperation() override;
@@ -605,7 +572,6 @@ void AMDAIEDmaLoopSubsumptionPass::runOnOperation() {
605572
MLIRContext *context = &getContext();
606573

607574
RewritePatternSet patterns(context);
608-
609575
{
610576
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(parentOp);
611577
std::optional<AMDAIEDevice> maybeDevice = getConfigAMDAIEDevice(targetAttr);
@@ -619,19 +585,17 @@ void AMDAIEDmaLoopSubsumptionPass::runOnOperation() {
619585
}
620586
AMDAIE::AMDAIEDeviceModel deviceModel =
621587
AMDAIE::getDeviceModel(maybeDevice.value());
622-
623-
SubsumeLoopIntoDMA pattern(context, std::move(deviceModel),
624-
onlyZeroStrideOnOuterDim);
625-
626-
patterns.insert<SubsumeLoopIntoDMA>(std::move(pattern));
588+
populateDmaLoopSubsumptionPattern(patterns, std::move(deviceModel),
589+
onlyZeroStrideOnOuterDim);
627590
}
628591

629592
if (failed(applyPatternsAndFoldGreedily(parentOp, std::move(patterns)))) {
630593
parentOp->emitOpError("failed to subsume some loops into DMA operations");
631594
return signalPassFailure();
632595
}
633596

634-
if (failed(moveUsersToHoistedDMAScope(parentOp))) {
597+
IRRewriter rewriter(parentOp->getContext());
598+
if (failed(moveNpuDmaSyncUsersAfterAncestorInSameBlock(rewriter, parentOp))) {
635599
parentOp->emitOpError(
636600
"failed to move DMA users to correct scope after loop subsumption");
637601
return signalPassFailure();
@@ -640,6 +604,14 @@ void AMDAIEDmaLoopSubsumptionPass::runOnOperation() {
640604

641605
} // namespace
642606

607+
/// Registers the `SubsumeLoopIntoDMA` rewrite pattern in `patterns`. The
/// pattern is constructed with the pattern set's context, the given device
/// model (moved from) and the `onlyZeroStrideOnOuterDim` flag.
void populateDmaLoopSubsumptionPattern(RewritePatternSet &patterns,
                                       AMDAIE::AMDAIEDeviceModel &&deviceModel,
                                       bool onlyZeroStrideOnOuterDim) {
  // Construct the pattern in place inside the pattern set; the arguments are
  // forwarded to the `SubsumeLoopIntoDMA` constructor.
  patterns.insert<SubsumeLoopIntoDMA>(patterns.getContext(),
                                      std::move(deviceModel),
                                      onlyZeroStrideOnOuterDim);
}
614+
643615
std::unique_ptr<Pass> createAMDAIEDmaLoopSubsumptionPass(
644616
AMDAIEDmaLoopSubsumptionOptions options) {
645617
return std::make_unique<AMDAIEDmaLoopSubsumptionPass>(options);

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEDmaUtils.cpp

+27
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,15 @@
1212

1313
namespace mlir::iree_compiler::AMDAIE {
1414

15+
/// Return an ancestor of 'op' in 'block', or nullptr if no such ancestor.
16+
Operation *getAncestorInBlock(Operation *op, Block *block) {
17+
if (!op || !block) return nullptr;
18+
auto parent = op;
19+
while (parent && (parent->getBlock() != block))
20+
parent = parent->getParentOp();
21+
return parent;
22+
}
23+
1524
/// Utility to retrieve a constant index from an OpFoldResult.
1625
int64_t getConstantIndexOrAssert(OpFoldResult dim) {
1726
std::optional<int64_t> size = getConstantIntValue(dim);
@@ -317,4 +326,22 @@ LogicalResult foldUnitDims(const SmallVector<OpFoldResult> &offsets,
317326
return success(foldableUnitDimsFound);
318327
}
319328

329+
/// For every `amdaie.npu.dma_wait` op nested under `parentOp`, moves the wait
/// op directly after its ancestor (possibly the wait op itself) that lives in
/// the same block as the operation defining the DMA value it waits on.
///
/// Returns failure, after emitting an error on the offending wait op, when the
/// DMA value is not produced by an operation or when the wait op has no
/// ancestor in the defining operation's block. The walk stops at the first such
/// wait op. Returns success otherwise.
LogicalResult moveNpuDmaSyncUsersAfterAncestorInSameBlock(
    RewriterBase &rewriter, Operation *parentOp) {
  WalkResult res = parentOp->walk([&](AMDAIE::NpuDmaWaitOp npuDmaWaitOp) {
    Operation *dmaOp = npuDmaWaitOp.getDma().getDefiningOp();
    // `getDefiningOp()` yields null when the value is a block argument. Bail
    // out with a diagnostic instead of dereferencing a null pointer below.
    if (!dmaOp) {
      npuDmaWaitOp->emitOpError(
          "expected the DMA operand to be defined by an operation");
      return WalkResult::interrupt();
    }
    Operation *ancestorInSameBlock =
        getAncestorInBlock(npuDmaWaitOp, dmaOp->getBlock());
    if (!ancestorInSameBlock) {
      npuDmaWaitOp->emitOpError(
          "doesn't have an ancestor in the same scope as the source DMA op");
      return WalkResult::interrupt();
    }
    rewriter.moveOpAfter(npuDmaWaitOp, ancestorInSameBlock);
    return WalkResult::advance();
  });
  return failure(res.wasInterrupted());
}
346+
320347
} // namespace mlir::iree_compiler::AMDAIE

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEDmaUtils.h

+10
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include "iree-amd-aie/IR/AMDAIEAttrs.h"
1111
#include "iree-amd-aie/IR/AMDAIEDmaOpInterface.h"
12+
#include "iree-amd-aie/IR/AMDAIEOps.h"
1213
#include "iree-amd-aie/aie_runtime/iree_aie_runtime.h"
1314
#include "llvm/ADT/SmallVector.h"
1415
#include "mlir/IR/MLIRContext.h"
@@ -301,6 +302,15 @@ struct DmaDimConfig {
301302
}
302303
};
303304

305+
/// Utility to move the synchronization users (`amdaie.npu.dma_wait`) directly
306+
/// after its ancestor in the same block as the DMA operation it's synchronizing
307+
/// on. This utility can be used for cleanup after DMA transformations to avoid
308+
/// deadlocks and/or ensure SSA dominance. The idea is to ensure correct
309+
/// synchronization by not influencing whatever is happening in between the
310+
/// async DMA operation and its synchronization op.
311+
LogicalResult moveNpuDmaSyncUsersAfterAncestorInSameBlock(
312+
RewriterBase &rewriter, Operation *parentOp);
313+
304314
} // namespace mlir::iree_compiler::AMDAIE
305315

306316
#endif

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ iree_cc_library(
6060
"AMDAIECreateLogicalObjectFifoLink.cpp"
6161
"AMDAIECreateReferenceToAllocation.cpp"
6262
"AMDAIEDistributeCoresAndObjectFifos.cpp"
63+
"AMDAIEDmaComposition.cpp"
6364
"AMDAIEDmaLoopSubsumption.cpp"
6465
"AMDAIEDmaToCircularDma.cpp"
6566
"AMDAIEDmaUtils.cpp"

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/PassDetail.h

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ namespace mlir::iree_compiler::AMDAIE {
3939
#define GEN_PASS_DEF_AMDAIECREATEREFERENCETOALLOCATION
4040
#define GEN_PASS_DEF_AMDAIEDECOMPOSELINALGEXTPACKUNPACKTOAIR
4141
#define GEN_PASS_DEF_AMDAIEDISTRIBUTECORESANDOBJECTFIFOS
42+
#define GEN_PASS_DEF_AMDAIEDMACOMPOSITION
4243
#define GEN_PASS_DEF_AMDAIEDMALOOPSUBSUMPTION
4344
#define GEN_PASS_DEF_AMDAIEDMATOCIRCULARDMA
4445
#define GEN_PASS_DEF_AMDAIEFLATTENLOGICALOBJECTFIFO

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp

+1-5
Original file line numberDiff line numberDiff line change
@@ -608,11 +608,7 @@ void addAMDAIEObjectFifoLoweringPasses(OpPassManager &passManager) {
608608
passManager.addPass(createCSEPass());
609609
passManager.addPass(createCanonicalizerPass());
610610

611-
passManager.addPass(createAMDAIEDmaLoopSubsumptionPass());
612-
passManager.addPass(createCSEPass());
613-
passManager.addPass(createCanonicalizerPass());
614-
615-
passManager.addPass(createAMDAIECombineStridedOpsPass());
611+
passManager.addPass(createAMDAIEDmaCompositionPass());
616612
passManager.addPass(createCSEPass());
617613
passManager.addPass(createCanonicalizerPass());
618614

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.h

+6
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,12 @@ std::unique_ptr<Pass> createAMDAIEDecomposeLinalgExtPackUnPackToAIRPass();
122122
/// operations and distribute the logical objectFifos.
123123
std::unique_ptr<Pass> createAMDAIEDistributeCoresAndObjectFifosPass();
124124

125+
/// Create a pass to compose more complex DMA operations, e.g. by combining DMA
126+
/// operations and/or subsuming loop iterations into the strided access
127+
/// patterns.
128+
std::unique_ptr<Pass> createAMDAIEDmaCompositionPass(
129+
AMDAIEDmaCompositionOptions options = {});
130+
125131
/// Create a pass to subsume loop iterations into DMA operations' access
126132
/// patterns.
127133
std::unique_ptr<Pass> createAMDAIEDmaLoopSubsumptionPass(

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td

+11-1
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,17 @@ def AMDAIEDistributeCoresAndObjectFifos :
173173
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEDistributeCoresAndObjectFifosPass()";
174174
}
175175

176+
// Pass record for `iree-amdaie-dma-composition`: composes DMA operations by DMA
// combination and loop subsumption. Exposes a single boolean option,
// `only-zero-stride-on-outer-dim`, which defaults to true.
def AMDAIEDmaComposition :
177+
Pass<"iree-amdaie-dma-composition"> {
178+
let summary = "Compose DMA operations by DMA combination and loop subsumption.";
179+
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEDmaCompositionPass()";
180+
let options = [
181+
Option<"onlyZeroStrideOnOuterDim", "only-zero-stride-on-outer-dim", "bool", /*default=*/"true",
182+
"Whether a stride of zero indicating a repeat is only supported on the "
183+
"outer dimension. This is the case of AIE2(+).">
184+
];
185+
}
186+
176187
def AMDAIEDmaLoopSubsumption :
177188
Pass<"iree-amdaie-dma-loop-subsumption"> {
178189
let summary = "Subsume loop iterations into DMA operations' access patterns.";
@@ -459,7 +470,6 @@ def AMDAIESinkIntoCore :
459470
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIESinkIntoCorePass()";
460471
}
461472

462-
463473
def AMDAIETile :
464474
InterfacePass<"iree-amdaie-tile", "mlir::FunctionOpInterface"> {
465475
let summary = "Pass to tile TilingInterface operations.";

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Transforms.h

+9
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#define IREE_AMD_AIE_TRANSFORMS_AMDAIETRANSFORMS_H_
99

1010
#include "iree-amd-aie/IR/AMDAIEOps.h"
11+
#include "iree-amd-aie/aie_runtime/iree_aie_runtime.h"
1112
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1213
#include "mlir/Dialect/SCF/IR/SCF.h"
1314

@@ -38,6 +39,14 @@ LogicalResult normalizeLoopBounds(RewriterBase &rewriter, scf::ForOp forOp);
3839
LogicalResult normalizeLoopBounds(RewriterBase &rewriter,
3940
scf::ForallOp forallOp);
4041

42+
/// Populate patterns that subsume loop iterations into DMA access patterns.
43+
void populateDmaLoopSubsumptionPattern(RewritePatternSet &patterns,
44+
AMDAIE::AMDAIEDeviceModel &&deviceModel,
45+
bool onlyZeroStrideOnOuterDim);
46+
47+
/// Populate patterns that combine strided ops in the same block.
48+
void populateStridedOpCombinationPattern(RewritePatternSet &patterns);
49+
4150
} // namespace mlir::iree_compiler::AMDAIE
4251

4352
#endif

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ iree_lit_test_suite(
2727
"create_reference_to_allocation.mlir"
2828
"disable_vectorization.mlir"
2929
"distribute_cores_and_objectfifos.mlir"
30+
"dma_composition.mlir"
3031
"dma_loop_subsumption.mlir"
3132
"dma_to_circular_dma.mlir"
3233
"flatten_logical_objectfifo.mlir"

0 commit comments

Comments
 (0)