From eb42868f25665ba6301a94a30e9df33e0d6ae61f Mon Sep 17 00:00:00 2001 From: Billy Zhu Date: Mon, 8 Jan 2024 10:29:32 -0800 Subject: [PATCH] [MLIR] Handle materializeConstant failure in GreedyPatternRewriteDriver (#77258) Make GreedyPatternRewriteDriver handle failures of `materializeConstant` gracefully. Previously it was not checking whether the returned op was null and crashing. This PR handles it similarly to how OperationFolder does it. --- .../Utils/GreedyPatternRewriteDriver.cpp | 34 ++++++++++++++++--- mlir/test/Transforms/canonicalize.mlir | 11 ++++++ 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 82438e2bf706c1..67c2d9d59f4c92 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -434,10 +434,10 @@ bool GreedyPatternRewriteDriver::processWorklist() { SmallVector foldResults; if (succeeded(op->fold(foldResults))) { LLVM_DEBUG(logResultWithLine("success", "operation was folded")); - changed = true; if (foldResults.empty()) { // Op was modified in-place. notifyOperationModified(op); + changed = true; #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS if (config.scope && failed(verify(config.scope->getParentOp()))) llvm::report_fatal_error("IR failed to verify after folding"); @@ -451,6 +451,7 @@ bool GreedyPatternRewriteDriver::processWorklist() { OpBuilder::InsertionGuard g(*this); setInsertionPoint(op); SmallVector replacements; + bool materializationSucceeded = true; for (auto [ofr, resultType] : llvm::zip_equal(foldResults, op->getResultTypes())) { if (auto value = ofr.dyn_cast()) { @@ -462,18 +463,41 @@ bool GreedyPatternRewriteDriver::processWorklist() { // Materialize Attributes as SSA values. Operation *constOp = op->getDialect()->materializeConstant( *this, ofr.get(), resultType, op->getLoc()); + + if (!constOp) { + // If materialization fails, cleanup any operations generated for + // the previous results. + llvm::SmallDenseSet replacementOps; + for (Value replacement : replacements) { + assert(replacement.use_empty() && + "folder reused existing op for one result but constant " + "materialization failed for another result"); + replacementOps.insert(replacement.getDefiningOp()); + } + for (Operation *op : replacementOps) { + eraseOp(op); + } + + materializationSucceeded = false; + break; + } + assert(constOp->hasTrait() && "materializeConstant produced op that is not a ConstantLike"); assert(constOp->getResultTypes()[0] == resultType && "materializeConstant produced incorrect result type"); replacements.push_back(constOp->getResult(0)); } - replaceOp(op, replacements); + + if (materializationSucceeded) { + replaceOp(op, replacements); + changed = true; #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - if (config.scope && failed(verify(config.scope->getParentOp()))) - llvm::report_fatal_error("IR failed to verify after folding"); + if (config.scope && failed(verify(config.scope->getParentOp()))) + llvm::report_fatal_error("IR failed to verify after folding"); #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - continue; + continue; + } } } diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index 47a19bb598c25e..9b578e6c2631a7 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -1224,3 +1224,14 @@ func.func @clone_nested_region(%arg0: index, %arg1: index, %arg2: index) -> memr // CHECK-NEXT: scf.yield %[[ALLOC3_2]] // CHECK: memref.dealloc %[[ALLOC1]] // CHECK-NEXT: return %[[ALLOC2]] + +// ----- + +// CHECK-LABEL: func @test_materialize_failure +func.func @test_materialize_failure() -> i64 { + %const = index.constant 1234 + // Cannot materialize this castu's output constant. + // CHECK: index.castu + %u = index.castu %const : index to i64 + return %u: i64 +}