diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIESplitLogicalObjFifos.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIESplitLogicalObjFifos.cpp index 403a0ed79..958cf4e9f 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIESplitLogicalObjFifos.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIESplitLogicalObjFifos.cpp @@ -6,7 +6,6 @@ #include "iree-amd-aie/IR/AMDAIEOps.h" #include "iree-amd-aie/Transforms/Passes.h" #include "iree-amd-aie/Transforms/Utils/AMDAIELogicalObjFifoSplittingUtils.h" -#include "mlir/IR/Iterators.h" #include "mlir/Pass/Pass.h" #define DEBUG_TYPE "iree-amdaie-split-logical-objectfifos" @@ -28,10 +27,11 @@ using DmaObjFifoPairT = /// each DMA and objectFifo pair. /// /// Each pair is handled in the following way: +/// /// First, compute the objectFifo splitting dimension as the last non-unit shape -/// dimension. Afterwards, depending on which logical objectFifo is being -/// split on, find the outermost dimension in either the source or -/// target access pattern that has: +/// dimension, less than 2. Afterwards, depending on which logical objectFifo is +/// being split on, find the outermost dimension in either the source or target +/// access pattern that has: /// - stride == sizeAfterSplit /// - size != 1 /// This is the splitting dimension to be used on the respective side of the DMA @@ -59,9 +59,9 @@ LogicalResult collectSplittingDims( auto iter = std::find_if(memrefShape.begin(), memrefShape.end(), [](int64_t size) { return size > 1; }); size_t objFifoSplitDim = std::distance(memrefShape.begin(), iter); - // If all dimensions are unit (1), no splitting can be done, so continue to - // the next pair. - if (objFifoSplitDim >= memrefShape.size()) continue; + // Only consider splitting on dimensions 0 and 1. + if (objFifoSplitDim >= 2) continue; + int64_t sizeAfterSplit = std::accumulate(memrefShape.begin() + objFifoSplitDim + 1, memrefShape.end(), 1, std::multiplies<>());