diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFusePackIntoLoop.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFusePackIntoLoop.cpp index 63baa0ab2..a5f314004 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFusePackIntoLoop.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFusePackIntoLoop.cpp @@ -19,40 +19,30 @@ namespace mlir::iree_compiler::AMDAIE { namespace { -FailureOr getParentBeforeLoop(BlockArgument blkArg) { - Operation *parent = blkArg.getOwner()->getParentOp(); - LoopLikeOpInterface loop = dyn_cast(parent); - if (!loop) return failure(); - Operation *operandParent = loop.getTiedLoopInit(blkArg)->getOwner(); - return operandParent; -} - /// A utility function specific to this pass which, given a value `operand`, /// traverses the def-chain till it finds a tensor.extract_slice. The 2 cases /// where it successfully finds and returns an extract_slice (SLICE) are: /// /// Case 1) -/// pack -> block arg -> SLICE -> pack -> pack -> pack -> operand -/// ^^^^^^^^^^^^^^^^^^^^ -/// any number (>= 0) of trailing packs -/// -/// Case 2) /// pack -> SLICE -> pack -> pack -> pack -> operand /// ^^^^^^^^^^^^^^^^^^^^ /// any number (>= 0) of trailing packs /// +/// Case 2) +/// pack -> block arg -> SLICE -> pack -> pack -> pack -> operand +/// ^^^^^^^^^^^^^^^^^^^^ +/// any number (>= 0) of trailing packs +/// +/// Case 2 only matches where `block arg` is for a loop operation. static FailureOr getTensorExtractSliceDefiningOp( Value operand) { // roll back through all the packs immediately preceding `operand`. - while (operand.getDefiningOp() && - isa(operand.getDefiningOp())) { + while (isa_and_present(operand.getDefiningOp())) { operand = operand.getDefiningOp()->getOperand(0); } - // If the parent of `operand` is not an extract_slice, return failure. - Operation *defOp = operand.getDefiningOp(); - if (!defOp) return failure(); - tensor::ExtractSliceOp sliceOp = dyn_cast(defOp); + tensor::ExtractSliceOp sliceOp = + dyn_cast_if_present(operand.getDefiningOp()); if (!sliceOp) return failure(); // Case 1 outlined above. @@ -62,9 +52,11 @@ static FailureOr getTensorExtractSliceDefiningOp( // Case 2 outlined above. else if (auto blkArg = dyn_cast(sliceOp.getSource())) { - FailureOr operandParent = getParentBeforeLoop(blkArg); - if (failed(operandParent)) return failure(); - if (isa_and_present(operandParent.value())) return sliceOp; + Operation *parent = blkArg.getOwner()->getParentOp(); + LoopLikeOpInterface loop = dyn_cast(parent); + if (!loop) return failure(); + Operation *operandParent = loop.getTiedLoopInit(blkArg)->getOwner(); + if (isa_and_present(operandParent)) return sliceOp; } return failure(); @@ -139,6 +131,7 @@ void AMDAIEFusePackIntoLoopPass::runOnOperation() { for (Value operand : genericOp.getOperands()) { FailureOr maybeSliceOp = getTensorExtractSliceDefiningOp(operand); + if (succeeded(maybeSliceOp)) { tensor::ExtractSliceOp sliceOp = maybeSliceOp.value(); std::optional fusedProducer = @@ -152,14 +145,12 @@ void AMDAIEFusePackIntoLoopPass::runOnOperation() { // Case where operand of generic is a pack op which is in a different // block than the generic's block. - else { - if (auto parent = - dyn_cast_if_present(operand.getDefiningOp())) { - Block *genericBlock = genericOp->getBlock(); - if (parent->getBlock() != genericBlock && parent->hasOneUse()) { - Operation *firstOpInBlock = &genericBlock->front(); - rewriter.moveOpBefore(parent, firstOpInBlock); - } + else if (auto parent = dyn_cast_if_present( + operand.getDefiningOp())) { + Block *genericBlock = genericOp->getBlock(); + if (parent->getBlock() != genericBlock && parent->hasOneUse()) { + Operation *firstOpInBlock = &genericBlock->front(); + rewriter.moveOpBefore(parent, firstOpInBlock); } } }