Skip to content

Conversation

@matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Oct 26, 2025

When converting a function, convert only the entry block signature. The remaining block signatures should be converted by the respective branching ops. The FuncToLLVM / ControlFlowToLLVM patterns already use that design.

struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {

  LogicalResult
  matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Convert successor block.
    SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
    FailureOr<Block *> convertedBlock =
        getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
                          TypeRange(ValueRange(flattenedAdaptor)));
    // ...
  }
};

This is consistent with the fact that operations from unreachable blocks are not put on the initial worklist.

With this change, parent ops are no longer recursively legalized when inserting a block, simplifying the conversion driver a bit.

Note for LLVM integration: If you are seeing failures, make sure to:

  • Drop converter.isLegal(&op.getBody()) when checking the legality of a function op. Only the entry block signature / function type should be taken into account.
  • If you need to convert all reachable blocks and are using cf branching ops, add populateCFStructuralTypeConversionsAndLegality.
  • If you need to convert all reachable blocks and are using custom branching ops, implement and populate custom structural type conversion patterns, similar to populateCFStructuralTypeConversionsAndLegality.

matthias-springer added a commit that referenced this pull request Oct 30, 2025
Add structural type conversion patterns for CF dialect ops. These
patterns are similar to the SCF structural type conversion patterns.

This commit adds missing functionality and is in preparation of #165180,
which changes the way blocks are converted. (Only entry blocks are
converted.)
@matthias-springer matthias-springer force-pushed the users/matthias-springer/func_convert_entry_block branch from b91825b to a9aa0a2 Compare October 30, 2025 15:14
aokblast pushed a commit to aokblast/llvm-project that referenced this pull request Oct 30, 2025
Add structural type conversion patterns for CF dialect ops. These
patterns are similar to the SCF structural type conversion patterns.

This commit adds missing functionality and is in preparation of llvm#165180,
which changes the way blocks are converted. (Only entry blocks are
converted.)
@matthias-springer matthias-springer force-pushed the users/matthias-springer/func_convert_entry_block branch from a9aa0a2 to e805f6f Compare November 2, 2025 04:56
@matthias-springer matthias-springer marked this pull request as ready for review November 2, 2025 05:03
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Nov 2, 2025
@llvmbot
Copy link
Member

llvmbot commented Nov 2, 2025

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

When converting a function, convert only the entry block signature. The remaining block signatures should be converted by the respective branching ops. The FuncToLLVM / ControlFlowToLLVM patterns already use that design.

struct BranchOpLowering : public ConvertOpToLLVMPattern&lt;cf::BranchOp&gt; {

  LogicalResult
  matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &amp;rewriter) const override {
    // Convert successor block.
    SmallVector&lt;Value&gt; flattenedAdaptor = flattenValues(adaptor.getOperands());
    FailureOr&lt;Block *&gt; convertedBlock =
        getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
                          TypeRange(ValueRange(flattenedAdaptor)));
    // ...
  }
};

This is consistent with the fact that operations from unreachable blocks are not put on the initial worklist.

With this change, parent ops are no longer recursively legalized when inserting a block, simplifying the conversion driver a bit.

Note for LLVM integration: If you are seeing failures, make sure to:

  • Drop converter.isLegal(&amp;op.getBody()) when checking the legality of a function op. Only the entry block signature / function type should be taken into account.
  • If you need to convert all reachable blocks and are using cf branching ops, add populateCFStructuralTypeConversionsAndLegality.
  • If you need to convert all reachable blocks and are using custom branching ops, implement and populate custom structural type conversion patterns, similar to populateCFStructuralTypeConversionsAndLegality.

Full diff: https://github.com/llvm/llvm-project/pull/165180.diff

3 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+20-92)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (-30)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+2-4)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3a23bbfd70eac..2fe06970eb568 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1105,10 +1105,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// A set of operations that were modified by the current pattern.
   SetVector<Operation *> patternModifiedOps;
 
-  /// A set of blocks that were inserted (newly-created blocks or moved blocks)
-  /// by the current pattern.
-  SetVector<Block *> patternInsertedBlocks;
-
   /// A list of unresolved materializations that were created by the current
   /// pattern.
   DenseSet<UnrealizedConversionCastOp> patternMaterializations;
@@ -2046,8 +2042,6 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
   if (!config.allowPatternRollback && config.listener)
     config.listener->notifyBlockInserted(block, previous, previousIt);
 
-  patternInsertedBlocks.insert(block);
-
   if (wasDetached) {
     // If the block was detached, it is most likely a newly created block.
     if (config.allowPatternRollback) {
@@ -2399,17 +2393,12 @@ class OperationLegalizer {
   bool canApplyPattern(Operation *op, const Pattern &pattern);
 
   /// Legalize the resultant IR after successfully applying the given pattern.
-  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
-                                      const RewriterState &curState,
-                                      const SetVector<Operation *> &newOps,
-                                      const SetVector<Operation *> &modifiedOps,
-                                      const SetVector<Block *> &insertedBlocks);
-
-  /// Legalizes the actions registered during the execution of a pattern.
   LogicalResult
-  legalizePatternBlockRewrites(Operation *op,
-                               const SetVector<Block *> &insertedBlocks,
-                               const SetVector<Operation *> &newOps);
+  legalizePatternResult(Operation *op, const Pattern &pattern,
+                        const RewriterState &curState,
+                        const SetVector<Operation *> &newOps,
+                        const SetVector<Operation *> &modifiedOps);
+
   LogicalResult
   legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
   LogicalResult
@@ -2608,7 +2597,6 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
   auto cleanup = llvm::make_scope_exit([&]() {
     rewriterImpl.patternNewOps.clear();
     rewriterImpl.patternModifiedOps.clear();
-    rewriterImpl.patternInsertedBlocks.clear();
   });
 
   // Upon failure, undo all changes made by the folder.
@@ -2662,24 +2650,16 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
 static void
 reportNewIrLegalizationFatalError(const Pattern &pattern,
                                   const SetVector<Operation *> &newOps,
-                                  const SetVector<Operation *> &modifiedOps,
-                                  const SetVector<Block *> &insertedBlocks) {
+                                  const SetVector<Operation *> &modifiedOps) {
   auto newOpNames = llvm::map_range(
       newOps, [](Operation *op) { return op->getName().getStringRef(); });
   auto modifiedOpNames = llvm::map_range(
       modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
-  StringRef detachedBlockStr = "(detached block)";
-  auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](Block *block) {
-    if (block->getParentOp())
-      return block->getParentOp()->getName().getStringRef();
-    return detachedBlockStr;
-  });
-  llvm::report_fatal_error(
-      "pattern '" + pattern.getDebugName() +
-      "' produced IR that could not be legalized. " + "new ops: {" +
-      llvm::join(newOpNames, ", ") + "}, " + "modified ops: {" +
-      llvm::join(modifiedOpNames, ", ") + "}, " + "inserted block into ops: {" +
-      llvm::join(insertedBlockNames, ", ") + "}");
+  llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
+                           "' produced IR that could not be legalized. " +
+                           "new ops: {" + llvm::join(newOpNames, ", ") + "}, " +
+                           "modified ops: {" +
+                           llvm::join(modifiedOpNames, ", ") + "}");
 }
 
 LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
@@ -2743,7 +2723,6 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
     }
     rewriterImpl.patternNewOps.clear();
     rewriterImpl.patternModifiedOps.clear();
-    rewriterImpl.patternInsertedBlocks.clear();
     LLVM_DEBUG({
       logFailure(rewriterImpl.logger, "pattern failed to match");
       if (rewriterImpl.config.notifyCallback) {
@@ -2777,15 +2756,12 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
     SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
     SetVector<Operation *> modifiedOps =
         moveAndReset(rewriterImpl.patternModifiedOps);
-    SetVector<Block *> insertedBlocks =
-        moveAndReset(rewriterImpl.patternInsertedBlocks);
-    auto result = legalizePatternResult(op, pattern, curState, newOps,
-                                        modifiedOps, insertedBlocks);
+    auto result =
+        legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
     appliedPatterns.erase(&pattern);
     if (failed(result)) {
       if (!rewriterImpl.config.allowPatternRollback)
-        reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps,
-                                          insertedBlocks);
+        reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);
       rewriterImpl.resetState(curState, pattern.getDebugName());
     }
     if (config.listener)
@@ -2823,8 +2799,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op,
 LogicalResult OperationLegalizer::legalizePatternResult(
     Operation *op, const Pattern &pattern, const RewriterState &curState,
     const SetVector<Operation *> &newOps,
-    const SetVector<Operation *> &modifiedOps,
-    const SetVector<Block *> &insertedBlocks) {
+    const SetVector<Operation *> &modifiedOps) {
   [[maybe_unused]] auto &impl = rewriter.getImpl();
   assert(impl.pendingRootUpdates.empty() && "dangling root updates");
 
@@ -2843,8 +2818,7 @@ LogicalResult OperationLegalizer::legalizePatternResult(
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
   // Legalize each of the actions registered during application.
-  if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
-      failed(legalizePatternRootUpdates(modifiedOps)) ||
+  if (failed(legalizePatternRootUpdates(modifiedOps)) ||
       failed(legalizePatternCreatedOperations(newOps))) {
     return failure();
   }
@@ -2853,53 +2827,6 @@ LogicalResult OperationLegalizer::legalizePatternResult(
   return success();
 }
 
-LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
-    Operation *op, const SetVector<Block *> &insertedBlocks,
-    const SetVector<Operation *> &newOps) {
-  ConversionPatternRewriterImpl &impl = rewriter.getImpl();
-  SmallPtrSet<Operation *, 16> alreadyLegalized;
-
-  // If the pattern moved or created any blocks, make sure the types of block
-  // arguments get legalized.
-  for (Block *block : insertedBlocks) {
-    if (impl.erasedBlocks.contains(block))
-      continue;
-
-    // Only check blocks outside of the current operation.
-    Operation *parentOp = block->getParentOp();
-    if (!parentOp || parentOp == op || block->getNumArguments() == 0)
-      continue;
-
-    // If the region of the block has a type converter, try to convert the block
-    // directly.
-    if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
-      std::optional<TypeConverter::SignatureConversion> conversion =
-          converter->convertBlockSignature(block);
-      if (!conversion) {
-        LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
-                                           "block"));
-        return failure();
-      }
-      impl.applySignatureConversion(block, converter, *conversion);
-      continue;
-    }
-
-    // Otherwise, try to legalize the parent operation if it was not generated
-    // by this pattern. This is because we will attempt to legalize the parent
-    // operation, and blocks in regions created by this pattern will already be
-    // legalized later on.
-    if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
-      if (failed(legalize(parentOp))) {
-        LLVM_DEBUG(logFailure(
-            impl.logger, "operation '{0}'({1}) became illegal after rewrite",
-            parentOp->getName(), parentOp));
-        return failure();
-      }
-    }
-  }
-  return success();
-}
-
 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
     const SetVector<Operation *> &newOps) {
   for (Operation *op : newOps) {
@@ -3800,10 +3727,11 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
   TypeConverter::SignatureConversion result(type.getNumInputs());
   SmallVector<Type, 1> newResults;
   if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
-      failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
-      failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
-                                         typeConverter, &result)))
+      failed(typeConverter.convertTypes(type.getResults(), newResults)))
     return failure();
+  if (!funcOp.getFunctionBody().empty())
+    rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result,
+                                      &typeConverter);
 
   // Update the function signature in-place.
   auto newType = FunctionType::get(rewriter.getContext(),
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 94c5bb4e93b06..ba1f962fdb68b 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -146,36 +146,6 @@ func.func @no_remap_nested() {
 
 // -----
 
-// CHECK-LABEL: func @remap_moved_region_args
-func.func @remap_moved_region_args() {
-  // CHECK-NEXT: return
-  // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
-  // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
-  // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
-  "test.region"() ({
-    ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
-      "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
-  }) : () -> ()
-  // expected-remark@+1 {{op 'func.return' is not legalizable}}
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @remap_cloned_region_args
-func.func @remap_cloned_region_args() {
-  // CHECK-NEXT: return
-  // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
-  // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
-  // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
-  "test.region"() ({
-    ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
-      "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
-  }) {legalizer.should_clone} : () -> ()
-  // expected-remark@+1 {{op 'func.return' is not legalizable}}
-  return
-}
-
 // CHECK-LABEL: func @remap_drop_region
 func.func @remap_drop_region() {
   // CHECK-NEXT: return
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index fd2b943ff1296..12edecc113495 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1553,8 +1553,7 @@ struct TestLegalizePatternDriver
                            [](Type type) { return type.isF32(); });
     });
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
-      return converter.isSignatureLegal(op.getFunctionType()) &&
-             converter.isLegal(&op.getBody());
+      return converter.isSignatureLegal(op.getFunctionType());
     });
     target.addDynamicallyLegalOp<func::CallOp>(
         [&](func::CallOp op) { return converter.isLegal(op); });
@@ -2156,8 +2155,7 @@ struct TestTypeConversionDriver
               recursiveType.getName() == "outer_converted_type");
     });
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
-      return converter.isSignatureLegal(op.getFunctionType()) &&
-             converter.isLegal(&op.getBody());
+      return converter.isSignatureLegal(op.getFunctionType());
     });
     target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
       // Allow casts from F64 to F32.

@matthias-springer matthias-springer marked this pull request as draft November 2, 2025 05:31
@matthias-springer matthias-springer marked this pull request as ready for review November 2, 2025 05:42
@matthias-springer matthias-springer force-pushed the users/matthias-springer/func_convert_entry_block branch from e805f6f to 7bee741 Compare November 2, 2025 05:42
Copy link
Contributor

@qedawkins qedawkins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me. I tried integrating downstream and am only seeing a single failure iree-org/iree#22520.

@@ -0,0 +1,67 @@
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=0 attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Splitting up the legalization test looks unrelated, do we have to do it as a part of this change?

DEBADRIBASAK pushed a commit to DEBADRIBASAK/llvm-project that referenced this pull request Nov 3, 2025
Add structural type conversion patterns for CF dialect ops. These
patterns are similar to the SCF structural type conversion patterns.

This commit adds missing functionality and is in preparation of llvm#165180,
which changes the way blocks are converted. (Only entry blocks are
converted.)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants