From 475e57d3b646d241ace3e7341845fcb9b541763b Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Fri, 24 Jan 2025 20:16:38 -0600 Subject: [PATCH 1/2] [Dispatch] Don't fuse bit truncate -> extend ops Signed-off-by: Ian Wood --- .../DispatchCreation/FuseMultiUseElementwiseProducer.cpp | 4 ++-- compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp index fc18e3a2f37b..07814294da91 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp @@ -198,10 +198,10 @@ static FailureOr fuseMultiUseProducers(Operation *funcOp, // 7. Skip dequantization-like `producer` ops as we would rather fuse // by cloning the producer instead of multi-use fusion. - if (IREE::LinalgExt::isBitExtendOp(producer)) { + if (IREE::LinalgExt::isBitTruncateOp(producer) || + IREE::LinalgExt::isBitExtendOp(producer)) { return; } - // 8. All uses from `producer` -> `consumer` need to be fusable. // Without this the `producer` is still live, and there is no // advantage to do the fusion. diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp index c428091f6cf8..b5097942af1e 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp @@ -69,6 +69,10 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand, return false; } + if (IREE::LinalgExt::isBitTruncateOp(producerOp)) { + return false; + } + auto linalgConsumerOp = dyn_cast(consumerOp); if (!linalgConsumerOp) { return false; From 12952e21d55794f0e2e0c1fc3049e007dae120c0 Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Mon, 27 Jan 2025 12:21:53 -0600 Subject: [PATCH 2/2] Fix codegen failure Signed-off-by: Ian Wood --- .../DispatchCreation/ElementwiseOpFusion.cpp | 36 ++++++++++++------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp b/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp index 6779d29ec1c6..5f890827b666 100644 --- a/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp @@ -63,31 +63,43 @@ struct GatherFusionPattern final : public OpRewritePattern { } // Check if the producerOp is fusible - if (producerOp.getNumDpsInputs() != 1 || producerOp.getNumResults() != 1 || - !isElementwise(producerOp) || + if (!isElementwise(producerOp) || !IREE::LinalgExt::isBitExtendOp(producerOp)) { return rewriter.notifyMatchFailure(producerOp, "producer op is not fusible"); } + auto result = cast(extractOp.getTensor()); + auto resultMap = producerOp.getIndexingMapMatchingResult(result); + const SmallVector resultPerm = llvm::map_to_vector<4>( + resultMap.getResults(), [](AffineExpr expr) -> int64_t { + return cast(expr).getPosition(); + }); + OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(extractOp); - // Create a new extract op that extracts from the original tensor - // (after the original extract). Clone the producerOp's body into the - // consumerOp, inline the cloned block (erases the block) after the new - // extract, and clean up. - auto newExtractOp = rewriter.create( - extractOp.getLoc(), producerOp.getDpsInputOperand(0)->get(), - extractOp.getIndices()); + SmallVector extractOps; + for (OpOperand &operand : producerOp->getOpOperands()) { + auto map = producerOp.getMatchingIndexingMap(&operand); + auto perm = llvm::map_to_vector<4>( + map.getResults(), [](AffineExpr expr) -> int64_t { + return cast(expr).getPosition(); + }); + SmallVector indices = extractOp.getIndices(); + perm = applyPermutation(perm, resultPerm); + indices = applyPermutation(indices, perm); + auto newExtract = rewriter.create( + extractOp.getLoc(), operand.get(), indices); + extractOps.push_back(newExtract); + } rewriter.cloneRegionBefore(producerOp.getRegion(), consumerOp.getRegion(), consumerOp.getRegion().begin()); Block &clonedBlock = consumerOp.getRegion().front(); auto producerTermOp = clonedBlock.getTerminator(); - rewriter.inlineBlockBefore( - &clonedBlock, extractOp->getNextNode(), - {newExtractOp.getResult(), newExtractOp.getResult()}); + rewriter.inlineBlockBefore(&clonedBlock, extractOp->getNextNode(), + extractOps); // Replace the the all references to the original extract result with the // result from the inlined producerOp.