diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ceb79a3e173..607661ed30b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -28,6 +28,9 @@ Bug fixes By `Kai Mühlbauer `_. - Fix the SciPy backend for netCDF3 files . (:issue:`8909`, :pull:`10376`) By `Deepak Cherian `_. +- Check and fix character array string dimension names, issue warnings as needed (:issue:`6352`, :pull:`10395`). + By `Kai Mühlbauer `_. + Documentation @@ -36,6 +39,8 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ +- Forward variable name down to coders for AbstractWritableDataStore.encode_variable and subclasses. (:pull:`10395`). + By `Kai Mühlbauer `_. .. _whats-new.2025.06.1: diff --git a/xarray/backends/common.py b/xarray/backends/common.py index e574f19e9d4..e1f8dc5cecd 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -389,11 +389,11 @@ def encode(self, variables, attributes): attributes : dict-like """ - variables = {k: self.encode_variable(v) for k, v in variables.items()} + variables = {k: self.encode_variable(v, name=k) for k, v in variables.items()} attributes = {k: self.encode_attribute(v) for k, v in attributes.items()} return variables, attributes - def encode_variable(self, v): + def encode_variable(self, v, name=None): """encode one variable""" return v @@ -641,7 +641,7 @@ def encode(self, variables, attributes): variables = { k: ensure_dtype_not_object(v, name=k) for k, v in variables.items() } - variables = {k: self.encode_variable(v) for k, v in variables.items()} + variables = {k: self.encode_variable(v, name=k) for k, v in variables.items()} attributes = {k: self.encode_attribute(v) for k, v in attributes.items()} return variables, attributes diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index ba3a6d20e37..f3e434c6e5e 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -286,8 +286,8 @@ def set_dimension(self, name, length, is_unlimited=False): def set_attribute(self, key, value): self.ds.attrs[key] = value - def encode_variable(self, variable): - return _encode_nc4_variable(variable) + def encode_variable(self, variable, name=None): + return _encode_nc4_variable(variable, name=name) def prepare_variable( self, name, variable, check_encoding=False, unlimited_dims=None diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index a23d247b6c3..8c3a01eba66 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -9,7 +9,6 @@ import numpy as np -from xarray import coding from xarray.backends.common import ( BACKEND_ENTRYPOINTS, BackendArray, @@ -30,6 +29,12 @@ ) from xarray.backends.netcdf3 import encode_nc3_attr_value, encode_nc3_variable from xarray.backends.store import StoreBackendEntrypoint +from xarray.coding.strings import ( + CharacterArrayCoder, + EncodedStringCoder, + create_vlen_dtype, + is_unicode_dtype, +) from xarray.coding.variables import pop_to from xarray.core import indexing from xarray.core.utils import ( @@ -73,7 +78,7 @@ def __init__(self, variable_name, datastore): # check vlen string dtype in further steps # it also prevents automatic string concatenation via # conventions.decode_cf_variable - dtype = coding.strings.create_vlen_dtype(str) + dtype = create_vlen_dtype(str) self.dtype = dtype def __setitem__(self, key, value): @@ -127,12 +132,12 @@ def _getitem(self, key): return array -def _encode_nc4_variable(var): +def _encode_nc4_variable(var, name=None): for coder in [ - coding.strings.EncodedStringCoder(allows_unicode=True), - coding.strings.CharacterArrayCoder(), + EncodedStringCoder(allows_unicode=True), + CharacterArrayCoder(), ]: - var = coder.encode(var) + var = coder.encode(var, name=name) return var @@ -164,7 +169,7 @@ def _nc4_dtype(var): if "dtype" in var.encoding: dtype = var.encoding.pop("dtype") _check_encoding_dtype_is_vlen_string(dtype) - elif coding.strings.is_unicode_dtype(var.dtype): + elif is_unicode_dtype(var.dtype): dtype = str elif var.dtype.kind in ["i", "u", "f", "c", "S"]: dtype = var.dtype @@ -535,12 +540,12 @@ def set_attribute(self, key, value): else: self.ds.setncattr(key, value) - def encode_variable(self, variable): + def encode_variable(self, variable, name=None): variable = _force_native_endianness(variable) if self.format == "NETCDF4": - variable = _encode_nc4_variable(variable) + variable = _encode_nc4_variable(variable, name=name) else: - variable = encode_nc3_variable(variable) + variable = encode_nc3_variable(variable, name=name) return variable def prepare_variable( diff --git a/xarray/backends/netcdf3.py b/xarray/backends/netcdf3.py index 3ae024c9760..6f66b6c1059 100644 --- a/xarray/backends/netcdf3.py +++ b/xarray/backends/netcdf3.py @@ -118,12 +118,12 @@ def _maybe_prepare_times(var): return data -def encode_nc3_variable(var): +def encode_nc3_variable(var, name=None): for coder in [ coding.strings.EncodedStringCoder(allows_unicode=False), coding.strings.CharacterArrayCoder(), ]: - var = coder.encode(var) + var = coder.encode(var, name=name) data = _maybe_prepare_times(var) data = coerce_nc3_dtype(data) attrs = encode_nc3_attrs(var.attrs) diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 16fb4528f55..b98d226cac6 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -227,8 +227,8 @@ def set_attribute(self, key, value): value = encode_nc3_attr_value(value) setattr(self.ds, key, value) - def encode_variable(self, variable): - variable = encode_nc3_variable(variable) + def encode_variable(self, variable, name=None): + variable = encode_nc3_variable(variable, name=name) return variable def prepare_variable( diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index b86b5d0b374..54ff419b2f2 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -855,8 +855,8 @@ def set_dimensions(self, variables, unlimited_dims=None): def set_attributes(self, attributes): _put_attrs(self.zarr_group, attributes) - def encode_variable(self, variable): - variable = encode_zarr_variable(variable) + def encode_variable(self, variable, name=None): + variable = encode_zarr_variable(variable, name=name) return variable def encode_attribute(self, a): diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index ea2f58274b6..c917faaf383 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -2,6 +2,7 @@ from __future__ import annotations +import re from functools import partial import numpy as np @@ -15,7 +16,7 @@ unpack_for_encoding, ) from xarray.core import indexing -from xarray.core.utils import module_available +from xarray.core.utils import emit_user_level_warning, module_available from xarray.core.variable import Variable from xarray.namedarray.parallelcompat import get_chunked_array_type from xarray.namedarray.pycompat import is_chunked_array @@ -113,6 +114,36 @@ def ensure_fixed_length_bytes(var: Variable) -> Variable: return var +def validate_char_dim_name(strlen, encoding, name) -> str: + """Check character array dimension naming and size and return it.""" + + if (char_dim_name := encoding.pop("char_dim_name", None)) is not None: + # 1 - extract all characters up to last number sequence + # 2 - extract last number sequence + match = re.search(r"^(.*?)(\d+)(?!.*\d)", char_dim_name) + if match: + new_dim_name = match.group(1) + if int(match.group(2)) != strlen: + emit_user_level_warning( + f"String dimension naming mismatch on variable {name!r}. {char_dim_name!r} provided by encoding, but data has length of '{strlen}'. Using '{new_dim_name}{strlen}' instead of {char_dim_name!r} to prevent possible naming clash.\n" + "To silence this warning either remove 'char_dim_name' from encoding or provide a fitting name." + ) + char_dim_name = f"{new_dim_name}{strlen}" + else: + if ( + original_shape := encoding.get("original_shape", [-1])[-1] + ) != -1 and original_shape != strlen: + emit_user_level_warning( + f"String dimension length mismatch on variable {name!r}. '{original_shape}' provided by encoding, but data has length of '{strlen}'. Using '{char_dim_name}{strlen}' instead of {char_dim_name!r} to prevent possible naming clash.\n" + f"To silence this warning remove 'original_shape' from encoding." + ) + char_dim_name = f"{char_dim_name}{strlen}" + else: + char_dim_name = f"string{strlen}" + + return char_dim_name + + class CharacterArrayCoder(VariableCoder): """Transforms between arrays containing bytes and character arrays.""" @@ -122,10 +153,7 @@ def encode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_encoding(variable) if data.dtype.kind == "S" and encoding.get("dtype") is not str: data = bytes_to_char(data) - if "char_dim_name" in encoding.keys(): - char_dim_name = encoding.pop("char_dim_name") - else: - char_dim_name = f"string{data.shape[-1]}" + char_dim_name = validate_char_dim_name(data.shape[-1], encoding, name) dims = dims + (char_dim_name,) return Variable(dims, data, attrs, encoding) diff --git a/xarray/tests/test_coding_strings.py b/xarray/tests/test_coding_strings.py index 17179a44a8a..e7971a311f5 100644 --- a/xarray/tests/test_coding_strings.py +++ b/xarray/tests/test_coding_strings.py @@ -139,6 +139,45 @@ def test_CharacterArrayCoder_char_dim_name(original, expected_char_dim_name) -> assert roundtripped.dims[-1] == original.dims[-1] +@pytest.mark.parametrize( + [ + "original", + "expected_char_dim_name", + "expected_char_dim_length", + "warning_message", + ], + [ + ( + Variable(("x",), [b"ab", b"cde"], encoding={"char_dim_name": "foo4"}), + "foo3", + 3, + "String dimension naming mismatch", + ), + ( + Variable( + ("x",), + [b"ab", b"cde"], + encoding={"original_shape": (2, 4), "char_dim_name": "foo"}, + ), + "foo3", + 3, + "String dimension length mismatch", + ), + ], +) +def test_CharacterArrayCoder_dim_mismatch_warnings( + original, expected_char_dim_name, expected_char_dim_length, warning_message +) -> None: + coder = strings.CharacterArrayCoder() + with pytest.warns(UserWarning, match=warning_message): + encoded = coder.encode(original) + roundtripped = coder.decode(encoded) + assert encoded.dims[-1] == expected_char_dim_name + assert encoded.sizes[expected_char_dim_name] == expected_char_dim_length + assert roundtripped.encoding["char_dim_name"] == expected_char_dim_name + assert roundtripped.dims[-1] == original.dims[-1] + + def test_StackedBytesArray() -> None: array = np.array([[b"a", b"b", b"c"], [b"d", b"e", b"f"]], dtype="S") actual = strings.StackedBytesArray(array) diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index 961df78154e..ce792c83740 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -555,10 +555,10 @@ def test_decode_cf_time_kwargs(self, time_unit) -> None: class CFEncodedInMemoryStore(WritableCFDataStore, InMemoryDataStore): - def encode_variable(self, var): + def encode_variable(self, var, name=None): """encode one variable""" coder = coding.strings.EncodedStringCoder(allows_unicode=True) - var = coder.encode(var) + var = coder.encode(var, name=name) return var