Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return scalar when accessing zero dimensional array #2718

Open
wants to merge 37 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
a9531ca
Return scalar when accessing zero dimensional array
brokkoli71 Jan 15, 2025
1a290c7
returning npt.ArrayLike instead of NDArrayLike because of scalar retu…
brokkoli71 Jan 15, 2025
3348439
returning npt.ArrayLike instead of NDArrayLike because of scalar retu…
brokkoli71 Jan 15, 2025
34bf260
fix mypy in tests
brokkoli71 Jan 16, 2025
75d6cdf
fix mypy in tests
brokkoli71 Jan 16, 2025
66c0798
fix mypy in tests
brokkoli71 Jan 16, 2025
a86d041
improve test_scalar_array
brokkoli71 Jan 23, 2025
f8393b0
Merge branch 'main' into return-scalar-for-zero-dim-indexing
brokkoli71 Jan 29, 2025
c2a0de0
fix typo
brokkoli71 Jan 29, 2025
2475960
add ScalarWrapper
brokkoli71 Jan 29, 2025
cced470
use ScalarWrapper as NDArrayLike
brokkoli71 Jan 29, 2025
d57e1f2
Revert "fix mypy in tests"
brokkoli71 Jan 29, 2025
842bf95
Revert "fix mypy in tests"
brokkoli71 Jan 29, 2025
ee6d62d
Revert "fix mypy in tests"
brokkoli71 Jan 29, 2025
e302ae6
format
brokkoli71 Jan 29, 2025
b55c8b3
Revert "returning npt.ArrayLike instead of NDArrayLike because of sca…
brokkoli71 Jan 29, 2025
f1cb6e1
Revert "returning npt.ArrayLike instead of NDArrayLike because of sca…
brokkoli71 Jan 29, 2025
c76be48
fix mypy for ScalarWrapper
brokkoli71 Jan 29, 2025
359eb66
add missing import NDArrayLike
brokkoli71 Jan 29, 2025
fc0937e
ignore unavoidable mypy error
brokkoli71 Jan 29, 2025
8454d7b
format
brokkoli71 Jan 29, 2025
15c5103
fix __array__
brokkoli71 Jan 29, 2025
a93ce00
extend tests
brokkoli71 Jan 29, 2025
805a8df
format
brokkoli71 Jan 29, 2025
f6b48ba
fix typing in test_scalar_array
brokkoli71 Jan 29, 2025
1b7966f
add dtype to ScalarWrapper
brokkoli71 Jan 29, 2025
81ab808
correct dtype type
brokkoli71 Jan 29, 2025
56c52ae
fix test_basic_indexing
brokkoli71 Jan 29, 2025
a3473fb
fix test_basic_indexing
brokkoli71 Jan 29, 2025
b98578b
Merge remote-tracking branch 'origin/return-scalar-for-zero-dim-index…
brokkoli71 Jan 29, 2025
c150462
fix test_basic_indexing for dtype=datetime64[Y]
brokkoli71 Jan 29, 2025
a7d4421
increase codecov
brokkoli71 Jan 29, 2025
632cad0
fix typing
brokkoli71 Jan 29, 2025
50fd5ff
document changes
brokkoli71 Jan 29, 2025
2aca6c2
move test_scalar_wrapper to test_buffer.py
brokkoli71 Jan 29, 2025
15ec284
Merge branch 'main' into return-scalar-for-zero-dim-indexing
brokkoli71 Jan 30, 2025
a842e88
Merge branch 'main' into return-scalar-for-zero-dim-indexing
brokkoli71 Feb 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changes/2718.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Indexing a 0-dimensional array now returns an instance of ``ScalarWrapper``, which behaves like a scalar
but follows the ``NDArrayLike`` interface. This change makes the behavior of 0-dimensional arrays
consistent with ``numpy`` scalars without breaking the output format of the ``NDArrayLike`` interface.
4 changes: 4 additions & 0 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,8 @@ async def _get_selection(
out_buffer,
drop_axes=indexer.drop_axes,
)
if isinstance(indexer, BasicIndexer) and indexer.shape == ():
return out_buffer.as_scalar()
return out_buffer.as_ndarray_like()

async def getitem(
Expand Down Expand Up @@ -2268,6 +2270,8 @@ def __array__(
raise ValueError(msg)

arr_np = self[...]
if self.ndim == 0:
arr_np = np.array(arr_np)

if dtype is not None:
arr_np = arr_np.astype(dtype)
Expand Down
144 changes: 144 additions & 0 deletions src/zarr/core/buffer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from zarr.codecs.bytes import Endian
from zarr.core.common import BytesLike, ChunkCoords
from zarr.core.indexing import Selection

# Everything here is imported into ``zarr.core.buffer`` namespace.
__all__: list[str] = []
Expand Down Expand Up @@ -105,6 +106,138 @@
"""


class ScalarWrapper:
    """Wrap a single value so it can be handled like a 0-dimensional array.

    The wrapper reports array-style metadata (``dtype``, ``ndim``, ``size``,
    ``shape``) and offers array-style methods (``reshape``, ``astype``,
    ``copy``, ...), while arithmetic, conversions and unknown attribute
    lookups are forwarded to the wrapped value.

    Parameters
    ----------
    value
        The value to wrap.
    dtype
        The data type to report. When ``None``, it is derived from
        ``type(value)``.
    """

    def __init__(self, value: Any, dtype: np.dtype[Any] | None = None) -> None:
        self._value: Any = value
        # Compare with ``None`` explicitly instead of relying on truthiness:
        # a dtype's length is its number of fields, so a non-structured dtype
        # can evaluate as falsy and an explicitly passed dtype would be lost.
        self._dtype: np.dtype[Any] = np.dtype(type(self._value)) if dtype is None else dtype

    @staticmethod
    def _check_key(key: Any) -> None:
        """Raise ``IndexError`` unless ``key`` is ``...``, ``slice(None)`` or ``()``."""
        # Type checks come first so that array-valued keys are rejected with an
        # IndexError instead of failing on an ambiguous element-wise comparison.
        if key is Ellipsis:
            return
        if isinstance(key, slice) and key == slice(None):
            return
        if isinstance(key, tuple) and len(key) == 0:
            return
        raise IndexError("Invalid index for scalar.")

    @property
    def dtype(self) -> np.dtype[Any]:
        """The data type reported by this wrapper."""
        return self._dtype

    @property
    def ndim(self) -> int:
        """Number of dimensions; always ``0``."""
        return 0

    @property
    def size(self) -> int:
        """Number of elements; always ``1``."""
        return 1

    @property
    def shape(self) -> tuple[()]:
        """Shape of the wrapped value; always the empty tuple."""
        return ()

    def __len__(self) -> int:
        """Always raise ``TypeError``: a scalar has no length."""
        raise TypeError("len() of unsized object.")

    def __getitem__(self, key: Selection) -> Self:
        """Return ``self`` for ``...``, ``slice(None)`` or ``()``; raise ``IndexError`` otherwise."""
        self._check_key(key)
        return self

    def __setitem__(self, key: Selection, value: Any) -> None:
        """Replace the wrapped value for ``...``, ``slice(None)`` or ``()``; raise ``IndexError`` otherwise."""
        self._check_key(key)
        self._value = value

    def __array__(
        self, dtype: npt.DTypeLike | None = None, copy: bool | None = True
    ) -> npt.NDArray[Any]:
        """Build a NumPy array from the wrapped value.

        ``dtype`` defaults to the wrapper's own dtype; ``copy`` is passed
        through to ``np.array``.
        """
        # ``is None`` rather than ``or``: an explicit dtype must not be
        # discarded just because the dtype object evaluates as falsy.
        target_dtype = self._dtype if dtype is None else dtype
        return np.array(self._value, dtype=target_dtype, copy=copy)

    def reshape(
        self, shape: tuple[int, ...] | Literal[-1], *, order: Literal["A", "C", "F"] = "C"
    ) -> Self:
        """Return ``self`` for the shapes ``()`` and ``-1``; raise ``ValueError`` otherwise."""
        if shape != () and shape != -1:
            raise ValueError("Cannot reshape scalar to non-scalar shape.")
        return self

    def view(self, dtype: npt.DTypeLike) -> Self:
        """Return a new wrapper reporting ``dtype``; this simply delegates to ``astype``."""
        return self.astype(dtype)

    def astype(
        self, dtype: npt.DTypeLike, order: Literal["K", "A", "C", "F"] = "K", *, copy: bool = True
    ) -> Self:
        """Change the reported dtype.

        With ``copy=True`` (the default) a new wrapper around the same value is
        returned; with ``copy=False`` this wrapper is updated in place and
        returned. The wrapped value itself is not converted here.
        """
        if copy:
            return self.__class__(self._value, np.dtype(dtype))
        # NOTE(review): Codecov reported this in-place branch as not covered by tests.
        self._dtype = np.dtype(dtype)
        return self

    def fill(self, value: Any) -> None:
        """Replace the wrapped value with ``value``."""
        self._value = value

    def copy(self) -> Self:
        """Return a new wrapper around the same value, keeping the reported dtype."""
        return self.__class__(self._value, self._dtype)

    def transpose(self, axes: SupportsIndex | Sequence[SupportsIndex] | None = None) -> Self:
        """Return ``self``; ``axes`` is ignored."""
        return self

    def ravel(self, order: Literal["K", "A", "C", "F"] = "C") -> Self:
        """Return ``self``; ``order`` is ignored."""
        return self

    def all(self) -> bool:
        """Return the truth value of the wrapped value."""
        return bool(self._value)

    def __eq__(self, other: object) -> Self:  # type: ignore[explicit-override, override]
        """Compare the wrapped value with ``other`` and wrap the result."""
        return self.__class__(self._value == other)

    def __repr__(self) -> str:
        # NOTE(review): Codecov reported this method as not covered by tests.
        return f"ScalarWrapper({self._value!r})"

    def __getattr__(self, name: str) -> Any:
        """Forward lookups of attributes not found on the wrapper to the wrapped value."""
        # ``__getattr__`` also runs on instances whose ``__init__`` has not been
        # executed (e.g. the empty object ``copy.copy`` creates). Without this
        # guard, reading ``self._value`` below would re-enter ``__getattr__``
        # forever and end in a RecursionError.
        if name == "_value":
            raise AttributeError(name)
        return getattr(self._value, name)

    def __add__(self, other: Any) -> Any:
        return self._value + other

    def __radd__(self, other: Any) -> Any:
        return other + self._value

    def __sub__(self, other: Any) -> Any:
        return self._value - other

    def __rsub__(self, other: Any) -> Any:
        return other - self._value

    def __mul__(self, other: Any) -> Any:
        return self._value * other

    def __rmul__(self, other: Any) -> Any:
        return other * self._value

    def __truediv__(self, other: Any) -> Any:
        return self._value / other

    def __rtruediv__(self, other: Any) -> Any:
        return other / self._value

    def __floordiv__(self, other: Any) -> Any:
        return self._value // other

    def __rfloordiv__(self, other: Any) -> Any:
        return other // self._value

    def __mod__(self, other: Any) -> Any:
        return self._value % other

    def __rmod__(self, other: Any) -> Any:
        return other % self._value

    def __pow__(self, other: Any) -> Any:
        return self._value**other

    def __rpow__(self, other: Any) -> Any:
        return other**self._value

    def __neg__(self) -> Any:
        return -self._value

    def __abs__(self) -> Any:
        """Return ``abs()`` of the wrapped value; raise ``TypeError`` if it does not support it."""
        if hasattr(self._value, "__abs__"):
            return abs(self._value)
        raise TypeError(f"bad operand type for abs(): '{self._value.__class__.__name__}'")

    def __int__(self) -> int:
        return int(self._value)

    def __float__(self) -> float:
        return float(self._value)

    def __complex__(self) -> complex:
        return complex(self._value)

    def __bool__(self) -> bool:
        return bool(self._value)

    def __hash__(self) -> int:
        return hash(self._value)

    def __str__(self) -> str:
        return str(self._value)

    def __format__(self, format_spec: str) -> str:
        return format(self._value, format_spec)


def check_item_key_is_1d_contiguous(key: Any) -> None:
"""Raises error if `key` isn't a 1d contiguous slice"""
if not isinstance(key, slice):
Expand Down Expand Up @@ -419,6 +552,17 @@
"""
...

def as_scalar(self) -> ScalarWrapper:
    """Return the single element held by this buffer wrapped as a scalar.

    Returns
    -------
    ScalarWrapper
        The buffer's only element, paired with the buffer's dtype.

    Raises
    ------
    ValueError
        If the underlying data does not have exactly one element.
    """
    if self._data.size == 1:
        scalar_value = self.as_numpy_array().item()
        return ScalarWrapper(scalar_value, np.dtype(self.dtype))
    raise ValueError("Buffer does not contain a single scalar value")

@property
def dtype(self) -> np.dtype[Any]:
    """Data type reported by the underlying ``_data`` object."""
    underlying = self._data
    return underlying.dtype
Expand Down
19 changes: 14 additions & 5 deletions tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
chunks_initialized,
create_array,
)
from zarr.core.buffer import default_buffer_prototype
from zarr.core.buffer import NDArrayLike, default_buffer_prototype
from zarr.core.buffer.core import ScalarWrapper
from zarr.core.buffer.cpu import NDBuffer
from zarr.core.chunk_grids import _auto_partition
from zarr.core.common import JSON, MemoryOrder, ZarrFormat
Expand Down Expand Up @@ -1324,11 +1325,19 @@ async def test_create_array_data_ignored_params(store: Store) -> None:
await create_array(store, data=data, shape=None, dtype=data.dtype, overwrite=True)


# NOTE(review): this looks like the pre-change side of the diff; a parametrized
# function with the same name is defined directly below it — confirm which one applies.
# Creates a scalar array from 1.5 and checks that both ``...`` and ``()``
# indexing compare equal to 1.5.
async def test_scalar_array() -> None:
arr = zarr.array(1.5)
assert arr[...] == 1.5
assert arr[()] == 1.5
@pytest.mark.parametrize("value", [1, 1.4, "a", b"a", np.array(1)])
@pytest.mark.parametrize("zarr_format", [2, 3])
def test_scalar_array(value: Any, zarr_format: ZarrFormat) -> None:
    """Check metadata and indexing results of an array created from a scalar value."""
    scalar_arr = zarr.array(value, zarr_format=zarr_format)
    assert arr_equals_value(scalar_arr[...], value) if False else scalar_arr[...] == value
    assert scalar_arr.shape == ()
    assert scalar_arr.ndim == 0

    item = scalar_arr[()]
    # Indexing with ``()`` must yield a ScalarWrapper that also satisfies NDArrayLike.
    for expected_type in (ScalarWrapper, NDArrayLike):
        assert isinstance(scalar_arr[()], expected_type)
    assert item.shape == scalar_arr.shape
    assert item.ndim == scalar_arr.ndim


async def test_orthogonal_set_total_slice() -> None:
Expand Down
56 changes: 55 additions & 1 deletion tests/test_buffer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING
import re
from typing import TYPE_CHECKING, Any

import numpy as np
import pytest
Expand Down Expand Up @@ -30,6 +31,9 @@
cp = None


import zarr.api.asynchronous
from zarr.core.buffer.core import ScalarWrapper

if TYPE_CHECKING:
import types

Expand All @@ -40,6 +44,54 @@ def test_nd_array_like(xp: types.ModuleType) -> None:
assert isinstance(ary, NDArrayLike)


@pytest.mark.parametrize("value", [1, 1.4, "a", b"a", np.array(1), False, True])
def test_scalar_wrapper(value: Any) -> None:
    """Exercise the scalar-like and array-like behavior of ``ScalarWrapper``."""
    x = ScalarWrapper(value)
    assert x == value
    assert value == x
    assert x == x[()]
    assert x.view(str) == x
    assert x.copy() == x
    assert x.transpose() == x
    assert x.ravel() == x
    assert x.all() == bool(value)
    # The two checks below cover ``__repr__`` and the ``copy=False`` branch of
    # ``astype``, which the coverage report listed as not exercised by tests.
    assert repr(x) == f"ScalarWrapper({value!r})"
    assert x.astype(x.dtype, copy=False) is x
    if isinstance(value, (int, float)):
        assert -x == -value
        assert abs(x) == abs(value)
        assert int(x) == int(value)
        assert float(x) == float(value)
        assert complex(x) == complex(value)
        assert x + 1 == value + 1
        assert x - 1 == value - 1
        assert x * 2 == value * 2
        assert x / 2 == value / 2
        assert x // 2 == value // 2
        assert x % 2 == value % 2
        assert x**2 == value**2
        assert x == value
        assert x != value + 1
        assert bool(x) == bool(value)
        assert hash(x) == hash(value)
        assert str(x) == str(value)
        assert format(x, "") == format(value, "")
        # In-place update through ``()`` indexing: fill with 2, then add 1.
        x.fill(2)
        x[()] += 1
        assert x == 3
    elif isinstance(value, str):
        assert str(x) == value
        with pytest.raises(TypeError, match=re.escape("bad operand type for abs(): 'str'")):
            abs(x)

    with pytest.raises(ValueError, match="Cannot reshape scalar to non-scalar shape."):
        x.reshape((1, 2))
    with pytest.raises(IndexError, match="Invalid index for scalar."):
        x[10] = value
    with pytest.raises(IndexError, match="Invalid index for scalar."):
        x[10]
    with pytest.raises(TypeError, match=re.escape("len() of unsized object.")):
        len(x)


@pytest.mark.asyncio
async def test_async_array_prototype() -> None:
"""Test the use of a custom buffer prototype"""
Expand Down Expand Up @@ -151,3 +203,5 @@ def test_numpy_buffer_prototype() -> None:
ndbuffer = cpu.buffer_prototype.nd_buffer.create(shape=(1, 2), dtype=np.dtype("int64"))
assert isinstance(buffer.as_array_like(), np.ndarray)
assert isinstance(ndbuffer.as_ndarray_like(), np.ndarray)
with pytest.raises(ValueError, match="Buffer does not contain a single scalar value"):
ndbuffer.as_scalar()