Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py): only require input type annotations when building #1199

Merged
merged 17 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 73 additions & 61 deletions hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Sequence
from dataclasses import dataclass, replace

import hugr._ops as ops

from ._dfg import _DfBase
from ._exceptions import NoSiblingAncestor, NotInSameCfg
from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit
from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire
from ._tys import FunctionType, Sum, TypeRow
from ._tys import FunctionType, TypeRow, Type


class Block(_DfBase[ops.DataflowBlock]):
Expand All @@ -19,99 +18,112 @@ def set_single_successor_outputs(self, *outputs: Wire) -> None:
# TODO requires constants
raise NotImplementedError

def _wire_up(self, node: Node, ports: Iterable[Wire]):
for i, p in enumerate(ports):
src = p.out_port()
cfg_node = self.hugr[self.root].parent
assert cfg_node is not None
src_parent = self.hugr[src.node].parent
try:
self._wire_up_port(node, i, p)
except NoSiblingAncestor:
# note this just checks if there is a common CFG ancestor
# it does not check for valid dominance between basic blocks
# that is deferred to full HUGR validation.
while cfg_node != src_parent:
if src_parent is None or src_parent == self.hugr.root:
raise NotInSameCfg(src.node.idx, node.idx)
src_parent = self.hugr[src_parent].parent

self.hugr.add_link(src, node.inp(i))
def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:
src = p.out_port()
cfg_node = self.hugr[self.parent_node].parent
assert cfg_node is not None
src_parent = self.hugr[src.node].parent
try:
super()._wire_up_port(node, offset, p)
except NoSiblingAncestor:
# note this just checks if there is a common CFG ancestor
# it does not check for valid dominance between basic blocks
# that is deferred to full HUGR validation.
while cfg_node != src_parent:
if src_parent is None or src_parent == self.hugr.root:
raise NotInSameCfg(src.node.idx, node.idx)
src_parent = self.hugr[src_parent].parent

self.hugr.add_link(src, node.inp(offset))
return self._get_dataflow_type(src)


@dataclass
class Cfg(ParentBuilder):
class Cfg(ParentBuilder[ops.CFG]):
hugr: Hugr
root: Node
parent_node: Node
_entry_block: Block
exit: Node

def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None:
root_op = ops.CFG(FunctionType(input=input_types, output=output_types))
def __init__(self, input_types: TypeRow) -> None:
root_op = ops.CFG(FunctionType(input=input_types, output=[]))
hugr = Hugr(root_op)
self._init_impl(hugr, hugr.root, input_types, output_types)
self._init_impl(hugr, hugr.root, input_types)

def _init_impl(
self: Cfg, hugr: Hugr, root: Node, input_types: TypeRow, output_types: TypeRow
) -> None:
def _init_impl(self: Cfg, hugr: Hugr, root: Node, input_types: TypeRow) -> None:
self.hugr = hugr
self.root = root
self.parent_node = root
# to ensure entry is first child, add a dummy entry at the start
self._entry_block = Block.new_nested(
ops.DataflowBlock(input_types, []), hugr, root
)
self._entry_block = Block.new_nested(ops.DataflowBlock(input_types), hugr, root)

self.exit = self.hugr.add_node(ops.ExitBlock(output_types), self.root)
self.exit = self.hugr.add_node(ops.ExitBlock(), self.parent_node)

@classmethod
def new_nested(
cls,
input_types: TypeRow,
output_types: TypeRow,
hugr: Hugr,
parent: ToNode | None = None,
) -> Cfg:
new = cls.__new__(cls)
root = hugr.add_node(
ops.CFG(FunctionType(input=input_types, output=output_types)),
ops.CFG(FunctionType(input=input_types, output=[])),
parent or hugr.root,
)
new._init_impl(hugr, root, input_types, output_types)
new._init_impl(hugr, root, input_types)
return new

@property
def entry(self) -> Node:
return self._entry_block.root
return self._entry_block.parent_node

@property
def _entry_op(self) -> ops.DataflowBlock:
dop = self.hugr[self.entry].op
assert isinstance(dop, ops.DataflowBlock)
return dop

def add_entry(self, sum_rows: Sequence[TypeRow], other_outputs: TypeRow) -> Block:
# update entry block types
self._entry_op().sum_rows = list(sum_rows)
self._entry_op().other_outputs = other_outputs
self._entry_block._output_op().types = [Sum(list(sum_rows)), *other_outputs]
return self._entry_block
return self.hugr._get_typed_op(self.entry, ops.DataflowBlock)

@property
def _exit_op(self) -> ops.ExitBlock:
mark-koch marked this conversation as resolved.
Show resolved Hide resolved
return self.hugr._get_typed_op(self.exit, ops.ExitBlock)

def simple_entry(self, n_branches: int, other_outputs: TypeRow) -> Block:
return self.add_entry([[]] * n_branches, other_outputs)
def add_entry(self) -> Block:
return self._entry_block

def add_block(
self, input_types: TypeRow, sum_rows: Sequence[TypeRow], other_outputs: TypeRow
) -> Block:
def add_block(self, input_types: TypeRow) -> Block:
new_block = Block.new_nested(
ops.DataflowBlock(input_types, list(sum_rows), other_outputs),
ops.DataflowBlock(input_types),
self.hugr,
self.root,
self.parent_node,
)
return new_block

def simple_block(
self, input_types: TypeRow, n_branches: int, other_outputs: TypeRow
) -> Block:
return self.add_block(input_types, [[]] * n_branches, other_outputs)
def add_successor(self, pred: Wire) -> Block:
b = self.add_block(self._nth_outputs(pred))

self.branch(pred, b)
return b

def _nth_outputs(self, wire: Wire) -> TypeRow:
port = wire.out_port()
block = self.hugr._get_typed_op(port.node, ops.DataflowBlock)
return block.nth_outputs(port.offset)

def branch(self, src: Wire, dst: ToNode) -> None:
self.hugr.add_link(src.out_port(), dst.inp(0))
# TODO check for existing link/type compatibility
if dst.to_node() == self.exit:
return self.branch_exit(src)
src = src.out_port()
self.hugr.add_link(src, dst.inp(0))

def branch_exit(self, src: Wire) -> None:
src = src.out_port()
self.hugr.add_link(src, self.exit.inp(0))

out_types = self._nth_outputs(src)
if self._exit_op._cfg_outputs is not None:
if self._exit_op._cfg_outputs != out_types:
raise MismatchedExit(src.node.idx)
else:
self._exit_op._cfg_outputs = out_types
self.parent_op.signature = replace(
self.parent_op.signature, output=out_types
)
119 changes: 63 additions & 56 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
from __future__ import annotations
from dataclasses import dataclass, replace
from typing import (
Iterable,
TYPE_CHECKING,
TypeVar,
)
from ._hugr import Hugr, Node, Wire, OutPort, ParentBuilder

from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Iterable, TypeVar, cast
from typing_extensions import Self
import hugr._ops as ops
from hugr._tys import FunctionType, TypeRow
from hugr._tys import TypeRow

from ._exceptions import NoSiblingAncestor
from ._hugr import Hugr, Node, OutPort, ParentBuilder, Wire, ToNode
from ._hugr import ToNode
from hugr._tys import Type

if TYPE_CHECKING:
from ._cfg import Cfg
Expand All @@ -17,122 +23,123 @@


@dataclass()
class _DfBase(ParentBuilder, Generic[DP]):
class _DfBase(ParentBuilder[DP]):
hugr: Hugr
root: Node
parent_node: Node
input_node: Node
output_node: Node

def __init__(self, root_op: DP) -> None:
self.hugr = Hugr(root_op)
self.root = self.hugr.root
self._init_io_nodes(root_op)
def __init__(self, parent_op: DP) -> None:
self.hugr = Hugr(parent_op)
self.parent_node = self.hugr.root
self._init_io_nodes(parent_op)

def _init_io_nodes(self, parent_op: DP):
inputs = parent_op._inputs()

def _init_io_nodes(self, root_op: DP):
input_types = root_op.input_types()
output_types = root_op.output_types()
self.input_node = self.hugr.add_node(
ops.Input(input_types), self.root, len(input_types)
ops.Input(inputs), self.parent_node, len(inputs)
)
self.output_node = self.hugr.add_node(ops.Output(output_types), self.root)
self.output_node = self.hugr.add_node(ops.Output(), self.parent_node)

@classmethod
def new_nested(cls, root_op: DP, hugr: Hugr, parent: ToNode | None = None) -> Self:
def new_nested(
cls, parent_op: DP, hugr: Hugr, parent: ToNode | None = None
) -> Self:
new = cls.__new__(cls)

new.hugr = hugr
new.root = hugr.add_node(root_op, parent or hugr.root)
new._init_io_nodes(root_op)
new.parent_node = hugr.add_node(parent_op, parent or hugr.root)
new._init_io_nodes(parent_op)
return new

def _input_op(self) -> ops.Input:
dop = self.hugr[self.input_node].op
assert isinstance(dop, ops.Input)
return dop
return self.hugr._get_typed_op(self.input_node, ops.Input)

def _output_op(self) -> ops.Output:
dop = self.hugr[self.output_node].op
assert isinstance(dop, ops.Output)
return dop

def root_op(self) -> DP:
return cast(DP, self.hugr[self.root].op)
return self.hugr._get_typed_op(self.output_node, ops.Output)

def inputs(self) -> list[OutPort]:
return [self.input_node.out(i) for i in range(len(self._input_op().types))]

def add_op(self, op: ops.Op, /, *args: Wire, num_outs: int | None = None) -> Node:
new_n = self.hugr.add_node(op, self.root, num_outs=num_outs)
def add_op(self, op: ops.DataflowOp, /, *args: Wire) -> Node:
new_n = self.hugr.add_node(op, self.parent_node)
self._wire_up(new_n, args)
return new_n

return replace(new_n, _num_out_ports=op.num_out)

def add(self, com: ops.Command) -> Node:
return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out)
return self.add_op(com.op, *com.incoming)

def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(dfg.hugr, self.root)
self._wire_up(mapping[dfg.root], args)
return mapping[dfg.root]
mapping = self.hugr.insert_hugr(dfg.hugr, self.parent_node)
self._wire_up(mapping[dfg.parent_node], args)
return mapping[dfg.parent_node]

def add_nested(
self,
input_types: TypeRow,
output_types: TypeRow,
*args: Wire,
) -> Dfg:
from ._dfg import Dfg

root_op = ops.DFG(FunctionType(input=input_types, output=output_types))
dfg = Dfg.new_nested(root_op, self.hugr, self.root)
self._wire_up(dfg.root, args)
input_types = [self._get_dataflow_type(w) for w in args]

parent_op = ops.DFG(list(input_types))
dfg = Dfg.new_nested(parent_op, self.hugr, self.parent_node)
self._wire_up(dfg.parent_node, args)
return dfg

def add_cfg(
self,
input_types: TypeRow,
output_types: TypeRow,
*args: Wire,
) -> Cfg:
from ._cfg import Cfg

ss2165 marked this conversation as resolved.
Show resolved Hide resolved
cfg = Cfg.new_nested(input_types, output_types, self.hugr, self.root)
self._wire_up(cfg.root, args)
cfg = Cfg.new_nested(input_types, self.hugr, self.parent_node)
self._wire_up(cfg.parent_node, args)
return cfg

def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(cfg.hugr, self.root)
self._wire_up(mapping[cfg.root], args)
return mapping[cfg.root]
mapping = self.hugr.insert_hugr(cfg.hugr, self.parent_node)
self._wire_up(mapping[cfg.parent_node], args)
return mapping[cfg.parent_node]

def set_outputs(self, *args: Wire) -> None:
self._wire_up(self.output_node, args)
self.parent_op._set_out_types(self._output_op().types)

def add_state_order(self, src: Node, dst: Node) -> None:
# adds edge to the right of all existing edges
self.hugr.add_link(src.out(-1), dst.inp(-1))

def _wire_up(self, node: Node, ports: Iterable[Wire]):
for i, p in enumerate(ports):
self._wire_up_port(node, i, p)

def _wire_up_port(self, node: Node, offset: int, p: Wire):
tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)]
if isinstance(op := self.hugr[node].op, ops.PartialOp):
op.set_in_types(tys)

def _get_dataflow_type(self, wire: Wire) -> Type:
port = wire.out_port()
ty = self.hugr.port_type(port)
if ty is None:
raise ValueError(f"Port {port} is not a dataflow port.")
return ty

def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:
src = p.out_port()
node_ancestor = _ancestral_sibling(self.hugr, src.node, node)
if node_ancestor is None:
raise NoSiblingAncestor(src.node.idx, node.idx)
if node_ancestor != node:
self.add_state_order(src.node, node_ancestor)
self.hugr.add_link(src, node.inp(offset))
return self._get_dataflow_type(src)


class Dfg(_DfBase[ops.DFG]):
def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None:
root_op = ops.DFG(FunctionType(input=input_types, output=output_types))
super().__init__(root_op)

@classmethod
def endo(cls, types: TypeRow) -> Dfg:
return cls(types, types)
def __init__(self, *input_types: Type) -> None:
parent_op = ops.DFG(list(input_types))
super().__init__(parent_op)


def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None:
Expand Down
Loading
Loading