Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py): only require input type annotations when building #1199

Merged
merged 17 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 27 additions & 32 deletions hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence
from dataclasses import dataclass, replace

import hugr._ops as ops

from ._dfg import _DfBase
from ._exceptions import NoSiblingAncestor, NotInSameCfg
from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit
from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire
from ._tys import FunctionType, Sum, TypeRow, Type
from ._tys import FunctionType, TypeRow, Type


class Block(_DfBase[ops.DataflowBlock]):
Expand All @@ -32,7 +31,7 @@
# that is deferred to full HUGR validation.
while cfg_node != src_parent:
if src_parent is None or src_parent == self.hugr.root:
raise NotInSameCfg(src.node.idx, node.idx)

Check warning on line 34 in hugr-py/src/hugr/_cfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_cfg.py#L34

Added line #L34 was not covered by tests
src_parent = self.hugr[src_parent].parent

self.hugr.add_link(src, node.inp(offset))
Expand All @@ -46,73 +45,69 @@
_entry_block: Block
exit: Node

def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None:
root_op = ops.CFG(FunctionType(input=input_types, output=output_types))
def __init__(self, input_types: TypeRow) -> None:
root_op = ops.CFG(FunctionType(input=input_types, output=[]))
hugr = Hugr(root_op)
self._init_impl(hugr, hugr.root, input_types, output_types)
self._init_impl(hugr, hugr.root, input_types)

def _init_impl(
self: Cfg, hugr: Hugr, root: Node, input_types: TypeRow, output_types: TypeRow
) -> None:
def _init_impl(self: Cfg, hugr: Hugr, root: Node, input_types: TypeRow) -> None:
self.hugr = hugr
self.parent_node = root
# to ensure entry is first child, add a dummy entry at the start
self._entry_block = Block.new_nested(
ops.DataflowBlock(input_types, []), hugr, root
)

self.exit = self.hugr.add_node(ops.ExitBlock(output_types), self.parent_node)
self.exit = self.hugr.add_node(ops.ExitBlock([]), self.parent_node)

@classmethod
def new_nested(
cls,
input_types: TypeRow,
output_types: TypeRow,
hugr: Hugr,
parent: ToNode | None = None,
) -> Cfg:
new = cls.__new__(cls)
root = hugr.add_node(
ops.CFG(FunctionType(input=input_types, output=output_types)),
ops.CFG(FunctionType(input=input_types, output=[])),
parent or hugr.root,
)
new._init_impl(hugr, root, input_types, output_types)
new._init_impl(hugr, root, input_types)
return new

@property
def entry(self) -> Node:
return self._entry_block.parent_node

Check warning on line 80 in hugr-py/src/hugr/_cfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_cfg.py#L80

Added line #L80 was not covered by tests

def _entry_op(self) -> ops.DataflowBlock:
return self.hugr._get_typed_op(self.entry, ops.DataflowBlock)

Check warning on line 83 in hugr-py/src/hugr/_cfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_cfg.py#L83

Added line #L83 was not covered by tests

def _exit_op(self) -> ops.ExitBlock:
mark-koch marked this conversation as resolved.
Show resolved Hide resolved
return self.hugr._get_typed_op(self.exit, ops.ExitBlock)

def add_entry(self, sum_rows: Sequence[TypeRow], other_outputs: TypeRow) -> Block:
# update entry block types
self._entry_op().sum_rows = list(sum_rows)
self._entry_op().other_outputs = other_outputs
self._entry_block._output_op().types = [Sum(list(sum_rows)), *other_outputs]
def add_entry(self) -> Block:
return self._entry_block

def simple_entry(self, n_branches: int, other_outputs: TypeRow) -> Block:
return self.add_entry([[]] * n_branches, other_outputs)

def add_block(
self, input_types: TypeRow, sum_rows: Sequence[TypeRow], other_outputs: TypeRow
) -> Block:
def add_block(self, input_types: TypeRow) -> Block:
new_block = Block.new_nested(
ops.DataflowBlock(input_types, list(sum_rows), other_outputs),
ops.DataflowBlock(input_types, [], []),
self.hugr,
self.parent_node,
)
return new_block

def simple_block(
self, input_types: TypeRow, n_branches: int, other_outputs: TypeRow
) -> Block:
return self.add_block(input_types, [[]] * n_branches, other_outputs)

def branch(self, src: Wire, dst: ToNode) -> None:
self.hugr.add_link(src.out_port(), dst.inp(0))
src = src.out_port()
self.hugr.add_link(src, dst.inp(0))

if dst == self.exit:
src_block = self.hugr._get_typed_op(src.node, ops.DataflowBlock)
out_types = [*src_block.sum_rows[src.offset], *src_block.other_outputs]
if self._exit_op().cfg_outputs:
mark-koch marked this conversation as resolved.
Show resolved Hide resolved
if self._exit_op().cfg_outputs != out_types:
raise MismatchedExit(src.node.idx)

Check warning on line 108 in hugr-py/src/hugr/_cfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_cfg.py#L108

Added line #L108 was not covered by tests
else:
self._exit_op().cfg_outputs = out_types
self.parent_op().signature = replace(
self.parent_op().signature, output=out_types
)
2 changes: 1 addition & 1 deletion hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,14 @@
) -> Cfg:
from ._cfg import Cfg

ss2165 marked this conversation as resolved.
Show resolved Hide resolved
cfg = Cfg.new_nested(input_types, output_types, self.hugr, self.parent_node)
cfg = Cfg.new_nested(input_types, self.hugr, self.parent_node)
self._wire_up(cfg.parent_node, args)
return cfg

def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(cfg.hugr, self.parent_node)
self._wire_up(mapping[cfg.parent_node], args)
return mapping[cfg.parent_node]

Check warning on line 107 in hugr-py/src/hugr/_dfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_dfg.py#L105-L107

Added lines #L105 - L107 were not covered by tests

def set_outputs(self, *args: Wire) -> None:
self._wire_up(self.output_node, args)
Expand All @@ -123,7 +123,7 @@
port = wire.out_port()
ty = self.hugr.port_type(port)
if ty is None:
raise ValueError(f"Port {port} is not a dataflow port.")

Check warning on line 126 in hugr-py/src/hugr/_dfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_dfg.py#L126

Added line #L126 was not covered by tests
return ty

def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:
Expand Down
11 changes: 11 additions & 0 deletions hugr-py/src/hugr/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,16 @@
return f"Source {self.src} is not in the same CFG as target {self.tgt}, so cannot wire up."


@dataclass
class MismatchedExit(Exception):
src: int

@property
def msg(self):
return (

Check warning on line 30 in hugr-py/src/hugr/_exceptions.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_exceptions.py#L30

Added line #L30 was not covered by tests
f"Exit branch from node {self.src} does not match existing exit block type."
)


class ParentBeforeChild(Exception):
msg: str = "Parent node must be added before child node."
26 changes: 25 additions & 1 deletion hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import hugr._tys as tys

if TYPE_CHECKING:
from hugr._hugr import Hugr, Node, Wire, _Port

Check warning on line 11 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L11

Added line #L11 was not covered by tests


@runtime_checkable
Expand All @@ -27,16 +27,16 @@
def outer_signature(self) -> tys.FunctionType: ...
mark-koch marked this conversation as resolved.
Show resolved Hide resolved

def port_kind(self, port: _Port) -> tys.Kind:
if port.offset == -1:
mark-koch marked this conversation as resolved.
Show resolved Hide resolved
return tys.OrderKind()
return tys.ValueKind(self.port_type(port))

Check warning on line 32 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L30-L32

Added lines #L30 - L32 were not covered by tests

def port_type(self, port: _Port) -> tys.Type:
from hugr._hugr import Direction

sig = self.outer_signature()
if port.direction == Direction.INCOMING:
return sig.input[port.offset]

Check warning on line 39 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L39

Added line #L39 was not covered by tests
return sig.output[port.offset]

def _set_in_types(self, types: tys.TypeRow) -> None:
Expand Down Expand Up @@ -66,7 +66,7 @@
return root

def port_kind(self, port: _Port) -> tys.Kind:
raise NotImplementedError

Check warning on line 69 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L69

Added line #L69 was not covered by tests
mark-koch marked this conversation as resolved.
Show resolved Hide resolved


@dataclass()
Expand Down Expand Up @@ -95,7 +95,7 @@
return sops.Output(parent=parent.idx, types=ser_it(self.types))

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType(input=self.types, output=[])

Check warning on line 98 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L98

Added line #L98 was not covered by tests

def _set_in_types(self, types: tys.TypeRow) -> None:
self.types = types
Expand Down Expand Up @@ -176,17 +176,35 @@
assert isinstance(t, tys.Sum), f"Expected unary Sum, got {t}"
(row,) = t.variant_rows
self.types = row
print(row)


UnpackTuple = UnpackTupleDef()


@dataclass()
class Tag(DataflowOp):
tag: int
variants: list[tys.TypeRow]
num_out: int | None = 1

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Tag:
return sops.Tag(
parent=parent.idx,
tag=self.tag,
variants=[ser_it(r) for r in self.variants],
)

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType(
input=self.variants[self.tag], output=[tys.Sum(self.variants)]
)

ss2165 marked this conversation as resolved.
Show resolved Hide resolved

class DfParentOp(Op, Protocol):
def inner_signature(self) -> tys.FunctionType: ...

def _set_out_types(self, types: tys.TypeRow) -> None:
return

Check warning on line 207 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L207

Added line #L207 was not covered by tests


@dataclass()
Expand Down Expand Up @@ -257,8 +275,14 @@
)

def port_kind(self, port: _Port) -> tys.Kind:
return tys.CFKind()

Check warning on line 278 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L278

Added line #L278 was not covered by tests

def _set_out_types(self, types: tys.TypeRow) -> None:
(sum_, *other) = types
assert isinstance(sum_, tys.Sum), f"Expected Sum, got {sum_}"
self.sum_rows = sum_.variant_rows
self.other_outputs = other


@dataclass
class ExitBlock(Op):
Expand All @@ -272,4 +296,4 @@
)

def port_kind(self, port: _Port) -> tys.Kind:
return tys.CFKind()

Check warning on line 299 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L299

Added line #L299 was not covered by tests
20 changes: 12 additions & 8 deletions hugr-py/src/hugr/_tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,6 @@ def to_serial(self) -> stys.Array:
return stys.Array(ty=self.ty.to_serial_root(), len=self.size)


@dataclass(frozen=True)
class UnitSum(Type):
size: int

def to_serial(self) -> stys.UnitSum:
return stys.UnitSum(size=self.size)


@dataclass()
class Sum(Type):
variant_rows: list[TypeRow]
Expand All @@ -181,6 +173,18 @@ def as_tuple(self) -> Tuple:
return Tuple(*self.variant_rows[0])


@dataclass()
class UnitSum(Sum):
size: int

def __init__(self, size: int):
self.size = size
super().__init__(variant_rows=[[]] * size)

def to_serial(self) -> stys.UnitSum: # type: ignore[override]
return stys.UnitSum(size=self.size)


@dataclass()
class Tuple(Sum):
def __init__(self, *tys: Type):
Expand Down
42 changes: 32 additions & 10 deletions hugr-py/tests/test_cfg.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,31 @@
from hugr._cfg import Cfg
import hugr._tys as tys
from hugr._dfg import Dfg
import hugr._ops as ops
from .test_hugr_build import _validate, INT_T, DivMod


def build_basic_cfg(cfg: Cfg) -> None:
entry = cfg.simple_entry(1, [tys.Bool])
entry = cfg.add_entry()

entry.set_block_outputs(*entry.inputs())
cfg.branch(entry[0], cfg.exit)


def test_basic_cfg() -> None:
cfg = Cfg([tys.Unit, tys.Bool], [tys.Bool])
cfg = Cfg([tys.Unit, tys.Bool])
build_basic_cfg(cfg)
_validate(cfg.hugr)


def test_branch() -> None:
cfg = Cfg([tys.Bool, tys.Unit, INT_T], [INT_T])
entry = cfg.simple_entry(2, [tys.Unit, INT_T])
cfg = Cfg([tys.Bool, tys.Unit, INT_T])
entry = cfg.add_entry()
entry.set_block_outputs(*entry.inputs())

middle_1 = cfg.simple_block([tys.Unit, INT_T], 1, [INT_T])
middle_1 = cfg.add_block([tys.Unit, INT_T])
middle_1.set_block_outputs(*middle_1.inputs())
middle_2 = cfg.simple_block([tys.Unit, INT_T], 1, [INT_T])
middle_2 = cfg.add_block([tys.Unit, INT_T])
u, i = middle_2.inputs()
n = middle_2.add(DivMod(i, i))
middle_2.set_block_outputs(u, n[0])
Expand All @@ -50,16 +51,16 @@ def test_nested_cfg() -> None:


def test_dom_edge() -> None:
cfg = Cfg([tys.Bool, tys.Unit, INT_T], [INT_T])
entry = cfg.simple_entry(2, [INT_T])
cfg = Cfg([tys.Bool, tys.Unit, INT_T])
entry = cfg.add_entry()
b, u, i = entry.inputs()
entry.set_block_outputs(b, i)

# entry dominates both middles so Unit type can be used as inter-graph
# value between basic blocks
middle_1 = cfg.simple_block([INT_T], 1, [INT_T])
middle_1 = cfg.add_block([INT_T])
middle_1.set_block_outputs(u, *middle_1.inputs())
middle_2 = cfg.simple_block([INT_T], 1, [INT_T])
middle_2 = cfg.add_block([INT_T])
middle_2.set_block_outputs(u, *middle_2.inputs())

cfg.branch(entry[0], middle_1)
Expand All @@ -69,3 +70,24 @@ def test_dom_edge() -> None:
cfg.branch(middle_2[0], cfg.exit)

_validate(cfg.hugr)


def test_asymm_types() -> None:
# test different types going to entry block's susccessors
cfg = Cfg([tys.Bool, tys.Unit, INT_T])
entry = cfg.add_entry()
b, u, i = entry.inputs()

tagged_int = entry.add(ops.Tag(0, [[INT_T], [tys.Bool]])(i))
entry.set_block_outputs(tagged_int)

middle = cfg.add_block([INT_T])
# discard the int and return the bool from entry
middle.set_block_outputs(u, b)

# middle expects an int and exit expects a bool
cfg.branch(entry[0], middle)
cfg.branch(entry[1], cfg.exit)
cfg.branch(middle[0], cfg.exit)

_validate(cfg.hugr)
Loading