diff --git a/sharding_test.py b/sharding_test.py
new file mode 100644
index 0000000000..8e0e7ff18b
--- /dev/null
+++ b/sharding_test.py
@@ -0,0 +1,55 @@
+import json
+import os
+
+import zarr
+
+store = zarr.DirectoryStore("data/chunking_test.zarr")
+z = zarr.zeros((20, 3), chunks=(3, 2), shards=(2, 2), store=store, overwrite=True, compressor=None)
+z[:10, :] = 42
+z[15, 1] = 389
+z[19, 2] = 1
+z[0, 1] = -4.2
+
+print(store[".zarray"].decode())
+# {
+#     "chunks": [
+#         3,
+#         2
+#     ],
+#     "compressor": null,
+#     "dtype": "<f8",
+#     "fill_value": 0.0,
+#     "filters": null,
+#     "order": "C",
+#     "shape": [
+#         20,
+#         3
+#     ],
+#     "shard_format": "indexed",
+#     "shards": [
+#         2,
+#         2
+#     ],
+#     "zarr_format": 2
+# }
diff --git a/zarr/core.py b/zarr/core.py
--- a/zarr/core.py
+++ b/zarr/core.py
@@ ... @@
+MAX_UINT_64 = 2 ** 64 - 1
+
+
+class _ShardIndex(NamedTuple):
+    array: "Array"
+    # dtype uint64, shape (shards_0, shards_1, ..., 2)
+    offsets_and_lengths: np.ndarray
+
+    def __localize_chunk__(self, chunk: Tuple[int, ...]) -> Tuple[int, ...]:
+        return tuple(chunk_i % shard_i for chunk_i, shard_i in zip(chunk, self.array._shards))
+
+    def get_chunk_slice(self, chunk: Tuple[int, ...]) -> Optional[slice]:
+        localized_chunk = self.__localize_chunk__(chunk)
+        chunk_start, chunk_len = self.offsets_and_lengths[localized_chunk]
+        if (chunk_start, chunk_len) == (MAX_UINT_64, MAX_UINT_64):
+            return None
+        else:
+            return slice(chunk_start, chunk_start + chunk_len)
+
+    def set_chunk_slice(self, chunk: Tuple[int, ...], chunk_slice: Optional[slice]) -> None:
+        localized_chunk = self.__localize_chunk__(chunk)
+        if chunk_slice is None:
+            self.offsets_and_lengths[localized_chunk] = (MAX_UINT_64, MAX_UINT_64)
+        else:
+            self.offsets_and_lengths[localized_chunk] = (
+                chunk_slice.start,
+                chunk_slice.stop - chunk_slice.start
+            )
+
+    def to_bytes(self) -> bytes:
+        return self.offsets_and_lengths.tobytes(order='C')
+
+    @classmethod
+    def from_bytes(
+        cls, buffer: Union[bytes, bytearray], array: "Array"
+    ) -> "_ShardIndex":
+        return cls(
+            array=array,
+            offsets_and_lengths=np.frombuffer(
+                bytearray(buffer), dtype="<u8"
+            ).reshape(*array._shards, 2)
+        )
+
+    @classmethod
+    def create_empty(cls, array: "Array"):
+        # reserving 2*64bit per chunk for offset and length, initialized to
+        # the MAX_UINT_64 sentinel that marks an empty chunk:
+        return cls.from_bytes(
+            MAX_UINT_64.to_bytes(8, byteorder="little") * (2 * array._num_chunks_per_shard),
+            array=array,
+        )
@@ ... @@
     @property
-    def chunks(self):
+    def chunks(self) -> Optional[Tuple[int, ...]]:
         """A tuple of integers describing the length of each dimension of a
-        chunk of the array."""
+        chunk of the array, or None."""
         return self._chunks
 
+    @property
+    def shards(self):
+        """A tuple of integers describing the number of chunks in each shard
+        of the array."""
+        return self._shards
+
     @property
     def dtype(self):
         """The NumPy data type."""
@@ -487,6 +552,38 @@ def write_empty_chunks(self) -> bool:
         """
         return self._write_empty_chunks
 
+    @property
+    def _num_chunks_per_shard(self) -> int:
+        return reduce(lambda x, y: x*y, self._shards, 1)
+
+    def __keys_to_shard_groups__(
+        self, keys: Iterable[str]
+    ) -> Dict[str, List[Tuple[str, Tuple[int, ...]]]]:
+        shard_indices_per_shard_key = defaultdict(list)
+        for chunk_key in keys:
+            # TODO: allow to be in a group (aka only use last parts for dimensions)
+            chunk_subkeys = tuple(map(int, chunk_key.split(self._dimension_separator)))
+            shard_key_tuple = (
+                subkey // shard_i for subkey, shard_i in zip(chunk_subkeys, self._shards)
+            )
+            shard_key = self._dimension_separator.join(map(str, shard_key_tuple))
+            shard_indices_per_shard_key[shard_key].append((chunk_key, chunk_subkeys))
+        return shard_indices_per_shard_key
+
+    def __get_index__(self, buffer: Union[bytes, bytearray]) -> _ShardIndex:
+        # At the end of each shard 2*64bit per chunk for offset and length define the index:
+        return _ShardIndex.from_bytes(buffer[-16 * self._num_chunks_per_shard:], self)
+
+    def __get_chunks_in_shard(self, shard_key: str) -> Iterator[Tuple[int, ...]]:
+        # TODO: allow to be in a group (aka only use last parts for dimensions)
+        shard_key_tuple = tuple(map(int, shard_key.split(self._dimension_separator)))
+        for chunk_offset in itertools.product(*(range(i) for i in self._shards)):
+            yield tuple(
+                shard_key_i * shards_i + offset_i
+                for shard_key_i, offset_i, shards_i
+                in zip(shard_key_tuple, chunk_offset, self._shards)
+            )
+
     def __eq__(self, other):
         return (
             isinstance(other, Array) and
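
The _ShardIndex above fixes the on-disk shard layout: chunk payloads are concatenated in arbitrary order, and the final 16 * num_chunks_per_shard bytes hold one little-endian uint64 (offset, length) pair per chunk in C order, with (2**64 - 1, 2**64 - 1) marking an absent chunk. The following standalone sketch (not part of the patch; read_shard and its shards argument are illustrative names) decodes a shard file written by this prototype:

    import numpy as np

    MAX_UINT_64 = 2 ** 64 - 1

    def read_shard(path, shards=(2, 2)):
        """Return {local_chunk_coords: chunk_bytes} for one shard file."""
        with open(path, "rb") as f:
            buffer = f.read()
        num_chunks = shards[0] * shards[1]
        # the index is the trailing 16 bytes per chunk, C order, little-endian
        index = np.frombuffer(buffer[-16 * num_chunks:], dtype="<u8").reshape(shards + (2,))
        chunks = {}
        for x in range(shards[0]):
            for y in range(shards[1]):
                offset, length = index[x, y]
                if (offset, length) == (MAX_UINT_64, MAX_UINT_64):
                    continue  # sentinel: chunk not present in this shard
                chunks[(x, y)] = buffer[int(offset):int(offset) + int(length)]
        return chunks

    # e.g. read_shard("data/chunking_test.zarr/0.0") for the array created above
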
@@ -857,7 +954,7 @@ def _get_basic_selection_zd(self, selection, out=None, fields=None):
         try:
             # obtain encoded data for chunk
             ckey = self._chunk_key((0,))
-            cdata = self.chunk_store[ckey]
+            cdata = self._read_single_possibly_sharded(ckey)
 
         except KeyError:
             # chunk not initialized
@@ -1170,8 +1267,10 @@ def _get_selection(self, indexer, out=None, fields=None):
         check_array_shape('out', out, out_shape)
 
         # iterate over chunks
-        if not hasattr(self.chunk_store, "getitems") or \
-                any(map(lambda x: x == 0, self.shape)):
+        if self._shards is None and (
+            not hasattr(self.chunk_store, "getitems")
+            or any(map(lambda x: x == 0, self.shape))
+        ):
 
             # sequentially get one key at a time from storage
             for chunk_coords, chunk_selection, out_selection in indexer:
@@ -1640,7 +1739,7 @@ def _set_basic_selection_zd(self, selection, value, fields=None):
         # setup chunk
         try:
             # obtain compressed data for chunk
-            cdata = self.chunk_store[ckey]
+            cdata = self._read_single_possibly_sharded(ckey)
 
         except KeyError:
             # chunk not initialized
@@ -1708,8 +1807,11 @@ def _set_selection(self, indexer, value, fields=None):
             check_array_shape('value', value, sel_shape)
 
         # iterate over chunks in range
-        if not hasattr(self.store, "setitems") or self._synchronizer is not None \
-                or any(map(lambda x: x == 0, self.shape)):
+        if self._shards is None and (
+            not hasattr(self.chunk_store, "setitems")
+            or self._synchronizer is not None
+            or any(map(lambda x: x == 0, self.shape))
+        ):
 
             # iterative approach
             for chunk_coords, chunk_selection, out_selection in indexer:
@@ -1883,6 +1985,21 @@ def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection,
             self._process_chunk(out, cdata, chunk_selection, drop_axes,
                                 out_is_ndarray, fields, out_selection)
 
+    def _read_single_possibly_sharded(self, ckey):
+        if self._shards is None:
+            return self.chunk_store[ckey]
+        else:
+            shard_key, chunks_in_shard = next(iter(self.__keys_to_shard_groups__([ckey]).items()))
+            # TODO use partial read if available
+            full_shard_value = self.chunk_store[shard_key]
+            index = self.__get_index__(full_shard_value)
+            for _chunk_key, chunk_subkeys in chunks_in_shard:
+                chunk_slice = index.get_chunk_slice(chunk_subkeys)
+                if chunk_slice is None:
+                    raise KeyError
+                else:
+                    return full_shard_value[chunk_slice]
+
     def _chunk_getitems(self, lchunk_coords, lchunk_selection, out, lout_selection,
                         drop_axes=None, fields=None):
         """As _chunk_getitem, but for lists of chunks
@@ -1904,6 +2021,7 @@ def _chunk_getitems(self, lchunk_coords, lchunk_selection, out, lout_selection,
                 and hasattr(self._compressor, "decode_partial")
                 and not fields
                 and self.dtype != object
+                # TODO: this should rather check for read_block or similar
                 and hasattr(self.chunk_store, "getitems")
             ):
                 partial_read_decode = True
@@ -1914,7 +2032,16 @@ def _chunk_getitems(self, lchunk_coords, lchunk_selection, out, lout_selection,
             }
         else:
             partial_read_decode = False
-            cdatas = self.chunk_store.getitems(ckeys, on_error="omit")
+            cdatas = {}
+            for shard_key, chunks_in_shard in self.__keys_to_shard_groups__(ckeys).items():
+                # TODO use partial read if available
+                full_shard_value = self.chunk_store[shard_key]
+                index = self.__get_index__(full_shard_value)
+                for chunk_key, chunk_subkeys in chunks_in_shard:
+                    chunk_slice = index.get_chunk_slice(chunk_subkeys)
+                    if chunk_slice is not None:
+                        cdatas[chunk_key] = full_shard_value[chunk_slice]
+
         for ckey, chunk_select, out_select in zip(ckeys, lchunk_selection, lout_selection):
             if ckey in cdatas:
                 self._process_chunk(
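
Concrete numbers make the chunk-to-shard key arithmetic in __keys_to_shard_groups__ and __get_chunks_in_shard easier to check. A standalone sketch, not part of the patch, assuming shards=(2, 2) and the default "." dimension separator (the helper names are hypothetical):

    import itertools

    shards = (2, 2)
    sep = "."

    def shard_key_for(chunk_key):
        # chunk key "5.1" -> chunk (5, 1) -> shard (5 // 2, 1 // 2) -> "2.0"
        subkeys = tuple(map(int, chunk_key.split(sep)))
        return sep.join(str(k // s) for k, s in zip(subkeys, shards))

    def chunks_in_shard(shard_key):
        # inverse direction: all chunk coordinates a shard may contain
        sk = tuple(map(int, shard_key.split(sep)))
        return [
            tuple(sk_i * s_i + off_i for sk_i, off_i, s_i in zip(sk, off, shards))
            for off in itertools.product(*(range(s) for s in shards))
        ]

    assert shard_key_for("5.1") == "2.0"
    assert shard_key_for("0.0") == shard_key_for("1.1") == "0.0"
    assert chunks_in_shard("2.0") == [(4, 0), (4, 1), (5, 0), (5, 1)]
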
@@ -1948,11 +2075,40 @@ def _chunk_setitems(self, lchunk_coords, lchunk_selection, values, fields=None):
             to_store = {k: self._encode_chunk(cdatas[k]) for k in nonempty_keys}
         else:
             to_store = {k: self._encode_chunk(v) for k, v in cdatas.items()}
-        self.chunk_store.setitems(to_store)
+
+        for shard_key, chunks_in_shard in self.__keys_to_shard_groups__(to_store.keys()).items():
+            all_chunks = set(self.__get_chunks_in_shard(shard_key))
+            chunks_to_set = set(chunk_subkeys for _chunk_key, chunk_subkeys in chunks_in_shard)
+            chunks_to_read = all_chunks - chunks_to_set
+            new_content = {
+                chunk_subkeys: to_store[chunk_key] for chunk_key, chunk_subkeys in chunks_in_shard
+            }
+            try:
+                # TODO use partial read if available
+                full_shard_value = self.chunk_store[shard_key]
+            except KeyError:
+                index = _ShardIndex.create_empty(self)
+            else:
+                index = self.__get_index__(full_shard_value)
+                for chunk_to_read in chunks_to_read:
+                    chunk_slice = index.get_chunk_slice(chunk_to_read)
+                    if chunk_slice is not None:
+                        new_content[chunk_to_read] = full_shard_value[chunk_slice]
+
+            # TODO use partial write if available and possible (e.g. at the end)
+            shard_content = b""
+            # TODO: order the chunks in the shard:
+            for chunk_subkeys, chunk_content in new_content.items():
+                chunk_slice = slice(len(shard_content), len(shard_content) + len(chunk_content))
+                index.set_chunk_slice(chunk_subkeys, chunk_slice)
+                shard_content += chunk_content
+            # Appending the index at the end of the shard:
+            shard_content += index.to_bytes()
+            self.chunk_store[shard_key] = shard_content
 
     def _chunk_delitems(self, ckeys):
-        if hasattr(self.store, "delitems"):
-            self.store.delitems(ckeys)
+        if hasattr(self.chunk_store, "delitems"):
+            self.chunk_store.delitems(ckeys)
         else:  # pragma: no cover
             # exempting this branch from coverage as there are no extant stores
             # that will trigger this condition, but it's possible that they
@@ -2028,7 +2184,7 @@ def _process_for_setitem(self, ckey, chunk_selection, value, fields=None):
         try:
             # obtain compressed data for chunk
-            cdata = self.chunk_store[ckey]
+            cdata = self._read_single_possibly_sharded(ckey)
 
         except KeyError:
@@ -2239,6 +2395,7 @@ def digest(self, hashname="sha1"):
 
         h = hashlib.new(hashname)
 
+        # TODO: operate on shards here if available:
         for i in itertools.product(*[range(s) for s in self.cdata_shape]):
             h.update(self.chunk_store.get(self._chunk_key(i), b""))
@@ -2365,6 +2522,7 @@ def _resize_nosync(self, *args):
                 except KeyError:
                     # chunk not initialized
                     pass
+                # TODO: collect all chunks to delete and use _chunk_delitems
 
     def append(self, data, axis=0):
         """Append `data` to `axis`.
diff --git a/zarr/creation.py b/zarr/creation.py
index 64c5666adb..fc87363ba5 100644
--- a/zarr/creation.py
+++ b/zarr/creation.py
@@ -1,3 +1,4 @@
+from typing import Tuple, Union
 from warnings import warn
 
 import numpy as np
@@ -19,7 +20,9 @@ def create(shape, chunks=True, dtype=None, compressor='default',
            fill_value=0, order='C', store=None, synchronizer=None,
            overwrite=False, path=None, chunk_store=None, filters=None,
            cache_metadata=True, cache_attrs=True, read_only=False,
-           object_codec=None, dimension_separator=None, write_empty_chunks=True, **kwargs):
+           object_codec=None, dimension_separator=None, write_empty_chunks=True,
+           shards: Union[int, Tuple[int, ...], None] = None,
+           shard_format: str = "indexed", **kwargs):
     """Create an array.
 
     Parameters
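
For reference, a minimal usage sketch of the extended create() signature, mirroring sharding_test.py above (the store path is illustrative):

    import zarr

    # one shard file covers 2x2 chunks of 3x2 elements each (up to 6x4
    # elements) and carries its 4 * 16 byte chunk index at the end
    z = zarr.create((20, 3), chunks=(3, 2), shards=(2, 2), shard_format="indexed",
                    store=zarr.DirectoryStore("data/example.zarr"),
                    overwrite=True, compressor=None)
    z[:10, :] = 42
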
@@ -145,7 +148,7 @@ def create(shape, chunks=True, dtype=None, compressor='default',
     init_array(store, shape=shape, chunks=chunks, dtype=dtype, compressor=compressor,
                fill_value=fill_value, order=order, overwrite=overwrite, path=path,
                chunk_store=chunk_store, filters=filters, object_codec=object_codec,
-               dimension_separator=dimension_separator)
+               dimension_separator=dimension_separator, shards=shards, shard_format=shard_format)
 
     # instantiate array
     z = Array(store, path=path, chunk_store=chunk_store, synchronizer=synchronizer,
diff --git a/zarr/meta.py b/zarr/meta.py
index c292b09a14..d63be624d3 100644
--- a/zarr/meta.py
+++ b/zarr/meta.py
@@ -51,6 +51,8 @@ def decode_array_metadata(cls, s: Union[MappingType, str]) -> MappingType[str, Any]:
                 object_codec = None
 
             dimension_separator = meta.get("dimension_separator", None)
+            shards = meta.get("shards", None)
+            shard_format = meta.get("shard_format", None)
             fill_value = cls.decode_fill_value(meta['fill_value'], dtype, object_codec)
             meta = dict(
                 zarr_format=meta["zarr_format"],
@@ -64,6 +66,10 @@ def decode_array_metadata(cls, s: Union[MappingType, str]) -> MappingType[str, Any]:
             )
             if dimension_separator:
                 meta['dimension_separator'] = dimension_separator
+            if shards:
+                meta['shards'] = tuple(shards)
+                assert shard_format is not None
+                meta['shard_format'] = shard_format
         except Exception as e:
             raise MetadataError("error decoding metadata") from e
         else:
@@ -77,6 +83,8 @@ def encode_array_metadata(cls, meta: MappingType[str, Any]) -> bytes:
             dtype, sdshape = dtype.subdtype
 
         dimension_separator = meta.get("dimension_separator")
+        shards = meta.get("shards")
+        shard_format = meta.get("shard_format")
         if dtype.hasobject:
             import numcodecs
             object_codec = numcodecs.get_codec(meta['filters'][0])
@@ -95,9 +103,10 @@ def encode_array_metadata(cls, meta: MappingType[str, Any]) -> bytes:
         )
         if dimension_separator:
             meta['dimension_separator'] = dimension_separator
-
-        if dimension_separator:
-            meta["dimension_separator"] = dimension_separator
+        if shards:
+            meta['shards'] = shards
+            assert shard_format is not None
+            meta['shard_format'] = shard_format
 
         return json_dumps(meta)
diff --git a/zarr/storage.py b/zarr/storage.py
index 7f572d35ff..c7f8c7fdeb 100644
--- a/zarr/storage.py
+++ b/zarr/storage.py
@@ -54,7 +54,7 @@
 from zarr.util import (buffer_size, json_loads, nolock, normalize_chunks,
                        normalize_dimension_separator, normalize_dtype,
                        normalize_fill_value, normalize_order,
-                       normalize_shape, normalize_storage_path, retry_call)
+                       normalize_shape, normalize_shards, normalize_storage_path, retry_call)
 
 from zarr._storage.absstore import ABSStore  # noqa: F401
 from zarr._storage.store import (_listdir_from_keys,
@@ -236,6 +238,8 @@ def init_array(
     filters=None,
     object_codec=None,
     dimension_separator=None,
+    shards: Union[int, Tuple[int, ...], None] = None,
+    shard_format: Optional[str] = None,
 ):
     """Initialize an array store with the given configuration. Note that this
     is a low-level function and there should be no need to call this directly
     from user code.
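
A side note on the tuple(shards) conversion in decode_array_metadata above: JSON has no tuple type, so the round-trip through json_dumps stores the shard shape as a list. A minimal standalone illustration:

    import json

    encoded = json.dumps({"shards": (2, 2)})               # tuples encode as JSON lists
    assert json.loads(encoded)["shards"] == [2, 2]
    assert tuple(json.loads(encoded)["shards"]) == (2, 2)  # what the Array code expects
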
@@ -353,7 +355,8 @@ def init_array(
                          order=order, overwrite=overwrite, path=path,
                          chunk_store=chunk_store, filters=filters,
                          object_codec=object_codec,
-                         dimension_separator=dimension_separator)
+                         dimension_separator=dimension_separator,
+                         shards=shards, shard_format=shard_format)
 
 
 def _init_array_metadata(
@@ -370,6 +373,8 @@ def _init_array_metadata(
     filters=None,
     object_codec=None,
     dimension_separator=None,
+    shards: Union[int, Tuple[int, ...], None] = None,
+    shard_format: Optional[str] = None,
 ):
 
     # guard conditions
@@ -388,6 +393,8 @@ def _init_array_metadata(
     shape = normalize_shape(shape) + dtype.shape
     dtype = dtype.base
     chunks = normalize_chunks(chunks, shape, dtype.itemsize)
+    shards = normalize_shards(shards, shape)
+    shard_format = shard_format or "indexed"
     order = normalize_order(order)
     fill_value = normalize_fill_value(fill_value, dtype)
@@ -445,6 +452,9 @@ def _init_array_metadata(
                 compressor=compressor_config, fill_value=fill_value,
                 order=order, filters=filters_config,
                 dimension_separator=dimension_separator)
+    if shards is not None:
+        meta["shards"] = shards
+        meta["shard_format"] = shard_format
     key = _path_to_prefix(path) + array_meta_key
     if hasattr(store, '_metadata_class'):
         store[key] = store._metadata_class.encode_array_metadata(meta)  # type: ignore
diff --git a/zarr/util.py b/zarr/util.py
index 04d350a68d..3b017b98ca 100644
--- a/zarr/util.py
+++ b/zarr/util.py
@@ -14,7 +14,7 @@
 from numcodecs.registry import codec_registry
 from numcodecs.blosc import cbuffer_sizes, cbuffer_metainfo
 
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union, cast
 
 
 def flatten(arg: Iterable) -> Iterable:
@@ -149,6 +149,40 @@ def normalize_chunks(
     return tuple(chunks)
 
 
+def normalize_shards(
+    shards: Union[int, Tuple[Optional[int], ...], None], shape: Tuple[int, ...],
+) -> Optional[Tuple[int, ...]]:
+    """Convenience function to normalize the `shards` argument for an array
+    with the given `shape`."""
+
+    # N.B., expect shape already normalized
+
+    if shards is None:
+        return None
+
+    # handle 1D convenience form
+    if isinstance(shards, numbers.Integral):
+        shards = tuple(int(shards) for _ in shape)
+    shards = cast(Tuple[Optional[int]], shards)
+
+    # handle bad dimensionality
+    if len(shards) > len(shape):
+        raise ValueError('too many dimensions in shards')
+
+    # handle underspecified shards
+    if len(shards) < len(shape):
+        # assume single shards across remaining dimensions
+        shards += (1, ) * (len(shape) - len(shards))
+
+    # handle None or -1 in shards
+    if -1 in shards or None in shards:
+        shards = tuple(s if c == -1 or c is None else int(c)
+                       for s, c in zip(shape, shards))
+
+    shards = cast(Tuple[int], shards)
+    return tuple(shards)
+
+
 def normalize_dtype(dtype: Union[str, np.dtype], object_codec) -> Tuple[np.dtype, Any]:
 
     # convenience API for object arrays
@@ -560,6 +594,7 @@ def __init__(self, store_key, chunk_store):
         # is it fsstore or an actual fsspec map object
         assert hasattr(self.chunk_store, "map")
         self.map = self.chunk_store.map
+        # TODO maybe use partial_read here also
        self.fs = self.chunk_store.fs
         self.store_key = store_key
         self.buff = None
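
The following sketch pins down the expected behavior of normalize_shards as implemented above (runnable only with this patch applied):

    from zarr.util import normalize_shards

    assert normalize_shards(None, (20, 3)) is None
    assert normalize_shards(2, (20, 3)) == (2, 2)          # int is broadcast to all axes
    assert normalize_shards((2,), (20, 3)) == (2, 1)       # missing axes default to 1
    assert normalize_shards((-1, 2), (20, 3)) == (20, 2)   # -1/None take the axis length
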