Skip to content

Commit

Permalink
progress integrating store mode
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamman committed May 31, 2024
1 parent 9d4e9d9 commit 81405b0
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 50 deletions.
40 changes: 24 additions & 16 deletions src/zarr/api/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from zarr.abc.codec import Codec
from zarr.array import Array, AsyncArray
from zarr.buffer import NDArrayLike
from zarr.common import JSON, MEMORY_ORDER, ChunkCoords, ZarrFormat
from zarr.common import JSON, MEMORY_ORDER, ChunkCoords, OpenMode, ZarrFormat
from zarr.group import AsyncGroup
from zarr.metadata import ArrayV2Metadata, ArrayV3Metadata, ChunkKeyEncoding
from zarr.store import (
Expand Down Expand Up @@ -129,7 +129,7 @@ async def load(
async def open(
*,
store: StoreLike | None = None,
mode: str | None = None, # type and value changed
mode: OpenMode | None = None, # type and value changed
zarr_version: ZarrFormat | None = None, # deprecated
zarr_format: ZarrFormat | None = None,
path: str | None = None,
Expand Down Expand Up @@ -164,10 +164,8 @@ async def open(
"zarr_version is deprecated, use zarr_format", DeprecationWarning, stacklevel=2
)
zarr_format = zarr_version
if mode is not None:
warnings.warn("mode is ignored", RuntimeWarning, stacklevel=2)

store_path = make_store_path(store)
store_path = make_store_path(store, mode=mode)

if path is not None:
store_path = store_path / path
Expand Down Expand Up @@ -252,10 +250,17 @@ async def save_array(
if zarr_format is None:
zarr_format = 3 # default via config?

store_path = make_store_path(store)
store_path = make_store_path(store, mode="w")
if path is not None:
store_path = store_path / path
new = await AsyncArray.create(store_path, zarr_format=zarr_format, **kwargs)
new = await AsyncArray.create(
store_path,
zarr_format=zarr_format,
shape=arr.shape,
dtype=arr.dtype,
chunks=arr.shape,
**kwargs,
)
await new.setitem(slice(None), arr)


Expand Down Expand Up @@ -295,7 +300,8 @@ async def save_group(
for i, arr in enumerate(args):
aws.append(save_array(store, arr, zarr_format=zarr_format, path=f"{path}/arr_{i}"))
for k, arr in kwargs.items():
aws.append(save_array(store, arr, zarr_format=zarr_format, path=f"{path}/{k}"))
path = f"{path}/{k}" if path is not None else k
aws.append(save_array(store, arr, zarr_format=zarr_format, path=path))
await asyncio.gather(*aws)


Expand Down Expand Up @@ -428,7 +434,7 @@ async def group(
async def open_group(
*, # Note: this is a change from v2
store: StoreLike | None = None,
mode: str | None = None, # not used
mode: OpenMode | None = None, # not used
cache_attrs: bool | None = None, # not used, default changed
synchronizer: Any = None, # not used
path: str | None = None,
Expand Down Expand Up @@ -480,8 +486,6 @@ async def open_group(
if zarr_format is None:
zarr_format = 3 # default from config?

if mode is not None:
warnings.warn("mode is not yet implemented", RuntimeWarning, stacklevel=2)
if cache_attrs is not None:
warnings.warn("cache_attrs is not yet implemented", RuntimeWarning, stacklevel=2)
if synchronizer is not None:
Expand All @@ -493,7 +497,7 @@ async def open_group(
if storage_options is not None:
warnings.warn("storage_options is not yet implemented", RuntimeWarning, stacklevel=2)

store_path = make_store_path(store)
store_path = make_store_path(store, mode=mode)
if path is not None:
store_path = store_path / path

Expand All @@ -508,7 +512,6 @@ async def open_group(
)


# TODO: require kwargs
async def create(
shape: ShapeLike,
*, # Note: this is a change from v2
Expand Down Expand Up @@ -680,7 +683,7 @@ async def create(
if meta_array is not None:
warnings.warn("meta_array is not yet implemented", RuntimeWarning, stacklevel=2)

store_path = make_store_path(store)
store_path = make_store_path(store, mode="w")
if path is not None:
store_path = store_path / path

Expand Down Expand Up @@ -801,9 +804,14 @@ async def open_array(
)

try:
print(store_path)
return await AsyncArray.open(store_path, zarr_format=zarr_format)
except KeyError:
pass
except KeyError as e:
print(e, type(e))
if store_path.store.writeable:
pass
else:
raise e

# if array was not found, create it
return await create(store=store, path=path, zarr_format=zarr_format, **kwargs)
Expand Down
6 changes: 3 additions & 3 deletions src/zarr/api/synchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import zarr.api.asynchronous as async_api
from zarr.array import Array
from zarr.buffer import NDArrayLike
from zarr.common import JSON, ZarrFormat
from zarr.common import JSON, OpenMode, ZarrFormat
from zarr.group import Group
from zarr.store import StoreLike
from zarr.sync import sync
Expand Down Expand Up @@ -66,7 +66,7 @@ def load(
def open(
*,
store: StoreLike | None = None,
mode: str | None = None, # type and value changed
mode: OpenMode | None = None, # type and value changed
zarr_version: ZarrFormat | None = None, # deprecated
zarr_format: ZarrFormat | None = None,
path: str | None = None,
Expand Down Expand Up @@ -297,7 +297,7 @@ def group(
def open_group(
*, # Note: this is a change from v2
store: StoreLike | None = None,
mode: str | None = None, # not used in async api
mode: OpenMode | None = None, # not used in async api
cache_attrs: bool | None = None, # default changed, not used in async api
synchronizer: Any = None, # not used in async api
path: str | None = None,
Expand Down
13 changes: 13 additions & 0 deletions src/zarr/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,19 +288,24 @@ async def open(
store: StoreLike,
zarr_format: ZarrFormat | None = 3,
) -> AsyncArray:
print(f"store: {store}")
store_path = make_store_path(store)
print(f"store_path: {store_path}")

if zarr_format == 2:
print("^^^^^^", (store_path / ZARR_JSON))
zarray_bytes, zattrs_bytes = await gather(
(store_path / ZARRAY_JSON).get(), (store_path / ZATTRS_JSON).get()
)
if zarray_bytes is None:
raise KeyError(store_path) # filenotfounderror?
elif zarr_format == 3:
print("*******", (store_path / ZARR_JSON))
zarr_json_bytes = await (store_path / ZARR_JSON).get()
if zarr_json_bytes is None:
raise KeyError(store_path) # filenotfounderror?
elif zarr_format is None:
print("$$$$$$", (store_path / ZARR_JSON))
zarr_json_bytes, zarray_bytes, zattrs_bytes = await gather(
(store_path / ZARR_JSON).get(),
(store_path / ZARRAY_JSON).get(),
Expand Down Expand Up @@ -355,6 +360,10 @@ def dtype(self) -> np.dtype[Any]:
def attrs(self) -> dict[str, JSON]:
return self.metadata.attributes

@property
def read_only(self) -> bool:
return bool(~self.store_path.store.writeable)

async def getitem(
self, selection: Selection, *, factory: Factory.Create = NDBuffer.create
) -> NDArrayLike:
Expand Down Expand Up @@ -582,6 +591,10 @@ def store_path(self) -> StorePath:
def order(self) -> Literal["C", "F"]:
return self._async_array.order

@property
def read_only(self) -> bool:
return self._async_array.read_only

def __getitem__(self, selection: Selection) -> NDArrayLike:
return sync(
self._async_array.getitem(selection),
Expand Down
21 changes: 20 additions & 1 deletion src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,24 @@ def merge_with_morton_order(
break
return obj

@classmethod
def merge_with_c_order(
cls,
chunks_per_shard: ChunkCoords,
tombstones: set[ChunkCoords],
*shard_dicts: ShardMapping,
) -> _ShardBuilder:
obj = cls.create_empty(chunks_per_shard)
for chunk_coords in c_order_iter(chunks_per_shard):
if tombstones is not None and chunk_coords in tombstones:
continue
for shard_dict in shard_dicts:
maybe_value = shard_dict.get(chunk_coords, None)
if maybe_value is not None:
obj[chunk_coords] = maybe_value
break
return obj

@classmethod
def create_empty(cls, chunks_per_shard: ChunkCoords) -> _ShardBuilder:
obj = cls()
Expand Down Expand Up @@ -284,7 +302,8 @@ async def finalize(
index_location: ShardingCodecIndexLocation,
index_encoder: Callable[[_ShardIndex], Awaitable[Buffer]],
) -> Buffer:
shard_builder = _ShardBuilder.merge_with_morton_order(
print("merging shards with c order")
shard_builder = _ShardBuilder.merge_with_c_order(
self.new_dict.index.chunks_per_shard,
self.tombstones,
self.new_dict,
Expand Down
32 changes: 4 additions & 28 deletions src/zarr/store/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,39 +67,15 @@ def make_store_path(store_like: StoreLike | None, *, mode: OpenMode | None = Non
if mode is not None:
assert mode == store_like.store.mode
return store_like
elif store_like is None:
if mode is None:
mode = "r"
return StorePath(MemoryStore(mode=mode))
elif isinstance(store_like, Store):
if mode is not None:
assert mode == store_like.mode
return StorePath(store_like)
elif store_like is None:
if mode is None:
mode = "r"
return StorePath(MemoryStore(mode=mode))
elif isinstance(store_like, str):
assert mode is not None
return StorePath(LocalStore(Path(store_like), mode=mode))
raise TypeError


def _normalize_interval_index(
data: Buffer, interval: None | tuple[int | None, int | None]
) -> tuple[int, int]:
"""
Convert an implicit interval into an explicit start and length
"""
if interval is None:
start = 0
length = len(data)
else:
maybe_start, maybe_len = interval
if maybe_start is None:
start = 0
else:
start = maybe_start

if maybe_len is None:
length = len(data) - start
else:
length = maybe_len

return (start, length)
2 changes: 1 addition & 1 deletion src/zarr/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.common import OpenMode, concurrent_map
from zarr.store.core import _normalize_interval_index
from zarr.store.utils import _normalize_interval_index


# TODO: this store could easily be extended to wrap any MutableMapping store from v2
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/testing/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.store.core import _normalize_interval_index
from zarr.store.utils import _normalize_interval_index
from zarr.testing.utils import assert_bytes_equal

S = TypeVar("S", bound=Store)
Expand Down

0 comments on commit 81405b0

Please sign in to comment.