@@ -1105,10 +1105,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
11051105 // / A set of operations that were modified by the current pattern.
11061106 SetVector<Operation *> patternModifiedOps;
11071107
1108- // / A set of blocks that were inserted (newly-created blocks or moved blocks)
1109- // / by the current pattern.
1110- SetVector<Block *> patternInsertedBlocks;
1111-
11121108 // / A list of unresolved materializations that were created by the current
11131109 // / pattern.
11141110 DenseSet<UnrealizedConversionCastOp> patternMaterializations;
@@ -2046,8 +2042,6 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
20462042 if (!config.allowPatternRollback && config.listener )
20472043 config.listener ->notifyBlockInserted (block, previous, previousIt);
20482044
2049- patternInsertedBlocks.insert (block);
2050-
20512045 if (wasDetached) {
20522046 // If the block was detached, it is most likely a newly created block.
20532047 if (config.allowPatternRollback ) {
@@ -2399,17 +2393,12 @@ class OperationLegalizer {
23992393 bool canApplyPattern (Operation *op, const Pattern &pattern);
24002394
24012395 // / Legalize the resultant IR after successfully applying the given pattern.
2402- LogicalResult legalizePatternResult (Operation *op, const Pattern &pattern,
2403- const RewriterState &curState,
2404- const SetVector<Operation *> &newOps,
2405- const SetVector<Operation *> &modifiedOps,
2406- const SetVector<Block *> &insertedBlocks);
2407-
2408- // / Legalizes the actions registered during the execution of a pattern.
24092396 LogicalResult
2410- legalizePatternBlockRewrites (Operation *op,
2411- const SetVector<Block *> &insertedBlocks,
2412- const SetVector<Operation *> &newOps);
2397+ legalizePatternResult (Operation *op, const Pattern &pattern,
2398+ const RewriterState &curState,
2399+ const SetVector<Operation *> &newOps,
2400+ const SetVector<Operation *> &modifiedOps);
2401+
24132402 LogicalResult
24142403 legalizePatternCreatedOperations (const SetVector<Operation *> &newOps);
24152404 LogicalResult
@@ -2608,7 +2597,6 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
26082597 auto cleanup = llvm::make_scope_exit ([&]() {
26092598 rewriterImpl.patternNewOps .clear ();
26102599 rewriterImpl.patternModifiedOps .clear ();
2611- rewriterImpl.patternInsertedBlocks .clear ();
26122600 });
26132601
26142602 // Upon failure, undo all changes made by the folder.
@@ -2662,24 +2650,16 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
26622650static void
26632651reportNewIrLegalizationFatalError (const Pattern &pattern,
26642652 const SetVector<Operation *> &newOps,
2665- const SetVector<Operation *> &modifiedOps,
2666- const SetVector<Block *> &insertedBlocks) {
2653+ const SetVector<Operation *> &modifiedOps) {
26672654 auto newOpNames = llvm::map_range (
26682655 newOps, [](Operation *op) { return op->getName ().getStringRef (); });
26692656 auto modifiedOpNames = llvm::map_range (
26702657 modifiedOps, [](Operation *op) { return op->getName ().getStringRef (); });
2671- StringRef detachedBlockStr = " (detached block)" ;
2672- auto insertedBlockNames = llvm::map_range (insertedBlocks, [&](Block *block) {
2673- if (block->getParentOp ())
2674- return block->getParentOp ()->getName ().getStringRef ();
2675- return detachedBlockStr;
2676- });
2677- llvm::report_fatal_error (
2678- " pattern '" + pattern.getDebugName () +
2679- " ' produced IR that could not be legalized. " + " new ops: {" +
2680- llvm::join (newOpNames, " , " ) + " }, " + " modified ops: {" +
2681- llvm::join (modifiedOpNames, " , " ) + " }, " + " inserted block into ops: {" +
2682- llvm::join (insertedBlockNames, " , " ) + " }" );
2658+ llvm::report_fatal_error (" pattern '" + pattern.getDebugName () +
2659+ " ' produced IR that could not be legalized. " +
2660+ " new ops: {" + llvm::join (newOpNames, " , " ) + " }, " +
2661+ " modified ops: {" +
2662+ llvm::join (modifiedOpNames, " , " ) + " }" );
26832663}
26842664
26852665LogicalResult OperationLegalizer::legalizeWithPattern (Operation *op) {
@@ -2743,7 +2723,6 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
27432723 }
27442724 rewriterImpl.patternNewOps .clear ();
27452725 rewriterImpl.patternModifiedOps .clear ();
2746- rewriterImpl.patternInsertedBlocks .clear ();
27472726 LLVM_DEBUG ({
27482727 logFailure (rewriterImpl.logger , " pattern failed to match" );
27492728 if (rewriterImpl.config .notifyCallback ) {
@@ -2777,15 +2756,12 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
27772756 SetVector<Operation *> newOps = moveAndReset (rewriterImpl.patternNewOps );
27782757 SetVector<Operation *> modifiedOps =
27792758 moveAndReset (rewriterImpl.patternModifiedOps );
2780- SetVector<Block *> insertedBlocks =
2781- moveAndReset (rewriterImpl.patternInsertedBlocks );
2782- auto result = legalizePatternResult (op, pattern, curState, newOps,
2783- modifiedOps, insertedBlocks);
2759+ auto result =
2760+ legalizePatternResult (op, pattern, curState, newOps, modifiedOps);
27842761 appliedPatterns.erase (&pattern);
27852762 if (failed (result)) {
27862763 if (!rewriterImpl.config .allowPatternRollback )
2787- reportNewIrLegalizationFatalError (pattern, newOps, modifiedOps,
2788- insertedBlocks);
2764+ reportNewIrLegalizationFatalError (pattern, newOps, modifiedOps);
27892765 rewriterImpl.resetState (curState, pattern.getDebugName ());
27902766 }
27912767 if (config.listener )
@@ -2823,8 +2799,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op,
28232799LogicalResult OperationLegalizer::legalizePatternResult (
28242800 Operation *op, const Pattern &pattern, const RewriterState &curState,
28252801 const SetVector<Operation *> &newOps,
2826- const SetVector<Operation *> &modifiedOps,
2827- const SetVector<Block *> &insertedBlocks) {
2802+ const SetVector<Operation *> &modifiedOps) {
28282803 [[maybe_unused]] auto &impl = rewriter.getImpl ();
28292804 assert (impl.pendingRootUpdates .empty () && " dangling root updates" );
28302805
@@ -2843,8 +2818,7 @@ LogicalResult OperationLegalizer::legalizePatternResult(
28432818#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
28442819
28452820 // Legalize each of the actions registered during application.
2846- if (failed (legalizePatternBlockRewrites (op, insertedBlocks, newOps)) ||
2847- failed (legalizePatternRootUpdates (modifiedOps)) ||
2821+ if (failed (legalizePatternRootUpdates (modifiedOps)) ||
28482822 failed (legalizePatternCreatedOperations (newOps))) {
28492823 return failure ();
28502824 }
@@ -2853,53 +2827,6 @@ LogicalResult OperationLegalizer::legalizePatternResult(
28532827 return success ();
28542828}
28552829
2856- LogicalResult OperationLegalizer::legalizePatternBlockRewrites (
2857- Operation *op, const SetVector<Block *> &insertedBlocks,
2858- const SetVector<Operation *> &newOps) {
2859- ConversionPatternRewriterImpl &impl = rewriter.getImpl ();
2860- SmallPtrSet<Operation *, 16 > alreadyLegalized;
2861-
2862- // If the pattern moved or created any blocks, make sure the types of block
2863- // arguments get legalized.
2864- for (Block *block : insertedBlocks) {
2865- if (impl.erasedBlocks .contains (block))
2866- continue ;
2867-
2868- // Only check blocks outside of the current operation.
2869- Operation *parentOp = block->getParentOp ();
2870- if (!parentOp || parentOp == op || block->getNumArguments () == 0 )
2871- continue ;
2872-
2873- // If the region of the block has a type converter, try to convert the block
2874- // directly.
2875- if (auto *converter = impl.regionToConverter .lookup (block->getParent ())) {
2876- std::optional<TypeConverter::SignatureConversion> conversion =
2877- converter->convertBlockSignature (block);
2878- if (!conversion) {
2879- LLVM_DEBUG (logFailure (impl.logger , " failed to convert types of moved "
2880- " block" ));
2881- return failure ();
2882- }
2883- impl.applySignatureConversion (block, converter, *conversion);
2884- continue ;
2885- }
2886-
2887- // Otherwise, try to legalize the parent operation if it was not generated
2888- // by this pattern. This is because we will attempt to legalize the parent
2889- // operation, and blocks in regions created by this pattern will already be
2890- // legalized later on.
2891- if (!newOps.count (parentOp) && alreadyLegalized.insert (parentOp).second ) {
2892- if (failed (legalize (parentOp))) {
2893- LLVM_DEBUG (logFailure (
2894- impl.logger , " operation '{0}'({1}) became illegal after rewrite" ,
2895- parentOp->getName (), parentOp));
2896- return failure ();
2897- }
2898- }
2899- }
2900- return success ();
2901- }
2902-
29032830LogicalResult OperationLegalizer::legalizePatternCreatedOperations (
29042831 const SetVector<Operation *> &newOps) {
29052832 for (Operation *op : newOps) {
@@ -3800,10 +3727,11 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
38003727 TypeConverter::SignatureConversion result (type.getNumInputs ());
38013728 SmallVector<Type, 1 > newResults;
38023729 if (failed (typeConverter.convertSignatureArgs (type.getInputs (), result)) ||
3803- failed (typeConverter.convertTypes (type.getResults (), newResults)) ||
3804- failed (rewriter.convertRegionTypes (&funcOp.getFunctionBody (),
3805- typeConverter, &result)))
3730+ failed (typeConverter.convertTypes (type.getResults (), newResults)))
38063731 return failure ();
3732+ if (!funcOp.getFunctionBody ().empty ())
3733+ rewriter.applySignatureConversion (&funcOp.getFunctionBody ().front (), result,
3734+ &typeConverter);
38073735
38083736 // Update the function signature in-place.
38093737 auto newType = FunctionType::get (rewriter.getContext (),
0 commit comments