Skip to content

Commit b91825b

Browse files
[mlir][Transforms] Dialect Conversion: Convert entry block only
1 parent 5142707 commit b91825b

File tree

2 files changed

+20
-122
lines changed

2 files changed

+20
-122
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 20 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,10 +1105,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
11051105
/// A set of operations that were modified by the current pattern.
11061106
SetVector<Operation *> patternModifiedOps;
11071107

1108-
/// A set of blocks that were inserted (newly-created blocks or moved blocks)
1109-
/// by the current pattern.
1110-
SetVector<Block *> patternInsertedBlocks;
1111-
11121108
/// A list of unresolved materializations that were created by the current
11131109
/// pattern.
11141110
DenseSet<UnrealizedConversionCastOp> patternMaterializations;
@@ -2046,8 +2042,6 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
20462042
if (!config.allowPatternRollback && config.listener)
20472043
config.listener->notifyBlockInserted(block, previous, previousIt);
20482044

2049-
patternInsertedBlocks.insert(block);
2050-
20512045
if (wasDetached) {
20522046
// If the block was detached, it is most likely a newly created block.
20532047
if (config.allowPatternRollback) {
@@ -2399,17 +2393,12 @@ class OperationLegalizer {
23992393
bool canApplyPattern(Operation *op, const Pattern &pattern);
24002394

24012395
/// Legalize the resultant IR after successfully applying the given pattern.
2402-
LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
2403-
const RewriterState &curState,
2404-
const SetVector<Operation *> &newOps,
2405-
const SetVector<Operation *> &modifiedOps,
2406-
const SetVector<Block *> &insertedBlocks);
2407-
2408-
/// Legalizes the actions registered during the execution of a pattern.
24092396
LogicalResult
2410-
legalizePatternBlockRewrites(Operation *op,
2411-
const SetVector<Block *> &insertedBlocks,
2412-
const SetVector<Operation *> &newOps);
2397+
legalizePatternResult(Operation *op, const Pattern &pattern,
2398+
const RewriterState &curState,
2399+
const SetVector<Operation *> &newOps,
2400+
const SetVector<Operation *> &modifiedOps);
2401+
24132402
LogicalResult
24142403
legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
24152404
LogicalResult
@@ -2608,7 +2597,6 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
26082597
auto cleanup = llvm::make_scope_exit([&]() {
26092598
rewriterImpl.patternNewOps.clear();
26102599
rewriterImpl.patternModifiedOps.clear();
2611-
rewriterImpl.patternInsertedBlocks.clear();
26122600
});
26132601

26142602
// Upon failure, undo all changes made by the folder.
@@ -2662,24 +2650,16 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
26622650
static void
26632651
reportNewIrLegalizationFatalError(const Pattern &pattern,
26642652
const SetVector<Operation *> &newOps,
2665-
const SetVector<Operation *> &modifiedOps,
2666-
const SetVector<Block *> &insertedBlocks) {
2653+
const SetVector<Operation *> &modifiedOps) {
26672654
auto newOpNames = llvm::map_range(
26682655
newOps, [](Operation *op) { return op->getName().getStringRef(); });
26692656
auto modifiedOpNames = llvm::map_range(
26702657
modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
2671-
StringRef detachedBlockStr = "(detached block)";
2672-
auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](Block *block) {
2673-
if (block->getParentOp())
2674-
return block->getParentOp()->getName().getStringRef();
2675-
return detachedBlockStr;
2676-
});
2677-
llvm::report_fatal_error(
2678-
"pattern '" + pattern.getDebugName() +
2679-
"' produced IR that could not be legalized. " + "new ops: {" +
2680-
llvm::join(newOpNames, ", ") + "}, " + "modified ops: {" +
2681-
llvm::join(modifiedOpNames, ", ") + "}, " + "inserted block into ops: {" +
2682-
llvm::join(insertedBlockNames, ", ") + "}");
2658+
llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
2659+
"' produced IR that could not be legalized. " +
2660+
"new ops: {" + llvm::join(newOpNames, ", ") + "}, " +
2661+
"modified ops: {" +
2662+
llvm::join(modifiedOpNames, ", ") + "}");
26832663
}
26842664

26852665
LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
@@ -2743,7 +2723,6 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
27432723
}
27442724
rewriterImpl.patternNewOps.clear();
27452725
rewriterImpl.patternModifiedOps.clear();
2746-
rewriterImpl.patternInsertedBlocks.clear();
27472726
LLVM_DEBUG({
27482727
logFailure(rewriterImpl.logger, "pattern failed to match");
27492728
if (rewriterImpl.config.notifyCallback) {
@@ -2777,15 +2756,12 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
27772756
SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
27782757
SetVector<Operation *> modifiedOps =
27792758
moveAndReset(rewriterImpl.patternModifiedOps);
2780-
SetVector<Block *> insertedBlocks =
2781-
moveAndReset(rewriterImpl.patternInsertedBlocks);
2782-
auto result = legalizePatternResult(op, pattern, curState, newOps,
2783-
modifiedOps, insertedBlocks);
2759+
auto result =
2760+
legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
27842761
appliedPatterns.erase(&pattern);
27852762
if (failed(result)) {
27862763
if (!rewriterImpl.config.allowPatternRollback)
2787-
reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps,
2788-
insertedBlocks);
2764+
reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);
27892765
rewriterImpl.resetState(curState, pattern.getDebugName());
27902766
}
27912767
if (config.listener)
@@ -2823,8 +2799,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op,
28232799
LogicalResult OperationLegalizer::legalizePatternResult(
28242800
Operation *op, const Pattern &pattern, const RewriterState &curState,
28252801
const SetVector<Operation *> &newOps,
2826-
const SetVector<Operation *> &modifiedOps,
2827-
const SetVector<Block *> &insertedBlocks) {
2802+
const SetVector<Operation *> &modifiedOps) {
28282803
[[maybe_unused]] auto &impl = rewriter.getImpl();
28292804
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
28302805

@@ -2843,8 +2818,7 @@ LogicalResult OperationLegalizer::legalizePatternResult(
28432818
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
28442819

28452820
// Legalize each of the actions registered during application.
2846-
if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
2847-
failed(legalizePatternRootUpdates(modifiedOps)) ||
2821+
if (failed(legalizePatternRootUpdates(modifiedOps)) ||
28482822
failed(legalizePatternCreatedOperations(newOps))) {
28492823
return failure();
28502824
}
@@ -2853,53 +2827,6 @@ LogicalResult OperationLegalizer::legalizePatternResult(
28532827
return success();
28542828
}
28552829

2856-
LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2857-
Operation *op, const SetVector<Block *> &insertedBlocks,
2858-
const SetVector<Operation *> &newOps) {
2859-
ConversionPatternRewriterImpl &impl = rewriter.getImpl();
2860-
SmallPtrSet<Operation *, 16> alreadyLegalized;
2861-
2862-
// If the pattern moved or created any blocks, make sure the types of block
2863-
// arguments get legalized.
2864-
for (Block *block : insertedBlocks) {
2865-
if (impl.erasedBlocks.contains(block))
2866-
continue;
2867-
2868-
// Only check blocks outside of the current operation.
2869-
Operation *parentOp = block->getParentOp();
2870-
if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2871-
continue;
2872-
2873-
// If the region of the block has a type converter, try to convert the block
2874-
// directly.
2875-
if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2876-
std::optional<TypeConverter::SignatureConversion> conversion =
2877-
converter->convertBlockSignature(block);
2878-
if (!conversion) {
2879-
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2880-
"block"));
2881-
return failure();
2882-
}
2883-
impl.applySignatureConversion(block, converter, *conversion);
2884-
continue;
2885-
}
2886-
2887-
// Otherwise, try to legalize the parent operation if it was not generated
2888-
// by this pattern. This is because we will attempt to legalize the parent
2889-
// operation, and blocks in regions created by this pattern will already be
2890-
// legalized later on.
2891-
if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
2892-
if (failed(legalize(parentOp))) {
2893-
LLVM_DEBUG(logFailure(
2894-
impl.logger, "operation '{0}'({1}) became illegal after rewrite",
2895-
parentOp->getName(), parentOp));
2896-
return failure();
2897-
}
2898-
}
2899-
}
2900-
return success();
2901-
}
2902-
29032830
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
29042831
const SetVector<Operation *> &newOps) {
29052832
for (Operation *op : newOps) {
@@ -3800,10 +3727,11 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
38003727
TypeConverter::SignatureConversion result(type.getNumInputs());
38013728
SmallVector<Type, 1> newResults;
38023729
if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3803-
failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3804-
failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3805-
typeConverter, &result)))
3730+
failed(typeConverter.convertTypes(type.getResults(), newResults)))
38063731
return failure();
3732+
if (!funcOp.getFunctionBody().empty())
3733+
rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result,
3734+
&typeConverter);
38073735

38083736
// Update the function signature in-place.
38093737
auto newType = FunctionType::get(rewriter.getContext(),

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -146,36 +146,6 @@ func.func @no_remap_nested() {
146146

147147
// -----
148148

149-
// CHECK-LABEL: func @remap_moved_region_args
150-
func.func @remap_moved_region_args() {
151-
// CHECK-NEXT: return
152-
// CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
153-
// CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
154-
// CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
155-
"test.region"() ({
156-
^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
157-
"test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
158-
}) : () -> ()
159-
// expected-remark@+1 {{op 'func.return' is not legalizable}}
160-
return
161-
}
162-
163-
// -----
164-
165-
// CHECK-LABEL: func @remap_cloned_region_args
166-
func.func @remap_cloned_region_args() {
167-
// CHECK-NEXT: return
168-
// CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
169-
// CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
170-
// CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
171-
"test.region"() ({
172-
^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
173-
"test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
174-
}) {legalizer.should_clone} : () -> ()
175-
// expected-remark@+1 {{op 'func.return' is not legalizable}}
176-
return
177-
}
178-
179149
// CHECK-LABEL: func @remap_drop_region
180150
func.func @remap_drop_region() {
181151
// CHECK-NEXT: return

0 commit comments

Comments
 (0)