From 21d48935aac6f7f054c307cf587cf1432f4c8a4c Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Mon, 4 Nov 2024 16:27:48 -0500 Subject: [PATCH] allow metadata as any name for interval set with tests --- pynapple/core/interval_set.py | 22 ++-- pynapple/core/metadata_class.py | 32 +++--- tests/test_metadata.py | 182 ++++++++++++++++++++------------ tests/test_test.py | 86 +++++++++++++++ tests/test_ts_group.py | 4 +- 5 files changed, 233 insertions(+), 93 deletions(-) create mode 100644 tests/test_test.py diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 86213bfd..9a893ca2 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -214,10 +214,11 @@ def __init__( self.index = np.arange(data.shape[0], dtype="int") self.columns = np.array(["start", "end"]) self.nap_class = self.__class__.__name__ - if drop_meta: - _MetadataMixin.__init__(self) - else: - _MetadataMixin.__init__(self, metadata) + # initialize metadata to get all attributes before setting metadata + _MetadataMixin.__init__(self) + self._class_attributes = self.__dir__() + if drop_meta is False: + self.set_info(metadata) self._initialized = True def __repr__(self): @@ -286,12 +287,21 @@ def __len__(self): def __setattr__(self, name, value): # necessary setter to allow metadata to be set as an attribute if self._initialized: - _MetadataMixin.__setattr__(self, name, value) + if name in self._class_attributes: + raise AttributeError( + f"Cannot set attribute '{name}'; IntervalSet is immutable. Use 'set_info()' to set '{name}' as metadata." + ) + else: + _MetadataMixin.__setattr__(self, name, value) else: object.__setattr__(self, name, value) def __setitem__(self, key, value): - if (isinstance(key, str)) and (key not in self.columns): + if key in self.columns: + raise RuntimeError( + "IntervalSet is immutable. Starts and ends have been already sorted." + ) + elif isinstance(key, str): _MetadataMixin.__setitem__(self, key, value) else: raise RuntimeError( diff --git a/pynapple/core/metadata_class.py b/pynapple/core/metadata_class.py index dd5126ab..077c84d8 100644 --- a/pynapple/core/metadata_class.py +++ b/pynapple/core/metadata_class.py @@ -115,32 +115,32 @@ def _raise_invalid_metadata_column_name(self, name): raise TypeError( f"Invalid metadata type {type(name)}. Metadata column names must be strings!" ) + # warnings for metadata names that cannot be accessed as attributes or keys if hasattr(self, name) and (name not in self.metadata_columns): # existing non-metadata attribute - raise ValueError( - f"Invalid metadata name '{name}'. Metadata name must differ from " - f"{type(self).__dict__.keys()} attribute names!" - ) - if hasattr(self, "columns") and name in self.columns: - # existing column (since TsdFrame columns are not attributes) - raise ValueError( - f"Invalid metadata name '{name}'. Metadata name must differ from " - f"{self.columns} column names!" + warnings.warn( + f"Metadata name '{name}' overlaps with an existing attribute, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata." ) - if name[0].isalpha() is False: - # starts with a number - raise ValueError( - f"Invalid metadata name '{name}'. Metadata name cannot start with a number" + elif hasattr(self, "columns") and name in self.columns: + # existing non-metadata attribute + warnings.warn( + f"Metadata name '{name}' overlaps with an existing property, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata." ) + # warnings for metadata that cannot be accessed as attributes if name.replace("_", "").isalnum() is False: # contains invalid characters - raise ValueError( - f"Invalid metadata name '{name}'. Metadata name cannot contain special characters except for underscores" + warnings.warn( + f"Metadata name '{name}' contains a special character, and cannot be accessed as an attribute. Use 'get_info()' or key indexing to access metadata." + ) + elif name[0].isalpha() is False: + # starts with a number + warnings.warn( + f"Metadata name '{name}' starts with a number, and cannot be accessed as an attribute. Use 'get_info()' or key indexing to access metadata." ) def _check_metadata_column_names(self, metadata=None, **kwargs): """ - Check that metadata column names don't conflict with existing attributes, don't start with a number, and don't contain invalid characters. + Throw warnings when metadata names cannot be accessed as attributes or keys. """ if metadata is not None: diff --git a/tests/test_metadata.py b/tests/test_metadata.py index bbcf3319..13dc478e 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -313,6 +313,71 @@ def test_drop_metadata_warnings(iset_meta): iset_meta.time_span() +@pytest.mark.parametrize( + "name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp", + [ + # existing attribute and key + ( + "start", + pytest.warns(UserWarning, match="overlaps with an existing"), + pytest.raises(AttributeError, match="IntervalSet is immutable"), + pytest.raises(RuntimeError, match="IntervalSet is immutable"), + does_not_raise(), + does_not_raise(), + ), + # existing attribute and key + ( + "end", + pytest.warns(UserWarning, match="overlaps with an existing"), + pytest.raises(AttributeError, match="IntervalSet is immutable"), + pytest.raises(RuntimeError, match="IntervalSet is immutable"), + does_not_raise(), + does_not_raise(), + ), + # existing attribute + ( + "values", + pytest.warns(UserWarning, match="overlaps with an existing"), + pytest.raises(AttributeError, match="IntervalSet is immutable"), + does_not_raise(), + pytest.raises(ValueError), # shape mismatch + pytest.raises(AssertionError), # we do want metadata + ), + # existing metdata + ( + "label", + does_not_raise(), + does_not_raise(), + does_not_raise(), + pytest.raises(AssertionError), # we do want metadata + pytest.raises(AssertionError), # we do want metadata + ), + ], +) +def test_iset_metadata_overlapping_names( + iset_meta, name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp +): + assert hasattr(iset_meta, name) + + # warning when set + with set_exp: + iset_meta.set_info({name: np.ones(4)}) + # error when set as attribute + with set_attr_exp: + setattr(iset_meta, name, np.ones(4)) + # error when set as key + with set_key_exp: + iset_meta[name] = np.ones(4) + # retrieve with get_info + assert np.all(iset_meta.get_info(name) == np.ones(4)) + # make sure it doesn't access metadata if its an existing attribute or key + with get_attr_exp: + assert np.all(getattr(iset_meta, name) == np.ones(4)) == False + # make sure it doesn't access metadata if its an existing key + with get_key_exp: + assert np.all(iset_meta[name] == np.ones(4)) == False + + ############## ## TsdFrame ## ############## @@ -347,26 +412,21 @@ def test_tsdframe_metadata_slicing(tsdframe_meta): ) -@pytest.mark.parametrize( - "args, kwargs, expected", - [ - ( - # invalid metadata names that are the same as column names - [ - pd.DataFrame( - index=["a", "b", "c"], - columns=["a", "b", "c"], - data=np.random.randint(0, 5, size=(3, 3)), - ) - ], - {}, - pytest.raises(ValueError, match="Invalid metadata name"), - ), - ], -) -def test_tsdframe_add_metadata_error(tsdframe_meta, args, kwargs, expected): - with expected: - tsdframe_meta.set_info(*args, **kwargs) +# @pytest.mark.parametrize( +# "name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp", +# [ +# ( +# # invalid metadata names that are the same as column names +# "a", +# pytest.warns(UserWarning, match="overlaps with an existing"), +# ), +# ], +# ) +# def test_tsdframe_metadata_overlapping_names(tsdframe_meta, args, kwargs, expected): +# assert + +# with expected: +# tsdframe_meta.set_info(*args, **kwargs) ############# @@ -575,52 +635,26 @@ def test_add_metadata_df(self, obj, info, obj_len): ( # invalid names as strings starting with a number [ - pd.DataFrame( - columns=["1"], - data=np.ones((4, 1)), - ) + {"1": np.ones(4)}, ], {}, - pytest.raises(ValueError, match="Invalid metadata name"), + pytest.warns(UserWarning, match="starts with a number"), ), ( # invalid names with spaces [ - pd.DataFrame( - columns=["l 1"], - data=np.ones((4, 1)), - ) + {"l 1": np.ones(4)}, ], {}, - pytest.raises(ValueError, match="Invalid metadata name"), + pytest.warns(UserWarning, match="contains a special character"), ), ( # invalid names with periods [ - pd.DataFrame( - columns=["l.1"], - data=np.ones((4, 1)), - ) - ], - {}, - pytest.raises(ValueError, match="Invalid metadata name"), - ), - ( - # invalid names with dashes - [ - pd.DataFrame( - columns=["l-1"], - data=np.ones((4, 1)), - ) + {"1.1": np.ones(4)}, ], {}, - pytest.raises(ValueError, match="Invalid metadata name"), - ), - ( - # name that overlaps with existing attribute - [], - {"__dir__": np.zeros(4)}, - pytest.raises(ValueError, match="Invalid metadata name"), + pytest.warns(UserWarning, match="contains a special character"), ), ( # metadata with wrong length @@ -633,25 +667,33 @@ def test_add_metadata_df(self, obj, info, obj_len): ), ], ) - def test_add_metadata_error(self, obj, args, kwargs, expected): + def test_add_metadata_error(self, obj, obj_len, args, kwargs, expected): + # trim to appropriate length + if len(args): + if isinstance(args[0], pd.DataFrame): + metadata = args[0].iloc[:obj_len] + elif isinstance(args[0], dict): + metadata = {k: v[:obj_len] for k, v in args[0].items()} + else: + metadata = None with expected: - obj.set_info(*args, **kwargs) - - def test_add_metadata_key_error(self, obj, obj_len): - # type specific key errors - info = np.ones(obj_len) - if isinstance(obj, nap.IntervalSet): - with pytest.raises(RuntimeError, match="IntervalSet is immutable"): - obj[0] = info - with pytest.raises(RuntimeError, match="IntervalSet is immutable"): - obj["start"] = info - with pytest.raises(RuntimeError, match="IntervalSet is immutable"): - obj["end"] = info - - elif isinstance(obj, nap.TsGroup): - # currently obj[0] does not raise an error for TsdFrame - with pytest.raises(TypeError, match="Metadata keys must be strings!"): - obj[0] = info + obj.set_info(metadata, **kwargs) + + # def test_add_metadata_key_error(self, obj, obj_len): + # # type specific key errors + # info = np.ones(obj_len) + # if isinstance(obj, nap.IntervalSet): + # with pytest.raises(RuntimeError, match="IntervalSet is immutable"): + # obj[0] = info + # with pytest.raises(RuntimeError, match="IntervalSet is immutable"): + # obj["start"] = info + # with pytest.raises(RuntimeError, match="IntervalSet is immutable"): + # obj["end"] = info + + # elif isinstance(obj, nap.TsGroup): + # # currently obj[0] does not raise an error for TsdFrame + # with pytest.raises(TypeError, match="Metadata keys must be strings!"): + # obj[0] = info def test_overwrite_metadata(self, obj, obj_len): # add metadata diff --git a/tests/test_test.py b/tests/test_test.py new file mode 100644 index 00000000..e23e2959 --- /dev/null +++ b/tests/test_test.py @@ -0,0 +1,86 @@ +from numbers import Number +import inspect + + +import pickle +import numpy as np +import pandas as pd +import pytest +from pathlib import Path +from contextlib import nullcontext as does_not_raise +import warnings + +import pynapple as nap + + +@pytest.fixture +def iset_meta(): + start = np.array([0, 10, 16, 25]) + end = np.array([5, 15, 20, 40]) + metadata = {"label": ["a", "b", "c", "d"], "info": np.arange(4)} + return nap.IntervalSet(start=start, end=end, metadata=metadata) + + +@pytest.mark.parametrize( + "name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp", + [ + # existing attribute and key + ( + "start", + pytest.warns(UserWarning, match="overlaps with an existing attribute"), + pytest.raises(AttributeError, match="IntervalSet is immutable"), + pytest.raises(RuntimeError, match="IntervalSet is immutable"), + does_not_raise(), + does_not_raise(), + ), + # existing attribute and key + ( + "end", + pytest.warns(UserWarning, match="overlaps with an existing attribute"), + pytest.raises(AttributeError, match="IntervalSet is immutable"), + pytest.raises(RuntimeError, match="IntervalSet is immutable"), + does_not_raise(), + does_not_raise(), + ), + # existing attribute + ( + "values", + pytest.warns(UserWarning, match="overlaps with an existing attribute"), + pytest.raises(AttributeError, match="IntervalSet is immutable"), + does_not_raise(), + pytest.raises(ValueError), # shape mismatch + pytest.raises(AssertionError), # we do want metadata + ), + # existing metdata + ( + "label", + does_not_raise(), + does_not_raise(), + does_not_raise(), + pytest.raises(AssertionError), # we do want metadata + pytest.raises(AssertionError), # we do want metadata + ), + ], +) +def test_iset_metadata_overlapping_names( + iset_meta, name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp +): + assert hasattr(iset_meta, name) + + # warning when set + with set_exp: + iset_meta.set_info({name: np.ones(4)}) + # error when set as attribute + with set_attr_exp: + setattr(iset_meta, name, np.ones(4)) + # error when set as key + with set_key_exp: + iset_meta[name] = np.ones(4) + # retrieve with get_info + assert np.all(iset_meta.get_info(name) == np.ones(4)) + # make sure it doesn't access metadata if its an existing attribute or key + with get_attr_exp: + assert np.all(getattr(iset_meta, name) == np.ones(4)) == False + # make sure it doesn't access metadata if its an existing key + with get_key_exp: + assert np.all(iset_meta[name] == np.ones(4)) == False diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py index 02b12fba..6d41b46b 100644 --- a/tests/test_ts_group.py +++ b/tests/test_ts_group.py @@ -629,7 +629,9 @@ def test_setitem_metadata_twice(self, group): assert all(group._metadata["a"] == np.arange(len(group)) + 10) def test_prevent_overwriting_existing_methods(self, ts_group): - with pytest.raises(ValueError, match=r"Invalid metadata name"): + # with pytest.raises(ValueError, match=r"Invalid metadata name"): + # ts_group["set_info"] = np.arange(2) + with pytest.warns(UserWarning, match=r"overlaps with an existing"): ts_group["set_info"] = np.arange(2) def test_getitem_ts_object(self, ts_group):