From 2a6fb563f40e97ff76b8cc741312d6f91f14466d Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 12 Jun 2024 15:39:38 +0100 Subject: [PATCH 1/6] feat: generalise dataflow parent builder --- hugr-py/src/hugr/_dfg.py | 89 ++++++++++++++++++++++++++++----------- hugr-py/src/hugr/_hugr.py | 10 ++--- hugr-py/src/hugr/_ops.py | 67 ++++++++++++++++++++++++++++- 3 files changed, 135 insertions(+), 31 deletions(-) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 4c4ab3da7..b371e0ba4 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -1,51 +1,59 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Sequence, Iterable +from typing import Sequence, Iterable, TYPE_CHECKING, Generic, TypeVar, cast +import typing from ._hugr import Hugr, Node, Wire, OutPort -from ._ops import Op, Command, Input, Output, DFG +import hugr._ops as ops from ._exceptions import NoSiblingAncestor -from hugr._tys import FunctionType, Type +from hugr._tys import FunctionType, Type, TypeRow + +if TYPE_CHECKING: + from ._cfg import Cfg + + +DP = TypeVar("DP", bound=ops.DfParentOp) @dataclass() -class Dfg: +class DfBase(Generic[DP]): hugr: Hugr root: Node input_node: Node output_node: Node - def __init__( - self, input_types: Sequence[Type], output_types: Sequence[Type] - ) -> None: - input_types = list(input_types) - output_types = list(output_types) - root_op = DFG(FunctionType(input=input_types, output=output_types)) + def __init__(self, root_op: DP) -> None: + input_types = root_op.input_types() + output_types = root_op.output_types() self.hugr = Hugr(root_op) self.root = self.hugr.root self.input_node = self.hugr.add_node( - Input(input_types), self.root, len(input_types) + ops.Input(input_types), self.root, len(input_types) ) - self.output_node = self.hugr.add_node(Output(output_types), self.root) + self.output_node = self.hugr.add_node(ops.Output(output_types), self.root) - @classmethod - def endo(cls, types: Sequence[Type]) -> Dfg: - return Dfg(types, types) - - def _input_op(self) -> Input: + def _input_op(self) -> ops.Input: dop = self.hugr[self.input_node].op - assert isinstance(dop, Input) + assert isinstance(dop, ops.Input) return dop + def _output_op(self) -> ops.Output: + dop = self.hugr[self.output_node].op + assert isinstance(dop, ops.Output) + return dop + + def root_op(self) -> DP: + return cast(DP, self.hugr[self.root].op) + def inputs(self) -> list[OutPort]: return [self.input_node.out(i) for i in range(len(self._input_op().types))] - def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node: + def add_op(self, op: ops.Op, /, *args: Wire, num_outs: int | None = None) -> Node: new_n = self.hugr.add_node(op, self.root, num_outs=num_outs) self._wire_up(new_n, args) return new_n - def add(self, com: Command) -> Node: + def add(self, com: ops.Command) -> Node: return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out) def insert_nested(self, dfg: Dfg, *args: Wire) -> Node: @@ -55,13 +63,20 @@ def insert_nested(self, dfg: Dfg, *args: Wire) -> Node: def add_nested( self, - input_types: Sequence[Type], - output_types: Sequence[Type], + input_types: TypeRow, + output_types: TypeRow, *args: Wire, ) -> Dfg: - dfg = self.hugr.add_dfg(input_types, output_types) + dfg = self.hugr.add_dfg( + ops.DFG(FunctionType(input=input_types, output=output_types)) + ) self._wire_up(dfg.root, args) - return dfg + return _from_base(Dfg, dfg) + + def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node: + mapping = self.hugr.insert_hugr(cfg.hugr, self.root) + self._wire_up(mapping[cfg.root], args) + return mapping[cfg.root] def set_outputs(self, *args: Wire) -> None: self._wire_up(self.output_node, args) @@ -81,6 +96,32 @@ def _wire_up(self, node: Node, ports: Iterable[Wire]): self.hugr.add_link(src, node.inp(i)) +C = TypeVar("C", bound=DfBase) + + +def _from_base(cls: typing.Type[C], base: DfBase[DP]) -> C: + new = cls.__new__(cls) + new.hugr = base.hugr + new.root = base.root + new.input_node = base.input_node + new.output_node = base.output_node + return new + + +class Dfg(DfBase[ops.DFG]): + def __init__( + self, input_types: Sequence[Type], output_types: Sequence[Type] + ) -> None: + input_types = list(input_types) + output_types = list(output_types) + root_op = ops.DFG(FunctionType(input=input_types, output=output_types)) + super().__init__(root_op) + + @classmethod + def endo(cls, types: Sequence[Type]) -> Dfg: + return cls(types, types) + + def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None: src_parent = h[src].parent diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index 45b3c180c..40882d0bc 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -10,7 +10,6 @@ Iterable, Iterator, Protocol, - Sequence, TypeVar, cast, overload, @@ -19,7 +18,6 @@ from typing_extensions import Self from hugr._ops import Op -from hugr._tys import Type from hugr.serialization.ops import OpType as SerialOp from hugr.serialization.serial_hugr import SerialHugr from hugr.utils import BiMap @@ -27,7 +25,7 @@ from ._exceptions import ParentBeforeChild if TYPE_CHECKING: - from ._dfg import Dfg + from ._dfg import DfBase, DP class Direction(Enum): @@ -337,10 +335,10 @@ def insert_hugr(self, hugr: Hugr, parent: Node | None = None) -> dict[Node, Node ) return mapping - def add_dfg(self, input_types: Sequence[Type], output_types: Sequence[Type]) -> Dfg: - from ._dfg import Dfg + def add_dfg(self, root_op: DP) -> DfBase[DP]: + from ._dfg import DfBase - dfg = Dfg(input_types, output_types) + dfg = DfBase(root_op) mapping = self.insert_hugr(dfg.hugr, self.root) dfg.hugr = self dfg.input_node = mapping[dfg.input_node] diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 9724162dd..6b6986081 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -121,8 +121,13 @@ def __call__(self, tuple_: Wire) -> Command: return super().__call__(tuple_) +class DfParentOp(Op, Protocol): + def input_types(self) -> tys.TypeRow: ... + def output_types(self) -> tys.TypeRow: ... + + @dataclass() -class DFG(Op): +class DFG(DfParentOp): signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) @property @@ -134,3 +139,63 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DFG: parent=parent.idx, signature=self.signature.to_serial(), ) + + def input_types(self) -> tys.TypeRow: + return self.signature.input + + def output_types(self) -> tys.TypeRow: + return self.signature.output + + +@dataclass() +class CFG(Op): + signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) + + @property + def num_out(self) -> int | None: + return len(self.signature.output) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CFG: + return sops.CFG( + parent=parent.idx, + signature=self.signature.to_serial(), + ) + + +@dataclass +class DataflowBlock(DfParentOp): + inputs: tys.TypeRow + sum_rows: list[tys.TypeRow] + other_outputs: tys.TypeRow = field(default_factory=list) + extension_delta: tys.ExtensionSet = field(default_factory=list) + + @property + def num_out(self) -> int | None: + return len(self.sum_rows) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DataflowBlock: + return sops.DataflowBlock( + parent=parent.idx, + inputs=ser_it(self.inputs), + sum_rows=list(map(ser_it, self.sum_rows)), + other_outputs=ser_it(self.other_outputs), + extension_delta=self.extension_delta, + ) + + def input_types(self) -> tys.TypeRow: + return self.inputs + + def output_types(self) -> tys.TypeRow: + return [tys.Sum(self.sum_rows), *self.other_outputs] + + +@dataclass +class ExitBlock(Op): + cfg_outputs: tys.TypeRow + num_out: int | None = 0 + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.ExitBlock: + return sops.ExitBlock( + parent=parent.idx, + cfg_outputs=ser_it(self.cfg_outputs), + ) From 86113166a55e449a0cd4de0df66bd57dfa5ccb0c Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 12 Jun 2024 16:12:20 +0100 Subject: [PATCH 2/6] feat(hugr-py): CFG builder Closes #1188 --- hugr-py/src/hugr/_cfg.py | 74 ++++++++++++++++++++++++++++++++ hugr-py/src/hugr/_dfg.py | 10 +++++ hugr-py/src/hugr/_hugr.py | 18 ++++++++ hugr-py/src/hugr/_tys.py | 1 + hugr-py/tests/test_cfg.py | 49 +++++++++++++++++++++ hugr-py/tests/test_hugr_build.py | 2 +- 6 files changed, 153 insertions(+), 1 deletion(-) create mode 100644 hugr-py/src/hugr/_cfg.py create mode 100644 hugr-py/tests/test_cfg.py diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py new file mode 100644 index 000000000..9a2452c4a --- /dev/null +++ b/hugr-py/src/hugr/_cfg.py @@ -0,0 +1,74 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Sequence +from ._hugr import Hugr, Node, Wire +from ._dfg import DfBase, _from_base +from ._tys import Type, FunctionType, TypeRow, Sum +import hugr._ops as ops + + +class Block(DfBase[ops.DataflowBlock]): + def block_outputs(self, branching: Wire, *other_outputs: Wire) -> None: + self.set_outputs(branching, *other_outputs) + + def single_successor_outputs(self, *outputs: Wire) -> None: + # TODO requires constants + raise NotImplementedError + + +@dataclass +class Cfg: + hugr: Hugr + root: Node + _entry_block: Block + exit: Node + + def __init__( + self, input_types: Sequence[Type], output_types: Sequence[Type] + ) -> None: + input_types = list(input_types) + output_types = list(output_types) + root_op = ops.CFG(FunctionType(input=input_types, output=output_types)) + self.hugr = Hugr(root_op) + self.root = self.hugr.root + # to ensure entry is first child, add a dummy entry at the start + self._entry_block = _from_base( + Block, self.hugr.add_dfg(ops.DataflowBlock(input_types, [])) + ) + + self.exit = self.hugr.add_node(ops.ExitBlock(output_types), self.root) + + @property + def entry(self) -> Node: + return self._entry_block.root + + def _entry_op(self) -> ops.DataflowBlock: + dop = self.hugr[self.entry].op + assert isinstance(dop, ops.DataflowBlock) + return dop + + def add_entry(self, sum_rows: Sequence[TypeRow], other_outputs: TypeRow) -> Block: + # update entry block types + self._entry_op().sum_rows = list(sum_rows) + self._entry_op().other_outputs = other_outputs + self._entry_block._output_op().types = [Sum(list(sum_rows)), *other_outputs] + return self._entry_block + + def simple_entry(self, n_branches: int, other_outputs: TypeRow) -> Block: + return self.add_entry([[]] * n_branches, other_outputs) + + def add_block( + self, input_types: TypeRow, sum_rows: Sequence[TypeRow], other_outputs: TypeRow + ) -> Block: + new_block = self.hugr.add_dfg( + ops.DataflowBlock(input_types, list(sum_rows), other_outputs) + ) + return _from_base(Block, new_block) + + def simple_block( + self, input_types: TypeRow, n_branches: int, other_outputs: TypeRow + ) -> Block: + return self.add_block(input_types, [[]] * n_branches, other_outputs) + + def branch(self, src: Wire, dst: Node) -> None: + self.hugr.add_link(src.out_port(), dst.inp(0)) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index b371e0ba4..54c5c6b58 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -73,6 +73,16 @@ def add_nested( self._wire_up(dfg.root, args) return _from_base(Dfg, dfg) + def add_cfg( + self, + input_types: Sequence[Type], + output_types: Sequence[Type], + *args: Wire, + ) -> Cfg: + cfg = self.hugr.add_cfg(input_types, output_types) + self._wire_up(cfg.root, args) + return cfg + def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node: mapping = self.hugr.insert_hugr(cfg.hugr, self.root) self._wire_up(mapping[cfg.root], args) diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index 40882d0bc..a35b87f4a 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -13,11 +13,13 @@ TypeVar, cast, overload, + Sequence, ) from typing_extensions import Self from hugr._ops import Op +from hugr._tys import Type from hugr.serialization.ops import OpType as SerialOp from hugr.serialization.serial_hugr import SerialHugr from hugr.utils import BiMap @@ -26,6 +28,7 @@ if TYPE_CHECKING: from ._dfg import DfBase, DP + from ._cfg import Cfg class Direction(Enum): @@ -346,6 +349,21 @@ def add_dfg(self, root_op: DP) -> DfBase[DP]: dfg.root = mapping[dfg.root] return dfg + def add_cfg(self, input_types: Sequence[Type], output_types: Sequence[Type]) -> Cfg: + from ._cfg import Cfg + + cfg = Cfg(input_types, output_types) + mapping = self.insert_hugr(cfg.hugr, self.root) + cfg.hugr = self + cfg._entry_block.root = mapping[cfg.entry] + cfg._entry_block.input_node = mapping[cfg._entry_block.input_node] + cfg._entry_block.output_node = mapping[cfg._entry_block.output_node] + cfg._entry_block.hugr = self + cfg.exit = mapping[cfg.exit] + cfg.root = mapping[cfg.root] + # TODO this is horrible + return cfg + def to_serial(self) -> SerialHugr: node_it = (node for node in self._nodes if node is not None) diff --git a/hugr-py/src/hugr/_tys.py b/hugr-py/src/hugr/_tys.py index 7cbd73f2f..6f199959c 100644 --- a/hugr-py/src/hugr/_tys.py +++ b/hugr-py/src/hugr/_tys.py @@ -269,3 +269,4 @@ def to_serial(self) -> stys.Qubit: Qubit = QubitDef() Bool = UnitSum(size=2) +Unit = UnitSum(size=1) diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py new file mode 100644 index 000000000..ce21186b3 --- /dev/null +++ b/hugr-py/tests/test_cfg.py @@ -0,0 +1,49 @@ +from hugr._cfg import Cfg +import hugr._tys as tys +from hugr._dfg import Dfg +from .test_hugr_build import _validate, INT_T, DivMod + + +def build_basic_cfg(cfg: Cfg) -> None: + entry = cfg.simple_entry(1, [tys.Bool]) + + entry.block_outputs(*entry.inputs()) + cfg.branch(entry.root.out(0), cfg.exit) + + +def test_basic_cfg() -> None: + cfg = Cfg([tys.Unit, tys.Bool], [tys.Bool]) + build_basic_cfg(cfg) + _validate(cfg.hugr) + + +def test_branch() -> None: + cfg = Cfg([tys.Bool, tys.Unit, INT_T], [INT_T]) + entry = cfg.simple_entry(2, [tys.Unit, INT_T]) + entry.block_outputs(*entry.inputs()) + + middle_1 = cfg.simple_block([tys.Unit, INT_T], 1, [INT_T]) + middle_1.block_outputs(*middle_1.inputs()) + middle_2 = cfg.simple_block([tys.Unit, INT_T], 1, [INT_T]) + u, i = middle_2.inputs() + n = middle_2.add(DivMod(i, i)) + middle_2.block_outputs(u, n[0]) + + cfg.branch(entry.root.out(0), middle_1.root) + cfg.branch(entry.root.out(1), middle_2.root) + + cfg.branch(middle_1.root.out(0), cfg.exit) + cfg.branch(middle_2.root.out(0), cfg.exit) + + _validate(cfg.hugr) + + +def test_nested_cfg() -> None: + dfg = Dfg([tys.Unit, tys.Bool], [tys.Bool]) + + cfg = dfg.add_cfg([tys.Unit, tys.Bool], [tys.Bool], *dfg.inputs()) + + build_basic_cfg(cfg) + dfg.set_outputs(cfg.root) + + _validate(dfg.hugr, True) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 37a3b624d..47dd9ce36 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -240,7 +240,7 @@ def test_build_inter_graph(): h.set_outputs(nested.root, b) - _validate(h.hugr, True) + _validate(h.hugr) assert _SubPort(h.input_node.out(-1)) in h.hugr._links assert h.hugr.num_outgoing(h.input_node) == 2 # doesn't count state order From b7e9dc5f3dd02f17e8bc38b554b92a6e59d0204e Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 17 Jun 2024 11:15:06 +0100 Subject: [PATCH 3/6] refactor: replace `Sequence[Type]` with `TypeRow` --- hugr-py/src/hugr/_cfg.py | 8 ++------ hugr-py/src/hugr/_dfg.py | 16 ++++++---------- hugr-py/src/hugr/_hugr.py | 5 ++--- 3 files changed, 10 insertions(+), 19 deletions(-) diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index 9a2452c4a..23e1e0808 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -3,7 +3,7 @@ from typing import Sequence from ._hugr import Hugr, Node, Wire from ._dfg import DfBase, _from_base -from ._tys import Type, FunctionType, TypeRow, Sum +from ._tys import FunctionType, TypeRow, Sum import hugr._ops as ops @@ -23,11 +23,7 @@ class Cfg: _entry_block: Block exit: Node - def __init__( - self, input_types: Sequence[Type], output_types: Sequence[Type] - ) -> None: - input_types = list(input_types) - output_types = list(output_types) + def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None: root_op = ops.CFG(FunctionType(input=input_types, output=output_types)) self.hugr = Hugr(root_op) self.root = self.hugr.root diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 54c5c6b58..5f10b5029 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -1,12 +1,12 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Sequence, Iterable, TYPE_CHECKING, Generic, TypeVar, cast +from typing import Iterable, TYPE_CHECKING, Generic, TypeVar, cast import typing from ._hugr import Hugr, Node, Wire, OutPort import hugr._ops as ops from ._exceptions import NoSiblingAncestor -from hugr._tys import FunctionType, Type, TypeRow +from hugr._tys import FunctionType, TypeRow if TYPE_CHECKING: from ._cfg import Cfg @@ -75,8 +75,8 @@ def add_nested( def add_cfg( self, - input_types: Sequence[Type], - output_types: Sequence[Type], + input_types: TypeRow, + output_types: TypeRow, *args: Wire, ) -> Cfg: cfg = self.hugr.add_cfg(input_types, output_types) @@ -119,16 +119,12 @@ def _from_base(cls: typing.Type[C], base: DfBase[DP]) -> C: class Dfg(DfBase[ops.DFG]): - def __init__( - self, input_types: Sequence[Type], output_types: Sequence[Type] - ) -> None: - input_types = list(input_types) - output_types = list(output_types) + def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None: root_op = ops.DFG(FunctionType(input=input_types, output=output_types)) super().__init__(root_op) @classmethod - def endo(cls, types: Sequence[Type]) -> Dfg: + def endo(cls, types: TypeRow) -> Dfg: return cls(types, types) diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index a35b87f4a..3bee26bd6 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -13,13 +13,12 @@ TypeVar, cast, overload, - Sequence, ) from typing_extensions import Self from hugr._ops import Op -from hugr._tys import Type +from hugr._tys import TypeRow from hugr.serialization.ops import OpType as SerialOp from hugr.serialization.serial_hugr import SerialHugr from hugr.utils import BiMap @@ -349,7 +348,7 @@ def add_dfg(self, root_op: DP) -> DfBase[DP]: dfg.root = mapping[dfg.root] return dfg - def add_cfg(self, input_types: Sequence[Type], output_types: Sequence[Type]) -> Cfg: + def add_cfg(self, input_types: TypeRow, output_types: TypeRow) -> Cfg: from ._cfg import Cfg cfg = Cfg(input_types, output_types) From 2a7e7e657580b6afb93c296f71434cbaf99a0837 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 17 Jun 2024 11:17:17 +0100 Subject: [PATCH 4/6] minor review fixes --- hugr-py/src/hugr/_cfg.py | 2 +- hugr-py/tests/test_cfg.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index 23e1e0808..4af60bb4f 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -8,7 +8,7 @@ class Block(DfBase[ops.DataflowBlock]): - def block_outputs(self, branching: Wire, *other_outputs: Wire) -> None: + def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None: self.set_outputs(branching, *other_outputs) def single_successor_outputs(self, *outputs: Wire) -> None: diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index ce21186b3..20372b328 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -7,7 +7,7 @@ def build_basic_cfg(cfg: Cfg) -> None: entry = cfg.simple_entry(1, [tys.Bool]) - entry.block_outputs(*entry.inputs()) + entry.set_block_outputs(*entry.inputs()) cfg.branch(entry.root.out(0), cfg.exit) @@ -20,14 +20,14 @@ def test_basic_cfg() -> None: def test_branch() -> None: cfg = Cfg([tys.Bool, tys.Unit, INT_T], [INT_T]) entry = cfg.simple_entry(2, [tys.Unit, INT_T]) - entry.block_outputs(*entry.inputs()) + entry.set_block_outputs(*entry.inputs()) middle_1 = cfg.simple_block([tys.Unit, INT_T], 1, [INT_T]) - middle_1.block_outputs(*middle_1.inputs()) + middle_1.set_block_outputs(*middle_1.inputs()) middle_2 = cfg.simple_block([tys.Unit, INT_T], 1, [INT_T]) u, i = middle_2.inputs() n = middle_2.add(DivMod(i, i)) - middle_2.block_outputs(u, n[0]) + middle_2.set_block_outputs(u, n[0]) cfg.branch(entry.root.out(0), middle_1.root) cfg.branch(entry.root.out(1), middle_2.root) @@ -46,4 +46,4 @@ def test_nested_cfg() -> None: build_basic_cfg(cfg) dfg.set_outputs(cfg.root) - _validate(dfg.hugr, True) + _validate(dfg.hugr) From c8c3bcfac798c59f55b959a9831092848d7eb737 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 17 Jun 2024 12:33:51 +0100 Subject: [PATCH 5/6] feat: allow edges between (dominating) basic blocks --- hugr-py/src/hugr/_cfg.py | 22 +++++++++++++++++++++- hugr-py/src/hugr/_dfg.py | 17 ++++++++++------- hugr-py/src/hugr/_exceptions.py | 10 ++++++++++ hugr-py/tests/test_cfg.py | 22 ++++++++++++++++++++++ 4 files changed, 63 insertions(+), 8 deletions(-) diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index 4af60bb4f..0151aeab2 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -1,9 +1,10 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Sequence +from typing import Iterable, Sequence from ._hugr import Hugr, Node, Wire from ._dfg import DfBase, _from_base from ._tys import FunctionType, TypeRow, Sum +from ._exceptions import NoSiblingAncestor, NotInSameCfg import hugr._ops as ops @@ -15,6 +16,25 @@ def single_successor_outputs(self, *outputs: Wire) -> None: # TODO requires constants raise NotImplementedError + def _wire_up(self, node: Node, ports: Iterable[Wire]): + for i, p in enumerate(ports): + src = p.out_port() + cfg_node = self.hugr[self.root].parent + assert cfg_node is not None + src_parent = self.hugr[src.node].parent + try: + self._wire_up_port(node, i, p) + except NoSiblingAncestor: + # note this just checks if there is a common CFG ancestor + # it does not check for valid dominance between basic blocks + # that is deferred to full HUGR validation. + while cfg_node != src_parent: + if src_parent is None or src_parent == self.hugr.root: + raise NotInSameCfg(src.node.idx, node.idx) + src_parent = self.hugr[src_parent].parent + + self.hugr.add_link(src, node.inp(i)) + @dataclass class Cfg: diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 5f10b5029..17499a7b9 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -97,13 +97,16 @@ def add_state_order(self, src: Node, dst: Node) -> None: def _wire_up(self, node: Node, ports: Iterable[Wire]): for i, p in enumerate(ports): - src = p.out_port() - node_ancestor = _ancestral_sibling(self.hugr, src.node, node) - if node_ancestor is None: - raise NoSiblingAncestor(src.node.idx, node.idx) - if node_ancestor != node: - self.add_state_order(src.node, node_ancestor) - self.hugr.add_link(src, node.inp(i)) + self._wire_up_port(node, i, p) + + def _wire_up_port(self, node: Node, offset: int, p: Wire): + src = p.out_port() + node_ancestor = _ancestral_sibling(self.hugr, src.node, node) + if node_ancestor is None: + raise NoSiblingAncestor(src.node.idx, node.idx) + if node_ancestor != node: + self.add_state_order(src.node, node_ancestor) + self.hugr.add_link(src, node.inp(offset)) C = TypeVar("C", bound=DfBase) diff --git a/hugr-py/src/hugr/_exceptions.py b/hugr-py/src/hugr/_exceptions.py index d59d99972..92ba7ceb0 100644 --- a/hugr-py/src/hugr/_exceptions.py +++ b/hugr-py/src/hugr/_exceptions.py @@ -11,5 +11,15 @@ def msg(self): return f"Source {self.src} has no sibling ancestor of target {self.tgt}, so cannot wire up." +@dataclass +class NotInSameCfg(Exception): + src: int + tgt: int + + @property + def msg(self): + return f"Source {self.src} is not in the same CFG as target {self.tgt}, so cannot wire up." + + class ParentBeforeChild(Exception): msg: str = "Parent node must be added before child node." diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index 20372b328..457334526 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -47,3 +47,25 @@ def test_nested_cfg() -> None: dfg.set_outputs(cfg.root) _validate(dfg.hugr) + + +def test_dom_edge() -> None: + cfg = Cfg([tys.Bool, tys.Unit, INT_T], [INT_T]) + entry = cfg.simple_entry(2, [INT_T]) + b, u, i = entry.inputs() + entry.set_block_outputs(b, i) + + # entry dominates both middles so Unit type can be used as inter-graph + # value between basic blocks + middle_1 = cfg.simple_block([INT_T], 1, [INT_T]) + middle_1.set_block_outputs(u, *middle_1.inputs()) + middle_2 = cfg.simple_block([INT_T], 1, [INT_T]) + middle_2.set_block_outputs(u, *middle_2.inputs()) + + cfg.branch(entry.root.out(0), middle_1.root) + cfg.branch(entry.root.out(1), middle_2.root) + + cfg.branch(middle_1.root.out(0), cfg.exit) + cfg.branch(middle_2.root.out(0), cfg.exit) + + _validate(cfg.hugr) From 58a03b3bef5f1219a176b5df57ec51eff6918706 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 17 Jun 2024 12:34:20 +0100 Subject: [PATCH 6/6] rename set_single_successor_outputs --- hugr-py/src/hugr/_cfg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index 0151aeab2..8ac058a23 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -12,7 +12,7 @@ class Block(DfBase[ops.DataflowBlock]): def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None: self.set_outputs(branching, *other_outputs) - def single_successor_outputs(self, *outputs: Wire) -> None: + def set_single_successor_outputs(self, *outputs: Wire) -> None: # TODO requires constants raise NotImplementedError