Skip to content

Commit

Permalink
add test for functionvalue and complex sum
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Jun 19, 2024
1 parent 8ba5de5 commit f5d4390
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 9 deletions.
19 changes: 13 additions & 6 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import inspect
import sys
from abc import ABC, abstractmethod
from typing import Any, Literal
from typing import Any, Literal, TYPE_CHECKING

from pydantic import Field, RootModel, ConfigDict

Expand All @@ -22,6 +22,9 @@
)
from hugr.utils import deser_it

if TYPE_CHECKING:
from hugr.serialization.serial_hugr import SerialHugr

NodeID = int


Expand Down Expand Up @@ -101,10 +104,12 @@ class FunctionValue(BaseValue):
"""A higher-order function value."""

v: Literal["Function"] = Field(default="Function", title="ValueTag")
hugr: Any # TODO
hugr: SerialHugr

def deserialize(self) -> _val.Value:
return _val.Function(self.hugr)
from hugr._hugr import Hugr

return _val.Function(Hugr.from_serial(self.hugr))


class TupleValue(BaseValue):
Expand Down Expand Up @@ -604,7 +609,9 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True):
+ tys_classes
)

tys_model_rebuild(dict(classes))
# needed to avoid circular imports
from hugr import _ops # noqa: E402
from hugr import _val # noqa: E402
from hugr.serialization.serial_hugr import SerialHugr # noqa: E402

from hugr import _ops # noqa: E402 # needed to avoid circular imports
from hugr import _val # noqa: E402 # needed to avoid circular imports
tys_model_rebuild(dict(classes))
6 changes: 5 additions & 1 deletion hugr-py/src/hugr/serialization/serial_hugr.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations
from typing import Any, Literal

from pydantic import Field, ConfigDict

from .ops import NodeID, OpType, classes as ops_classes
from .tys import model_rebuild, ConfiguredBaseModel
import hugr

NodeID = int
Port = tuple[NodeID, int | None] # (node, offset)
Edge = tuple[Port, Port]

Expand Down Expand Up @@ -48,3 +49,6 @@ def _pydantic_rebuild(cls, config: ConfigDict = ConfigDict(), **kwargs):
"required": ["version", "nodes", "edges"],
},
)


from .ops import OpType, classes as ops_classes # noqa: E402 # needed to avoid circular import
21 changes: 19 additions & 2 deletions hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,15 @@ def test_stable_indices():
assert h._free_nodes == []


def test_simple_id():
def simple_id() -> Dfg:
h = Dfg(tys.Qubit, tys.Qubit)
a, b = h.inputs()
h.set_outputs(a, b)
return h

_validate(h.hugr)

def test_simple_id():
_validate(simple_id().hugr)


def test_multiport():
Expand Down Expand Up @@ -266,3 +269,17 @@ def test_ancestral_sibling():
nt = nested.add(Not(a))

assert _ancestral_sibling(h.hugr, h.input_node, nt) == nested.parent_node


@pytest.mark.parametrize(
"val",
[
val.Function(simple_id().hugr),
val.Sum(1, tys.Sum([[INT_T], [tys.Bool, INT_T]]), [IntVal(34)]),
],
)
def test_vals(val: val.Value):
d = Dfg()
d.set_outputs(d.add_load_const(val))

_validate(d.hugr)

0 comments on commit f5d4390

Please sign in to comment.