diff --git a/numba_rvsdg/core/datastructures/ast_transforms.py b/numba_rvsdg/core/datastructures/ast_transforms.py index 17a652b..5149ccd 100644 --- a/numba_rvsdg/core/datastructures/ast_transforms.py +++ b/numba_rvsdg/core/datastructures/ast_transforms.py @@ -762,11 +762,16 @@ def codegen_view() -> list[Any]: # A loop region gives rise to a Python while True loop. We # recursively visit the body. rval = [ + ast.Assign( + [ast.Name("__loop_cont__")], + ast.Constant(True), + lineno=0, + ), ast.While( - test=ast.Constant(value=True), + test=ast.Name("__loop_cont__"), body=codegen_view(), orelse=[], - ) + ), ] else: raise NotImplementedError @@ -791,28 +796,17 @@ def codegen_view() -> list[Any]: # special reserved variable. return [ast.Return(ast.Name("__return_value__"))] elif type(block) is SyntheticExitingLatch: - # The synthetic exiting latch much create a query on the variable - # it holds and then insert a Python if that will either break or - # continue. This effectively generates the backedge for the looping - # region. + # The synthetic exiting latch simply assigns the negated value of + # the exit variable to '__loop_cont__'. assert len(block.jump_targets) == 1 assert len(block.backedges) == 1 - compare_value = [ - i - for i, v in block.branch_value_table.items() - if v == block.backedges[0] - ][0] - if_beak_node_test = ast.Compare( - left=ast.Name(block.variable), - ops=[ast.Eq()], - comparators=[ast.Constant(compare_value)], - ) - if_break_node = ast.If( - test=if_beak_node_test, - body=[ast.Continue()], - orelse=[ast.Break()], - ) - return [if_break_node] + return [ + ast.Assign( + [ast.Name("__loop_cont__")], + ast.UnaryOp(ast.Not(), ast.Name(block.variable)), + lineno=0, + ) + ] elif type(block) in (SyntheticExitBranch, SyntheticHead): # Both the Synthetic exit branch and the synthetic head contain a # branching statement with potentially multiple outgoing branches.