Skip to content

Commit

Permalink
Merge branch 'main' into refactor-bufferize-operand
Browse files Browse the repository at this point in the history
  • Loading branch information
jtuyls authored Dec 11, 2024
2 parents a1fd889 + 71e17ed commit a07c87c
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 165 deletions.
22 changes: 22 additions & 0 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,28 @@ bool NpuDmaCpyNdOp::hasDmaWaitOpUser() {
[](auto userOp) { return isa<NpuDmaWaitOp>(userOp); });
}

FailureOr<AMDAIE::ChannelOp> NpuDmaCpyNdOp::getSourceChannelOp() {
AMDAIE::ConnectionOp connectionOp = getConnectionOp();
if (!connectionOp)
return emitOpError() << "should operate on an `amdaie.connection` op";
if (connectionOp.getSourceChannels().size() != 1)
return emitOpError() << "expected a single source channel";
auto sourceChannelOp = dyn_cast<AMDAIE::ChannelOp>(
connectionOp.getSourceChannels()[0].getDefiningOp());
return sourceChannelOp;
}

FailureOr<AMDAIE::ChannelOp> NpuDmaCpyNdOp::getTargetChannelOp() {
AMDAIE::ConnectionOp connectionOp = getConnectionOp();
if (!connectionOp)
return emitOpError() << "should operate on an `amdaie.connection` op";
if (connectionOp.getTargetChannels().size() != 1)
return emitOpError() << "expected a single target channel";
auto targetChannelOp = dyn_cast<AMDAIE::ChannelOp>(
connectionOp.getTargetChannels()[0].getDefiningOp());
return targetChannelOp;
}

namespace {
struct NpuDmaCpyNdOpReplacementBuilder {
static void replace(NpuDmaCpyNdOp dmaOp, PatternRewriter &rewriter,
Expand Down
4 changes: 4 additions & 0 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,10 @@ def AMDAIE_NpuDmaCpyNdOp: AMDAIE_Op<"npu.dma_cpy_nd", [
if (!bdIdValue) return nullptr;
return dyn_cast_if_present<BdIdOp>(bdIdValue.getDefiningOp());
}

FailureOr<AMDAIE::ChannelOp> getSourceChannelOp();

FailureOr<AMDAIE::ChannelOp> getTargetChannelOp();

// A utility to create a new doubly strided operation from this one with a
// new set of source and target offsets, sizes and strides.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,27 @@ template <CopyOpOperateOn OperateOn>
FailureOr<AMDAIE::BdIdOp> getBdIdOp(
IRRewriter &rewriter, AMDAIE::NpuDmaCpyNdOp &npuDmaOp,
DenseMap<Value, ChannelBdIdGenerator> &shimTileToGeneratorMap,
DenseMap<AMDAIE::BdIdOp, SmallVector<uint32_t>> &bdIdOpToBdIdsMap,
uint32_t channel) {
FailureOr<AMDAIE::TileOp> tileOp =
DenseMap<AMDAIE::BdIdOp, SmallVector<uint32_t>> &bdIdOpToBdIdsMap) {
// Get the TileOp.
FailureOr<AMDAIE::TileOp> maybeTileOp =
getGeneratorTileOp<OperateOn>(npuDmaOp, shimTileToGeneratorMap);
if (failed(tileOp)) return failure();
if (failed(maybeTileOp)) return failure();
AMDAIE::TileOp tileOp = maybeTileOp.value();

// Get the channel.
FailureOr<AMDAIE::ChannelOp> maybeChannelOp;
if constexpr (OperateOn == CopyOpOperateOn::Source) {
maybeChannelOp = npuDmaOp.getSourceChannelOp();
} else if constexpr (OperateOn == CopyOpOperateOn::Target) {
maybeChannelOp = npuDmaOp.getTargetChannelOp();
} else {
return npuDmaOp.emitOpError()
<< "Function can only operate on Source or Target";
}
if (failed(maybeChannelOp)) return failure();
uint32_t channel = maybeChannelOp.value().getValue();

ChannelBdIdGenerator &generator = shimTileToGeneratorMap[tileOp->getResult()];
ChannelBdIdGenerator &generator = shimTileToGeneratorMap[tileOp.getResult()];
rewriter.setInsertionPoint(npuDmaOp);
if (scf::ForOp loop = npuDmaOp->getParentOfType<scf::ForOp>();
loop && getNumberIterations(loop)) {
Expand All @@ -165,7 +179,7 @@ FailureOr<AMDAIE::BdIdOp> getBdIdOp(

// Get the number of BD IDs will be assigned to current DMA op.
uint32_t numRequired = 0;
getNumRequiredBdIds(loop, npuDmaOp, *tileOp, shimTileToGeneratorMap,
getNumRequiredBdIds(loop, npuDmaOp, tileOp, shimTileToGeneratorMap,
numRequired);
uint32_t numAvailable = generator.getNumAvailableBdIds(channel);
uint32_t size = std::max(numAvailable / numRequired, 1u);
Expand Down Expand Up @@ -193,7 +207,7 @@ FailureOr<AMDAIE::BdIdOp> getBdIdOp(
iv,
});
AMDAIE::BdIdOp bdIdOp = rewriter.create<AMDAIE::BdIdOp>(
rewriter.getUnknownLoc(), *tileOp, affineApply.getResult());
rewriter.getUnknownLoc(), tileOp, affineApply.getResult());
bdIdOpToBdIdsMap[bdIdOp] = bdIds;
return bdIdOp;
}
Expand All @@ -206,7 +220,7 @@ FailureOr<AMDAIE::BdIdOp> getBdIdOp(
auto constant = rewriter.create<arith::ConstantOp>(
rewriter.getUnknownLoc(), rewriter.getIndexAttr(bdId.value()));
AMDAIE::BdIdOp bdIdOp = rewriter.create<AMDAIE::BdIdOp>(
rewriter.getUnknownLoc(), *tileOp, constant.getResult());
rewriter.getUnknownLoc(), tileOp, constant.getResult());
return bdIdOp;
};

Expand Down Expand Up @@ -266,13 +280,6 @@ LogicalResult assignNpuDmaBdIds(AMDAIE::WorkgroupOp workgroupOp) {
}
});

// TODO(jornt): Temporarily use channel 0 for all DMAs. This should
// return correct results for Shim channels, however, for generality
// towards other DMAs and future hardware generations, channel
// assignment should happen before BD assignemnt. This requires more
// refactoring.
const uint32_t channel = 0;

DenseMap<AMDAIE::BdIdOp, SmallVector<uint32_t>> bdIdOpToBdIdsMap;
// Walk `amdaie.npu_dma_cpy_nd` and `amdaie.dma_wait` operations and assign
// and release BD IDs when encountering the respective operations using the
Expand All @@ -282,8 +289,7 @@ LogicalResult assignNpuDmaBdIds(AMDAIE::WorkgroupOp workgroupOp) {
if (auto npuDmaOp = dyn_cast<AMDAIE::NpuDmaCpyNdOp>(op)) {
if (npuDmaOp.getSource()) {
FailureOr<AMDAIE::BdIdOp> bdIdOp = getBdIdOp<CopyOpOperateOn::Source>(
rewriter, npuDmaOp, shimTileToGeneratorMap, bdIdOpToBdIdsMap,
channel);
rewriter, npuDmaOp, shimTileToGeneratorMap, bdIdOpToBdIdsMap);
if (failed(bdIdOp)) return WalkResult::interrupt();
rewriter.setInsertionPoint(npuDmaOp);
npuDmaOp = rewriter.replaceOpWithNewOp<AMDAIE::NpuDmaCpyNdOp>(
Expand All @@ -296,8 +302,7 @@ LogicalResult assignNpuDmaBdIds(AMDAIE::WorkgroupOp workgroupOp) {
}
if (npuDmaOp.getTarget()) {
FailureOr<AMDAIE::BdIdOp> bdIdOp = getBdIdOp<CopyOpOperateOn::Target>(
rewriter, npuDmaOp, shimTileToGeneratorMap, bdIdOpToBdIdsMap,
channel);
rewriter, npuDmaOp, shimTileToGeneratorMap, bdIdOpToBdIdsMap);
if (failed(bdIdOp)) return WalkResult::interrupt();
rewriter.setInsertionPoint(npuDmaOp);
(void)rewriter.replaceOpWithNewOp<AMDAIE::NpuDmaCpyNdOp>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,10 @@ void addAMDAIEObjectFifoLoweringPasses(
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createAMDAIEDmaCSEPass());

passManager.addPass(createAMDAIEAssignChannelsPass());
passManager.addPass(createCSEPass());
passManager.addPass(createCanonicalizerPass());

passManager.addPass(createAMDAIEAssignNpuDmaBdIdsPass());
passManager.addPass(createCSEPass());
passManager.addPass(createCanonicalizerPass());
Expand All @@ -650,10 +654,6 @@ void addAMDAIEObjectFifoLoweringPasses(
passManager.addPass(createAMDAIEConvertCoreForallToForPass());
passManager.addPass(createCanonicalizerPass());

passManager.addPass(createAMDAIEAssignChannelsPass());
passManager.addPass(createCSEPass());
passManager.addPass(createCanonicalizerPass());

passManager.addPass(createAMDAIEObjFifoBufferizationPass());
passManager.addPass(createAMDAIETemporaryAllocBufferizationPass());
passManager.addPass(createAMDAIEConnectionToFlowPass());
Expand Down
Loading

0 comments on commit a07c87c

Please sign in to comment.