Skip to content

Commit

Permalink
[MLIR] Handle materializeConstant failure in GreedyPatternRewriteDriv…
Browse files Browse the repository at this point in the history
…er (llvm#77258)

Make GreedyPatternRewriteDriver handle failures of `materializeConstant`
gracefully. Previously it was not checking whether the returned op was
null and crashing. This PR handles it similarly to how OperationFolder
does it.
  • Loading branch information
zyx-billy authored Jan 8, 2024
1 parent c1023c5 commit eb42868
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
34 changes: 29 additions & 5 deletions mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,10 +434,10 @@ bool GreedyPatternRewriteDriver::processWorklist() {
SmallVector<OpFoldResult> foldResults;
if (succeeded(op->fold(foldResults))) {
LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
changed = true;
if (foldResults.empty()) {
// Op was modified in-place.
notifyOperationModified(op);
changed = true;
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope && failed(verify(config.scope->getParentOp())))
llvm::report_fatal_error("IR failed to verify after folding");
Expand All @@ -451,6 +451,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
OpBuilder::InsertionGuard g(*this);
setInsertionPoint(op);
SmallVector<Value> replacements;
bool materializationSucceeded = true;
for (auto [ofr, resultType] :
llvm::zip_equal(foldResults, op->getResultTypes())) {
if (auto value = ofr.dyn_cast<Value>()) {
Expand All @@ -462,18 +463,41 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// Materialize Attributes as SSA values.
Operation *constOp = op->getDialect()->materializeConstant(
*this, ofr.get<Attribute>(), resultType, op->getLoc());

if (!constOp) {
// If materialization fails, cleanup any operations generated for
// the previous results.
llvm::SmallDenseSet<Operation *> replacementOps;
for (Value replacement : replacements) {
assert(replacement.use_empty() &&
"folder reused existing op for one result but constant "
"materialization failed for another result");
replacementOps.insert(replacement.getDefiningOp());
}
for (Operation *op : replacementOps) {
eraseOp(op);
}

materializationSucceeded = false;
break;
}

assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
"materializeConstant produced op that is not a ConstantLike");
assert(constOp->getResultTypes()[0] == resultType &&
"materializeConstant produced incorrect result type");
replacements.push_back(constOp->getResult(0));
}
replaceOp(op, replacements);

if (materializationSucceeded) {
replaceOp(op, replacements);
changed = true;
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope && failed(verify(config.scope->getParentOp())))
llvm::report_fatal_error("IR failed to verify after folding");
if (config.scope && failed(verify(config.scope->getParentOp())))
llvm::report_fatal_error("IR failed to verify after folding");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
continue;
continue;
}
}
}

Expand Down
11 changes: 11 additions & 0 deletions mlir/test/Transforms/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1224,3 +1224,14 @@ func.func @clone_nested_region(%arg0: index, %arg1: index, %arg2: index) -> memr
// CHECK-NEXT: scf.yield %[[ALLOC3_2]]
// CHECK: memref.dealloc %[[ALLOC1]]
// CHECK-NEXT: return %[[ALLOC2]]

// -----

// CHECK-LABEL: func @test_materialize_failure
func.func @test_materialize_failure() -> i64 {
%const = index.constant 1234
// Cannot materialize this castu's output constant.
// CHECK: index.castu
%u = index.castu %const : index to i64
return %u: i64
}

0 comments on commit eb42868

Please sign in to comment.