Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py): CFG builder #1192

Merged
merged 6 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence
from ._hugr import Hugr, Node, Wire
from ._dfg import DfBase, _from_base
from ._tys import Type, FunctionType, TypeRow, Sum
import hugr._ops as ops


class Block(DfBase[ops.DataflowBlock]):
mark-koch marked this conversation as resolved.
Show resolved Hide resolved
def block_outputs(self, branching: Wire, *other_outputs: Wire) -> None:
mark-koch marked this conversation as resolved.
Show resolved Hide resolved
self.set_outputs(branching, *other_outputs)

def single_successor_outputs(self, *outputs: Wire) -> None:
mark-koch marked this conversation as resolved.
Show resolved Hide resolved
# TODO requires constants
raise NotImplementedError

Check warning on line 16 in hugr-py/src/hugr/_cfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_cfg.py#L16

Added line #L16 was not covered by tests


@dataclass
class Cfg:
hugr: Hugr
root: Node
_entry_block: Block
exit: Node

def __init__(
self, input_types: Sequence[Type], output_types: Sequence[Type]
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
input_types = list(input_types)
output_types = list(output_types)
root_op = ops.CFG(FunctionType(input=input_types, output=output_types))
self.hugr = Hugr(root_op)
self.root = self.hugr.root
# to ensure entry is first child, add a dummy entry at the start
self._entry_block = _from_base(
Block, self.hugr.add_dfg(ops.DataflowBlock(input_types, []))
)

self.exit = self.hugr.add_node(ops.ExitBlock(output_types), self.root)

@property
def entry(self) -> Node:
return self._entry_block.root

def _entry_op(self) -> ops.DataflowBlock:
dop = self.hugr[self.entry].op
assert isinstance(dop, ops.DataflowBlock)
return dop

def add_entry(self, sum_rows: Sequence[TypeRow], other_outputs: TypeRow) -> Block:
# update entry block types
self._entry_op().sum_rows = list(sum_rows)
self._entry_op().other_outputs = other_outputs
self._entry_block._output_op().types = [Sum(list(sum_rows)), *other_outputs]
return self._entry_block

def simple_entry(self, n_branches: int, other_outputs: TypeRow) -> Block:
return self.add_entry([[]] * n_branches, other_outputs)

def add_block(
self, input_types: TypeRow, sum_rows: Sequence[TypeRow], other_outputs: TypeRow
) -> Block:
new_block = self.hugr.add_dfg(
ops.DataflowBlock(input_types, list(sum_rows), other_outputs)
)
return _from_base(Block, new_block)

def simple_block(
self, input_types: TypeRow, n_branches: int, other_outputs: TypeRow
) -> Block:
return self.add_block(input_types, [[]] * n_branches, other_outputs)

def branch(self, src: Wire, dst: Node) -> None:
self.hugr.add_link(src.out_port(), dst.inp(0))
99 changes: 75 additions & 24 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,59 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence, Iterable
from typing import Sequence, Iterable, TYPE_CHECKING, Generic, TypeVar, cast
import typing
from ._hugr import Hugr, Node, Wire, OutPort

from ._ops import Op, Command, Input, Output, DFG
import hugr._ops as ops
from ._exceptions import NoSiblingAncestor
from hugr._tys import FunctionType, Type
from hugr._tys import FunctionType, Type, TypeRow

if TYPE_CHECKING:
from ._cfg import Cfg

Check warning on line 12 in hugr-py/src/hugr/_dfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_dfg.py#L12

Added line #L12 was not covered by tests


DP = TypeVar("DP", bound=ops.DfParentOp)


@dataclass()
class Dfg:
class DfBase(Generic[DP]):
hugr: Hugr
root: Node
input_node: Node
output_node: Node

def __init__(
self, input_types: Sequence[Type], output_types: Sequence[Type]
) -> None:
input_types = list(input_types)
output_types = list(output_types)
root_op = DFG(FunctionType(input=input_types, output=output_types))
def __init__(self, root_op: DP) -> None:
input_types = root_op.input_types()
output_types = root_op.output_types()
self.hugr = Hugr(root_op)
self.root = self.hugr.root
self.input_node = self.hugr.add_node(
Input(input_types), self.root, len(input_types)
ops.Input(input_types), self.root, len(input_types)
)
self.output_node = self.hugr.add_node(Output(output_types), self.root)
self.output_node = self.hugr.add_node(ops.Output(output_types), self.root)

@classmethod
def endo(cls, types: Sequence[Type]) -> Dfg:
return Dfg(types, types)

def _input_op(self) -> Input:
def _input_op(self) -> ops.Input:
dop = self.hugr[self.input_node].op
assert isinstance(dop, Input)
assert isinstance(dop, ops.Input)
return dop

def _output_op(self) -> ops.Output:
dop = self.hugr[self.output_node].op
assert isinstance(dop, ops.Output)
return dop

def root_op(self) -> DP:
return cast(DP, self.hugr[self.root].op)

Check warning on line 46 in hugr-py/src/hugr/_dfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_dfg.py#L46

Added line #L46 was not covered by tests

def inputs(self) -> list[OutPort]:
return [self.input_node.out(i) for i in range(len(self._input_op().types))]

def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node:
def add_op(self, op: ops.Op, /, *args: Wire, num_outs: int | None = None) -> Node:
new_n = self.hugr.add_node(op, self.root, num_outs=num_outs)
self._wire_up(new_n, args)
return new_n

def add(self, com: Command) -> Node:
def add(self, com: ops.Command) -> Node:
return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out)

def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:
Expand All @@ -55,13 +63,30 @@

def add_nested(
self,
input_types: Sequence[Type],
output_types: Sequence[Type],
input_types: TypeRow,
output_types: TypeRow,
*args: Wire,
) -> Dfg:
dfg = self.hugr.add_dfg(input_types, output_types)
dfg = self.hugr.add_dfg(
ops.DFG(FunctionType(input=input_types, output=output_types))
)
self._wire_up(dfg.root, args)
return dfg
return _from_base(Dfg, dfg)

def add_cfg(
self,
input_types: Sequence[Type],
output_types: Sequence[Type],
*args: Wire,
) -> Cfg:
cfg = self.hugr.add_cfg(input_types, output_types)
self._wire_up(cfg.root, args)
return cfg

def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(cfg.hugr, self.root)
self._wire_up(mapping[cfg.root], args)
return mapping[cfg.root]

Check warning on line 89 in hugr-py/src/hugr/_dfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_dfg.py#L87-L89

Added lines #L87 - L89 were not covered by tests

def set_outputs(self, *args: Wire) -> None:
self._wire_up(self.output_node, args)
Expand All @@ -81,6 +106,32 @@
self.hugr.add_link(src, node.inp(i))


C = TypeVar("C", bound=DfBase)


def _from_base(cls: typing.Type[C], base: DfBase[DP]) -> C:
new = cls.__new__(cls)
new.hugr = base.hugr
new.root = base.root
new.input_node = base.input_node
new.output_node = base.output_node
return new


class Dfg(DfBase[ops.DFG]):
def __init__(
self, input_types: Sequence[Type], output_types: Sequence[Type]
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
input_types = list(input_types)
output_types = list(output_types)
root_op = ops.DFG(FunctionType(input=input_types, output=output_types))
super().__init__(root_op)

@classmethod
def endo(cls, types: Sequence[Type]) -> Dfg:
return cls(types, types)


def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None:
src_parent = h[src].parent

Expand Down
26 changes: 21 additions & 5 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
Iterable,
Iterator,
Protocol,
Sequence,
TypeVar,
cast,
overload,
Sequence,
)

from typing_extensions import Self
Expand All @@ -27,7 +27,8 @@
from ._exceptions import ParentBeforeChild

if TYPE_CHECKING:
from ._dfg import Dfg
from ._dfg import DfBase, DP
from ._cfg import Cfg

Check warning on line 31 in hugr-py/src/hugr/_hugr.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_hugr.py#L30-L31

Added lines #L30 - L31 were not covered by tests


class Direction(Enum):
Expand Down Expand Up @@ -337,17 +338,32 @@
)
return mapping

def add_dfg(self, input_types: Sequence[Type], output_types: Sequence[Type]) -> Dfg:
from ._dfg import Dfg
def add_dfg(self, root_op: DP) -> DfBase[DP]:
from ._dfg import DfBase

dfg = Dfg(input_types, output_types)
dfg = DfBase(root_op)
mapping = self.insert_hugr(dfg.hugr, self.root)
dfg.hugr = self
dfg.input_node = mapping[dfg.input_node]
dfg.output_node = mapping[dfg.output_node]
dfg.root = mapping[dfg.root]
return dfg

def add_cfg(self, input_types: Sequence[Type], output_types: Sequence[Type]) -> Cfg:
from ._cfg import Cfg

cfg = Cfg(input_types, output_types)
mapping = self.insert_hugr(cfg.hugr, self.root)
cfg.hugr = self
cfg._entry_block.root = mapping[cfg.entry]
cfg._entry_block.input_node = mapping[cfg._entry_block.input_node]
cfg._entry_block.output_node = mapping[cfg._entry_block.output_node]
cfg._entry_block.hugr = self
cfg.exit = mapping[cfg.exit]
cfg.root = mapping[cfg.root]
# TODO this is horrible
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in #1194

Comment on lines +354 to +363
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be nicer to add an optional hugr argument to Cfg.__init__? If supplied, the CFG is inserted into the given Hugr, otherwise a new one is created.

Alternatively, it could be a classmethod Cfg.new_nested(hugr, input_types, output_types) or something similar?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the class method idea, I'll try it out on #1194

return cfg

def to_serial(self) -> SerialHugr:
node_it = (node for node in self._nodes if node is not None)

Expand Down
67 changes: 66 additions & 1 deletion hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,13 @@
return super().__call__(tuple_)


class DfParentOp(Op, Protocol):
def input_types(self) -> tys.TypeRow: ...
def output_types(self) -> tys.TypeRow: ...
mark-koch marked this conversation as resolved.
Show resolved Hide resolved


@dataclass()
class DFG(Op):
class DFG(DfParentOp):
signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty)

@property
Expand All @@ -134,3 +139,63 @@
parent=parent.idx,
signature=self.signature.to_serial(),
)

def input_types(self) -> tys.TypeRow:
return self.signature.input

def output_types(self) -> tys.TypeRow:
return self.signature.output


@dataclass()
class CFG(Op):
signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty)

@property
def num_out(self) -> int | None:
return len(self.signature.output)

Check warning on line 156 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L156

Added line #L156 was not covered by tests

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CFG:
return sops.CFG(
parent=parent.idx,
signature=self.signature.to_serial(),
)


@dataclass
class DataflowBlock(DfParentOp):
inputs: tys.TypeRow
sum_rows: list[tys.TypeRow]
other_outputs: tys.TypeRow = field(default_factory=list)
extension_delta: tys.ExtensionSet = field(default_factory=list)

@property
def num_out(self) -> int | None:
return len(self.sum_rows)

Check warning on line 174 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L174

Added line #L174 was not covered by tests

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DataflowBlock:
return sops.DataflowBlock(
parent=parent.idx,
inputs=ser_it(self.inputs),
sum_rows=list(map(ser_it, self.sum_rows)),
other_outputs=ser_it(self.other_outputs),
extension_delta=self.extension_delta,
)

def input_types(self) -> tys.TypeRow:
return self.inputs

def output_types(self) -> tys.TypeRow:
return [tys.Sum(self.sum_rows), *self.other_outputs]


@dataclass
class ExitBlock(Op):
cfg_outputs: tys.TypeRow
num_out: int | None = 0

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.ExitBlock:
return sops.ExitBlock(
parent=parent.idx,
cfg_outputs=ser_it(self.cfg_outputs),
)
1 change: 1 addition & 0 deletions hugr-py/src/hugr/_tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,4 @@ def to_serial(self) -> stys.Qubit:

Qubit = QubitDef()
Bool = UnitSum(size=2)
Unit = UnitSum(size=1)
Loading