Skip to content

Conversation

@matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Oct 29, 2025

Add structural type conversion patterns for CF dialect ops. These patterns are similar to the SCF structural type conversion patterns.

This commit adds missing functionality and is in preparation of #165180, which changes the way blocks are converted. (Only entry blocks are converted.)

@llvmbot
Copy link
Member

llvmbot commented Oct 29, 2025

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Add structural type conversion patterns for CF dialect ops. These patterns are similar to the SCF structural type conversion patterns.


Full diff: https://github.com/llvm/llvm-project/pull/165629.diff

6 Files Affected:

  • (added) mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h (+48)
  • (modified) mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp (+169)
  • (modified) mlir/test/Transforms/test-legalize-type-conversion.mlir (+22)
  • (modified) mlir/test/lib/Dialect/Test/CMakeLists.txt (+1)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+7)
diff --git a/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h b/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h
new file mode 100644
index 0000000000000..a32d9e2025c76
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h
@@ -0,0 +1,48 @@
+//===- StructuralTypeConversions.h - CF Type Conversions --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
+#define MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+
+class ConversionTarget;
+class TypeConverter;
+
+namespace cf {
+
+/// Populates patterns for CF structural type conversions and sets up the
+/// provided ConversionTarget with the appropriate legality configuration for
+/// the ops to get converted properly.
+///
+/// A "structural" type conversion is one where the underlying ops are
+/// completely agnostic to the actual types involved and simply need to update
+/// their types. An example of this is cf.br -- the cf.br op needs to update
+/// its types accordingly to the TypeConverter, but otherwise does not care
+/// what type conversions are happening.
+void populateCFStructuralTypeConversionsAndLegality(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, PatternBenefit benefit = 1);
+
+/// Similar to `populateCFStructuralTypeConversionsAndLegality` but does not
+/// populate the conversion target.
+void populateCFStructuralTypeConversions(const TypeConverter &typeConverter,
+                                         RewritePatternSet &patterns,
+                                         PatternBenefit benefit = 1);
+
+/// Updates the ConversionTarget with dynamic legality of CF operations based
+/// on the provided type converter.
+void populateCFStructuralTypeConversionTarget(
+    const TypeConverter &typeConverter, ConversionTarget &target);
+
+} // namespace cf
+} // namespace mlir
+
+#endif // MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
index 47740d31844f4..e9da135ed46f9 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRControlFlowTransforms
   BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
+  StructuralTypeConversions.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp
new file mode 100644
index 0000000000000..5e2a742c2d64c
--- /dev/null
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp
@@ -0,0 +1,169 @@
+//===- TypeConversion.cpp - Type Conversion of Unstructured Control Flow --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert MLIR standard and builtin dialects
+// into the LLVM IR dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
+
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Helper function for converting branch ops. This function converts the
+/// signature of the given block. If the new block signature is different from
+/// `expectedTypes`, returns "failure".
+static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
+                                            const TypeConverter *converter,
+                                            Operation *branchOp, Block *block,
+                                            TypeRange expectedTypes) {
+  assert(converter && "expected non-null type converter");
+  assert(!block->isEntryBlock() && "entry blocks have no predecessors");
+
+  // There is nothing to do if the types already match.
+  if (block->getArgumentTypes() == expectedTypes)
+    return block;
+
+  // Compute the new block argument types and convert the block.
+  std::optional<TypeConverter::SignatureConversion> conversion =
+      converter->convertBlockSignature(block);
+  if (!conversion)
+    return rewriter.notifyMatchFailure(branchOp,
+                                       "could not compute block signature");
+  if (expectedTypes != conversion->getConvertedTypes())
+    return rewriter.notifyMatchFailure(
+        branchOp,
+        "mismatch between adaptor operand types and computed block signature");
+  return rewriter.applySignatureConversion(block, *conversion, converter);
+}
+
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+  SmallVector<Value> result;
+  for (const ValueRange &vals : values)
+    llvm::append_range(result, vals);
+  return result;
+}
+
+/// Convert the destination block signature (if necessary) and change the
+/// operands of the branch op.
+struct BranchOpConversion : public OpConversionPattern<cf::BranchOp> {
+  using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
+    FailureOr<Block *> convertedBlock =
+        getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
+                          TypeRange(ValueRange(flattenedAdaptor)));
+    if (failed(convertedBlock))
+      return failure();
+    rewriter.replaceOpWithNewOp<cf::BranchOp>(op, flattenedAdaptor,
+                                              *convertedBlock);
+    return success();
+  }
+};
+
+/// Convert the destination block signatures (if necessary) and change the
+/// operands of the branch op.
+struct CondBranchOpConversion : public OpConversionPattern<cf::CondBranchOp> {
+  using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::CondBranchOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> flattenedAdaptorTrue =
+        flattenValues(adaptor.getTrueDestOperands());
+    SmallVector<Value> flattenedAdaptorFalse =
+        flattenValues(adaptor.getFalseDestOperands());
+    if (!llvm::hasSingleElement(adaptor.getCondition()))
+      return rewriter.notifyMatchFailure(op,
+                                         "expected single element condition");
+    FailureOr<Block *> convertedTrueBlock =
+        getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
+                          TypeRange(ValueRange(flattenedAdaptorTrue)));
+    if (failed(convertedTrueBlock))
+      return failure();
+    FailureOr<Block *> convertedFalseBlock =
+        getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
+                          TypeRange(ValueRange(flattenedAdaptorFalse)));
+    if (failed(convertedFalseBlock))
+      return failure();
+    rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
+        op, llvm::getSingleElement(adaptor.getCondition()),
+        flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
+        *convertedTrueBlock, *convertedFalseBlock);
+    return success();
+  }
+};
+
+/// Convert the destination block signatures (if necessary) and change the
+/// operands of the switch op.
+struct SwitchOpConversion : public OpConversionPattern<cf::SwitchOp> {
+  using OpConversionPattern<cf::SwitchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::SwitchOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Get or convert default block.
+    FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
+        rewriter, getTypeConverter(), op, op.getDefaultDestination(),
+        TypeRange(adaptor.getDefaultOperands()));
+    if (failed(convertedDefaultBlock))
+      return failure();
+
+    // Get or convert all case blocks.
+    SmallVector<Block *> caseDestinations;
+    SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
+    for (auto it : llvm::enumerate(op.getCaseDestinations())) {
+      Block *b = it.value();
+      FailureOr<Block *> convertedBlock =
+          getConvertedBlock(rewriter, getTypeConverter(), op, b,
+                            TypeRange(caseOperands[it.index()]));
+      if (failed(convertedBlock))
+        return failure();
+      caseDestinations.push_back(*convertedBlock);
+    }
+
+    rewriter.replaceOpWithNewOp<cf::SwitchOp>(
+        op, adaptor.getFlag(), *convertedDefaultBlock,
+        adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
+        caseDestinations, caseOperands);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::cf::populateCFStructuralTypeConversions(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns,
+    PatternBenefit benefit) {
+  patterns.add<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>(
+      typeConverter, patterns.getContext(), benefit);
+}
+
+void mlir::cf::populateCFStructuralTypeConversionTarget(
+    const TypeConverter &typeConverter, ConversionTarget &target) {
+  target.addDynamicallyLegalOp<cf::BranchOp, cf::CondBranchOp, cf::SwitchOp>(
+      [&](Operation *op) { return typeConverter.isLegal(op->getOperands()); });
+}
+
+void mlir::cf::populateCFStructuralTypeConversionsAndLegality(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, PatternBenefit benefit) {
+  populateCFStructuralTypeConversions(typeConverter, patterns, benefit);
+  populateCFStructuralTypeConversionTarget(typeConverter, target);
+}
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index c003f8b2cb1cd..91f83a0afaeef 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -143,3 +143,25 @@ func.func @test_signature_conversion_no_converter() {
   return
 }
 
+// -----
+
+// CHECK-LABEL: func @test_unstructured_cf_conversion(
+//  CHECK-SAME:     %[[arg0:.*]]: f64, %[[c:.*]]: i1)
+//       CHECK:   %[[cast1:.*]] = "test.cast"(%[[arg0]]) : (f64) -> f32
+//       CHECK:   "test.foo"(%[[cast1]])
+//       CHECK:   cf.br ^[[bb1:.*]](%[[arg0]] : f64)
+//       CHECK: ^[[bb1]](%[[arg1:.*]]: f64):
+//       CHECK:   cf.cond_br %[[c]], ^[[bb1]](%[[arg1]] : f64), ^[[bb2:.*]](%[[arg1]] : f64)
+//       CHECK: ^[[bb2]](%[[arg2:.*]]: f64):
+//       CHECK:   %[[cast2:.*]] = "test.cast"(%[[arg2]]) : (f64) -> f32
+//       CHECK:   "test.bar"(%[[cast2]])
+//       CHECK: return
+func.func @test_unstructured_cf_conversion(%arg0: f32, %c: i1) {
+  "test.foo"(%arg0) : (f32) -> ()
+  cf.br ^bb1(%arg0: f32)
+^bb1(%arg1: f32):
+  cf.cond_br %c, ^bb1(%arg1 : f32), ^bb2(%arg1 : f32)
+^bb2(%arg2: f32):
+  "test.bar"(%arg2) : (f32) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index f099d01abd31a..9354a85d984c9 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -71,6 +71,7 @@ add_mlir_library(MLIRTestDialect
   )
 mlir_target_link_libraries(MLIRTestDialect PUBLIC
   MLIRControlFlowInterfaces
+  MLIRControlFlowTransforms
   MLIRDataLayoutInterfaces
   MLIRDerivedAttributeOpInterface
   MLIRDestinationStyleOpInterface
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index efbdbfb65d65b..fd2b943ff1296 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -11,6 +11,7 @@
 #include "TestTypes.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/CommonFolders.h"
+#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -2042,6 +2043,10 @@ struct TestTypeConversionDriver
     });
     converter.addConversion([](IndexType type) { return type; });
     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &types) {
+      if (type.isInteger(1)) {
+        // i1 is legal.
+        types.push_back(type);
+      }
       if (type.isInteger(38)) {
         // i38 is legal.
         types.push_back(type);
@@ -2175,6 +2180,8 @@ struct TestTypeConversionDriver
                                                               converter);
     mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
         converter, patterns, target);
+    mlir::cf::populateCFStructuralTypeConversionsAndLegality(converter,
+                                                             patterns, target);
 
     ConversionConfig config;
     config.allowPatternRollback = allowPatternRollback;

@llvmbot
Copy link
Member

llvmbot commented Oct 29, 2025

@llvm/pr-subscribers-mlir-cf

Author: Matthias Springer (matthias-springer)

Changes

Add structural type conversion patterns for CF dialect ops. These patterns are similar to the SCF structural type conversion patterns.


Full diff: https://github.com/llvm/llvm-project/pull/165629.diff

6 Files Affected:

  • (added) mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h (+48)
  • (modified) mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp (+169)
  • (modified) mlir/test/Transforms/test-legalize-type-conversion.mlir (+22)
  • (modified) mlir/test/lib/Dialect/Test/CMakeLists.txt (+1)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+7)
diff --git a/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h b/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h
new file mode 100644
index 0000000000000..a32d9e2025c76
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h
@@ -0,0 +1,48 @@
+//===- StructuralTypeConversions.h - CF Type Conversions --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
+#define MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+
+class ConversionTarget;
+class TypeConverter;
+
+namespace cf {
+
+/// Populates patterns for CF structural type conversions and sets up the
+/// provided ConversionTarget with the appropriate legality configuration for
+/// the ops to get converted properly.
+///
+/// A "structural" type conversion is one where the underlying ops are
+/// completely agnostic to the actual types involved and simply need to update
+/// their types. An example of this is cf.br -- the cf.br op needs to update
+/// its types accordingly to the TypeConverter, but otherwise does not care
+/// what type conversions are happening.
+void populateCFStructuralTypeConversionsAndLegality(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, PatternBenefit benefit = 1);
+
+/// Similar to `populateCFStructuralTypeConversionsAndLegality` but does not
+/// populate the conversion target.
+void populateCFStructuralTypeConversions(const TypeConverter &typeConverter,
+                                         RewritePatternSet &patterns,
+                                         PatternBenefit benefit = 1);
+
+/// Updates the ConversionTarget with dynamic legality of CF operations based
+/// on the provided type converter.
+void populateCFStructuralTypeConversionTarget(
+    const TypeConverter &typeConverter, ConversionTarget &target);
+
+} // namespace cf
+} // namespace mlir
+
+#endif // MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
index 47740d31844f4..e9da135ed46f9 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRControlFlowTransforms
   BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
+  StructuralTypeConversions.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp
new file mode 100644
index 0000000000000..5e2a742c2d64c
--- /dev/null
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp
@@ -0,0 +1,169 @@
+//===- TypeConversion.cpp - Type Conversion of Unstructured Control Flow --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert MLIR standard and builtin dialects
+// into the LLVM IR dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
+
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Helper function for converting branch ops. This function converts the
+/// signature of the given block. If the new block signature is different from
+/// `expectedTypes`, returns "failure".
+static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
+                                            const TypeConverter *converter,
+                                            Operation *branchOp, Block *block,
+                                            TypeRange expectedTypes) {
+  assert(converter && "expected non-null type converter");
+  assert(!block->isEntryBlock() && "entry blocks have no predecessors");
+
+  // There is nothing to do if the types already match.
+  if (block->getArgumentTypes() == expectedTypes)
+    return block;
+
+  // Compute the new block argument types and convert the block.
+  std::optional<TypeConverter::SignatureConversion> conversion =
+      converter->convertBlockSignature(block);
+  if (!conversion)
+    return rewriter.notifyMatchFailure(branchOp,
+                                       "could not compute block signature");
+  if (expectedTypes != conversion->getConvertedTypes())
+    return rewriter.notifyMatchFailure(
+        branchOp,
+        "mismatch between adaptor operand types and computed block signature");
+  return rewriter.applySignatureConversion(block, *conversion, converter);
+}
+
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+  SmallVector<Value> result;
+  for (const ValueRange &vals : values)
+    llvm::append_range(result, vals);
+  return result;
+}
+
+/// Convert the destination block signature (if necessary) and change the
+/// operands of the branch op.
+struct BranchOpConversion : public OpConversionPattern<cf::BranchOp> {
+  using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
+    FailureOr<Block *> convertedBlock =
+        getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
+                          TypeRange(ValueRange(flattenedAdaptor)));
+    if (failed(convertedBlock))
+      return failure();
+    rewriter.replaceOpWithNewOp<cf::BranchOp>(op, flattenedAdaptor,
+                                              *convertedBlock);
+    return success();
+  }
+};
+
+/// Convert the destination block signatures (if necessary) and change the
+/// operands of the branch op.
+struct CondBranchOpConversion : public OpConversionPattern<cf::CondBranchOp> {
+  using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::CondBranchOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> flattenedAdaptorTrue =
+        flattenValues(adaptor.getTrueDestOperands());
+    SmallVector<Value> flattenedAdaptorFalse =
+        flattenValues(adaptor.getFalseDestOperands());
+    if (!llvm::hasSingleElement(adaptor.getCondition()))
+      return rewriter.notifyMatchFailure(op,
+                                         "expected single element condition");
+    FailureOr<Block *> convertedTrueBlock =
+        getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
+                          TypeRange(ValueRange(flattenedAdaptorTrue)));
+    if (failed(convertedTrueBlock))
+      return failure();
+    FailureOr<Block *> convertedFalseBlock =
+        getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
+                          TypeRange(ValueRange(flattenedAdaptorFalse)));
+    if (failed(convertedFalseBlock))
+      return failure();
+    rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
+        op, llvm::getSingleElement(adaptor.getCondition()),
+        flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
+        *convertedTrueBlock, *convertedFalseBlock);
+    return success();
+  }
+};
+
+/// Convert the destination block signatures (if necessary) and change the
+/// operands of the switch op.
+struct SwitchOpConversion : public OpConversionPattern<cf::SwitchOp> {
+  using OpConversionPattern<cf::SwitchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::SwitchOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Get or convert default block.
+    FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
+        rewriter, getTypeConverter(), op, op.getDefaultDestination(),
+        TypeRange(adaptor.getDefaultOperands()));
+    if (failed(convertedDefaultBlock))
+      return failure();
+
+    // Get or convert all case blocks.
+    SmallVector<Block *> caseDestinations;
+    SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
+    for (auto it : llvm::enumerate(op.getCaseDestinations())) {
+      Block *b = it.value();
+      FailureOr<Block *> convertedBlock =
+          getConvertedBlock(rewriter, getTypeConverter(), op, b,
+                            TypeRange(caseOperands[it.index()]));
+      if (failed(convertedBlock))
+        return failure();
+      caseDestinations.push_back(*convertedBlock);
+    }
+
+    rewriter.replaceOpWithNewOp<cf::SwitchOp>(
+        op, adaptor.getFlag(), *convertedDefaultBlock,
+        adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
+        caseDestinations, caseOperands);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::cf::populateCFStructuralTypeConversions(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns,
+    PatternBenefit benefit) {
+  patterns.add<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>(
+      typeConverter, patterns.getContext(), benefit);
+}
+
+void mlir::cf::populateCFStructuralTypeConversionTarget(
+    const TypeConverter &typeConverter, ConversionTarget &target) {
+  target.addDynamicallyLegalOp<cf::BranchOp, cf::CondBranchOp, cf::SwitchOp>(
+      [&](Operation *op) { return typeConverter.isLegal(op->getOperands()); });
+}
+
+void mlir::cf::populateCFStructuralTypeConversionsAndLegality(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, PatternBenefit benefit) {
+  populateCFStructuralTypeConversions(typeConverter, patterns, benefit);
+  populateCFStructuralTypeConversionTarget(typeConverter, target);
+}
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index c003f8b2cb1cd..91f83a0afaeef 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -143,3 +143,25 @@ func.func @test_signature_conversion_no_converter() {
   return
 }
 
+// -----
+
+// CHECK-LABEL: func @test_unstructured_cf_conversion(
+//  CHECK-SAME:     %[[arg0:.*]]: f64, %[[c:.*]]: i1)
+//       CHECK:   %[[cast1:.*]] = "test.cast"(%[[arg0]]) : (f64) -> f32
+//       CHECK:   "test.foo"(%[[cast1]])
+//       CHECK:   cf.br ^[[bb1:.*]](%[[arg0]] : f64)
+//       CHECK: ^[[bb1]](%[[arg1:.*]]: f64):
+//       CHECK:   cf.cond_br %[[c]], ^[[bb1]](%[[arg1]] : f64), ^[[bb2:.*]](%[[arg1]] : f64)
+//       CHECK: ^[[bb2]](%[[arg2:.*]]: f64):
+//       CHECK:   %[[cast2:.*]] = "test.cast"(%[[arg2]]) : (f64) -> f32
+//       CHECK:   "test.bar"(%[[cast2]])
+//       CHECK: return
+func.func @test_unstructured_cf_conversion(%arg0: f32, %c: i1) {
+  "test.foo"(%arg0) : (f32) -> ()
+  cf.br ^bb1(%arg0: f32)
+^bb1(%arg1: f32):
+  cf.cond_br %c, ^bb1(%arg1 : f32), ^bb2(%arg1 : f32)
+^bb2(%arg2: f32):
+  "test.bar"(%arg2) : (f32) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index f099d01abd31a..9354a85d984c9 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -71,6 +71,7 @@ add_mlir_library(MLIRTestDialect
   )
 mlir_target_link_libraries(MLIRTestDialect PUBLIC
   MLIRControlFlowInterfaces
+  MLIRControlFlowTransforms
   MLIRDataLayoutInterfaces
   MLIRDerivedAttributeOpInterface
   MLIRDestinationStyleOpInterface
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index efbdbfb65d65b..fd2b943ff1296 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -11,6 +11,7 @@
 #include "TestTypes.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/CommonFolders.h"
+#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -2042,6 +2043,10 @@ struct TestTypeConversionDriver
     });
     converter.addConversion([](IndexType type) { return type; });
     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &types) {
+      if (type.isInteger(1)) {
+        // i1 is legal.
+        types.push_back(type);
+      }
       if (type.isInteger(38)) {
         // i38 is legal.
         types.push_back(type);
@@ -2175,6 +2180,8 @@ struct TestTypeConversionDriver
                                                               converter);
     mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
         converter, patterns, target);
+    mlir::cf::populateCFStructuralTypeConversionsAndLegality(converter,
+                                                             patterns, target);
 
     ConversionConfig config;
     config.allowPatternRollback = allowPatternRollback;

Copy link
Contributor

@j2kun j2kun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@matthias-springer matthias-springer merged commit ca84e9e into main Oct 30, 2025
13 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/cf_structural_conv branch October 30, 2025 01:12
rupprecht added a commit to rupprecht/llvm-project that referenced this pull request Oct 30, 2025
aokblast pushed a commit to aokblast/llvm-project that referenced this pull request Oct 30, 2025
Add structural type conversion patterns for CF dialect ops. These
patterns are similar to the SCF structural type conversion patterns.

This commit adds missing functionality and is in preparation of llvm#165180,
which changes the way blocks are converted. (Only entry blocks are
converted.)
aokblast pushed a commit to aokblast/llvm-project that referenced this pull request Oct 30, 2025
luciechoi pushed a commit to luciechoi/llvm-project that referenced this pull request Nov 1, 2025
DEBADRIBASAK pushed a commit to DEBADRIBASAK/llvm-project that referenced this pull request Nov 3, 2025
Add structural type conversion patterns for CF dialect ops. These
patterns are similar to the SCF structural type conversion patterns.

This commit adds missing functionality and is in preparation of llvm#165180,
which changes the way blocks are converted. (Only entry blocks are
converted.)
DEBADRIBASAK pushed a commit to DEBADRIBASAK/llvm-project that referenced this pull request Nov 3, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants