Skip to content

Commit

Permalink
feat: allow edges between (dominating) basic blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Jun 17, 2024
1 parent 2a7e7e6 commit c8c3bcf
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 8 deletions.
22 changes: 21 additions & 1 deletion hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence
from typing import Iterable, Sequence
from ._hugr import Hugr, Node, Wire
from ._dfg import DfBase, _from_base
from ._tys import FunctionType, TypeRow, Sum
from ._exceptions import NoSiblingAncestor, NotInSameCfg
import hugr._ops as ops


Expand All @@ -15,6 +16,25 @@ def single_successor_outputs(self, *outputs: Wire) -> None:
# TODO requires constants
raise NotImplementedError

Check warning on line 17 in hugr-py/src/hugr/_cfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_cfg.py#L17

Added line #L17 was not covered by tests

def _wire_up(self, node: Node, ports: Iterable[Wire]):
for i, p in enumerate(ports):
src = p.out_port()
cfg_node = self.hugr[self.root].parent
assert cfg_node is not None
src_parent = self.hugr[src.node].parent
try:
self._wire_up_port(node, i, p)
except NoSiblingAncestor:
# note this just checks if there is a common CFG ancestor
# it does not check for valid dominance between basic blocks
# that is deferred to full HUGR validation.
while cfg_node != src_parent:
if src_parent is None or src_parent == self.hugr.root:
raise NotInSameCfg(src.node.idx, node.idx)

Check warning on line 33 in hugr-py/src/hugr/_cfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_cfg.py#L33

Added line #L33 was not covered by tests
src_parent = self.hugr[src_parent].parent

self.hugr.add_link(src, node.inp(i))


@dataclass
class Cfg:
Expand Down
17 changes: 10 additions & 7 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,16 @@ def add_state_order(self, src: Node, dst: Node) -> None:

def _wire_up(self, node: Node, ports: Iterable[Wire]):
for i, p in enumerate(ports):
src = p.out_port()
node_ancestor = _ancestral_sibling(self.hugr, src.node, node)
if node_ancestor is None:
raise NoSiblingAncestor(src.node.idx, node.idx)
if node_ancestor != node:
self.add_state_order(src.node, node_ancestor)
self.hugr.add_link(src, node.inp(i))
self._wire_up_port(node, i, p)

def _wire_up_port(self, node: Node, offset: int, p: Wire):
src = p.out_port()
node_ancestor = _ancestral_sibling(self.hugr, src.node, node)
if node_ancestor is None:
raise NoSiblingAncestor(src.node.idx, node.idx)
if node_ancestor != node:
self.add_state_order(src.node, node_ancestor)
self.hugr.add_link(src, node.inp(offset))


C = TypeVar("C", bound=DfBase)
Expand Down
10 changes: 10 additions & 0 deletions hugr-py/src/hugr/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,15 @@ def msg(self):
return f"Source {self.src} has no sibling ancestor of target {self.tgt}, so cannot wire up."


@dataclass
class NotInSameCfg(Exception):
src: int
tgt: int

@property
def msg(self):
return f"Source {self.src} is not in the same CFG as target {self.tgt}, so cannot wire up."

Check warning on line 21 in hugr-py/src/hugr/_exceptions.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_exceptions.py#L21

Added line #L21 was not covered by tests


class ParentBeforeChild(Exception):
msg: str = "Parent node must be added before child node."
22 changes: 22 additions & 0 deletions hugr-py/tests/test_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,25 @@ def test_nested_cfg() -> None:
dfg.set_outputs(cfg.root)

_validate(dfg.hugr)


def test_dom_edge() -> None:
cfg = Cfg([tys.Bool, tys.Unit, INT_T], [INT_T])
entry = cfg.simple_entry(2, [INT_T])
b, u, i = entry.inputs()
entry.set_block_outputs(b, i)

# entry dominates both middles so Unit type can be used as inter-graph
# value between basic blocks
middle_1 = cfg.simple_block([INT_T], 1, [INT_T])
middle_1.set_block_outputs(u, *middle_1.inputs())
middle_2 = cfg.simple_block([INT_T], 1, [INT_T])
middle_2.set_block_outputs(u, *middle_2.inputs())

cfg.branch(entry.root.out(0), middle_1.root)
cfg.branch(entry.root.out(1), middle_2.root)

cfg.branch(middle_1.root.out(0), cfg.exit)
cfg.branch(middle_2.root.out(0), cfg.exit)

_validate(cfg.hugr)

0 comments on commit c8c3bcf

Please sign in to comment.