From 13ec7a8bdae3457d6f36c7e3487c5af75fc6a6de Mon Sep 17 00:00:00 2001
From: Tom White
Date: Mon, 12 Aug 2024 15:37:21 +0100
Subject: [PATCH] Set object codec for object arrays

---
 cubed/storage/backend.py  | 36 ++++++++++++++++++++++++++----------
 cubed/tests/test_types.py | 12 ++++++++++++
 2 files changed, 38 insertions(+), 10 deletions(-)

diff --git a/cubed/storage/backend.py b/cubed/storage/backend.py
index 4ad56032..3844a079 100644
--- a/cubed/storage/backend.py
+++ b/cubed/storage/backend.py
@@ -4,16 +4,7 @@
 from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store
 
 
-def open_backend_array(
-    store: T_Store,
-    mode: str,
-    *,
-    shape: Optional[T_Shape] = None,
-    dtype: Optional[T_DType] = None,
-    chunks: Optional[T_RegularChunks] = None,
-    path: Optional[str] = None,
-    **kwargs,
-):
+def backend_storage_name():
     # get storage name from top-level config
     # e.g. set globally with CUBED_STORAGE_NAME=tensorstore
     storage_name = config.get("storage_name", None)
@@ -26,10 +17,35 @@ def open_backend_array(
         else:
             storage_name = "zarr-python"
 
+    return storage_name
+
+
+def open_backend_array(
+    store: T_Store,
+    mode: str,
+    *,
+    shape: Optional[T_Shape] = None,
+    dtype: Optional[T_DType] = None,
+    chunks: Optional[T_RegularChunks] = None,
+    path: Optional[str] = None,
+    **kwargs,
+):
+    storage_name = backend_storage_name()
+
     if storage_name == "zarr-python":
         from cubed.storage.backends.zarr_python import open_zarr_array
 
         open_func = open_zarr_array
+
+        # set object codec if needed
+        import numpy as np
+
+        if np.dtype(dtype).hasobject and "object_codec" not in kwargs:
+            import numcodecs
+
+            object_codec = numcodecs.Pickle()
+            kwargs["object_codec"] = object_codec
+
     elif storage_name == "zarr-python-v3":
         from cubed.storage.backends.zarr_python_v3 import open_zarr_v3_array
 
diff --git a/cubed/tests/test_types.py b/cubed/tests/test_types.py
index b21eca6e..6103c319 100644
--- a/cubed/tests/test_types.py
+++ b/cubed/tests/test_types.py
@@ -1,6 +1,9 @@
+import pytest
 from numpy.testing import assert_array_equal
 
+import cubed
 import cubed.array_api as xp
+from cubed.storage.backend import backend_storage_name
 
 
 # This is less strict than the spec, but is supported by implementations like NumPy
@@ -8,3 +11,12 @@ def test_prod_sum_bool():
     a = xp.ones((2,), dtype=xp.bool)
     assert_array_equal(xp.prod(a).compute(), xp.asarray([1], dtype=xp.int64))
     assert_array_equal(xp.sum(a).compute(), xp.asarray([2], dtype=xp.int64))
+
+
+@pytest.mark.skipif(
+    backend_storage_name() != "zarr-python",
+    reason="object dtype only works on zarr-python",
+)
+def test_object_dtype():
+    a = xp.asarray(["a", "b"], dtype=object, chunks=2)
+    cubed.to_zarr(a, store=None)