diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertDmaBdChain.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertDmaBdChain.cpp index d42e7f935..2796c5706 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertDmaBdChain.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertDmaBdChain.cpp @@ -87,7 +87,7 @@ void checkForChainsToBeBroken( const DenseMap> &dmaChainToBdIds, SmallVector &chainsToBreak) { for (auto &[entry, bdIds] : dmaChainToBdIds) { - if (entry.first == currDmaChain.first && bdIds.contains(currbdId)) { + if (entry.first == currDmaChain.first && bdIds.contains(currBdId)) { // Break the chain that contains the duplicate BD ID. chainsToBreak.push_back(entry); if (entry != currDmaChain) { @@ -155,12 +155,9 @@ LogicalResult insertDmaBdChain(const AMDAIE::AMDAIEDeviceModel &deviceModel, // Repeat count > 1, do not chain BDs. int32_t repeatCount = 1; - uint8_t numIntraAddrDim = deviceModel.getDmaProp( - AMDAIE::AMDAIETileType::SHIMNOC, AMDAIE::AMDAIEDmaProp::NumAddrDim); - uint8_t numAddrDim = - numIntraAddrDim + deviceModel.deviceConfig.dmaNbInterDims; - auto sizes = npuHalfDmaCpyNdOp.getMixedSizes(); - auto strides = npuHalfDmaCpyNdOp.getMixedStrides(); + uint8_t numAddrDim = DmaDimConfig(deviceModel, 0).maxNbDims; + SmallVector sizes = npuHalfDmaCpyNdOp.getMixedSizes(); + SmallVector strides = npuHalfDmaCpyNdOp.getMixedStrides(); if (!sizes.empty() && !strides.empty()) { int64_t size = getConstantIndexOrAssert(sizes[0]); int64_t stride = getConstantIndexOrAssert(strides[0]); @@ -197,23 +194,28 @@ LogicalResult insertDmaBdChain(const AMDAIE::AMDAIEDeviceModel &deviceModel, // the `updateChainOperands` function. if (!chainsToBreak.empty()) { for (auto &entry : chainsToBreak) { + // Since the controlcode is traversed in reverse order, we need to + // restore the original order of the DMA operations. + std::reverse(dmaChainToDmaOps[entry].begin(), + dmaChainToDmaOps[entry].end()); if (failed(updateChainOperands(rewriter, dmaChainToDmaOps[entry]))) WalkResult::interrupt(); dmaChainToBdIds[entry].clear(); dmaChainToDmaOps[entry].clear(); } } - dmaChainToBdIds[currDmaChain].insert(bdId); - // Insert at the front, as we are walking in reverse order. - dmaChainToDmaOps[currDmaChain].insert( - dmaChainToDmaOps[currDmaChain].begin(), npuHalfDmaCpyNdOp); + dmaChainToDmaOps[currDmaChain].push_back(npuHalfDmaCpyNdOp); } return WalkResult::advance(); }); // Build the remaining chains. for (auto &[entry, _] : dmaChainToBdIds) { + // Since the controlcode is traversed in reverse order, we need to + // restore the original order of the DMA operations. + std::reverse(dmaChainToDmaOps[entry].begin(), + dmaChainToDmaOps[entry].end()); if (failed(updateChainOperands(rewriter, dmaChainToDmaOps[entry]))) return failure(); }