- 
                Notifications
    
You must be signed in to change notification settings  - Fork 15.1k
 
[mlir][CF] Add structural type conversion patterns #165629
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][CF] Add structural type conversion patterns #165629
Conversation
| 
          
 @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesAdd structural type conversion patterns for CF dialect ops. These patterns are similar to the SCF structural type conversion patterns. Full diff: https://github.com/llvm/llvm-project/pull/165629.diff 6 Files Affected: 
 diff --git a/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h b/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h
new file mode 100644
index 0000000000000..a32d9e2025c76
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h
@@ -0,0 +1,48 @@
+//===- StructuralTypeConversions.h - CF Type Conversions --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
+#define MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+
+class ConversionTarget;
+class TypeConverter;
+
+namespace cf {
+
+/// Populates patterns for CF structural type conversions and sets up the
+/// provided ConversionTarget with the appropriate legality configuration for
+/// the ops to get converted properly.
+///
+/// A "structural" type conversion is one where the underlying ops are
+/// completely agnostic to the actual types involved and simply need to update
+/// their types. An example of this is cf.br -- the cf.br op needs to update
+/// its types accordingly to the TypeConverter, but otherwise does not care
+/// what type conversions are happening.
+void populateCFStructuralTypeConversionsAndLegality(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, PatternBenefit benefit = 1);
+
+/// Similar to `populateCFStructuralTypeConversionsAndLegality` but does not
+/// populate the conversion target.
+void populateCFStructuralTypeConversions(const TypeConverter &typeConverter,
+                                         RewritePatternSet &patterns,
+                                         PatternBenefit benefit = 1);
+
+/// Updates the ConversionTarget with dynamic legality of CF operations based
+/// on the provided type converter.
+void populateCFStructuralTypeConversionTarget(
+    const TypeConverter &typeConverter, ConversionTarget &target);
+
+} // namespace cf
+} // namespace mlir
+
+#endif // MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
index 47740d31844f4..e9da135ed46f9 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRControlFlowTransforms
   BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
+  StructuralTypeConversions.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp
new file mode 100644
index 0000000000000..5e2a742c2d64c
--- /dev/null
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp
@@ -0,0 +1,169 @@
+//===- TypeConversion.cpp - Type Conversion of Unstructured Control Flow --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert MLIR standard and builtin dialects
+// into the LLVM IR dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
+
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Helper function for converting branch ops. This function converts the
+/// signature of the given block. If the new block signature is different from
+/// `expectedTypes`, returns "failure".
+static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
+                                            const TypeConverter *converter,
+                                            Operation *branchOp, Block *block,
+                                            TypeRange expectedTypes) {
+  assert(converter && "expected non-null type converter");
+  assert(!block->isEntryBlock() && "entry blocks have no predecessors");
+
+  // There is nothing to do if the types already match.
+  if (block->getArgumentTypes() == expectedTypes)
+    return block;
+
+  // Compute the new block argument types and convert the block.
+  std::optional<TypeConverter::SignatureConversion> conversion =
+      converter->convertBlockSignature(block);
+  if (!conversion)
+    return rewriter.notifyMatchFailure(branchOp,
+                                       "could not compute block signature");
+  if (expectedTypes != conversion->getConvertedTypes())
+    return rewriter.notifyMatchFailure(
+        branchOp,
+        "mismatch between adaptor operand types and computed block signature");
+  return rewriter.applySignatureConversion(block, *conversion, converter);
+}
+
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+  SmallVector<Value> result;
+  for (const ValueRange &vals : values)
+    llvm::append_range(result, vals);
+  return result;
+}
+
+/// Convert the destination block signature (if necessary) and change the
+/// operands of the branch op.
+struct BranchOpConversion : public OpConversionPattern<cf::BranchOp> {
+  using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
+    FailureOr<Block *> convertedBlock =
+        getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
+                          TypeRange(ValueRange(flattenedAdaptor)));
+    if (failed(convertedBlock))
+      return failure();
+    rewriter.replaceOpWithNewOp<cf::BranchOp>(op, flattenedAdaptor,
+                                              *convertedBlock);
+    return success();
+  }
+};
+
+/// Convert the destination block signatures (if necessary) and change the
+/// operands of the branch op.
+struct CondBranchOpConversion : public OpConversionPattern<cf::CondBranchOp> {
+  using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::CondBranchOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> flattenedAdaptorTrue =
+        flattenValues(adaptor.getTrueDestOperands());
+    SmallVector<Value> flattenedAdaptorFalse =
+        flattenValues(adaptor.getFalseDestOperands());
+    if (!llvm::hasSingleElement(adaptor.getCondition()))
+      return rewriter.notifyMatchFailure(op,
+                                         "expected single element condition");
+    FailureOr<Block *> convertedTrueBlock =
+        getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
+                          TypeRange(ValueRange(flattenedAdaptorTrue)));
+    if (failed(convertedTrueBlock))
+      return failure();
+    FailureOr<Block *> convertedFalseBlock =
+        getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
+                          TypeRange(ValueRange(flattenedAdaptorFalse)));
+    if (failed(convertedFalseBlock))
+      return failure();
+    rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
+        op, llvm::getSingleElement(adaptor.getCondition()),
+        flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
+        *convertedTrueBlock, *convertedFalseBlock);
+    return success();
+  }
+};
+
+/// Convert the destination block signatures (if necessary) and change the
+/// operands of the switch op.
+struct SwitchOpConversion : public OpConversionPattern<cf::SwitchOp> {
+  using OpConversionPattern<cf::SwitchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::SwitchOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Get or convert default block.
+    FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
+        rewriter, getTypeConverter(), op, op.getDefaultDestination(),
+        TypeRange(adaptor.getDefaultOperands()));
+    if (failed(convertedDefaultBlock))
+      return failure();
+
+    // Get or convert all case blocks.
+    SmallVector<Block *> caseDestinations;
+    SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
+    for (auto it : llvm::enumerate(op.getCaseDestinations())) {
+      Block *b = it.value();
+      FailureOr<Block *> convertedBlock =
+          getConvertedBlock(rewriter, getTypeConverter(), op, b,
+                            TypeRange(caseOperands[it.index()]));
+      if (failed(convertedBlock))
+        return failure();
+      caseDestinations.push_back(*convertedBlock);
+    }
+
+    rewriter.replaceOpWithNewOp<cf::SwitchOp>(
+        op, adaptor.getFlag(), *convertedDefaultBlock,
+        adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
+        caseDestinations, caseOperands);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::cf::populateCFStructuralTypeConversions(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns,
+    PatternBenefit benefit) {
+  patterns.add<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>(
+      typeConverter, patterns.getContext(), benefit);
+}
+
+void mlir::cf::populateCFStructuralTypeConversionTarget(
+    const TypeConverter &typeConverter, ConversionTarget &target) {
+  target.addDynamicallyLegalOp<cf::BranchOp, cf::CondBranchOp, cf::SwitchOp>(
+      [&](Operation *op) { return typeConverter.isLegal(op->getOperands()); });
+}
+
+void mlir::cf::populateCFStructuralTypeConversionsAndLegality(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, PatternBenefit benefit) {
+  populateCFStructuralTypeConversions(typeConverter, patterns, benefit);
+  populateCFStructuralTypeConversionTarget(typeConverter, target);
+}
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index c003f8b2cb1cd..91f83a0afaeef 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -143,3 +143,25 @@ func.func @test_signature_conversion_no_converter() {
   return
 }
 
+// -----
+
+// CHECK-LABEL: func @test_unstructured_cf_conversion(
+//  CHECK-SAME:     %[[arg0:.*]]: f64, %[[c:.*]]: i1)
+//       CHECK:   %[[cast1:.*]] = "test.cast"(%[[arg0]]) : (f64) -> f32
+//       CHECK:   "test.foo"(%[[cast1]])
+//       CHECK:   cf.br ^[[bb1:.*]](%[[arg0]] : f64)
+//       CHECK: ^[[bb1]](%[[arg1:.*]]: f64):
+//       CHECK:   cf.cond_br %[[c]], ^[[bb1]](%[[arg1]] : f64), ^[[bb2:.*]](%[[arg1]] : f64)
+//       CHECK: ^[[bb2]](%[[arg2:.*]]: f64):
+//       CHECK:   %[[cast2:.*]] = "test.cast"(%[[arg2]]) : (f64) -> f32
+//       CHECK:   "test.bar"(%[[cast2]])
+//       CHECK: return
+func.func @test_unstructured_cf_conversion(%arg0: f32, %c: i1) {
+  "test.foo"(%arg0) : (f32) -> ()
+  cf.br ^bb1(%arg0: f32)
+^bb1(%arg1: f32):
+  cf.cond_br %c, ^bb1(%arg1 : f32), ^bb2(%arg1 : f32)
+^bb2(%arg2: f32):
+  "test.bar"(%arg2) : (f32) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index f099d01abd31a..9354a85d984c9 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -71,6 +71,7 @@ add_mlir_library(MLIRTestDialect
   )
 mlir_target_link_libraries(MLIRTestDialect PUBLIC
   MLIRControlFlowInterfaces
+  MLIRControlFlowTransforms
   MLIRDataLayoutInterfaces
   MLIRDerivedAttributeOpInterface
   MLIRDestinationStyleOpInterface
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index efbdbfb65d65b..fd2b943ff1296 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -11,6 +11,7 @@
 #include "TestTypes.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/CommonFolders.h"
+#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -2042,6 +2043,10 @@ struct TestTypeConversionDriver
     });
     converter.addConversion([](IndexType type) { return type; });
     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &types) {
+      if (type.isInteger(1)) {
+        // i1 is legal.
+        types.push_back(type);
+      }
       if (type.isInteger(38)) {
         // i38 is legal.
         types.push_back(type);
@@ -2175,6 +2180,8 @@ struct TestTypeConversionDriver
                                                               converter);
     mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
         converter, patterns, target);
+    mlir::cf::populateCFStructuralTypeConversionsAndLegality(converter,
+                                                             patterns, target);
 
     ConversionConfig config;
     config.allowPatternRollback = allowPatternRollback;
 | 
    
| 
          
 @llvm/pr-subscribers-mlir-cf Author: Matthias Springer (matthias-springer) ChangesAdd structural type conversion patterns for CF dialect ops. These patterns are similar to the SCF structural type conversion patterns. Full diff: https://github.com/llvm/llvm-project/pull/165629.diff 6 Files Affected: 
 diff --git a/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h b/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h
new file mode 100644
index 0000000000000..a32d9e2025c76
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h
@@ -0,0 +1,48 @@
+//===- StructuralTypeConversions.h - CF Type Conversions --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
+#define MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+
+class ConversionTarget;
+class TypeConverter;
+
+namespace cf {
+
+/// Populates patterns for CF structural type conversions and sets up the
+/// provided ConversionTarget with the appropriate legality configuration for
+/// the ops to get converted properly.
+///
+/// A "structural" type conversion is one where the underlying ops are
+/// completely agnostic to the actual types involved and simply need to update
+/// their types. An example of this is cf.br -- the cf.br op needs to update
+/// its types accordingly to the TypeConverter, but otherwise does not care
+/// what type conversions are happening.
+void populateCFStructuralTypeConversionsAndLegality(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, PatternBenefit benefit = 1);
+
+/// Similar to `populateCFStructuralTypeConversionsAndLegality` but does not
+/// populate the conversion target.
+void populateCFStructuralTypeConversions(const TypeConverter &typeConverter,
+                                         RewritePatternSet &patterns,
+                                         PatternBenefit benefit = 1);
+
+/// Updates the ConversionTarget with dynamic legality of CF operations based
+/// on the provided type converter.
+void populateCFStructuralTypeConversionTarget(
+    const TypeConverter &typeConverter, ConversionTarget &target);
+
+} // namespace cf
+} // namespace mlir
+
+#endif // MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
index 47740d31844f4..e9da135ed46f9 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRControlFlowTransforms
   BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
+  StructuralTypeConversions.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp
new file mode 100644
index 0000000000000..5e2a742c2d64c
--- /dev/null
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp
@@ -0,0 +1,169 @@
+//===- TypeConversion.cpp - Type Conversion of Unstructured Control Flow --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert MLIR standard and builtin dialects
+// into the LLVM IR dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
+
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Helper function for converting branch ops. This function converts the
+/// signature of the given block. If the new block signature is different from
+/// `expectedTypes`, returns "failure".
+static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
+                                            const TypeConverter *converter,
+                                            Operation *branchOp, Block *block,
+                                            TypeRange expectedTypes) {
+  assert(converter && "expected non-null type converter");
+  assert(!block->isEntryBlock() && "entry blocks have no predecessors");
+
+  // There is nothing to do if the types already match.
+  if (block->getArgumentTypes() == expectedTypes)
+    return block;
+
+  // Compute the new block argument types and convert the block.
+  std::optional<TypeConverter::SignatureConversion> conversion =
+      converter->convertBlockSignature(block);
+  if (!conversion)
+    return rewriter.notifyMatchFailure(branchOp,
+                                       "could not compute block signature");
+  if (expectedTypes != conversion->getConvertedTypes())
+    return rewriter.notifyMatchFailure(
+        branchOp,
+        "mismatch between adaptor operand types and computed block signature");
+  return rewriter.applySignatureConversion(block, *conversion, converter);
+}
+
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+  SmallVector<Value> result;
+  for (const ValueRange &vals : values)
+    llvm::append_range(result, vals);
+  return result;
+}
+
+/// Convert the destination block signature (if necessary) and change the
+/// operands of the branch op.
+struct BranchOpConversion : public OpConversionPattern<cf::BranchOp> {
+  using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
+    FailureOr<Block *> convertedBlock =
+        getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
+                          TypeRange(ValueRange(flattenedAdaptor)));
+    if (failed(convertedBlock))
+      return failure();
+    rewriter.replaceOpWithNewOp<cf::BranchOp>(op, flattenedAdaptor,
+                                              *convertedBlock);
+    return success();
+  }
+};
+
+/// Convert the destination block signatures (if necessary) and change the
+/// operands of the branch op.
+struct CondBranchOpConversion : public OpConversionPattern<cf::CondBranchOp> {
+  using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::CondBranchOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> flattenedAdaptorTrue =
+        flattenValues(adaptor.getTrueDestOperands());
+    SmallVector<Value> flattenedAdaptorFalse =
+        flattenValues(adaptor.getFalseDestOperands());
+    if (!llvm::hasSingleElement(adaptor.getCondition()))
+      return rewriter.notifyMatchFailure(op,
+                                         "expected single element condition");
+    FailureOr<Block *> convertedTrueBlock =
+        getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
+                          TypeRange(ValueRange(flattenedAdaptorTrue)));
+    if (failed(convertedTrueBlock))
+      return failure();
+    FailureOr<Block *> convertedFalseBlock =
+        getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
+                          TypeRange(ValueRange(flattenedAdaptorFalse)));
+    if (failed(convertedFalseBlock))
+      return failure();
+    rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
+        op, llvm::getSingleElement(adaptor.getCondition()),
+        flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
+        *convertedTrueBlock, *convertedFalseBlock);
+    return success();
+  }
+};
+
+/// Convert the destination block signatures (if necessary) and change the
+/// operands of the switch op.
+struct SwitchOpConversion : public OpConversionPattern<cf::SwitchOp> {
+  using OpConversionPattern<cf::SwitchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::SwitchOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Get or convert default block.
+    FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
+        rewriter, getTypeConverter(), op, op.getDefaultDestination(),
+        TypeRange(adaptor.getDefaultOperands()));
+    if (failed(convertedDefaultBlock))
+      return failure();
+
+    // Get or convert all case blocks.
+    SmallVector<Block *> caseDestinations;
+    SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
+    for (auto it : llvm::enumerate(op.getCaseDestinations())) {
+      Block *b = it.value();
+      FailureOr<Block *> convertedBlock =
+          getConvertedBlock(rewriter, getTypeConverter(), op, b,
+                            TypeRange(caseOperands[it.index()]));
+      if (failed(convertedBlock))
+        return failure();
+      caseDestinations.push_back(*convertedBlock);
+    }
+
+    rewriter.replaceOpWithNewOp<cf::SwitchOp>(
+        op, adaptor.getFlag(), *convertedDefaultBlock,
+        adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
+        caseDestinations, caseOperands);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::cf::populateCFStructuralTypeConversions(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns,
+    PatternBenefit benefit) {
+  patterns.add<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>(
+      typeConverter, patterns.getContext(), benefit);
+}
+
+void mlir::cf::populateCFStructuralTypeConversionTarget(
+    const TypeConverter &typeConverter, ConversionTarget &target) {
+  target.addDynamicallyLegalOp<cf::BranchOp, cf::CondBranchOp, cf::SwitchOp>(
+      [&](Operation *op) { return typeConverter.isLegal(op->getOperands()); });
+}
+
+void mlir::cf::populateCFStructuralTypeConversionsAndLegality(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, PatternBenefit benefit) {
+  populateCFStructuralTypeConversions(typeConverter, patterns, benefit);
+  populateCFStructuralTypeConversionTarget(typeConverter, target);
+}
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index c003f8b2cb1cd..91f83a0afaeef 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -143,3 +143,25 @@ func.func @test_signature_conversion_no_converter() {
   return
 }
 
+// -----
+
+// CHECK-LABEL: func @test_unstructured_cf_conversion(
+//  CHECK-SAME:     %[[arg0:.*]]: f64, %[[c:.*]]: i1)
+//       CHECK:   %[[cast1:.*]] = "test.cast"(%[[arg0]]) : (f64) -> f32
+//       CHECK:   "test.foo"(%[[cast1]])
+//       CHECK:   cf.br ^[[bb1:.*]](%[[arg0]] : f64)
+//       CHECK: ^[[bb1]](%[[arg1:.*]]: f64):
+//       CHECK:   cf.cond_br %[[c]], ^[[bb1]](%[[arg1]] : f64), ^[[bb2:.*]](%[[arg1]] : f64)
+//       CHECK: ^[[bb2]](%[[arg2:.*]]: f64):
+//       CHECK:   %[[cast2:.*]] = "test.cast"(%[[arg2]]) : (f64) -> f32
+//       CHECK:   "test.bar"(%[[cast2]])
+//       CHECK: return
+func.func @test_unstructured_cf_conversion(%arg0: f32, %c: i1) {
+  "test.foo"(%arg0) : (f32) -> ()
+  cf.br ^bb1(%arg0: f32)
+^bb1(%arg1: f32):
+  cf.cond_br %c, ^bb1(%arg1 : f32), ^bb2(%arg1 : f32)
+^bb2(%arg2: f32):
+  "test.bar"(%arg2) : (f32) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index f099d01abd31a..9354a85d984c9 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -71,6 +71,7 @@ add_mlir_library(MLIRTestDialect
   )
 mlir_target_link_libraries(MLIRTestDialect PUBLIC
   MLIRControlFlowInterfaces
+  MLIRControlFlowTransforms
   MLIRDataLayoutInterfaces
   MLIRDerivedAttributeOpInterface
   MLIRDestinationStyleOpInterface
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index efbdbfb65d65b..fd2b943ff1296 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -11,6 +11,7 @@
 #include "TestTypes.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/CommonFolders.h"
+#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -2042,6 +2043,10 @@ struct TestTypeConversionDriver
     });
     converter.addConversion([](IndexType type) { return type; });
     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &types) {
+      if (type.isInteger(1)) {
+        // i1 is legal.
+        types.push_back(type);
+      }
       if (type.isInteger(38)) {
         // i38 is legal.
         types.push_back(type);
@@ -2175,6 +2180,8 @@ struct TestTypeConversionDriver
                                                               converter);
     mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
         converter, patterns, target);
+    mlir::cf::populateCFStructuralTypeConversionsAndLegality(converter,
+                                                             patterns, target);
 
     ConversionConfig config;
     config.allowPatternRollback = allowPatternRollback;
 | 
    
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Add structural type conversion patterns for CF dialect ops. These patterns are similar to the SCF structural type conversion patterns. This commit adds missing functionality and is in preparation of llvm#165180, which changes the way blocks are converted. (Only entry blocks are converted.)
Add structural type conversion patterns for CF dialect ops. These patterns are similar to the SCF structural type conversion patterns. This commit adds missing functionality and is in preparation of llvm#165180, which changes the way blocks are converted. (Only entry blocks are converted.)
Add structural type conversion patterns for CF dialect ops. These patterns are similar to the SCF structural type conversion patterns.
This commit adds missing functionality and is in preparation of #165180, which changes the way blocks are converted. (Only entry blocks are converted.)