17 changes: 16 additions & 1 deletion src/zarr/abc/codec.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from abc import abstractmethod
from abc import abstractmethod, abstractproperty
from collections.abc import Mapping
from typing import TYPE_CHECKING, Generic, TypeGuard, TypeVar

@@ -17,6 +17,7 @@

from zarr.abc.store import ByteGetter, ByteSetter, Store
from zarr.core.array_spec import ArraySpec
from zarr.core.buffer import BufferPrototype
from zarr.core.chunk_grids import ChunkGrid
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType
from zarr.core.indexing import SelectorTuple
@@ -185,14 +186,23 @@ async def encode(
class ArrayArrayCodec(BaseCodec[NDBuffer, NDBuffer]):
"""Base class for array-to-array codecs."""

codec_input: type[NDBuffer]
codec_output: type[NDBuffer]


class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer]):
"""Base class for array-to-bytes codecs."""

codec_input: type[NDBuffer]
codec_output: type[Buffer]


class BytesBytesCodec(BaseCodec[Buffer, Buffer]):
"""Base class for bytes-to-bytes codecs."""

codec_input: type[Buffer]
codec_output: type[Buffer]


Codec = ArrayArrayCodec | ArrayBytesCodec | BytesBytesCodec

@@ -276,6 +286,11 @@ class CodecPipeline:
decoding them and assembling an output array. On the write path, it encodes the chunks
and writes them to a store (via ByteSetter)."""

@abstractproperty
def prototype(self) -> BufferPrototype:
"""The buffer prototype of the codec pipeline"""
...

@abstractmethod
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
"""Fills in codec configuration parameters that can be automatically
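To make the new abstract `prototype` contract concrete, here is a minimal sketch (not part of this PR): `ToyCpuPipeline` is a hypothetical stand-in that only shows the property, omitting the pipeline's other abstract methods (e.g. `evolve_from_array_spec`, shown above).

```python
# Minimal sketch, not from this PR: the shape of the new `prototype` contract.
# A real CodecPipeline subclass must also implement its other abstract methods.
from zarr.core.buffer import BufferPrototype
from zarr.core.buffer.cpu import buffer_prototype as cpu_buffer_prototype


class ToyCpuPipeline:
    """Hypothetical CPU-backed pipeline used only for illustration."""

    @property
    def prototype(self) -> BufferPrototype:
        # Advertise CPU Buffer/NDBuffer classes; a GPU-backed pipeline would
        # return a prototype wrapping GPU buffer classes here instead.
        return cpu_buffer_prototype


print(ToyCpuPipeline().prototype)
```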
5 changes: 4 additions & 1 deletion src/zarr/codecs/_v2.py
@@ -8,16 +8,19 @@
from numcodecs.compat import ensure_bytes, ensure_ndarray_like

from zarr.abc.codec import ArrayBytesCodec
from zarr.core.buffer import Buffer, NDBuffer
from zarr.registry import get_ndbuffer_class

if TYPE_CHECKING:
from zarr.abc.numcodec import Numcodec
from zarr.core.array_spec import ArraySpec
from zarr.core.buffer import Buffer, NDBuffer


@dataclass(frozen=True)
class V2Codec(ArrayBytesCodec):
codec_input = NDBuffer
codec_output = Buffer

filters: tuple[Numcodec, ...] | None
compressor: Numcodec | None

5 changes: 4 additions & 1 deletion src/zarr/codecs/blosc.py
@@ -11,6 +11,7 @@
from packaging.version import Version

from zarr.abc.codec import BytesBytesCodec
from zarr.core.buffer import Buffer
from zarr.core.buffer.cpu import as_numpy_array_wrapper
from zarr.core.common import JSON, parse_enum, parse_named_configuration
from zarr.core.dtype.common import HasItemSize
@@ -19,7 +20,6 @@
from typing import Self

from zarr.core.array_spec import ArraySpec
from zarr.core.buffer import Buffer


class BloscShuffle(Enum):
@@ -88,6 +88,9 @@ def parse_blocksize(data: JSON) -> int:
class BloscCodec(BytesBytesCodec):
"""blosc codec"""

codec_input = Buffer
codec_output = Buffer

is_fixed_size = False

typesize: int | None
3 changes: 3 additions & 0 deletions src/zarr/codecs/bytes.py
@@ -34,6 +34,9 @@ class Endian(Enum):
class BytesCodec(ArrayBytesCodec):
"""bytes codec"""

codec_input = NDBuffer
codec_output = Buffer

is_fixed_size = True

endian: Endian | None
5 changes: 4 additions & 1 deletion src/zarr/codecs/crc32c_.py
@@ -8,19 +8,22 @@
import typing_extensions

from zarr.abc.codec import BytesBytesCodec
from zarr.core.buffer import Buffer
from zarr.core.common import JSON, parse_named_configuration

if TYPE_CHECKING:
from typing import Self

from zarr.core.array_spec import ArraySpec
from zarr.core.buffer import Buffer


@dataclass(frozen=True)
class Crc32cCodec(BytesBytesCodec):
"""crc32c codec"""

codec_input = Buffer
codec_output = Buffer

is_fixed_size = True

@classmethod
5 changes: 4 additions & 1 deletion src/zarr/codecs/gzip.py
@@ -7,14 +7,14 @@
from numcodecs.gzip import GZip

from zarr.abc.codec import BytesBytesCodec
from zarr.core.buffer import Buffer
from zarr.core.buffer.cpu import as_numpy_array_wrapper
from zarr.core.common import JSON, parse_named_configuration

if TYPE_CHECKING:
from typing import Self

from zarr.core.array_spec import ArraySpec
from zarr.core.buffer import Buffer


def parse_gzip_level(data: JSON) -> int:
@@ -31,6 +31,9 @@ def parse_gzip_level(data: JSON) -> int:
class GzipCodec(BytesBytesCodec):
"""gzip codec"""

codec_input = Buffer
codec_output = Buffer

is_fixed_size = False

level: int = 5
3 changes: 3 additions & 0 deletions src/zarr/codecs/sharding.py
@@ -334,6 +334,9 @@ class ShardingCodec(
):
"""Sharding codec"""

codec_input = NDBuffer
codec_output = Buffer

chunk_shape: tuple[int, ...]
codecs: tuple[Codec, ...]
index_codecs: tuple[Codec, ...]
5 changes: 4 additions & 1 deletion src/zarr/codecs/transpose.py
@@ -8,12 +8,12 @@

from zarr.abc.codec import ArrayArrayCodec
from zarr.core.array_spec import ArraySpec
from zarr.core.buffer import NDBuffer
from zarr.core.common import JSON, parse_named_configuration

if TYPE_CHECKING:
from typing import Self

from zarr.core.buffer import NDBuffer
from zarr.core.chunk_grids import ChunkGrid
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType

@@ -30,6 +30,9 @@ def parse_transpose_order(data: JSON | Iterable[int]) -> tuple[int, ...]:
class TransposeCodec(ArrayArrayCodec):
"""Transpose codec"""

codec_input = NDBuffer
codec_output = NDBuffer

is_fixed_size = True

order: tuple[int, ...]
6 changes: 6 additions & 0 deletions src/zarr/codecs/vlen_utf8.py
@@ -25,6 +25,9 @@
class VLenUTF8Codec(ArrayBytesCodec):
"""Variable-length UTF8 codec"""

codec_input = NDBuffer
codec_output = Buffer

@classmethod
def from_dict(cls, data: dict[str, JSON]) -> Self:
_, configuration_parsed = parse_named_configuration(
@@ -71,6 +74,9 @@ def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -

@dataclass(frozen=True)
class VLenBytesCodec(ArrayBytesCodec):
codec_input = NDBuffer
codec_output = Buffer

@classmethod
def from_dict(cls, data: dict[str, JSON]) -> Self:
_, configuration_parsed = parse_named_configuration(
5 changes: 4 additions & 1 deletion src/zarr/codecs/zstd.py
@@ -10,14 +10,14 @@
from packaging.version import Version

from zarr.abc.codec import BytesBytesCodec
from zarr.core.buffer import Buffer
from zarr.core.buffer.cpu import as_numpy_array_wrapper
from zarr.core.common import JSON, parse_named_configuration

if TYPE_CHECKING:
from typing import Self

from zarr.core.array_spec import ArraySpec
from zarr.core.buffer import Buffer


def parse_zstd_level(data: JSON) -> int:
@@ -38,6 +38,9 @@ def parse_checksum(data: JSON) -> bool:
class ZstdCodec(BytesBytesCodec):
"""zstd codec"""

codec_input = Buffer
codec_output = Buffer

is_fixed_size = True

level: int = 0
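Every concrete codec above now declares its buffer types through the same two class attributes. One possible use, sketched below under that assumption (this check is not something the PR itself adds), is verifying that adjacent codecs in a chain agree: each codec's `codec_output` should match the next codec's `codec_input`.

```python
# Sketch only: validate a codec chain using the class attributes added in this PR.
# The chain mirrors a typical v3 pipeline: array->array, array->bytes, bytes->bytes.
from zarr.codecs import BloscCodec, BytesCodec, TransposeCodec

chain = [TransposeCodec(order=(1, 0)), BytesCodec(), BloscCodec()]

for upstream, downstream in zip(chain, chain[1:]):
    # TransposeCodec emits NDBuffer, which BytesCodec consumes; BytesCodec
    # emits Buffer, which BloscCodec consumes.
    assert upstream.codec_output is downstream.codec_input, (
        f"{type(upstream).__name__} emits {upstream.codec_output.__name__}, "
        f"but {type(downstream).__name__} expects {downstream.codec_input.__name__}"
    )
print("codec chain buffer types line up")
```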
32 changes: 16 additions & 16 deletions src/zarr/core/array.py
@@ -37,7 +37,6 @@
NDArrayLike,
NDArrayLikeOrScalar,
NDBuffer,
default_buffer_prototype,
)
from zarr.core.buffer.cpu import buffer_prototype as cpu_buffer_prototype
from zarr.core.chunk_grids import RegularChunkGrid, _auto_partition, normalize_chunks
@@ -1623,7 +1622,7 @@ async def example():
```
"""
if prototype is None:
prototype = default_buffer_prototype()
prototype = self.codec_pipeline.prototype
indexer = BasicIndexer(
selection,
shape=self.metadata.shape,
@@ -1640,7 +1639,7 @@ async def get_orthogonal_selection(
prototype: BufferPrototype | None = None,
) -> NDArrayLikeOrScalar:
if prototype is None:
prototype = default_buffer_prototype()
prototype = self.codec_pipeline.prototype
indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
return await self._get_selection(
indexer=indexer, out=out, fields=fields, prototype=prototype
@@ -1655,7 +1654,7 @@ async def get_mask_selection(
prototype: BufferPrototype | None = None,
) -> NDArrayLikeOrScalar:
if prototype is None:
prototype = default_buffer_prototype()
prototype = self.codec_pipeline.prototype
indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
return await self._get_selection(
indexer=indexer, out=out, fields=fields, prototype=prototype
@@ -1670,7 +1669,7 @@ async def get_coordinate_selection(
prototype: BufferPrototype | None = None,
) -> NDArrayLikeOrScalar:
if prototype is None:
prototype = default_buffer_prototype()
prototype = self.codec_pipeline.prototype
indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
out_array = await self._get_selection(
indexer=indexer, out=out, fields=fields, prototype=prototype
@@ -1787,7 +1786,8 @@ async def setitem(
- Supports basic indexing, where the selection is contiguous and does not involve advanced indexing.
"""
if prototype is None:
prototype = default_buffer_prototype()
prototype = self.codec_pipeline.prototype

indexer = BasicIndexer(
selection,
shape=self.metadata.shape,
@@ -3086,7 +3086,7 @@ def get_basic_selection(
"""

if prototype is None:
prototype = default_buffer_prototype()
prototype = self._async_array.codec_pipeline.prototype
return sync(
self._async_array._get_selection(
BasicIndexer(selection, self.shape, self.metadata.chunk_grid),
@@ -3195,7 +3195,7 @@ def set_basic_selection(

"""
if prototype is None:
prototype = default_buffer_prototype()
prototype = self._async_array.codec_pipeline.prototype
indexer = BasicIndexer(selection, self.shape, self.metadata.chunk_grid)
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))

@@ -3323,7 +3323,7 @@ def get_orthogonal_selection(

"""
if prototype is None:
prototype = default_buffer_prototype()
prototype = self._async_array.codec_pipeline.prototype
indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
return sync(
self._async_array._get_selection(
@@ -3442,7 +3442,7 @@ def set_orthogonal_selection(
[__setitem__][zarr.Array.__setitem__]
"""
if prototype is None:
prototype = default_buffer_prototype()
prototype = self._async_array.codec_pipeline.prototype
indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
return sync(
self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)
@@ -3530,7 +3530,7 @@ def get_mask_selection(
"""

if prototype is None:
prototype = default_buffer_prototype()
prototype = self._async_array.codec_pipeline.prototype
indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
return sync(
self._async_array._get_selection(
@@ -3620,7 +3620,7 @@ def set_mask_selection(

"""
if prototype is None:
prototype = default_buffer_prototype()
prototype = self._async_array.codec_pipeline.prototype
indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))

@@ -3708,7 +3708,7 @@ def get_coordinate_selection(

"""
if prototype is None:
prototype = default_buffer_prototype()
prototype = self._async_array.codec_pipeline.prototype
indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
out_array = sync(
self._async_array._get_selection(
@@ -3800,7 +3800,7 @@ def set_coordinate_selection(

"""
if prototype is None:
prototype = default_buffer_prototype()
prototype = self._async_array.codec_pipeline.prototype
# setup indexer
indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)

@@ -3923,7 +3923,7 @@ def get_block_selection(
[__setitem__][zarr.Array.__setitem__]
"""
if prototype is None:
prototype = default_buffer_prototype()
prototype = self._async_array.codec_pipeline.prototype
indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid)
return sync(
self._async_array._get_selection(
@@ -4024,7 +4024,7 @@ def set_block_selection(

"""
if prototype is None:
prototype = default_buffer_prototype()
prototype = self._async_array.codec_pipeline.prototype
indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid)
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))

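The net effect of the `array.py` changes: when callers do not pass a `prototype`, selections now source their buffers from the array's codec pipeline rather than from the global `default_buffer_prototype()`. A usage sketch under that assumption follows; `zarr.create_array` and `MemoryStore` are the current public API as I understand it and are not part of this diff.

```python
# Sketch: no `prototype` argument is passed, so with this change the read path
# allocates buffers via arr._async_array.codec_pipeline.prototype.
import numpy as np

import zarr
from zarr.storage import MemoryStore

arr = zarr.create_array(MemoryStore(), shape=(4, 4), chunks=(2, 2), dtype="float32")
arr[:] = np.arange(16, dtype="float32").reshape(4, 4)

# Buffers for this read come from the pipeline's prototype by default.
block = arr.get_basic_selection((slice(0, 2), slice(0, 2)))
print(block)
```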