Skip to content

Commit

Permalink
poc: concurrently stream from list_dir -> getitem
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamman committed Apr 10, 2024
1 parent 7dff5e5 commit 94b00c7
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 24 deletions.
31 changes: 16 additions & 15 deletions src/zarr/v3/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ async def getitem(
) -> AsyncArray | AsyncGroup:

store_path = self.store_path / key
logger.warning("key=%s, store_path=%s", key, store_path)

# Note:
# in zarr-python v2, we first check if `key` references an Array, else if `key` references
Expand Down Expand Up @@ -305,20 +304,22 @@ async def children(self) -> AsyncGenerator[AsyncArray | AsyncGroup, None]:

raise ValueError(msg)

async for key in self.store_path.store.list_dir(self.store_path.path):
# these keys are not valid child names so we make sure to skip them
# TODO: it would be nice to make these special keys accessible programmatically,
# and scoped to specific zarr versions
if key not in ("zarr.json", ".zgroup", ".zattrs"):
try:
# TODO: performance optimization -- batch
print(key)
child = await self.getitem(key)
# keyerror is raised when `subkey``names an object in the store
# in which case `subkey` cannot be the name of a sub-array or sub-group.
yield child
except KeyError:
pass
# leaving these imports here for demo purposes
from aiostream import stream, async_, pipe
from aiostream.aiter_utils import aitercontext

children = (
stream.iterate(self.store_path.store.list_dir(self.store_path.path))
| pipe.filter(lambda x: x not in ("zarr.json", ".zgroup", ".zattrs"))
|
# TODO: need to handle directories without a metadata doc
# previously, we gracefully ignored them by catching the KeyError here.
pipe.map(async_(self.getitem))
)

async with aitercontext(children) as safe_children:
async for child in safe_children:
yield child

async def contains(self, child: str) -> bool:
raise NotImplementedError
Expand Down
5 changes: 1 addition & 4 deletions src/zarr/v3/store/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ async def list(self) -> AsyncGenerator[str, None]:
if p.is_file():
yield str(p)


async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
"""Retrieve all keys in the store with a given prefix.
Expand All @@ -170,7 +169,6 @@ async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
if p.is_file():
yield str(p)


async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
"""
Retrieve all keys and prefixes with a given prefix and which do not contain the character
Expand All @@ -186,12 +184,11 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
"""
base = self.root / prefix
to_strip = str(base) + "/"

try:
key_iter = base.iterdir()
except (FileNotFoundError, NotADirectoryError):
key_iter = []

for key in key_iter:
yield str(key).replace(to_strip, "")

2 changes: 0 additions & 2 deletions src/zarr/v3/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
yield key

async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
print('prefix', prefix)
print('keys in list_dir', list(self._store_dict))
for key in self._store_dict:
if key.startswith(prefix + "/") and key != prefix:
yield key.strip(prefix + "/").rsplit("/", maxsplit=1)[0]
4 changes: 1 addition & 3 deletions tests/v3/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_group_children(store: MemoryStore | LocalStore):
# if group.children guarantees a particular order for the children.
# If order is not guaranteed, then the better version of this test is
# to compare two sets, but presently neither the group nor array classes are hashable.
print('getting children')
print("getting children")
observed = group.children
print(observed)
print(list([subgroup, subarray, implicit_subgroup]))
Expand All @@ -66,8 +66,6 @@ def test_group_children(store: MemoryStore | LocalStore):
assert subgroup in observed




@pytest.mark.parametrize("store", (("local", "memory")), indirect=["store"])
def test_group(store: MemoryStore | LocalStore) -> None:
store_path = StorePath(store)
Expand Down

0 comments on commit 94b00c7

Please sign in to comment.