Skip to content

Commit

Permalink
feat(hugr-py): only require input type annotations when building (#1199)
Browse files Browse the repository at this point in the history
Closes #1198

- ops have to report their signature, but may be initialised incomplete
- dataflow ops have hooks to update themselves based on the types
flowing in to them
- parent ops can update themselves based on the types flowing out of
them
- Tag needs to know all possible variants, can't infer from input types
- calculate cfg outputs by looking at output types of blocks that
connect to exit block

reviewing commit by commit might be easier

---------

Co-authored-by: Mark Koch <[email protected]>
  • Loading branch information
ss2165 and mark-koch authored Jun 18, 2024
1 parent cce468a commit 2bb079f
Show file tree
Hide file tree
Showing 9 changed files with 480 additions and 208 deletions.
134 changes: 73 additions & 61 deletions hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Sequence
from dataclasses import dataclass, replace

import hugr._ops as ops

from ._dfg import _DfBase
from ._exceptions import NoSiblingAncestor, NotInSameCfg
from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit
from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire
from ._tys import FunctionType, Sum, TypeRow
from ._tys import FunctionType, TypeRow, Type


class Block(_DfBase[ops.DataflowBlock]):
Expand All @@ -19,99 +18,112 @@ def set_single_successor_outputs(self, *outputs: Wire) -> None:
# TODO requires constants
raise NotImplementedError

def _wire_up(self, node: Node, ports: Iterable[Wire]):
for i, p in enumerate(ports):
src = p.out_port()
cfg_node = self.hugr[self.root].parent
assert cfg_node is not None
src_parent = self.hugr[src.node].parent
try:
self._wire_up_port(node, i, p)
except NoSiblingAncestor:
# note this just checks if there is a common CFG ancestor
# it does not check for valid dominance between basic blocks
# that is deferred to full HUGR validation.
while cfg_node != src_parent:
if src_parent is None or src_parent == self.hugr.root:
raise NotInSameCfg(src.node.idx, node.idx)
src_parent = self.hugr[src_parent].parent

self.hugr.add_link(src, node.inp(i))
def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:
src = p.out_port()
cfg_node = self.hugr[self.parent_node].parent
assert cfg_node is not None
src_parent = self.hugr[src.node].parent
try:
super()._wire_up_port(node, offset, p)
except NoSiblingAncestor:
# note this just checks if there is a common CFG ancestor
# it does not check for valid dominance between basic blocks
# that is deferred to full HUGR validation.
while cfg_node != src_parent:
if src_parent is None or src_parent == self.hugr.root:
raise NotInSameCfg(src.node.idx, node.idx)
src_parent = self.hugr[src_parent].parent

self.hugr.add_link(src, node.inp(offset))
return self._get_dataflow_type(src)


@dataclass
class Cfg(ParentBuilder):
class Cfg(ParentBuilder[ops.CFG]):
hugr: Hugr
root: Node
parent_node: Node
_entry_block: Block
exit: Node

def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None:
root_op = ops.CFG(FunctionType(input=input_types, output=output_types))
def __init__(self, input_types: TypeRow) -> None:
root_op = ops.CFG(FunctionType(input=input_types, output=[]))
hugr = Hugr(root_op)
self._init_impl(hugr, hugr.root, input_types, output_types)
self._init_impl(hugr, hugr.root, input_types)

def _init_impl(
self: Cfg, hugr: Hugr, root: Node, input_types: TypeRow, output_types: TypeRow
) -> None:
def _init_impl(self: Cfg, hugr: Hugr, root: Node, input_types: TypeRow) -> None:
self.hugr = hugr
self.root = root
self.parent_node = root
# to ensure entry is first child, add a dummy entry at the start
self._entry_block = Block.new_nested(
ops.DataflowBlock(input_types, []), hugr, root
)
self._entry_block = Block.new_nested(ops.DataflowBlock(input_types), hugr, root)

self.exit = self.hugr.add_node(ops.ExitBlock(output_types), self.root)
self.exit = self.hugr.add_node(ops.ExitBlock(), self.parent_node)

@classmethod
def new_nested(
cls,
input_types: TypeRow,
output_types: TypeRow,
hugr: Hugr,
parent: ToNode | None = None,
) -> Cfg:
new = cls.__new__(cls)
root = hugr.add_node(
ops.CFG(FunctionType(input=input_types, output=output_types)),
ops.CFG(FunctionType(input=input_types, output=[])),
parent or hugr.root,
)
new._init_impl(hugr, root, input_types, output_types)
new._init_impl(hugr, root, input_types)
return new

@property
def entry(self) -> Node:
return self._entry_block.root
return self._entry_block.parent_node

@property
def _entry_op(self) -> ops.DataflowBlock:
dop = self.hugr[self.entry].op
assert isinstance(dop, ops.DataflowBlock)
return dop

def add_entry(self, sum_rows: Sequence[TypeRow], other_outputs: TypeRow) -> Block:
# update entry block types
self._entry_op().sum_rows = list(sum_rows)
self._entry_op().other_outputs = other_outputs
self._entry_block._output_op().types = [Sum(list(sum_rows)), *other_outputs]
return self._entry_block
return self.hugr._get_typed_op(self.entry, ops.DataflowBlock)

@property
def _exit_op(self) -> ops.ExitBlock:
return self.hugr._get_typed_op(self.exit, ops.ExitBlock)

def simple_entry(self, n_branches: int, other_outputs: TypeRow) -> Block:
return self.add_entry([[]] * n_branches, other_outputs)
def add_entry(self) -> Block:
return self._entry_block

def add_block(
self, input_types: TypeRow, sum_rows: Sequence[TypeRow], other_outputs: TypeRow
) -> Block:
def add_block(self, input_types: TypeRow) -> Block:
new_block = Block.new_nested(
ops.DataflowBlock(input_types, list(sum_rows), other_outputs),
ops.DataflowBlock(input_types),
self.hugr,
self.root,
self.parent_node,
)
return new_block

def simple_block(
self, input_types: TypeRow, n_branches: int, other_outputs: TypeRow
) -> Block:
return self.add_block(input_types, [[]] * n_branches, other_outputs)
def add_successor(self, pred: Wire) -> Block:
b = self.add_block(self._nth_outputs(pred))

self.branch(pred, b)
return b

def _nth_outputs(self, wire: Wire) -> TypeRow:
port = wire.out_port()
block = self.hugr._get_typed_op(port.node, ops.DataflowBlock)
return block.nth_outputs(port.offset)

def branch(self, src: Wire, dst: ToNode) -> None:
self.hugr.add_link(src.out_port(), dst.inp(0))
# TODO check for existing link/type compatibility
if dst.to_node() == self.exit:
return self.branch_exit(src)
src = src.out_port()
self.hugr.add_link(src, dst.inp(0))

def branch_exit(self, src: Wire) -> None:
src = src.out_port()
self.hugr.add_link(src, self.exit.inp(0))

out_types = self._nth_outputs(src)
if self._exit_op._cfg_outputs is not None:
if self._exit_op._cfg_outputs != out_types:
raise MismatchedExit(src.node.idx)
else:
self._exit_op._cfg_outputs = out_types
self.parent_op.signature = replace(
self.parent_op.signature, output=out_types
)
119 changes: 63 additions & 56 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
from __future__ import annotations
from dataclasses import dataclass, replace
from typing import (
Iterable,
TYPE_CHECKING,
TypeVar,
)
from ._hugr import Hugr, Node, Wire, OutPort, ParentBuilder

from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Iterable, TypeVar, cast
from typing_extensions import Self
import hugr._ops as ops
from hugr._tys import FunctionType, TypeRow
from hugr._tys import TypeRow

from ._exceptions import NoSiblingAncestor
from ._hugr import Hugr, Node, OutPort, ParentBuilder, Wire, ToNode
from ._hugr import ToNode
from hugr._tys import Type

if TYPE_CHECKING:
from ._cfg import Cfg
Expand All @@ -17,122 +23,123 @@


@dataclass()
class _DfBase(ParentBuilder, Generic[DP]):
class _DfBase(ParentBuilder[DP]):
hugr: Hugr
root: Node
parent_node: Node
input_node: Node
output_node: Node

def __init__(self, root_op: DP) -> None:
self.hugr = Hugr(root_op)
self.root = self.hugr.root
self._init_io_nodes(root_op)
def __init__(self, parent_op: DP) -> None:
self.hugr = Hugr(parent_op)
self.parent_node = self.hugr.root
self._init_io_nodes(parent_op)

def _init_io_nodes(self, parent_op: DP):
inputs = parent_op._inputs()

def _init_io_nodes(self, root_op: DP):
input_types = root_op.input_types()
output_types = root_op.output_types()
self.input_node = self.hugr.add_node(
ops.Input(input_types), self.root, len(input_types)
ops.Input(inputs), self.parent_node, len(inputs)
)
self.output_node = self.hugr.add_node(ops.Output(output_types), self.root)
self.output_node = self.hugr.add_node(ops.Output(), self.parent_node)

@classmethod
def new_nested(cls, root_op: DP, hugr: Hugr, parent: ToNode | None = None) -> Self:
def new_nested(
cls, parent_op: DP, hugr: Hugr, parent: ToNode | None = None
) -> Self:
new = cls.__new__(cls)

new.hugr = hugr
new.root = hugr.add_node(root_op, parent or hugr.root)
new._init_io_nodes(root_op)
new.parent_node = hugr.add_node(parent_op, parent or hugr.root)
new._init_io_nodes(parent_op)
return new

def _input_op(self) -> ops.Input:
dop = self.hugr[self.input_node].op
assert isinstance(dop, ops.Input)
return dop
return self.hugr._get_typed_op(self.input_node, ops.Input)

def _output_op(self) -> ops.Output:
dop = self.hugr[self.output_node].op
assert isinstance(dop, ops.Output)
return dop

def root_op(self) -> DP:
return cast(DP, self.hugr[self.root].op)
return self.hugr._get_typed_op(self.output_node, ops.Output)

def inputs(self) -> list[OutPort]:
return [self.input_node.out(i) for i in range(len(self._input_op().types))]

def add_op(self, op: ops.Op, /, *args: Wire, num_outs: int | None = None) -> Node:
new_n = self.hugr.add_node(op, self.root, num_outs=num_outs)
def add_op(self, op: ops.DataflowOp, /, *args: Wire) -> Node:
new_n = self.hugr.add_node(op, self.parent_node)
self._wire_up(new_n, args)
return new_n

return replace(new_n, _num_out_ports=op.num_out)

def add(self, com: ops.Command) -> Node:
return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out)
return self.add_op(com.op, *com.incoming)

def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(dfg.hugr, self.root)
self._wire_up(mapping[dfg.root], args)
return mapping[dfg.root]
mapping = self.hugr.insert_hugr(dfg.hugr, self.parent_node)
self._wire_up(mapping[dfg.parent_node], args)
return mapping[dfg.parent_node]

def add_nested(
self,
input_types: TypeRow,
output_types: TypeRow,
*args: Wire,
) -> Dfg:
from ._dfg import Dfg

root_op = ops.DFG(FunctionType(input=input_types, output=output_types))
dfg = Dfg.new_nested(root_op, self.hugr, self.root)
self._wire_up(dfg.root, args)
input_types = [self._get_dataflow_type(w) for w in args]

parent_op = ops.DFG(list(input_types))
dfg = Dfg.new_nested(parent_op, self.hugr, self.parent_node)
self._wire_up(dfg.parent_node, args)
return dfg

def add_cfg(
self,
input_types: TypeRow,
output_types: TypeRow,
*args: Wire,
) -> Cfg:
from ._cfg import Cfg

cfg = Cfg.new_nested(input_types, output_types, self.hugr, self.root)
self._wire_up(cfg.root, args)
cfg = Cfg.new_nested(input_types, self.hugr, self.parent_node)
self._wire_up(cfg.parent_node, args)
return cfg

def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(cfg.hugr, self.root)
self._wire_up(mapping[cfg.root], args)
return mapping[cfg.root]
mapping = self.hugr.insert_hugr(cfg.hugr, self.parent_node)
self._wire_up(mapping[cfg.parent_node], args)
return mapping[cfg.parent_node]

def set_outputs(self, *args: Wire) -> None:
self._wire_up(self.output_node, args)
self.parent_op._set_out_types(self._output_op().types)

def add_state_order(self, src: Node, dst: Node) -> None:
# adds edge to the right of all existing edges
self.hugr.add_link(src.out(-1), dst.inp(-1))

def _wire_up(self, node: Node, ports: Iterable[Wire]):
for i, p in enumerate(ports):
self._wire_up_port(node, i, p)

def _wire_up_port(self, node: Node, offset: int, p: Wire):
tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)]
if isinstance(op := self.hugr[node].op, ops.PartialOp):
op.set_in_types(tys)

def _get_dataflow_type(self, wire: Wire) -> Type:
port = wire.out_port()
ty = self.hugr.port_type(port)
if ty is None:
raise ValueError(f"Port {port} is not a dataflow port.")
return ty

def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:
src = p.out_port()
node_ancestor = _ancestral_sibling(self.hugr, src.node, node)
if node_ancestor is None:
raise NoSiblingAncestor(src.node.idx, node.idx)
if node_ancestor != node:
self.add_state_order(src.node, node_ancestor)
self.hugr.add_link(src, node.inp(offset))
return self._get_dataflow_type(src)


class Dfg(_DfBase[ops.DFG]):
def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None:
root_op = ops.DFG(FunctionType(input=input_types, output=output_types))
super().__init__(root_op)

@classmethod
def endo(cls, types: TypeRow) -> Dfg:
return cls(types, types)
def __init__(self, *input_types: Type) -> None:
parent_op = ops.DFG(list(input_types))
super().__init__(parent_op)


def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None:
Expand Down
Loading

0 comments on commit 2bb079f

Please sign in to comment.