-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[mlir][Transforms] Dialect Conversion: Convert entry block only #165180
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][Transforms] Dialect Conversion: Convert entry block only #165180
Conversation
Add structural type conversion patterns for CF dialect ops. These patterns are similar to the SCF structural type conversion patterns. This commit adds missing functionality and is in preparation of #165180, which changes the way blocks are converted. (Only entry blocks are converted.)
b91825b to
a9aa0a2
Compare
Add structural type conversion patterns for CF dialect ops. These patterns are similar to the SCF structural type conversion patterns. This commit adds missing functionality and is in preparation of llvm#165180, which changes the way blocks are converted. (Only entry blocks are converted.)
a9aa0a2 to
e805f6f
Compare
|
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesWhen converting a function, convert only the entry block signature. The remaining block signatures should be converted by the respective branching ops. The struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
LogicalResult
matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Convert successor block.
SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
FailureOr<Block *> convertedBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
TypeRange(ValueRange(flattenedAdaptor)));
// ...
}
};This is consistent with the fact that operations from unreachable blocks are not put on the initial worklist. With this change, parent ops are no longer recursively legalized when inserting a block, simplifying the conversion driver a bit. Note for LLVM integration: If you are seeing failures, make sure to:
Full diff: https://github.com/llvm/llvm-project/pull/165180.diff 3 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3a23bbfd70eac..2fe06970eb568 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1105,10 +1105,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// A set of operations that were modified by the current pattern.
SetVector<Operation *> patternModifiedOps;
- /// A set of blocks that were inserted (newly-created blocks or moved blocks)
- /// by the current pattern.
- SetVector<Block *> patternInsertedBlocks;
-
/// A list of unresolved materializations that were created by the current
/// pattern.
DenseSet<UnrealizedConversionCastOp> patternMaterializations;
@@ -2046,8 +2042,6 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
if (!config.allowPatternRollback && config.listener)
config.listener->notifyBlockInserted(block, previous, previousIt);
- patternInsertedBlocks.insert(block);
-
if (wasDetached) {
// If the block was detached, it is most likely a newly created block.
if (config.allowPatternRollback) {
@@ -2399,17 +2393,12 @@ class OperationLegalizer {
bool canApplyPattern(Operation *op, const Pattern &pattern);
/// Legalize the resultant IR after successfully applying the given pattern.
- LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
- const RewriterState &curState,
- const SetVector<Operation *> &newOps,
- const SetVector<Operation *> &modifiedOps,
- const SetVector<Block *> &insertedBlocks);
-
- /// Legalizes the actions registered during the execution of a pattern.
LogicalResult
- legalizePatternBlockRewrites(Operation *op,
- const SetVector<Block *> &insertedBlocks,
- const SetVector<Operation *> &newOps);
+ legalizePatternResult(Operation *op, const Pattern &pattern,
+ const RewriterState &curState,
+ const SetVector<Operation *> &newOps,
+ const SetVector<Operation *> &modifiedOps);
+
LogicalResult
legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
LogicalResult
@@ -2608,7 +2597,6 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
auto cleanup = llvm::make_scope_exit([&]() {
rewriterImpl.patternNewOps.clear();
rewriterImpl.patternModifiedOps.clear();
- rewriterImpl.patternInsertedBlocks.clear();
});
// Upon failure, undo all changes made by the folder.
@@ -2662,24 +2650,16 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
static void
reportNewIrLegalizationFatalError(const Pattern &pattern,
const SetVector<Operation *> &newOps,
- const SetVector<Operation *> &modifiedOps,
- const SetVector<Block *> &insertedBlocks) {
+ const SetVector<Operation *> &modifiedOps) {
auto newOpNames = llvm::map_range(
newOps, [](Operation *op) { return op->getName().getStringRef(); });
auto modifiedOpNames = llvm::map_range(
modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
- StringRef detachedBlockStr = "(detached block)";
- auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](Block *block) {
- if (block->getParentOp())
- return block->getParentOp()->getName().getStringRef();
- return detachedBlockStr;
- });
- llvm::report_fatal_error(
- "pattern '" + pattern.getDebugName() +
- "' produced IR that could not be legalized. " + "new ops: {" +
- llvm::join(newOpNames, ", ") + "}, " + "modified ops: {" +
- llvm::join(modifiedOpNames, ", ") + "}, " + "inserted block into ops: {" +
- llvm::join(insertedBlockNames, ", ") + "}");
+ llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
+ "' produced IR that could not be legalized. " +
+ "new ops: {" + llvm::join(newOpNames, ", ") + "}, " +
+ "modified ops: {" +
+ llvm::join(modifiedOpNames, ", ") + "}");
}
LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
@@ -2743,7 +2723,6 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
}
rewriterImpl.patternNewOps.clear();
rewriterImpl.patternModifiedOps.clear();
- rewriterImpl.patternInsertedBlocks.clear();
LLVM_DEBUG({
logFailure(rewriterImpl.logger, "pattern failed to match");
if (rewriterImpl.config.notifyCallback) {
@@ -2777,15 +2756,12 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
SetVector<Operation *> modifiedOps =
moveAndReset(rewriterImpl.patternModifiedOps);
- SetVector<Block *> insertedBlocks =
- moveAndReset(rewriterImpl.patternInsertedBlocks);
- auto result = legalizePatternResult(op, pattern, curState, newOps,
- modifiedOps, insertedBlocks);
+ auto result =
+ legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
appliedPatterns.erase(&pattern);
if (failed(result)) {
if (!rewriterImpl.config.allowPatternRollback)
- reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps,
- insertedBlocks);
+ reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);
rewriterImpl.resetState(curState, pattern.getDebugName());
}
if (config.listener)
@@ -2823,8 +2799,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op,
LogicalResult OperationLegalizer::legalizePatternResult(
Operation *op, const Pattern &pattern, const RewriterState &curState,
const SetVector<Operation *> &newOps,
- const SetVector<Operation *> &modifiedOps,
- const SetVector<Block *> &insertedBlocks) {
+ const SetVector<Operation *> &modifiedOps) {
[[maybe_unused]] auto &impl = rewriter.getImpl();
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
@@ -2843,8 +2818,7 @@ LogicalResult OperationLegalizer::legalizePatternResult(
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Legalize each of the actions registered during application.
- if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
- failed(legalizePatternRootUpdates(modifiedOps)) ||
+ if (failed(legalizePatternRootUpdates(modifiedOps)) ||
failed(legalizePatternCreatedOperations(newOps))) {
return failure();
}
@@ -2853,53 +2827,6 @@ LogicalResult OperationLegalizer::legalizePatternResult(
return success();
}
-LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
- Operation *op, const SetVector<Block *> &insertedBlocks,
- const SetVector<Operation *> &newOps) {
- ConversionPatternRewriterImpl &impl = rewriter.getImpl();
- SmallPtrSet<Operation *, 16> alreadyLegalized;
-
- // If the pattern moved or created any blocks, make sure the types of block
- // arguments get legalized.
- for (Block *block : insertedBlocks) {
- if (impl.erasedBlocks.contains(block))
- continue;
-
- // Only check blocks outside of the current operation.
- Operation *parentOp = block->getParentOp();
- if (!parentOp || parentOp == op || block->getNumArguments() == 0)
- continue;
-
- // If the region of the block has a type converter, try to convert the block
- // directly.
- if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
- std::optional<TypeConverter::SignatureConversion> conversion =
- converter->convertBlockSignature(block);
- if (!conversion) {
- LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
- "block"));
- return failure();
- }
- impl.applySignatureConversion(block, converter, *conversion);
- continue;
- }
-
- // Otherwise, try to legalize the parent operation if it was not generated
- // by this pattern. This is because we will attempt to legalize the parent
- // operation, and blocks in regions created by this pattern will already be
- // legalized later on.
- if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
- if (failed(legalize(parentOp))) {
- LLVM_DEBUG(logFailure(
- impl.logger, "operation '{0}'({1}) became illegal after rewrite",
- parentOp->getName(), parentOp));
- return failure();
- }
- }
- }
- return success();
-}
-
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
const SetVector<Operation *> &newOps) {
for (Operation *op : newOps) {
@@ -3800,10 +3727,11 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
TypeConverter::SignatureConversion result(type.getNumInputs());
SmallVector<Type, 1> newResults;
if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
- failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
- failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
- typeConverter, &result)))
+ failed(typeConverter.convertTypes(type.getResults(), newResults)))
return failure();
+ if (!funcOp.getFunctionBody().empty())
+ rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result,
+ &typeConverter);
// Update the function signature in-place.
auto newType = FunctionType::get(rewriter.getContext(),
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 94c5bb4e93b06..ba1f962fdb68b 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -146,36 +146,6 @@ func.func @no_remap_nested() {
// -----
-// CHECK-LABEL: func @remap_moved_region_args
-func.func @remap_moved_region_args() {
- // CHECK-NEXT: return
- // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
- // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
- // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
- "test.region"() ({
- ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
- "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
- }) : () -> ()
- // expected-remark@+1 {{op 'func.return' is not legalizable}}
- return
-}
-
-// -----
-
-// CHECK-LABEL: func @remap_cloned_region_args
-func.func @remap_cloned_region_args() {
- // CHECK-NEXT: return
- // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
- // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
- // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
- "test.region"() ({
- ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
- "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
- }) {legalizer.should_clone} : () -> ()
- // expected-remark@+1 {{op 'func.return' is not legalizable}}
- return
-}
-
// CHECK-LABEL: func @remap_drop_region
func.func @remap_drop_region() {
// CHECK-NEXT: return
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index fd2b943ff1296..12edecc113495 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1553,8 +1553,7 @@ struct TestLegalizePatternDriver
[](Type type) { return type.isF32(); });
});
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
- return converter.isSignatureLegal(op.getFunctionType()) &&
- converter.isLegal(&op.getBody());
+ return converter.isSignatureLegal(op.getFunctionType());
});
target.addDynamicallyLegalOp<func::CallOp>(
[&](func::CallOp op) { return converter.isLegal(op); });
@@ -2156,8 +2155,7 @@ struct TestTypeConversionDriver
recursiveType.getName() == "outer_converted_type");
});
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
- return converter.isSignatureLegal(op.getFunctionType()) &&
- converter.isLegal(&op.getBody());
+ return converter.isSignatureLegal(op.getFunctionType());
});
target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
// Allow casts from F64 to F32.
|
e805f6f to
7bee741
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense to me. I tried integrating downstream and am only seeing a single failure iree-org/iree#22520.
| @@ -0,0 +1,67 @@ | |||
| // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=0 attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Splitting up the legalization test looks unrelated, do we have to do it as a part of this change?
Add structural type conversion patterns for CF dialect ops. These patterns are similar to the SCF structural type conversion patterns. This commit adds missing functionality and is in preparation of llvm#165180, which changes the way blocks are converted. (Only entry blocks are converted.)
When converting a function, convert only the entry block signature. The remaining block signatures should be converted by the respective branching ops. The
FuncToLLVM/ControlFlowToLLVMpatterns already use that design.This is consistent with the fact that operations from unreachable blocks are not put on the initial worklist.
With this change, parent ops are no longer recursively legalized when inserting a block, simplifying the conversion driver a bit.
Note for LLVM integration: If you are seeing failures, make sure to:
converter.isLegal(&op.getBody())when checking the legality of a function op. Only the entry block signature / function type should be taken into account.cfbranching ops, addpopulateCFStructuralTypeConversionsAndLegality.populateCFStructuralTypeConversionsAndLegality.