From c8c3bcfac798c59f55b959a9831092848d7eb737 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 17 Jun 2024 12:33:51 +0100 Subject: [PATCH] feat: allow edges between (dominating) basic blocks --- hugr-py/src/hugr/_cfg.py | 22 +++++++++++++++++++++- hugr-py/src/hugr/_dfg.py | 17 ++++++++++------- hugr-py/src/hugr/_exceptions.py | 10 ++++++++++ hugr-py/tests/test_cfg.py | 22 ++++++++++++++++++++++ 4 files changed, 63 insertions(+), 8 deletions(-) diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index 4af60bb4f..0151aeab2 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -1,9 +1,10 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Sequence +from typing import Iterable, Sequence from ._hugr import Hugr, Node, Wire from ._dfg import DfBase, _from_base from ._tys import FunctionType, TypeRow, Sum +from ._exceptions import NoSiblingAncestor, NotInSameCfg import hugr._ops as ops @@ -15,6 +16,25 @@ def single_successor_outputs(self, *outputs: Wire) -> None: # TODO requires constants raise NotImplementedError + def _wire_up(self, node: Node, ports: Iterable[Wire]): + for i, p in enumerate(ports): + src = p.out_port() + cfg_node = self.hugr[self.root].parent + assert cfg_node is not None + src_parent = self.hugr[src.node].parent + try: + self._wire_up_port(node, i, p) + except NoSiblingAncestor: + # note this just checks if there is a common CFG ancestor + # it does not check for valid dominance between basic blocks + # that is deferred to full HUGR validation. + while cfg_node != src_parent: + if src_parent is None or src_parent == self.hugr.root: + raise NotInSameCfg(src.node.idx, node.idx) + src_parent = self.hugr[src_parent].parent + + self.hugr.add_link(src, node.inp(i)) + @dataclass class Cfg: diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 5f10b5029..17499a7b9 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -97,13 +97,16 @@ def add_state_order(self, src: Node, dst: Node) -> None: def _wire_up(self, node: Node, ports: Iterable[Wire]): for i, p in enumerate(ports): - src = p.out_port() - node_ancestor = _ancestral_sibling(self.hugr, src.node, node) - if node_ancestor is None: - raise NoSiblingAncestor(src.node.idx, node.idx) - if node_ancestor != node: - self.add_state_order(src.node, node_ancestor) - self.hugr.add_link(src, node.inp(i)) + self._wire_up_port(node, i, p) + + def _wire_up_port(self, node: Node, offset: int, p: Wire): + src = p.out_port() + node_ancestor = _ancestral_sibling(self.hugr, src.node, node) + if node_ancestor is None: + raise NoSiblingAncestor(src.node.idx, node.idx) + if node_ancestor != node: + self.add_state_order(src.node, node_ancestor) + self.hugr.add_link(src, node.inp(offset)) C = TypeVar("C", bound=DfBase) diff --git a/hugr-py/src/hugr/_exceptions.py b/hugr-py/src/hugr/_exceptions.py index d59d99972..92ba7ceb0 100644 --- a/hugr-py/src/hugr/_exceptions.py +++ b/hugr-py/src/hugr/_exceptions.py @@ -11,5 +11,15 @@ def msg(self): return f"Source {self.src} has no sibling ancestor of target {self.tgt}, so cannot wire up." +@dataclass +class NotInSameCfg(Exception): + src: int + tgt: int + + @property + def msg(self): + return f"Source {self.src} is not in the same CFG as target {self.tgt}, so cannot wire up." + + class ParentBeforeChild(Exception): msg: str = "Parent node must be added before child node." diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index 20372b328..457334526 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -47,3 +47,25 @@ def test_nested_cfg() -> None: dfg.set_outputs(cfg.root) _validate(dfg.hugr) + + +def test_dom_edge() -> None: + cfg = Cfg([tys.Bool, tys.Unit, INT_T], [INT_T]) + entry = cfg.simple_entry(2, [INT_T]) + b, u, i = entry.inputs() + entry.set_block_outputs(b, i) + + # entry dominates both middles so Unit type can be used as inter-graph + # value between basic blocks + middle_1 = cfg.simple_block([INT_T], 1, [INT_T]) + middle_1.set_block_outputs(u, *middle_1.inputs()) + middle_2 = cfg.simple_block([INT_T], 1, [INT_T]) + middle_2.set_block_outputs(u, *middle_2.inputs()) + + cfg.branch(entry.root.out(0), middle_1.root) + cfg.branch(entry.root.out(1), middle_2.root) + + cfg.branch(middle_1.root.out(0), cfg.exit) + cfg.branch(middle_2.root.out(0), cfg.exit) + + _validate(cfg.hugr)