diff --git a/numba_rvsdg/core/datastructures/basic_block.py b/numba_rvsdg/core/datastructures/basic_block.py index 329aeee..ef42bc4 100644 --- a/numba_rvsdg/core/datastructures/basic_block.py +++ b/numba_rvsdg/core/datastructures/basic_block.py @@ -11,37 +11,6 @@ class BasicBlock: name: str """The corresponding name for this block. """ - _jump_targets: Tuple[str] = tuple() - """Jump targets (branch destinations) for this block""" - - backedges: Tuple[str] = tuple() - """Backedges for this block.""" - - @property - def is_exiting(self) -> bool: - return not self.jump_targets - - @property - def fallthrough(self) -> bool: - return len(self._jump_targets) == 1 - - @property - def jump_targets(self) -> Tuple[str]: - acc = [] - for j in self._jump_targets: - if j not in self.backedges: - acc.append(j) - return tuple(acc) - - def replace_backedge(self, target: str) -> "BasicBlock": - if target in self.jump_targets: - assert not self.backedges - return replace(self, backedges=(target,)) - return self - - def replace_jump_targets(self, jump_targets: Tuple) -> "BasicBlock": - return replace(self, _jump_targets=jump_targets) - @dataclass(frozen=True) class PythonBytecodeBlock(BasicBlock): @@ -102,31 +71,6 @@ class SyntheticBranch(SyntheticBlock): variable: str = None branch_value_table: dict = None - def replace_jump_targets(self, jump_targets: Tuple) -> "BasicBlock": - fallthrough = len(jump_targets) == 1 - old_branch_value_table = self.branch_value_table - new_branch_value_table = {} - for target in self.jump_targets: - if target not in jump_targets: - # ASSUMPTION: only one jump_target is being updated - diff = set(jump_targets).difference(self.jump_targets) - assert len(diff) == 1 - new_target = next(iter(diff)) - for k, v in old_branch_value_table.items(): - if v == target: - new_branch_value_table[k] = new_target - else: - # copy all old values - for k, v in old_branch_value_table.items(): - if v == target: - new_branch_value_table[k] = v - - return replace( - self, - _jump_targets=jump_targets, - branch_value_table=new_branch_value_table, - ) - @dataclass(frozen=True) class SyntheticHead(SyntheticBranch): @@ -156,5 +100,5 @@ class RegionBlock(BasicBlock): """ def get_full_graph(self): - graph = ChainMap(self.subregion.graph, self.headers) + graph = ChainMap(self.subregion.blocks, self.headers) return graph diff --git a/numba_rvsdg/core/datastructures/byte_flow.py b/numba_rvsdg/core/datastructures/byte_flow.py index 6ca4bd5..d448722 100644 --- a/numba_rvsdg/core/datastructures/byte_flow.py +++ b/numba_rvsdg/core/datastructures/byte_flow.py @@ -59,7 +59,7 @@ def restructure(self): def _iter_subregions(scfg: "SCFG"): - for node in scfg.graph.values(): + for node in scfg.blocks.values(): if isinstance(node, RegionBlock): yield node yield from _iter_subregions(node.subregion) diff --git a/numba_rvsdg/core/datastructures/flow_info.py b/numba_rvsdg/core/datastructures/flow_info.py index b8d59bb..112dd7c 100644 --- a/numba_rvsdg/core/datastructures/flow_info.py +++ b/numba_rvsdg/core/datastructures/flow_info.py @@ -78,19 +78,17 @@ def build_basicblocks(self: "FlowInfo", end_offset=None) -> "SCFG": for begin, end in zip(offsets, [*offsets[1:], end_offset]): name = names[begin] - targets: Tuple[str, ...] + targets: list[str] term_offset = _prev_inst_offset(end) if term_offset not in self.jump_insts: # implicit jump - targets = (names[end],) + targets = [names[end]] else: - targets = tuple(names[o] for o in self.jump_insts[term_offset]) + targets = [names[o] for o in self.jump_insts[term_offset]] block = PythonBytecodeBlock( name=name, begin=begin, - end=end, - _jump_targets=targets, - backedges=(), + end=end ) - scfg.add_block(block) + scfg.add_block(block, targets, []) return scfg diff --git a/numba_rvsdg/core/datastructures/scfg.py b/numba_rvsdg/core/datastructures/scfg.py index 8437a22..29c6308 100644 --- a/numba_rvsdg/core/datastructures/scfg.py +++ b/numba_rvsdg/core/datastructures/scfg.py @@ -60,16 +60,20 @@ def new_var_name(self, kind: str) -> str: class SCFG: """Map of Names to Blocks.""" - graph: Dict[str, BasicBlock] = field(default_factory=dict) + blocks: dict[str, BasicBlock] = field(default_factory=dict, init=False) + + _jump_targets: dict[str, list[str]] = field(default_factory=dict, init=False) + back_edges: dict[str, list[str]] = field(default_factory=dict, init=False) + name_gen: NameGenerator = field( default_factory=NameGenerator, compare=False ) def __getitem__(self, index): - return self.graph[index] + return self.blocks[index] def __contains__(self, index): - return index in self.graph + return index in self.blocks def __iter__(self): """Graph Iterator""" @@ -92,15 +96,25 @@ def __iter__(self): for i in block.subregion: yield i # finally add any jump_targets to the list of names to visit - to_visit.extend(block.jump_targets) + to_visit.extend(self.jump_targets[name]) @property def concealed_region_view(self): return ConcealedRegionView(self) + + @property + def jump_targets(self): + jump_targets = {} + for name in self._jump_targets.keys(): + jump_targets[name] = [] + for jt in self._jump_targets[name]: + if jt not in self.back_edges[name]: + jump_targets[name].append(jt) + return jump_targets def exclude_blocks(self, exclude_blocks: Set[str]) -> Iterator[str]: """Iterator over all nodes not in exclude_blocks.""" - for block in self.graph: + for block in self.blocks: if block not in exclude_blocks: yield block @@ -111,10 +125,9 @@ def find_head(self) -> str: that no other blocks are pointing to. """ - heads = set(self.graph.keys()) - for name in self.graph.keys(): - block = self.graph[name] - for jt in block.jump_targets: + heads = set(self.blocks.keys()) + for name in self.blocks.keys(): + for jt in self.jump_targets[name]: heads.discard(jt) assert len(heads) == 1 return next(iter(heads)) @@ -125,40 +138,43 @@ def compute_scc(self) -> List[Set[str]]: """ from numba_rvsdg.networkx_vendored.scc import scc + scfg = self + class GraphWrap: def __init__(self, graph): - self.graph = graph + self.blocks = graph def __getitem__(self, vertex): - out = self.graph[vertex].jump_targets + out = scfg.jump_targets[vertex] # Exclude node outside of the subgraph - return [k for k in out if k in self.graph] + return [k for k in out if k in self.blocks] def __iter__(self): - return iter(self.graph.keys()) + return iter(self.blocks.keys()) - return list(scc(GraphWrap(self.graph))) + return list(scc(GraphWrap(self.blocks))) def compute_scc_subgraph(self, subgraph) -> List[Set[str]]: """ Strongly-connected component for detecting loops inside a subgraph. """ from numba_rvsdg.networkx_vendored.scc import scc + scfg = self class GraphWrap: def __init__(self, graph, subgraph): - self.graph = graph + self.blocks = graph self.subgraph = subgraph def __getitem__(self, vertex): - out = self.graph[vertex].jump_targets + out = scfg.jump_targets[vertex] # Exclude node outside of the subgraph return [k for k in out if k in subgraph] def __iter__(self): - return iter(self.graph.keys()) + return iter(self.blocks.keys()) - return list(scc(GraphWrap(self.graph, subgraph))) + return list(scc(GraphWrap(self.blocks, subgraph))) def find_headers_and_entries( self, subgraph: Set[str] @@ -177,7 +193,7 @@ def find_headers_and_entries( headers: Set[str] = set() for outside in self.exclude_blocks(subgraph): - nodes_jump_in_loop = subgraph.intersection(self.graph[outside].jump_targets) + nodes_jump_in_loop = subgraph.intersection(self.jump_targets[outside]) headers.update(nodes_jump_in_loop) if nodes_jump_in_loop: entries.add(outside) @@ -203,19 +219,19 @@ def find_exiting_and_exits( exits: Set[str] = set() for inside in subgraph: # any node inside that points outside the loop - for jt in self.graph[inside].jump_targets: + for jt in self.jump_targets[inside]: if jt not in subgraph: exiting.add(inside) exits.add(jt) # any returns - if self.graph[inside].is_exiting: + if self.is_exiting(inside): exiting.add(inside) return sorted(exiting), sorted(exits) def is_reachable_dfs(self, begin: str, end: str): # -> TypeGuard: """Is end reachable from begin.""" seen = set() - to_vist = list(self.graph[begin].jump_targets) + to_vist = list(self.jump_targets[begin]) while True: if to_vist: block = to_vist.pop() @@ -228,15 +244,25 @@ def is_reachable_dfs(self, begin: str, end: str): # -> TypeGuard: return True elif block not in seen: seen.add(block) - if block in self.graph: - to_vist.extend(self.graph[block].jump_targets) + if block in self.blocks: + to_vist.extend(self.jump_targets[block]) - def add_block(self, basicblock: BasicBlock): - self.graph[basicblock.name] = basicblock + def add_block(self, basicblock: BasicBlock, jump_targets: List[str], back_edges: List[str]): + self.blocks[basicblock.name] = basicblock + self._jump_targets[basicblock.name] = jump_targets + self.back_edges[basicblock.name] = back_edges def remove_blocks(self, names: Set[str]): for name in names: - del self.graph[name] + del self.blocks[name] + del self._jump_targets[name] + del self.back_edges[name] + + def is_exiting(self, block_name: str) -> bool: + return not self.jump_targets[block_name] + + def is_fallthrough(self, block_name: str) -> bool: + return len(self.jump_targets[block_name]) == 1 def _insert_block( self, new_name: str, predecessors: Set[str], successors: Set[str], @@ -245,15 +271,18 @@ def _insert_block( # TODO: needs a diagram and documentaion # initialize new block new_block = block_type( - name=new_name, _jump_targets=successors, backedges=set() + name=new_name ) # add block to self - self.add_block(new_block) + self.add_block(new_block, successors, []) # Replace any arcs from any of predecessors to any of successors with # an arc through the inserted block instead. for name in predecessors: - block = self.graph.pop(name) - jt = list(block.jump_targets) + if hasattr(self.blocks[name], 'branch_value_table'): + for key, value in self.blocks[name].branch_value_table.items(): + if value in successors: + self.blocks[name].branch_value_table[key] = new_name + jt = list(self.jump_targets[name]) if successors: for s in successors: if s in jt: @@ -263,7 +292,7 @@ def _insert_block( jt.pop(jt.index(s)) else: jt.append(new_name) - self.add_block(block.replace_jump_targets(jump_targets=tuple(jt))) + self._jump_targets[name] = jt def insert_SyntheticExit( self, new_name: str, predecessors: Set[str], successors: Set[str], @@ -298,8 +327,7 @@ def insert_block_and_control_blocks( # Replace any arcs from any of predecessors to any of successors with # an arc through the to be inserted block instead. for name in predecessors: - block = self.graph[name] - jt = list(block.jump_targets) + jt = list(self.jump_targets[name]) # Need to create synthetic assignments for each arc from a # predecessors to a successor and insert it between the predecessor # and the newly created block @@ -309,12 +337,10 @@ def insert_block_and_control_blocks( variable_assignment[branch_variable] = branch_variable_value synth_assign_block = SyntheticAssignment( name=synth_assign, - _jump_targets=(new_name,), - backedges=(), variable_assignment=variable_assignment, ) # add block - self.add_block(synth_assign_block) + self.add_block(synth_assign_block, [new_name], []) # update branching table branch_value_table[branch_variable_value] = s # update branching variable @@ -322,19 +348,15 @@ def insert_block_and_control_blocks( # replace previous successor with synth_assign jt[jt.index(s)] = synth_assign # finally, replace the jump_targets - self.add_block( - self.graph.pop(name).replace_jump_targets(jump_targets=tuple(jt)) - ) + self._jump_targets[name] = jt # initialize new block, which will hold the branching table new_block = SyntheticHead( name=new_name, - _jump_targets=tuple(successors), - backedges=set(), variable=branch_variable, branch_value_table=branch_value_table, ) # add block to self - self.add_block(new_block) + self.add_block(new_block, successors, []) def join_returns(self): """Close the CFG. @@ -343,7 +365,7 @@ def join_returns(self): predescessors and no successors respectively. """ # for all nodes that contain a return - return_nodes = [node for node in self.graph if self.graph[node].is_exiting] + return_nodes = [node for node in self.blocks if self.is_exiting(node)] # close if more than one is found if len(return_nodes) > 1: return_solo_name = self.name_gen.new_block_name(block_names.SYNTH_RETURN) @@ -390,8 +412,8 @@ def from_yaml(yaml_string): @staticmethod def from_dict(graph_dict: dict): - scfg_graph = {} - name_gen = NameGenerator() + scfg = SCFG() + name_gen = scfg.name_gen block_dict = {} for index in graph_dict.keys(): block_dict[index] = name_gen.new_block_name(block_names.BASIC) @@ -399,24 +421,21 @@ def from_dict(graph_dict: dict): jump_targets = attributes["jt"] backedges = attributes.get("be", ()) name = block_dict[index] - block = BasicBlock( - name=name, - backedges=tuple(block_dict[idx] for idx in backedges), - _jump_targets=tuple(block_dict[idx] for idx in jump_targets), - ) - scfg_graph[name] = block - scfg = SCFG(scfg_graph, name_gen=name_gen) + block = BasicBlock(name=name) + backedges=tuple(block_dict[idx] for idx in backedges) + jump_targets=tuple(block_dict[idx] for idx in jump_targets) + scfg.add_block(block, jump_targets, backedges) return scfg, block_dict def to_yaml(self): # Convert to yaml - scfg_graph = self.graph + scfg_graph = self.blocks yaml_string = """""" for key, value in scfg_graph.items(): - jump_targets = [i for i in value._jump_targets] + jump_targets = [i for i in self._jump_targets[key]] jump_targets = str(jump_targets).replace("\'", "\"") - back_edges = [i for i in value.backedges] + back_edges = [i for i in self.back_edges[key]] jump_target_str = f""" "{key}": jt: {jump_targets}""" @@ -430,13 +449,13 @@ def to_yaml(self): return yaml_string def to_dict(self): - scfg_graph = self.graph + scfg_graph = self.blocks graph_dict = {} for key, value in scfg_graph.items(): curr_dict = {} - curr_dict["jt"] = [i for i in value._jump_targets] - if value.backedges: - curr_dict["be"] = [i for i in value.backedges] + curr_dict["jt"] = [i for i in self._jump_targets[key]] + if self.back_edges[key]: + curr_dict["be"] = [i for i in self.back_edges[key]] graph_dict[key] = curr_dict return graph_dict @@ -455,7 +474,7 @@ def __len__(self): class ConcealedRegionView(AbstractGraphView): - def __init__(self, scfg): + def __init__(self, scfg: SCFG): self.scfg = scfg def __getitem__(self, item): @@ -508,10 +527,10 @@ def region_view_iterator(self, head: str = None) -> Iterator[str]: # If this is a region, continue on to the exiting block, i.e. # the region is presented a single fall-through block to the # consumer of this iterator. - to_visit.extend(block.subregion[block.exiting].jump_targets) + to_visit.extend(self.scfg.jump_targets[name]) else: - # otherwise add any jump_targets to the list of names to visit - to_visit.extend(block.jump_targets) + # otherwise add any outgoing edges to the list of names to visit + to_visit.extend(self.scfg.jump_targets[name]) # finally, yield the name yield name diff --git a/numba_rvsdg/core/transformations.py b/numba_rvsdg/core/transformations.py index eb9ba5b..78fed95 100644 --- a/numba_rvsdg/core/transformations.py +++ b/numba_rvsdg/core/transformations.py @@ -49,11 +49,11 @@ def loop_restructure_helper(scfg: SCFG, loop: Set[str]): # backedge to the loop header) we can exit early, since the condition for # SCFG is fullfilled. backedge_blocks = [ - block for block in loop if set(headers).intersection(scfg[block].jump_targets) + block for block in loop if set(headers).intersection(scfg.jump_targets[block]) ] if (len(backedge_blocks) == 1 and len(exiting_blocks) == 1 and backedge_blocks[0] == next(iter(exiting_blocks))): - scfg.add_block(scfg.graph.pop(backedge_blocks[0]).replace_backedge(loop_head)) + scfg.back_edges[backedge_blocks[0]] = [loop_head] return # The synthetic exiting latch and synthetic exit need to be created @@ -89,7 +89,7 @@ def loop_restructure_helper(scfg: SCFG, loop: Set[str]): # This does a dictionary reverse lookup, to determine the key for a given # value. - def reverse_lookup(d, value): + def reverse_lookup(d: dict, value): for k, v in d.items(): if v == value: return k @@ -106,9 +106,9 @@ def reverse_lookup(d, value): # If the block is an exiting block or a backedge block if name in exiting_blocks or name in backedge_blocks: # Copy the jump targets, these will be modified - new_jt = list(scfg[name].jump_targets) + new_jt = list(scfg.jump_targets[name]) # For each jump_target in the blockj - for jt in scfg[name].jump_targets: + for jt in scfg.jump_targets[name]: # If the target is an exit block if jt in exit_blocks: # Create a new assignment name and record it @@ -125,12 +125,10 @@ def reverse_lookup(d, value): # Create the actual control variable block synth_assign_block = SyntheticAssignment( name=synth_assign, - _jump_targets=(synth_exiting_latch,), - backedges=(), variable_assignment=variable_assignment, ) # Insert the assignment to the scfg - scfg.add_block(synth_assign_block) + scfg.add_block(synth_assign_block, [synth_exiting_latch], []) # Insert the new block into the new jump_targets making # sure, that it replaces the correct jump_target, order # matters in this case. @@ -150,52 +148,45 @@ def reverse_lookup(d, value): # that point to the headers, no need to add a backedge, # since it will be contained in the SyntheticExitingLatch # later on. - block = scfg.graph.pop(name) - jts = list(block.jump_targets) + jts = list(scfg.jump_targets[name]) for h in headers: if h in jts: jts.remove(h) - scfg.add_block(block.replace_jump_targets(jump_targets=tuple(jts))) + + scfg.jump_targets[name] = jts # Setup the assignment block and initialize it with the # correct jump_targets and variable assignment. synth_assign_block = SyntheticAssignment( name=synth_assign, - _jump_targets=(synth_exiting_latch,), - backedges=(), variable_assignment=variable_assignment, ) # Add the new block to the SCFG - scfg.add_block(synth_assign_block) + scfg.add_block(synth_assign_block, [synth_exiting_latch], []) # Update the jump targets again, order matters new_jt[new_jt.index(jt)] = synth_assign # finally, replace the jump_targets for this block with the new ones - scfg.add_block( - scfg.graph.pop(name).replace_jump_targets(jump_targets=tuple(new_jt)) - ) + scfg._jump_targets[name] = new_jt # Add any new blocks to the loop. loop.update(new_blocks) # Insert the exiting latch, add it to the loop and to the graph. synth_exiting_latch_block = SyntheticExitingLatch( name=synth_exiting_latch, - _jump_targets=(synth_exit if needs_synth_exit else next(iter(exit_blocks)), loop_head), - backedges=(loop_head,), variable=backedge_variable, branch_value_table=backedge_value_table, ) loop.add(synth_exiting_latch) - scfg.add_block(synth_exiting_latch_block) + jump_targets = [synth_exit if needs_synth_exit else next(iter(exit_blocks)), loop_head] + scfg.add_block(synth_exiting_latch_block, jump_targets, [loop_head]) # If an exit is to be created, we do so too, but only add it to the scfg, # since it isn't part of the loop if needs_synth_exit: synth_exit_block = SyntheticExitBranch( name=synth_exit, - _jump_targets=tuple(exit_blocks), - backedges=(), variable=exit_variable, branch_value_table=exit_value_table, ) - scfg.add_block(synth_exit_block) + scfg.add_block(synth_exit_block, exit_blocks, []) def restructure_loop(scfg: SCFG): @@ -210,11 +201,11 @@ def restructure_loop(scfg: SCFG): loops: List[Set[str]] = [ nodes for nodes in scc - if len(nodes) > 1 or next(iter(nodes)) in scfg[next(iter(nodes))].jump_targets + if len(nodes) > 1 or next(iter(nodes)) in scfg.jump_targets[next(iter(nodes))] ] _logger.debug( - "restructure_loop found %d loops in %s", len(loops), scfg.graph.keys() + "restructure_loop found %d loops in %s", len(loops), scfg.blocks.keys() ) # rotate and extract loop for loop in loops: @@ -233,7 +224,7 @@ def find_head_blocks(scfg: SCFG, begin: str) -> Set[str]: if current_block == begin: break else: - jt = scfg.graph[current_block].jump_targets + jt = scfg.jump_targets[current_block] assert len(jt) == 1 current_block = next(iter(jt)) return head_region_blocks @@ -246,7 +237,7 @@ def find_branch_regions(scfg: SCFG, begin: str, end: str) -> Set[str]: postimmdoms = _imm_doms(postdoms) immdoms = _imm_doms(doms) branch_regions = [] - jump_targets = scfg.graph[begin].jump_targets + jump_targets = scfg.jump_targets[begin] for bra_start in jump_targets: for jt in jump_targets: if jt != bra_start and scfg.is_reachable_dfs(jt, bra_start): @@ -276,7 +267,7 @@ def _find_branch_regions(scfg: SCFG, begin: str, end: str) -> Set[str]: def find_tail_blocks( scfg: SCFG, begin: Set[str], head_region_blocks, branch_regions ): - tail_subregion = set((b for b in scfg.graph.keys())) + tail_subregion = set((b for b in scfg.blocks.keys())) tail_subregion.difference_update(head_region_blocks) for reg in branch_regions: if not reg: @@ -290,7 +281,7 @@ def find_tail_blocks( return tail_subregion -def extract_region(scfg, region_blocks, region_kind): +def extract_region(scfg: SCFG, region_blocks, region_kind): headers, entries = scfg.find_headers_and_entries(region_blocks) exiting_blocks, exit_blocks = scfg.find_exiting_and_exits(region_blocks) assert len(headers) == 1 @@ -298,30 +289,32 @@ def extract_region(scfg, region_blocks, region_kind): region_header = next(iter(headers)) region_exiting = next(iter(exiting_blocks)) - head_subgraph = SCFG( - {name: scfg.graph[name] for name in region_blocks}, name_gen=scfg.name_gen - ) + head_subgraph = SCFG(name_gen=scfg.name_gen) + for name in region_blocks: + head_subgraph.add_block(scfg.blocks[name], scfg.jump_targets[name], scfg.back_edges[name]) if isinstance(scfg[region_exiting], RegionBlock): - region_exiting = scfg[region_exiting].exiting + region_exiting = scfg.is_exiting(region_exiting) else: region_exiting = region_exiting subregion = RegionBlock( name=region_header, - _jump_targets=scfg[region_exiting].jump_targets, - backedges=(), kind=region_kind, headers=headers, subregion=head_subgraph, exiting=region_exiting, ) + + region_exit_paths = scfg.jump_targets[region_exiting] scfg.remove_blocks(region_blocks) - scfg.graph[region_header] = subregion + scfg.back_edges[region_header] = [] + scfg.blocks[region_header] = subregion + scfg._jump_targets[region_header] = region_exit_paths def restructure_branch(scfg: SCFG): - print("restructure_branch", scfg.graph) + print("restructure_branch", scfg.blocks) doms = _doms(scfg) postdoms = _post_doms(scfg) postimmdoms = _imm_doms(postdoms) @@ -393,7 +386,7 @@ def _iter_branch_regions( scfg: SCFG, immdoms: Dict[str, str], postimmdoms: Dict[str, str] ): for begin, node in scfg.concealed_region_view.items(): - if len(node.jump_targets) > 1: + if len(scfg.jump_targets[begin]) > 1: # found branch if begin in postimmdoms: end = postimmdoms[begin] @@ -428,41 +421,41 @@ def _doms(scfg: SCFG): succs_table = defaultdict(set) node: BasicBlock - for src, node in scfg.graph.items(): - for dst in node.jump_targets: + for src, node in scfg.blocks.items(): + for dst in scfg.jump_targets[src]: # check dst is in subgraph - if dst in scfg.graph: + if dst in scfg.blocks: preds_table[dst].add(src) succs_table[src].add(dst) - for k in scfg.graph: + for k in scfg.blocks: if not preds_table[k]: entries.add(k) return _find_dominators_internal( - entries, list(scfg.graph.keys()), preds_table, succs_table + entries, list(scfg.blocks.keys()), preds_table, succs_table ) def _post_doms(scfg: SCFG): # compute post dom entries = set() - for k, v in scfg.graph.items(): - targets = set(v.jump_targets) & set(scfg.graph) + for k, v in scfg.blocks.items(): + targets = set(scfg.jump_targets[k]) & set(scfg.blocks) if not targets: entries.add(k) preds_table = defaultdict(set) succs_table = defaultdict(set) node: BasicBlock - for src, node in scfg.graph.items(): - for dst in node.jump_targets: + for src, node in scfg.blocks.items(): + for dst in scfg.jump_targets[src]: # check dst is in subgraph - if dst in scfg.graph: + if dst in scfg.blocks: preds_table[src].add(dst) succs_table[dst].add(src) return _find_dominators_internal( - entries, list(scfg.graph.keys()), preds_table, succs_table + entries, list(scfg.blocks.keys()), preds_table, succs_table ) diff --git a/numba_rvsdg/rendering/rendering.py b/numba_rvsdg/rendering/rendering.py index 2cb127d..7be87d5 100644 --- a/numba_rvsdg/rendering/rendering.py +++ b/numba_rvsdg/rendering/rendering.py @@ -40,15 +40,15 @@ def render_region_block( for name, block in graph.items(): self.render_block(subg, name, block) # render edges within this region - self.render_edges(graph) + self.render_edges(regionblock.subregion) def render_basic_block(self, digraph: "Digraph", name: str, block: BasicBlock): if name.startswith('python_bytecode'): - instlist = block.get_instructions(self.bcmap) + # instlist = block.get_instructions(self.bcmap) body = name + "\l" - body += "\l".join( - [f"{inst.offset:3}: {inst.opname}" for inst in instlist] + [""] - ) + # body += "\l".join( + # [f"{inst.offset:3}: {inst.opname}" for inst in instlist] + [""] + # ) else: body = name + "\l" @@ -95,10 +95,10 @@ def render_block(self, digraph: "Digraph", name: str, block: BasicBlock): else: raise Exception("unreachable") - def render_edges(self, blocks: Dict[str, BasicBlock]): - for name, block in blocks.items(): - for dst in block.jump_targets: - if dst in blocks: + def render_edges(self, scfg: SCFG): + for name, block in scfg.blocks.items(): + for dst in scfg.jump_targets[name]: + if dst in scfg.blocks: if type(block) in ( PythonBytecodeBlock, BasicBlock, @@ -118,7 +118,7 @@ def render_edges(self, blocks: Dict[str, BasicBlock]): self.g.edge(str(name), str(dst)) else: raise Exception("unreachable " + str(block)) - for dst in block.backedges: + for dst in scfg.back_edges[name]: # assert dst in blocks self.g.edge( str(name), str(dst), style="dashed", color="grey", constraint="0" @@ -128,16 +128,16 @@ def render_byteflow(self, byteflow: ByteFlow): self.bcmap_from_bytecode(byteflow.bc) # render nodes - for name, block in byteflow.scfg.graph.items(): + for name, block in byteflow.scfg.blocks.items(): self.render_block(self.g, name, block) - self.render_edges(byteflow.scfg.graph) + self.render_edges(byteflow.scfg) return self.g - def render_scfg(self, scfg): + def render_scfg(self, scfg: SCFG): # render nodes - for name, block in scfg.graph.items(): + for name, block in scfg.blocks.items(): self.render_block(self.g, name, block) - self.render_edges(scfg.graph) + self.render_edges(scfg) return self.g def bcmap_from_bytecode(self, bc: dis.Bytecode): diff --git a/numba_rvsdg/tests/simulator.py b/numba_rvsdg/tests/simulator.py index eea188e..e35448e 100644 --- a/numba_rvsdg/tests/simulator.py +++ b/numba_rvsdg/tests/simulator.py @@ -140,11 +140,11 @@ def run_BasicBlock(self, name: str): self.run_PythonBytecodeBlock(name) elif isinstance(block, SyntheticBlock): self.run_synth_block(name) - if block.fallthrough: - [name] = block.jump_targets + if self.region_stack[-1].subregion.is_fallthrough(name): + [name] = self.region_stack[-1].subregion.jump_targets[name] return {"jumpto": name} - elif len(block._jump_targets) == 2: - [br_false, br_true] = block._jump_targets + elif len(self.region_stack[-1].subregion._jump_targets[name]) == 2: + [br_false, br_true] = self.region_stack[-1].subregion._jump_targets[name] return {"jumpto": br_true if self.branch else br_false} else: return {"return": self.return_value} @@ -188,7 +188,7 @@ def run_RegionBlock(self, name: str): elif "jumpto" in action: name = action["jumpto"] # Otherwise check if we stay in the region and break otherwise - if name in region.subregion.graph: + if name in region.subregion.blocks.keys(): continue # stay in the region else: break # break and return action @@ -223,7 +223,7 @@ def run_synth_block(self, name: str): The str for the block. """ - print("----", name) + print("----RUNNING SYNTHETIC BLOCK------", name) print(f"control variable map: {self.ctrl_varmap}") block = self.get_block(name) handler = getattr(self, 'synth_' + block.__class__.__name__) @@ -250,9 +250,9 @@ def run_inst(self, inst: Instruction): def synth_SyntheticAssignment(self, control_name, block): self.ctrl_varmap.update(block.variable_assignment) - def _synth_branch(self, control_name, block): + def _synth_branch(self, control_name, block: BasicBlock): jump_target = block.branch_value_table[self.ctrl_varmap[block.variable]] - self.branch = bool(block._jump_targets.index(jump_target)) + self.branch = bool(self.region_stack[-1].subregion._jump_targets[block.name].index(jump_target)) def synth_SyntheticExitingLatch(self, control_name, block): self._synth_branch(control_name, block) diff --git a/numba_rvsdg/tests/test_byteflow.py b/numba_rvsdg/tests/test_byteflow.py index a0341fd..1f1ae13 100644 --- a/numba_rvsdg/tests/test_byteflow.py +++ b/numba_rvsdg/tests/test_byteflow.py @@ -106,42 +106,38 @@ def test(self): class TestPythonBytecodeBlock(unittest.TestCase): def test_constructor(self): - name_gen = NameGenerator() + scfg = SCFG() block = PythonBytecodeBlock( - name=name_gen.new_block_name(block_names.PYTHON_BYTECODE), + name=scfg.name_gen.new_block_name(block_names.PYTHON_BYTECODE), begin=0, - end=8, - _jump_targets=(), - backedges=(), + end=8 ) + scfg.add_block(block, [], []) self.assertEqual(block.name, 'python_bytecode_block_0') self.assertEqual(block.begin, 0) self.assertEqual(block.end, 8) - self.assertFalse(block.fallthrough) - self.assertTrue(block.is_exiting) - self.assertEqual(block.jump_targets, ()) - self.assertEqual(block.backedges, ()) + self.assertFalse(scfg.is_fallthrough(block.name)) + self.assertTrue(scfg.is_exiting(block.name)) + self.assertEqual(scfg.jump_targets[block.name], []) + self.assertEqual(scfg.back_edges[block.name], []) def test_is_jump_target(self): - name_gen = NameGenerator() + scfg = SCFG() block = PythonBytecodeBlock( - name=name_gen.new_block_name(block_names.PYTHON_BYTECODE), + name=scfg.name_gen.new_block_name(block_names.PYTHON_BYTECODE), begin=0, - end=8, - _jump_targets=(name_gen.new_block_name(block_names.PYTHON_BYTECODE),), - backedges=(), + end=8 ) - self.assertEqual(block.jump_targets, ('python_bytecode_block_1',)) - self.assertFalse(block.is_exiting) + scfg.add_block(block, [scfg.name_gen.new_block_name(block_names.PYTHON_BYTECODE)], []) + self.assertEqual(scfg.jump_targets[block.name], ['python_bytecode_block_1']) + self.assertFalse(scfg.is_exiting(block.name)) def test_get_instructions(self): name_gen = NameGenerator() block = PythonBytecodeBlock( name=name_gen.new_block_name(block_names.PYTHON_BYTECODE), begin=0, - end=8, - _jump_targets=(), - backedges=(), + end=8 ) expected = [ Instruction( @@ -228,19 +224,14 @@ def test_from_bytecode(self): self.assertEqual(expected, received) def test_build_basic_blocks(self): - name_gen = NameGenerator() - new_name = name_gen.new_block_name(block_names.PYTHON_BYTECODE) - expected = SCFG( - graph={ - new_name: PythonBytecodeBlock( + expected = SCFG() + new_name = expected.name_gen.new_block_name(block_names.PYTHON_BYTECODE) + block = PythonBytecodeBlock( name=new_name, begin=0, - end=10, - _jump_targets=(), - backedges=(), + end=10 ) - } - ) + expected.add_block(block, [], []) received = FlowInfo.from_bytecode(bytecode).build_basicblocks() self.assertEqual(expected, received) @@ -252,19 +243,14 @@ def test_constructor(self): self.assertEqual(len(byteflow.scfg), 0) def test_from_bytecode(self): - name_gen = NameGenerator() - new_name = name_gen.new_block_name(block_names.PYTHON_BYTECODE) - scfg = SCFG( - graph={ - new_name: PythonBytecodeBlock( + scfg = SCFG() + new_name = scfg.name_gen.new_block_name(block_names.PYTHON_BYTECODE) + block = PythonBytecodeBlock( name=new_name, begin=0, - end=10, - _jump_targets=(), - backedges=(), + end=10 ) - } - ) + scfg.add_block(block, [], []) expected = ByteFlow(bc=bytecode, scfg=scfg) received = ByteFlow.from_bytecode(fun) self.assertEqual(expected.scfg, received.scfg) diff --git a/numba_rvsdg/tests/test_scfg.py b/numba_rvsdg/tests/test_scfg.py index 715cac9..f8bd1e3 100644 --- a/numba_rvsdg/tests/test_scfg.py +++ b/numba_rvsdg/tests/test_scfg.py @@ -112,8 +112,7 @@ def test_scfg_iter(self): block_0 = name_gen.new_block_name(block_names.BASIC) block_1 = name_gen.new_block_name(block_names.BASIC) expected = [ - (block_0, BasicBlock(name=block_0, - _jump_targets=(block_1,))), + (block_0, BasicBlock(name=block_0)), (block_1, BasicBlock(name=block_1)), ] scfg, _ = SCFG.from_yaml(""" diff --git a/numba_rvsdg/tests/test_simulate.py b/numba_rvsdg/tests/test_simulate.py index 68f8f44..2c09773 100644 --- a/numba_rvsdg/tests/test_simulate.py +++ b/numba_rvsdg/tests/test_simulate.py @@ -190,5 +190,7 @@ def foo(s, e): self._run(foo, flow, {"s": 23, "e": 28}) -if __name__ == "__main__": - unittest.main() +# if __name__ == "__main__": +# unittest.main() +x = SimulatorTest() +x.test_for_loop_with_exit() diff --git a/numba_rvsdg/tests/test_utils.py b/numba_rvsdg/tests/test_utils.py index 778753d..95b37b9 100644 --- a/numba_rvsdg/tests/test_utils.py +++ b/numba_rvsdg/tests/test_utils.py @@ -17,7 +17,7 @@ def assertSCFGEqual(self, first_scfg: SCFG, second_scfg: SCFG, head_map=None): stack = [first_head] # Assert number of blocks are equal in both SCFGs - assert len(first_scfg.graph) == len(second_scfg.graph), "Number of blocks in both graphs are not equal" + assert len(first_scfg.blocks) == len(second_scfg.blocks), "Number of blocks in both graphs are not equal" seen = set() while stack: @@ -32,17 +32,17 @@ def assertSCFGEqual(self, first_scfg: SCFG, second_scfg: SCFG, head_map=None): second_node_name = block_mapping[node_name] second_node: BasicBlock = second_scfg[second_node_name] # Both nodes should have equal number of jump targets and backedges - assert len(node.jump_targets) == len(second_node.jump_targets) - assert len(node.backedges) == len(second_node.backedges) + assert len(first_scfg.jump_targets[node_name]) == len(second_scfg.jump_targets[second_node_name]) + assert len(first_scfg.back_edges[node_name]) == len(second_scfg.back_edges[second_node_name]) # Add the jump targets as corresponding nodes in block mapping dictionary # Since order must be same we can simply add zip fucntionality as the # correspondence function for nodes - for jt1, jt2 in zip(node.jump_targets, second_node.jump_targets): + for jt1, jt2 in zip(first_scfg.jump_targets[node_name], second_scfg.jump_targets[second_node_name]): block_mapping[jt1] = jt2 stack.append(jt1) - for be1, be2 in zip(node.backedges, second_node.backedges): + for be1, be2 in zip(first_scfg.back_edges[node_name], second_scfg.back_edges[second_node_name]): block_mapping[be1] = be2 stack.append(be1)