From f5d4390f85e8e3a21e17ac91b3554b46b3f2957e Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 19 Jun 2024 14:28:11 +0100 Subject: [PATCH] add test for functionvalue and complex sum --- hugr-py/src/hugr/serialization/ops.py | 19 +++++++++++------ hugr-py/src/hugr/serialization/serial_hugr.py | 6 +++++- hugr-py/tests/test_hugr_build.py | 21 +++++++++++++++++-- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index 39d2a5676..d76b13e84 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -2,7 +2,7 @@ import inspect import sys from abc import ABC, abstractmethod -from typing import Any, Literal +from typing import Any, Literal, TYPE_CHECKING from pydantic import Field, RootModel, ConfigDict @@ -22,6 +22,9 @@ ) from hugr.utils import deser_it +if TYPE_CHECKING: + from hugr.serialization.serial_hugr import SerialHugr + NodeID = int @@ -101,10 +104,12 @@ class FunctionValue(BaseValue): """A higher-order function value.""" v: Literal["Function"] = Field(default="Function", title="ValueTag") - hugr: Any # TODO + hugr: SerialHugr def deserialize(self) -> _val.Value: - return _val.Function(self.hugr) + from hugr._hugr import Hugr + + return _val.Function(Hugr.from_serial(self.hugr)) class TupleValue(BaseValue): @@ -604,7 +609,9 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True): + tys_classes ) -tys_model_rebuild(dict(classes)) +# needed to avoid circular imports +from hugr import _ops # noqa: E402 +from hugr import _val # noqa: E402 +from hugr.serialization.serial_hugr import SerialHugr # noqa: E402 -from hugr import _ops # noqa: E402 # needed to avoid circular imports -from hugr import _val # noqa: E402 # needed to avoid circular imports +tys_model_rebuild(dict(classes)) diff --git a/hugr-py/src/hugr/serialization/serial_hugr.py b/hugr-py/src/hugr/serialization/serial_hugr.py index 49bfbd2f7..1dc60176b 100644 --- a/hugr-py/src/hugr/serialization/serial_hugr.py +++ b/hugr-py/src/hugr/serialization/serial_hugr.py @@ -1,11 +1,12 @@ +from __future__ import annotations from typing import Any, Literal from pydantic import Field, ConfigDict -from .ops import NodeID, OpType, classes as ops_classes from .tys import model_rebuild, ConfiguredBaseModel import hugr +NodeID = int Port = tuple[NodeID, int | None] # (node, offset) Edge = tuple[Port, Port] @@ -48,3 +49,6 @@ def _pydantic_rebuild(cls, config: ConfigDict = ConfigDict(), **kwargs): "required": ["version", "nodes", "edges"], }, ) + + +from .ops import OpType, classes as ops_classes # noqa: E402 # needed to avoid circular import diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 79dc24d0e..f37f283bd 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -128,12 +128,15 @@ def test_stable_indices(): assert h._free_nodes == [] -def test_simple_id(): +def simple_id() -> Dfg: h = Dfg(tys.Qubit, tys.Qubit) a, b = h.inputs() h.set_outputs(a, b) + return h - _validate(h.hugr) + +def test_simple_id(): + _validate(simple_id().hugr) def test_multiport(): @@ -266,3 +269,17 @@ def test_ancestral_sibling(): nt = nested.add(Not(a)) assert _ancestral_sibling(h.hugr, h.input_node, nt) == nested.parent_node + + +@pytest.mark.parametrize( + "val", + [ + val.Function(simple_id().hugr), + val.Sum(1, tys.Sum([[INT_T], [tys.Bool, INT_T]]), [IntVal(34)]), + ], +) +def test_vals(val: val.Value): + d = Dfg() + d.set_outputs(d.add_load_const(val)) + + _validate(d.hugr)