Skip to content

Commit

Permalink
feat(py): add ArrayVal
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Dec 12, 2024
1 parent 6f035d6 commit 8c7d0d1
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
27 changes: 25 additions & 2 deletions hugr-py/src/hugr/std/collections/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

from dataclasses import dataclass

import hugr.tys as tys
from hugr import tys, val
from hugr.std import _load_extension
from hugr.utils import comma_sep_str

EXTENSION = _load_extension("collections.array")

Expand All @@ -14,7 +15,7 @@
class Array(tys.ExtType):
"""Fixed `size` array of `ty` elements."""

def __init__(self, ty: tys.Type, size: int | tys.BoundedNatArg) -> None:
def __init__(self, ty: tys.Type, size: int | tys.TypeArg) -> None:
if isinstance(size, int):
size = tys.BoundedNatArg(size)

Expand Down Expand Up @@ -52,3 +53,25 @@ def size(self) -> int | None:

def type_bound(self) -> tys.TypeBound:
return self.ty.type_bound()


@dataclass
class ArrayVal(val.ExtensionValue):
"""Constant value for a statically sized array of elements."""

v: list[val.Value]
ty: tys.Type

def __init__(self, v: list[val.Value], elem_ty: tys.Type) -> None:
self.v = v
self.ty = Array(elem_ty, len(v))

def to_value(self) -> val.Extension:
name = "ArrayValue"

Check warning on line 70 in hugr-py/src/hugr/std/collections/array.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/std/collections/array.py#L70

Added line #L70 was not covered by tests
# The value list must be serialized at this point, otherwise the
# `Extension` value would not be serializable.
vs = [v._to_serial_root() for v in self.v]
return val.Extension(name, typ=self.ty, val=vs, extensions=[EXTENSION.name])

Check warning on line 74 in hugr-py/src/hugr/std/collections/array.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/std/collections/array.py#L73-L74

Added lines #L73 - L74 were not covered by tests

def __str__(self) -> str:
return f"array({comma_sep_str(self.v)})"

Check warning on line 77 in hugr-py/src/hugr/std/collections/array.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/std/collections/array.py#L77

Added line #L77 was not covered by tests
6 changes: 5 additions & 1 deletion hugr-py/tests/test_tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from hugr import val
from hugr.std.collections.array import Array
from hugr.std.collections.array import Array, ArrayVal
from hugr.std.collections.list import List, ListVal
from hugr.std.float import FLOAT_T
from hugr.std.int import INT_T, _int_tv
Expand Down Expand Up @@ -170,3 +170,7 @@ def test_array():
ls = Array(ty_var, len_var)
assert ls.ty == ty_var
assert ls.size is None

ar_val = ArrayVal([val.TRUE, val.FALSE], Bool)
assert ar_val.v == [val.TRUE, val.FALSE]
assert ar_val.ty == Array(Bool, 2)

0 comments on commit 8c7d0d1

Please sign in to comment.