Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py): only require input type annotations when building #1199

Merged
merged 17 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ def __init__(self, root_op: DP) -> None:
self._init_io_nodes(root_op)

def _init_io_nodes(self, root_op: DP):
input_types = root_op.input_types()
output_types = root_op.output_types()
inner_sig = root_op.inner_signature()

self.input_node = self.hugr.add_node(
ops.Input(input_types), self.root, len(input_types)
ops.Input(inner_sig.input), self.root, len(inner_sig.input)
)
self.output_node = self.hugr.add_node(ops.Output(output_types), self.root)
self.output_node = self.hugr.add_node(ops.Output(inner_sig.output), self.root)

@classmethod
def new_nested(cls, root_op: DP, hugr: Hugr, parent: ToNode | None = None) -> Self:
Expand All @@ -61,7 +61,9 @@ def root_op(self) -> DP:
def inputs(self) -> list[OutPort]:
return [self.input_node.out(i) for i in range(len(self._input_op().types))]

def add_op(self, op: ops.Op, /, *args: Wire, num_outs: int | None = None) -> Node:
def add_op(
self, op: ops.DataflowOp, /, *args: Wire, num_outs: int | None = None
) -> Node:
new_n = self.hugr.add_node(op, self.root, num_outs=num_outs)
self._wire_up(new_n, args)
return new_n
Expand Down
12 changes: 11 additions & 1 deletion hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

from typing_extensions import Self

from hugr._ops import Op
from hugr._ops import Op, DataflowOp
from hugr._tys import Type, Kind
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.utils import BiMap
Expand Down Expand Up @@ -333,6 +334,15 @@ def num_outgoing(self, node: ToNode) -> int:

# TODO: num_links and _linked_ports

def port_kind(self, port: InPort | OutPort) -> Kind:
return self[port.node].op.port_kind(port)

def port_type(self, port: InPort | OutPort) -> Type | None:
op = self[port.node].op
if isinstance(op, DataflowOp):
return op.port_type(port)
return None

def insert_hugr(self, hugr: Hugr, parent: ToNode | None = None) -> dict[Node, Node]:
mapping: dict[Node, Node] = {}

Expand Down
86 changes: 66 additions & 20 deletions hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,51 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Generic, Protocol, TypeVar, TYPE_CHECKING
from typing import Generic, Protocol, TypeVar, TYPE_CHECKING, runtime_checkable
from hugr.serialization.ops import BaseOp
import hugr.serialization.ops as sops
from hugr.utils import ser_it
import hugr._tys as tys

if TYPE_CHECKING:
from hugr._hugr import Hugr, Node, Wire
from hugr._hugr import Hugr, Node, Wire, _Port


@runtime_checkable
class Op(Protocol):
@property
def num_out(self) -> int | None:
return None

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp: ...

def port_kind(self, port: _Port) -> tys.Kind: ...
mark-koch marked this conversation as resolved.
Show resolved Hide resolved


@runtime_checkable
class DataflowOp(Op, Protocol):
def outer_signature(self) -> tys.FunctionType: ...
mark-koch marked this conversation as resolved.
Show resolved Hide resolved

def port_kind(self, port: _Port) -> tys.Kind:
if port.offset == -1:
mark-koch marked this conversation as resolved.
Show resolved Hide resolved
return tys.OrderKind()
return tys.ValueKind(self.port_type(port))

def port_type(self, port: _Port) -> tys.Type:
from hugr._hugr import Direction

sig = self.outer_signature()
if port.direction == Direction.INCOMING:
return sig.input[port.offset]
return sig.output[port.offset]

def __call__(self, *args) -> Command:
return Command(self, list(args))


@dataclass(frozen=True)
class Command:
op: Op
op: DataflowOp
incoming: list[Wire]


Expand All @@ -41,9 +62,12 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> T:
root.parent = parent.idx
return root

def port_kind(self, port: _Port) -> tys.Kind:
raise NotImplementedError
mark-koch marked this conversation as resolved.
Show resolved Hide resolved


@dataclass()
class Input(Op):
class Input(DataflowOp):
types: tys.TypeRow

@property
Expand All @@ -53,20 +77,26 @@ def num_out(self) -> int | None:
def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Input:
return sops.Input(parent=parent.idx, types=ser_it(self.types))

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType(input=[], output=self.types)

def __call__(self) -> Command:
return super().__call__()


@dataclass()
class Output(Op):
class Output(DataflowOp):
types: tys.TypeRow

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output:
return sops.Output(parent=parent.idx, types=ser_it(self.types))

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType(input=self.types, output=[])


@dataclass()
class Custom(Op):
class Custom(DataflowOp):
op_name: str
signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty)
description: str = ""
Expand All @@ -87,9 +117,12 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp:
args=ser_it(self.args),
)

def outer_signature(self) -> tys.FunctionType:
return self.signature


@dataclass()
class MakeTuple(Op):
class MakeTuple(DataflowOp):
types: tys.TypeRow
num_out: int | None = 1

Expand All @@ -102,9 +135,12 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple:
def __call__(self, *elements: Wire) -> Command:
return super().__call__(*elements)

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType(input=self.types, output=[tys.Tuple(*self.types)])


@dataclass()
class UnpackTuple(Op):
class UnpackTuple(DataflowOp):
types: tys.TypeRow

@property
Expand All @@ -120,14 +156,16 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.UnpackTuple:
def __call__(self, tuple_: Wire) -> Command:
return super().__call__(tuple_)

def outer_signature(self) -> tys.FunctionType:
return MakeTuple(self.types).outer_signature().flip()

ss2165 marked this conversation as resolved.
Show resolved Hide resolved

class DfParentOp(Op, Protocol):
def input_types(self) -> tys.TypeRow: ...
def output_types(self) -> tys.TypeRow: ...
def inner_signature(self) -> tys.FunctionType: ...


@dataclass()
class DFG(DfParentOp):
class DFG(DfParentOp, DataflowOp):
signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty)

@property
Expand All @@ -140,15 +178,15 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DFG:
signature=self.signature.to_serial(),
)

def input_types(self) -> tys.TypeRow:
return self.signature.input
def inner_signature(self) -> tys.FunctionType:
return self.signature

def output_types(self) -> tys.TypeRow:
return self.signature.output
def outer_signature(self) -> tys.FunctionType:
return self.signature


@dataclass()
class CFG(Op):
class CFG(DataflowOp):
signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty)

@property
Expand All @@ -161,6 +199,9 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CFG:
signature=self.signature.to_serial(),
)

def outer_signature(self) -> tys.FunctionType:
return self.signature


@dataclass
class DataflowBlock(DfParentOp):
Expand All @@ -182,11 +223,13 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DataflowBlock:
extension_delta=self.extension_delta,
)

def input_types(self) -> tys.TypeRow:
return self.inputs
def inner_signature(self) -> tys.FunctionType:
return tys.FunctionType(
input=self.inputs, output=[tys.Sum(self.sum_rows), *self.other_outputs]
)

def output_types(self) -> tys.TypeRow:
return [tys.Sum(self.sum_rows), *self.other_outputs]
def port_kind(self, port: _Port) -> tys.Kind:
return tys.CFKind()


@dataclass
Expand All @@ -199,3 +242,6 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.ExitBlock:
parent=parent.idx,
cfg_outputs=ser_it(self.cfg_outputs),
)

def port_kind(self, port: _Port) -> tys.Kind:
return tys.CFKind()
29 changes: 29 additions & 0 deletions hugr-py/src/hugr/_tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ def to_serial(self) -> stys.FunctionType:
def empty(cls) -> FunctionType:
return cls(input=[], output=[])

def flip(self) -> FunctionType:
return FunctionType(input=self.output, output=self.input)
mark-koch marked this conversation as resolved.
Show resolved Hide resolved


@dataclass(frozen=True)
class PolyFuncType(Type):
Expand Down Expand Up @@ -270,3 +273,29 @@ def to_serial(self) -> stys.Qubit:
Qubit = QubitDef()
Bool = UnitSum(size=2)
Unit = UnitSum(size=1)


@dataclass(frozen=True)
class ValueKind:
ty: Type


@dataclass(frozen=True)
class ConstKind:
ty: Type


@dataclass(frozen=True)
class FunctionKind:
ty: PolyFuncType


@dataclass(frozen=True)
class CFKind: ...


@dataclass(frozen=True)
class OrderKind: ...


Kind = ValueKind | ConstKind | FunctionKind | CFKind | OrderKind