Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove all refops if a basic block is a raising block #1045

Closed
wants to merge 12 commits into from
17 changes: 16 additions & 1 deletion ffi/custom_passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,22 @@ struct RefPrunePass : public FunctionPass {
stats_per_bb += 1;
}

// Second: Find matching pairs of incref decref
// Second: if it's a Raising block, then remove all refs
if (isRaising(&bb)) {
for (CallInst *ci : incref_list) {
ci->eraseFromParent();
mutated = true;
stats_per_bb += 1;
}
for (CallInst *ci : decref_list) {
ci->eraseFromParent();
mutated = true;
stats_per_bb += 1;
}
continue;
}

// Third: Find matching pairs of incref decref
while (incref_list.size() > 0) {
// get an incref
CallInst *incref = incref_list.pop_back_val();
Expand Down
26 changes: 26 additions & 0 deletions llvmlite/tests/test_refprune.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,32 @@ def test_per_bb_4(self):
# not pruned
self.assertIn("call void @NRT_decref(i8* %other)", str(mod))

# test case is for removing all refs if the BB is raising
per_bb_ir_5 = r"""
define i32 @main(i8* %ptr, i1 %cond1, i1 %cond2, i8** %excinfo) {
bb_A:
br i1 %cond1, label %bb_C, label %bb_B
bb_B:
br i1 %cond2, label %bb_D, label %bb_C
bb_C:
%sroa = phi i8* [ %ptr, %bb_A ], [ null, %bb_B ]
tail call void @NRT_decref(i8* %ptr)
tail call void @NRT_decref(i8* %sroa)
store i8* null, i8** %excinfo, !numba_exception_output !0
br label %common.ret
bb_D:
br label %common.ret
common.ret:
%common.ret.op = phi i32 [ 0, %bb_D ], [ 1, %bb_C ]
ret i32 %common.ret.op
}
!0 = !{i1 1}
"""

def test_per_bb_5(self):
mod, stats = self.check(self.per_bb_ir_5)
self.assertEqual(stats.basicblock, 2)


class TestDiamond(BaseTestByIR):
refprune_bitmask = llvm.RefPruneSubpasses.DIAMOND
Expand Down