Skip to content

Commit

Permalink
fix: replace tests that went missing in zarr-developers#2006
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamman committed Sep 17, 2024
1 parent b1ecdd5 commit 6b6cc3a
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 3 deletions.
154 changes: 153 additions & 1 deletion tests/v3/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import numpy as np
import pytest

import zarr
from zarr import Array, AsyncArray, AsyncGroup, Group
from zarr.abc.store import Store
from zarr.core.buffer import default_buffer_prototype
from zarr.core.common import JSON, ZarrFormat
from zarr.core.group import GroupMetadata
from zarr.core.sync import sync
from zarr.errors import ContainsArrayError, ContainsGroupError
from zarr.store import LocalStore, StorePath
from zarr.store import LocalStore, MemoryStore, StorePath
from zarr.store.common import make_store_path

from .conftest import parse_store
Expand Down Expand Up @@ -699,3 +700,154 @@ def test_serializable_sync_group(store: LocalStore, zarr_format: ZarrFormat) ->
actual = pickle.loads(p)

assert actual == expected


async def test_group_members_async(store: LocalStore | MemoryStore) -> None:
group = AsyncGroup(
GroupMetadata(),
store_path=StorePath(store=store, path="root"),
)
a0 = await group.create_array("a0", shape=(1,))
g0 = await group.create_group("g0")
a1 = await g0.create_array("a1", shape=(1,))
g1 = await g0.create_group("g1")
a2 = await g1.create_array("a2", shape=(1,))
g2 = await g1.create_group("g2")

# immediate children
children = sorted([x async for x in group.members()], key=lambda x: x[0])
assert children == [
("a0", a0),
("g0", g0),
]

nmembers = await group.nmembers()
assert nmembers == 2

# partial
children = sorted([x async for x in group.members(max_depth=1)], key=lambda x: x[0])
expected = [
("a0", a0),
("g0", g0),
("g0/a1", a1),
("g0/g1", g1),
]
assert children == expected
nmembers = await group.nmembers(max_depth=1)
assert nmembers == 4

# all children
all_children = sorted([x async for x in group.members(max_depth=None)], key=lambda x: x[0])
expected = [
("a0", a0),
("g0", g0),
("g0/a1", a1),
("g0/g1", g1),
("g0/g1/a2", a2),
("g0/g1/g2", g2),
]
assert all_children == expected

nmembers = await group.nmembers(max_depth=None)
assert nmembers == 6

with pytest.raises(ValueError, match="max_depth"):
[x async for x in group.members(max_depth=-1)]


async def test_require_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
root = await AsyncGroup.create(store=store, zarr_format=zarr_format)

# create foo group
_ = await root.create_group("foo", attributes={"foo": 100})

# test that we can get the group using require_group
foo_group = await root.require_group("foo")
assert foo_group.attrs == {"foo": 100}

# test that we can get the group using require_group and overwrite=True
foo_group = await root.require_group("foo", overwrite=True)

_ = await foo_group.create_array(
"bar", shape=(10,), dtype="uint8", chunk_shape=(2,), attributes={"foo": 100}
)

# test that overwriting a group w/ children fails
# TODO: figure out why ensure_no_existing_node is not catching the foo.bar array
#
# with pytest.raises(ContainsArrayError):
# await root.require_group("foo", overwrite=True)

# test that requiring a group where an array is fails
with pytest.raises(TypeError):
await foo_group.require_group("bar")


async def test_require_groups(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
root = await AsyncGroup.create(store=store, zarr_format=zarr_format)
# create foo group
_ = await root.create_group("foo", attributes={"foo": 100})
# create bar group
_ = await root.create_group("bar", attributes={"bar": 200})

foo_group, bar_group = await root.require_groups("foo", "bar")
assert foo_group.attrs == {"foo": 100}
assert bar_group.attrs == {"bar": 200}

# get a mix of existing and new groups
foo_group, spam_group = await root.require_groups("foo", "spam")
assert foo_group.attrs == {"foo": 100}
assert spam_group.attrs == {}

# no names
no_group = await root.require_groups()
assert no_group == ()


async def test_create_dataset(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
root = await AsyncGroup.create(store=store, zarr_format=zarr_format)
with pytest.warns(DeprecationWarning):
foo = await root.create_dataset("foo", shape=(10,), dtype="uint8")
assert foo.shape == (10,)

with pytest.raises(ContainsArrayError), pytest.warns(DeprecationWarning):
await root.create_dataset("foo", shape=(100,), dtype="int8")

_ = await root.create_group("bar")
with pytest.raises(ContainsGroupError), pytest.warns(DeprecationWarning):
await root.create_dataset("bar", shape=(100,), dtype="int8")


async def test_require_array(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
root = await AsyncGroup.create(store=store, zarr_format=zarr_format)
foo1 = await root.require_array("foo", shape=(10,), dtype="i8", attributes={"foo": 101})
assert foo1.attrs == {"foo": 101}
foo2 = await root.require_array("foo", shape=(10,), dtype="i8")
assert foo2.attrs == {"foo": 101}

# exact = False
_ = await root.require_array("foo", shape=10, dtype="f8")

# errors w/ exact True
with pytest.raises(TypeError, match="Incompatible dtype"):
await root.require_array("foo", shape=(10,), dtype="f8", exact=True)

with pytest.raises(TypeError, match="Incompatible shape"):
await root.require_array("foo", shape=(100, 100), dtype="i8")

with pytest.raises(TypeError, match="Incompatible dtype"):
await root.require_array("foo", shape=(10,), dtype="f4")

_ = await root.create_group("bar")
with pytest.raises(TypeError, match="Incompatible object"):
await root.require_array("bar", shape=(10,), dtype="int8")


async def test_open_mutable_mapping():
group = await zarr.api.asynchronous.open_group(store={}, mode="w")
assert isinstance(group.store_path.store, MemoryStore)


def test_open_mutable_mapping_sync():
group = zarr.open_group(store={}, mode="w")
assert isinstance(group.store_path.store, MemoryStore)
40 changes: 38 additions & 2 deletions tests/v3/test_store/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

import pytest

from zarr.core.buffer import Buffer, cpu
from zarr.store.memory import MemoryStore
from zarr.core.buffer import Buffer, cpu, gpu
from zarr.store.memory import GpuMemoryStore, MemoryStore
from zarr.testing.store import StoreTests
from zarr.testing.utils import gpu_test


class TestMemoryStore(StoreTests[MemoryStore, cpu.Buffer]):
Expand Down Expand Up @@ -56,3 +57,38 @@ def test_serizalizable_store(self, store: MemoryStore) -> None:

with pytest.raises(NotImplementedError):
pickle.dumps(store)


@gpu_test
class TestGpuMemoryStore(StoreTests[GpuMemoryStore, gpu.Buffer]):
store_cls = GpuMemoryStore
buffer_cls = gpu.Buffer

def set(self, store: GpuMemoryStore, key: str, value: Buffer) -> None:
store._store_dict[key] = value

def get(self, store: MemoryStore, key: str) -> Buffer:
return store._store_dict[key]

@pytest.fixture(scope="function", params=[None, {}])
def store_kwargs(self, request) -> dict[str, str | None | dict[str, Buffer]]:
return {"store_dict": request.param, "mode": "r+"}

@pytest.fixture(scope="function")
def store(self, store_kwargs: str | None | dict[str, gpu.Buffer]) -> GpuMemoryStore:
return self.store_cls(**store_kwargs)

def test_store_repr(self, store: GpuMemoryStore) -> None:
assert str(store) == f"gpumemory://{id(store._store_dict)}"

def test_store_supports_writes(self, store: GpuMemoryStore) -> None:
assert store.supports_writes

def test_store_supports_listing(self, store: GpuMemoryStore) -> None:
assert store.supports_listing

def test_store_supports_partial_writes(self, store: GpuMemoryStore) -> None:
assert store.supports_partial_writes

def test_list_prefix(self, store: GpuMemoryStore) -> None:
assert True

0 comments on commit 6b6cc3a

Please sign in to comment.