diff --git a/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h b/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h new file mode 100644 index 0000000000000..a32d9e2025c76 --- /dev/null +++ b/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h @@ -0,0 +1,48 @@ +//===- StructuralTypeConversions.h - CF Type Conversions --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H +#define MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H + +#include "mlir/IR/PatternMatch.h" + +namespace mlir { + +class ConversionTarget; +class TypeConverter; + +namespace cf { + +/// Populates patterns for CF structural type conversions and sets up the +/// provided ConversionTarget with the appropriate legality configuration for +/// the ops to get converted properly. +/// +/// A "structural" type conversion is one where the underlying ops are +/// completely agnostic to the actual types involved and simply need to update +/// their types. An example of this is cf.br -- the cf.br op needs to update +/// its types accordingly to the TypeConverter, but otherwise does not care +/// what type conversions are happening. +void populateCFStructuralTypeConversionsAndLegality( + const TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, PatternBenefit benefit = 1); + +/// Similar to `populateCFStructuralTypeConversionsAndLegality` but does not +/// populate the conversion target. +void populateCFStructuralTypeConversions(const TypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit = 1); + +/// Updates the ConversionTarget with dynamic legality of CF operations based +/// on the provided type converter. +void populateCFStructuralTypeConversionTarget( + const TypeConverter &typeConverter, ConversionTarget &target); + +} // namespace cf +} // namespace mlir + +#endif // MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt index 47740d31844f4..e9da135ed46f9 100644 --- a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRControlFlowTransforms BufferDeallocationOpInterfaceImpl.cpp BufferizableOpInterfaceImpl.cpp + StructuralTypeConversions.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp new file mode 100644 index 0000000000000..5e2a742c2d64c --- /dev/null +++ b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp @@ -0,0 +1,169 @@ +//===- TypeConversion.cpp - Type Conversion of Unstructured Control Flow --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert MLIR standard and builtin dialects +// into the LLVM IR dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +namespace { + +/// Helper function for converting branch ops. This function converts the +/// signature of the given block. If the new block signature is different from +/// `expectedTypes`, returns "failure". +static FailureOr getConvertedBlock(ConversionPatternRewriter &rewriter, + const TypeConverter *converter, + Operation *branchOp, Block *block, + TypeRange expectedTypes) { + assert(converter && "expected non-null type converter"); + assert(!block->isEntryBlock() && "entry blocks have no predecessors"); + + // There is nothing to do if the types already match. + if (block->getArgumentTypes() == expectedTypes) + return block; + + // Compute the new block argument types and convert the block. + std::optional conversion = + converter->convertBlockSignature(block); + if (!conversion) + return rewriter.notifyMatchFailure(branchOp, + "could not compute block signature"); + if (expectedTypes != conversion->getConvertedTypes()) + return rewriter.notifyMatchFailure( + branchOp, + "mismatch between adaptor operand types and computed block signature"); + return rewriter.applySignatureConversion(block, *conversion, converter); +} + +/// Flatten the given value ranges into a single vector of values. +static SmallVector flattenValues(ArrayRef values) { + SmallVector result; + for (const ValueRange &vals : values) + llvm::append_range(result, vals); + return result; +} + +/// Convert the destination block signature (if necessary) and change the +/// operands of the branch op. +struct BranchOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector flattenedAdaptor = flattenValues(adaptor.getOperands()); + FailureOr convertedBlock = + getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(), + TypeRange(ValueRange(flattenedAdaptor))); + if (failed(convertedBlock)) + return failure(); + rewriter.replaceOpWithNewOp(op, flattenedAdaptor, + *convertedBlock); + return success(); + } +}; + +/// Convert the destination block signatures (if necessary) and change the +/// operands of the branch op. +struct CondBranchOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::CondBranchOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector flattenedAdaptorTrue = + flattenValues(adaptor.getTrueDestOperands()); + SmallVector flattenedAdaptorFalse = + flattenValues(adaptor.getFalseDestOperands()); + if (!llvm::hasSingleElement(adaptor.getCondition())) + return rewriter.notifyMatchFailure(op, + "expected single element condition"); + FailureOr convertedTrueBlock = + getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(), + TypeRange(ValueRange(flattenedAdaptorTrue))); + if (failed(convertedTrueBlock)) + return failure(); + FailureOr convertedFalseBlock = + getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(), + TypeRange(ValueRange(flattenedAdaptorFalse))); + if (failed(convertedFalseBlock)) + return failure(); + rewriter.replaceOpWithNewOp( + op, llvm::getSingleElement(adaptor.getCondition()), + flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(), + *convertedTrueBlock, *convertedFalseBlock); + return success(); + } +}; + +/// Convert the destination block signatures (if necessary) and change the +/// operands of the switch op. +struct SwitchOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::SwitchOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get or convert default block. + FailureOr convertedDefaultBlock = getConvertedBlock( + rewriter, getTypeConverter(), op, op.getDefaultDestination(), + TypeRange(adaptor.getDefaultOperands())); + if (failed(convertedDefaultBlock)) + return failure(); + + // Get or convert all case blocks. + SmallVector caseDestinations; + SmallVector caseOperands = adaptor.getCaseOperands(); + for (auto it : llvm::enumerate(op.getCaseDestinations())) { + Block *b = it.value(); + FailureOr convertedBlock = + getConvertedBlock(rewriter, getTypeConverter(), op, b, + TypeRange(caseOperands[it.index()])); + if (failed(convertedBlock)) + return failure(); + caseDestinations.push_back(*convertedBlock); + } + + rewriter.replaceOpWithNewOp( + op, adaptor.getFlag(), *convertedDefaultBlock, + adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(), + caseDestinations, caseOperands); + return success(); + } +}; + +} // namespace + +void mlir::cf::populateCFStructuralTypeConversions( + const TypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add( + typeConverter, patterns.getContext(), benefit); +} + +void mlir::cf::populateCFStructuralTypeConversionTarget( + const TypeConverter &typeConverter, ConversionTarget &target) { + target.addDynamicallyLegalOp( + [&](Operation *op) { return typeConverter.isLegal(op->getOperands()); }); +} + +void mlir::cf::populateCFStructuralTypeConversionsAndLegality( + const TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, PatternBenefit benefit) { + populateCFStructuralTypeConversions(typeConverter, patterns, benefit); + populateCFStructuralTypeConversionTarget(typeConverter, target); +} diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir index c003f8b2cb1cd..91f83a0afaeef 100644 --- a/mlir/test/Transforms/test-legalize-type-conversion.mlir +++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir @@ -143,3 +143,25 @@ func.func @test_signature_conversion_no_converter() { return } +// ----- + +// CHECK-LABEL: func @test_unstructured_cf_conversion( +// CHECK-SAME: %[[arg0:.*]]: f64, %[[c:.*]]: i1) +// CHECK: %[[cast1:.*]] = "test.cast"(%[[arg0]]) : (f64) -> f32 +// CHECK: "test.foo"(%[[cast1]]) +// CHECK: cf.br ^[[bb1:.*]](%[[arg0]] : f64) +// CHECK: ^[[bb1]](%[[arg1:.*]]: f64): +// CHECK: cf.cond_br %[[c]], ^[[bb1]](%[[arg1]] : f64), ^[[bb2:.*]](%[[arg1]] : f64) +// CHECK: ^[[bb2]](%[[arg2:.*]]: f64): +// CHECK: %[[cast2:.*]] = "test.cast"(%[[arg2]]) : (f64) -> f32 +// CHECK: "test.bar"(%[[cast2]]) +// CHECK: return +func.func @test_unstructured_cf_conversion(%arg0: f32, %c: i1) { + "test.foo"(%arg0) : (f32) -> () + cf.br ^bb1(%arg0: f32) +^bb1(%arg1: f32): + cf.cond_br %c, ^bb1(%arg1 : f32), ^bb2(%arg1 : f32) +^bb2(%arg2: f32): + "test.bar"(%arg2) : (f32) -> () + return +} diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt index f099d01abd31a..9354a85d984c9 100644 --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -71,6 +71,7 @@ add_mlir_library(MLIRTestDialect ) mlir_target_link_libraries(MLIRTestDialect PUBLIC MLIRControlFlowInterfaces + MLIRControlFlowTransforms MLIRDataLayoutInterfaces MLIRDerivedAttributeOpInterface MLIRDestinationStyleOpInterface diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index efbdbfb65d65b..fd2b943ff1296 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -11,6 +11,7 @@ #include "TestTypes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/CommonFolders.h" +#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" @@ -2042,6 +2043,10 @@ struct TestTypeConversionDriver }); converter.addConversion([](IndexType type) { return type; }); converter.addConversion([](IntegerType type, SmallVectorImpl &types) { + if (type.isInteger(1)) { + // i1 is legal. + types.push_back(type); + } if (type.isInteger(38)) { // i38 is legal. types.push_back(type); @@ -2175,6 +2180,8 @@ struct TestTypeConversionDriver converter); mlir::scf::populateSCFStructuralTypeConversionsAndLegality( converter, patterns, target); + mlir::cf::populateCFStructuralTypeConversionsAndLegality(converter, + patterns, target); ConversionConfig config; config.allowPatternRollback = allowPatternRollback;