diff --git a/hugr-py/src/hugr/std/collections/array.py b/hugr-py/src/hugr/std/collections/array.py index d59b00af3..f7638e4f7 100644 --- a/hugr-py/src/hugr/std/collections/array.py +++ b/hugr-py/src/hugr/std/collections/array.py @@ -4,8 +4,9 @@ from dataclasses import dataclass -import hugr.tys as tys +from hugr import tys, val from hugr.std import _load_extension +from hugr.utils import comma_sep_str EXTENSION = _load_extension("collections.array") @@ -14,7 +15,7 @@ class Array(tys.ExtType): """Fixed `size` array of `ty` elements.""" - def __init__(self, ty: tys.Type, size: int | tys.BoundedNatArg) -> None: + def __init__(self, ty: tys.Type, size: int | tys.TypeArg) -> None: if isinstance(size, int): size = tys.BoundedNatArg(size) @@ -52,3 +53,25 @@ def size(self) -> int | None: def type_bound(self) -> tys.TypeBound: return self.ty.type_bound() + + +@dataclass +class ArrayVal(val.ExtensionValue): + """Constant value for a statically sized array of elements.""" + + v: list[val.Value] + ty: tys.Type + + def __init__(self, v: list[val.Value], elem_ty: tys.Type) -> None: + self.v = v + self.ty = Array(elem_ty, len(v)) + + def to_value(self) -> val.Extension: + name = "ArrayValue" + # The value list must be serialized at this point, otherwise the + # `Extension` value would not be serializable. + vs = [v._to_serial_root() for v in self.v] + return val.Extension(name, typ=self.ty, val=vs, extensions=[EXTENSION.name]) + + def __str__(self) -> str: + return f"array({comma_sep_str(self.v)})" diff --git a/hugr-py/tests/test_tys.py b/hugr-py/tests/test_tys.py index bca301182..e12e6fb0b 100644 --- a/hugr-py/tests/test_tys.py +++ b/hugr-py/tests/test_tys.py @@ -3,7 +3,7 @@ import pytest from hugr import val -from hugr.std.collections.array import Array +from hugr.std.collections.array import Array, ArrayVal from hugr.std.collections.list import List, ListVal from hugr.std.float import FLOAT_T from hugr.std.int import INT_T, _int_tv @@ -170,3 +170,7 @@ def test_array(): ls = Array(ty_var, len_var) assert ls.ty == ty_var assert ls.size is None + + ar_val = ArrayVal([val.TRUE, val.FALSE], Bool) + assert ar_val.v == [val.TRUE, val.FALSE] + assert ar_val.ty == Array(Bool, 2)