Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py): only require input type annotations when building #1199

Merged
merged 17 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 36 additions & 19 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
from __future__ import annotations
from dataclasses import dataclass, replace
from typing import (
Iterator,
Iterable,
TYPE_CHECKING,
Generic,
TypeVar,
cast,
)
from ._hugr import Hugr, Node, Wire, OutPort, ParentBuilder

from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Iterable, TypeVar, cast
from typing_extensions import Self
import hugr._ops as ops
from hugr._tys import FunctionType, TypeRow

from ._exceptions import NoSiblingAncestor
from ._hugr import Hugr, Node, OutPort, ParentBuilder, Wire, ToNode
from ._hugr import ToNode
from hugr._tys import Type

if TYPE_CHECKING:
from ._cfg import Cfg
Expand Down Expand Up @@ -61,15 +70,14 @@ def root_op(self) -> DP:
def inputs(self) -> list[OutPort]:
return [self.input_node.out(i) for i in range(len(self._input_op().types))]

def add_op(
self, op: ops.DataflowOp, /, *args: Wire, num_outs: int | None = None
) -> Node:
new_n = self.hugr.add_node(op, self.root, num_outs=num_outs)
def add_op(self, op: ops.DataflowOp, /, *args: Wire) -> Node:
new_n = self.hugr.add_node(op, self.root)
self._wire_up(new_n, args)
return new_n

return replace(new_n, _num_out_ports=op.num_out)

def add(self, com: ops.Command) -> Node:
return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out)
return self.add_op(com.op, *com.incoming)

def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(dfg.hugr, self.root)
Expand All @@ -78,13 +86,13 @@ def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:

def add_nested(
self,
input_types: TypeRow,
output_types: TypeRow,
*args: Wire,
) -> Dfg:
from ._dfg import Dfg

root_op = ops.DFG(FunctionType(input=input_types, output=output_types))
_, input_types = zip(*self._get_dataflow_types(args)) if args else ([], [])

root_op = ops.DFG(FunctionType(input=list(input_types), output=[]))
dfg = Dfg.new_nested(root_op, self.hugr, self.root)
self._wire_up(dfg.root, args)
return dfg
Expand All @@ -108,14 +116,27 @@ def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node:

def set_outputs(self, *args: Wire) -> None:
self._wire_up(self.output_node, args)
self.root_op()._set_out_types(self._output_op().types)

def add_state_order(self, src: Node, dst: Node) -> None:
# adds edge to the right of all existing edges
self.hugr.add_link(src.out(-1), dst.inp(-1))

def _wire_up(self, node: Node, ports: Iterable[Wire]):
for i, p in enumerate(ports):
tys = []
for i, (p, ty) in enumerate(self._get_dataflow_types(ports)):
tys.append(ty)
self._wire_up_port(node, i, p)
if isinstance(op := self.hugr[node].op, ops.DataflowOp):
op._set_in_types(tys)

def _get_dataflow_types(self, wires: Iterable[Wire]) -> Iterator[tuple[Wire, Type]]:
for w in wires:
port = w.out_port()
ty = self.hugr.port_type(port)
if ty is None:
raise ValueError(f"Port {port} is not a dataflow port.")
yield w, ty

def _wire_up_port(self, node: Node, offset: int, p: Wire):
src = p.out_port()
Expand All @@ -128,14 +149,10 @@ def _wire_up_port(self, node: Node, offset: int, p: Wire):


class Dfg(_DfBase[ops.DFG]):
def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None:
root_op = ops.DFG(FunctionType(input=input_types, output=output_types))
def __init__(self, *input_types: Type) -> None:
root_op = ops.DFG(FunctionType(input=list(input_types), output=[]))
super().__init__(root_op)

@classmethod
def endo(cls, types: TypeRow) -> Dfg:
return cls(types, types)


def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None:
src_parent = h[src].parent
Expand Down
42 changes: 35 additions & 7 deletions hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from typing import Generic, Protocol, TypeVar, TYPE_CHECKING, runtime_checkable
from hugr.serialization.ops import BaseOp
import hugr.serialization.ops as sops
Expand Down Expand Up @@ -39,6 +39,9 @@ def port_type(self, port: _Port) -> tys.Type:
return sig.input[port.offset]
return sig.output[port.offset]

def _set_in_types(self, types: tys.TypeRow) -> None:
return

def __call__(self, *args) -> Command:
return Command(self, list(args))

Expand Down Expand Up @@ -86,14 +89,17 @@ def __call__(self) -> Command:

@dataclass()
class Output(DataflowOp):
types: tys.TypeRow
types: tys.TypeRow = field(default_factory=list)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output:
return sops.Output(parent=parent.idx, types=ser_it(self.types))

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType(input=self.types, output=[])

def _set_in_types(self, types: tys.TypeRow) -> None:
self.types = types


@dataclass()
class Custom(DataflowOp):
Expand Down Expand Up @@ -122,8 +128,8 @@ def outer_signature(self) -> tys.FunctionType:


@dataclass()
class MakeTuple(DataflowOp):
types: tys.TypeRow
class MakeTupleDef(DataflowOp):
types: tys.TypeRow = field(default_factory=list)
num_out: int | None = 1

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple:
Expand All @@ -138,10 +144,16 @@ def __call__(self, *elements: Wire) -> Command:
def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType(input=self.types, output=[tys.Tuple(*self.types)])

def _set_in_types(self, types: tys.TypeRow) -> None:
self.types = types


MakeTuple = MakeTupleDef()


@dataclass()
class UnpackTuple(DataflowOp):
types: tys.TypeRow
class UnpackTupleDef(DataflowOp):
types: tys.TypeRow = field(default_factory=list)

@property
def num_out(self) -> int | None:
Expand All @@ -157,12 +169,25 @@ def __call__(self, tuple_: Wire) -> Command:
return super().__call__(tuple_)

def outer_signature(self) -> tys.FunctionType:
return MakeTuple(self.types).outer_signature().flip()
return MakeTupleDef(self.types).outer_signature().flip()

def _set_in_types(self, types: tys.TypeRow) -> None:
(t,) = types
assert isinstance(t, tys.Sum), f"Expected unary Sum, got {t}"
(row,) = t.variant_rows
self.types = row
print(row)
ss2165 marked this conversation as resolved.
Show resolved Hide resolved


UnpackTuple = UnpackTupleDef()


class DfParentOp(Op, Protocol):
def inner_signature(self) -> tys.FunctionType: ...

def _set_out_types(self, types: tys.TypeRow) -> None:
return


@dataclass()
class DFG(DfParentOp, DataflowOp):
Expand All @@ -184,6 +209,9 @@ def inner_signature(self) -> tys.FunctionType:
def outer_signature(self) -> tys.FunctionType:
return self.signature

def _set_out_types(self, types: tys.TypeRow) -> None:
self.signature = replace(self.signature, output=types)


@dataclass()
class CFG(DataflowOp):
Expand Down
8 changes: 4 additions & 4 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,8 +447,8 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
in_types = []
self.tys = list(in_types)

def deserialize(self) -> _ops.MakeTuple:
return _ops.MakeTuple(deser_it(self.tys))
def deserialize(self) -> _ops.MakeTupleDef:
return _ops.MakeTupleDef(deser_it(self.tys))


class UnpackTuple(DataflowOp):
Expand All @@ -460,8 +460,8 @@ class UnpackTuple(DataflowOp):
def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
self.tys = list(out_types)

def deserialize(self) -> _ops.UnpackTuple:
return _ops.UnpackTuple(deser_it(self.tys))
def deserialize(self) -> _ops.UnpackTupleDef:
return _ops.UnpackTupleDef(deser_it(self.tys))


class Tag(DataflowOp):
Expand Down
2 changes: 1 addition & 1 deletion hugr-py/tests/test_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_branch() -> None:


def test_nested_cfg() -> None:
dfg = Dfg([tys.Unit, tys.Bool], [tys.Bool])
dfg = Dfg(tys.Unit, tys.Bool)

cfg = dfg.add_cfg([tys.Unit, tys.Bool], [tys.Bool], *dfg.inputs())

Expand Down
38 changes: 19 additions & 19 deletions hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,15 @@ def test_stable_indices():


def test_simple_id():
h = Dfg.endo([tys.Qubit] * 2)
h = Dfg(tys.Qubit, tys.Qubit)
a, b = h.inputs()
h.set_outputs(a, b)

_validate(h.hugr)


def test_multiport():
h = Dfg([tys.Bool], [tys.Bool] * 2)
h = Dfg(tys.Bool)
(a,) = h.inputs()
h.set_outputs(a, a)
in_n, ou_n = h.input_node, h.output_node
Expand All @@ -151,7 +151,7 @@ def test_multiport():


def test_add_op():
h = Dfg.endo([tys.Bool])
h = Dfg(tys.Bool)
(a,) = h.inputs()
nt = h.add_op(Not, a)
h.set_outputs(nt)
Expand All @@ -161,33 +161,33 @@ def test_add_op():

def test_tuple():
row = [tys.Bool, tys.Qubit]
h = Dfg.endo(row)
h = Dfg(*row)
a, b = h.inputs()
t = h.add(ops.MakeTuple(row)(a, b))
a, b = h.add(ops.UnpackTuple(row)(t))
t = h.add(ops.MakeTuple(a, b))
a, b = h.add(ops.UnpackTuple(t))
h.set_outputs(a, b)

_validate(h.hugr)

h1 = Dfg.endo(row)
h1 = Dfg(*row)
a, b = h1.inputs()
mt = h1.add_op(ops.MakeTuple(row), a, b)
a, b = h1.add_op(ops.UnpackTuple(row), mt)[0, 1]
mt = h1.add_op(ops.MakeTuple, a, b)
a, b = h1.add_op(ops.UnpackTuple, mt)[0, 1]
h1.set_outputs(a, b)

assert h.hugr.to_serial() == h1.hugr.to_serial()


def test_multi_out():
h = Dfg([INT_T] * 2, [INT_T] * 2)
h = Dfg(INT_T, INT_T)
a, b = h.inputs()
a, b = h.add(DivMod(a, b))
h.set_outputs(a, b)
_validate(h.hugr)


def test_insert():
h1 = Dfg.endo([tys.Bool])
h1 = Dfg(tys.Bool)
(a1,) = h1.inputs()
nt = h1.add(Not(a1))
h1.set_outputs(nt)
Expand All @@ -200,12 +200,12 @@ def test_insert():


def test_insert_nested():
h1 = Dfg.endo([tys.Bool])
h1 = Dfg(tys.Bool)
(a1,) = h1.inputs()
nt = h1.add(Not(a1))
h1.set_outputs(nt)

h = Dfg.endo([tys.Bool])
h = Dfg(tys.Bool)
(a,) = h.inputs()
nested = h.insert_nested(h1, a)
h.set_outputs(nested)
Expand All @@ -219,9 +219,9 @@ def _nested_nop(dfg: Dfg):
nt = dfg.add(Not(a1))
dfg.set_outputs(nt)

h = Dfg.endo([tys.Bool])
h = Dfg(tys.Bool)
(a,) = h.inputs()
nested = h.add_nested([tys.Bool], [tys.Bool], a)
nested = h.add_nested(a)

_nested_nop(nested)
assert len(h.hugr.children(nested)) == 3
Expand All @@ -231,9 +231,9 @@ def _nested_nop(dfg: Dfg):


def test_build_inter_graph():
h = Dfg.endo([tys.Bool, tys.Bool])
h = Dfg(tys.Bool, tys.Bool)
(a, b) = h.inputs()
nested = h.add_nested([], [tys.Bool])
nested = h.add_nested()

nt = nested.add(Not(a))
nested.set_outputs(nt)
Expand All @@ -250,9 +250,9 @@ def test_build_inter_graph():


def test_ancestral_sibling():
h = Dfg.endo([tys.Bool])
h = Dfg(tys.Bool)
(a,) = h.inputs()
nested = h.add_nested([], [tys.Bool])
nested = h.add_nested()

nt = nested.add(Not(a))

Expand Down