Skip to content

Commit

Permalink
allow metadata as any name for interval set with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sjvenditto committed Nov 4, 2024
1 parent e4bc841 commit 21d4893
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 93 deletions.
22 changes: 16 additions & 6 deletions pynapple/core/interval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,11 @@ def __init__(
self.index = np.arange(data.shape[0], dtype="int")
self.columns = np.array(["start", "end"])
self.nap_class = self.__class__.__name__
if drop_meta:
_MetadataMixin.__init__(self)
else:
_MetadataMixin.__init__(self, metadata)
# initialize metadata to get all attributes before setting metadata
_MetadataMixin.__init__(self)
self._class_attributes = self.__dir__()
if drop_meta is False:
self.set_info(metadata)
self._initialized = True

def __repr__(self):
Expand Down Expand Up @@ -286,12 +287,21 @@ def __len__(self):
def __setattr__(self, name, value):
# necessary setter to allow metadata to be set as an attribute
if self._initialized:
_MetadataMixin.__setattr__(self, name, value)
if name in self._class_attributes:
raise AttributeError(
f"Cannot set attribute '{name}'; IntervalSet is immutable. Use 'set_info()' to set '{name}' as metadata."
)
else:
_MetadataMixin.__setattr__(self, name, value)
else:
object.__setattr__(self, name, value)

def __setitem__(self, key, value):
if (isinstance(key, str)) and (key not in self.columns):
if key in self.columns:
raise RuntimeError(
"IntervalSet is immutable. Starts and ends have been already sorted."
)
elif isinstance(key, str):
_MetadataMixin.__setitem__(self, key, value)
else:
raise RuntimeError(
Expand Down
32 changes: 16 additions & 16 deletions pynapple/core/metadata_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,32 +115,32 @@ def _raise_invalid_metadata_column_name(self, name):
raise TypeError(
f"Invalid metadata type {type(name)}. Metadata column names must be strings!"
)
# warnings for metadata names that cannot be accessed as attributes or keys
if hasattr(self, name) and (name not in self.metadata_columns):
# existing non-metadata attribute
raise ValueError(
f"Invalid metadata name '{name}'. Metadata name must differ from "
f"{type(self).__dict__.keys()} attribute names!"
)
if hasattr(self, "columns") and name in self.columns:
# existing column (since TsdFrame columns are not attributes)
raise ValueError(
f"Invalid metadata name '{name}'. Metadata name must differ from "
f"{self.columns} column names!"
warnings.warn(
f"Metadata name '{name}' overlaps with an existing attribute, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata."
)
if name[0].isalpha() is False:
# starts with a number
raise ValueError(
f"Invalid metadata name '{name}'. Metadata name cannot start with a number"
elif hasattr(self, "columns") and name in self.columns:
# existing column name (columns are accessed as keys, not attributes)
warnings.warn(

Check warning on line 126 in pynapple/core/metadata_class.py

View check run for this annotation

Codecov / codecov/patch

pynapple/core/metadata_class.py#L126

Added line #L126 was not covered by tests
f"Metadata name '{name}' overlaps with an existing property, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata."
)
# warnings for metadata that cannot be accessed as attributes
if name.replace("_", "").isalnum() is False:
# contains invalid characters
raise ValueError(
f"Invalid metadata name '{name}'. Metadata name cannot contain special characters except for underscores"
warnings.warn(
f"Metadata name '{name}' contains a special character, and cannot be accessed as an attribute. Use 'get_info()' or key indexing to access metadata."
)
elif name[0].isalpha() is False:
# starts with a number
warnings.warn(
f"Metadata name '{name}' starts with a number, and cannot be accessed as an attribute. Use 'get_info()' or key indexing to access metadata."
)

def _check_metadata_column_names(self, metadata=None, **kwargs):
"""
Check that metadata column names don't conflict with existing attributes, don't start with a number, and don't contain invalid characters.
Throw warnings when metadata names cannot be accessed as attributes or keys.
"""

if metadata is not None:
Expand Down
182 changes: 112 additions & 70 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,71 @@ def test_drop_metadata_warnings(iset_meta):
iset_meta.time_span()


@pytest.mark.parametrize(
"name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp",
[
# existing attribute and key
(
"start",
pytest.warns(UserWarning, match="overlaps with an existing"),
pytest.raises(AttributeError, match="IntervalSet is immutable"),
pytest.raises(RuntimeError, match="IntervalSet is immutable"),
does_not_raise(),
does_not_raise(),
),
# existing attribute and key
(
"end",
pytest.warns(UserWarning, match="overlaps with an existing"),
pytest.raises(AttributeError, match="IntervalSet is immutable"),
pytest.raises(RuntimeError, match="IntervalSet is immutable"),
does_not_raise(),
does_not_raise(),
),
# existing attribute
(
"values",
pytest.warns(UserWarning, match="overlaps with an existing"),
pytest.raises(AttributeError, match="IntervalSet is immutable"),
does_not_raise(),
pytest.raises(ValueError), # shape mismatch
pytest.raises(AssertionError), # we do want metadata
),
# existing metadata
(
"label",
does_not_raise(),
does_not_raise(),
does_not_raise(),
pytest.raises(AssertionError), # we do want metadata
pytest.raises(AssertionError), # we do want metadata
),
],
)
def test_iset_metadata_overlapping_names(
iset_meta, name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp
):
assert hasattr(iset_meta, name)

# warning when set
with set_exp:
iset_meta.set_info({name: np.ones(4)})
# error when set as attribute
with set_attr_exp:
setattr(iset_meta, name, np.ones(4))
# error when set as key
with set_key_exp:
iset_meta[name] = np.ones(4)
# retrieve with get_info
assert np.all(iset_meta.get_info(name) == np.ones(4))
# make sure it doesn't access metadata if it's an existing attribute or key
with get_attr_exp:
assert np.all(getattr(iset_meta, name) == np.ones(4)) == False
# make sure it doesn't access metadata if it's an existing key
with get_key_exp:
assert np.all(iset_meta[name] == np.ones(4)) == False


##############
## TsdFrame ##
##############
Expand Down Expand Up @@ -347,26 +412,21 @@ def test_tsdframe_metadata_slicing(tsdframe_meta):
)


@pytest.mark.parametrize(
"args, kwargs, expected",
[
(
# invalid metadata names that are the same as column names
[
pd.DataFrame(
index=["a", "b", "c"],
columns=["a", "b", "c"],
data=np.random.randint(0, 5, size=(3, 3)),
)
],
{},
pytest.raises(ValueError, match="Invalid metadata name"),
),
],
)
def test_tsdframe_add_metadata_error(tsdframe_meta, args, kwargs, expected):
with expected:
tsdframe_meta.set_info(*args, **kwargs)
# @pytest.mark.parametrize(
# "name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp",
# [
# (
# # invalid metadata names that are the same as column names
# "a",
# pytest.warns(UserWarning, match="overlaps with an existing"),
# ),
# ],
# )
# def test_tsdframe_metadata_overlapping_names(tsdframe_meta, args, kwargs, expected):
# assert

# with expected:
# tsdframe_meta.set_info(*args, **kwargs)


#############
Expand Down Expand Up @@ -575,52 +635,26 @@ def test_add_metadata_df(self, obj, info, obj_len):
(
# invalid names as strings starting with a number
[
pd.DataFrame(
columns=["1"],
data=np.ones((4, 1)),
)
{"1": np.ones(4)},
],
{},
pytest.raises(ValueError, match="Invalid metadata name"),
pytest.warns(UserWarning, match="starts with a number"),
),
(
# invalid names with spaces
[
pd.DataFrame(
columns=["l 1"],
data=np.ones((4, 1)),
)
{"l 1": np.ones(4)},
],
{},
pytest.raises(ValueError, match="Invalid metadata name"),
pytest.warns(UserWarning, match="contains a special character"),
),
(
# invalid names with periods
[
pd.DataFrame(
columns=["l.1"],
data=np.ones((4, 1)),
)
],
{},
pytest.raises(ValueError, match="Invalid metadata name"),
),
(
# invalid names with dashes
[
pd.DataFrame(
columns=["l-1"],
data=np.ones((4, 1)),
)
{"1.1": np.ones(4)},
],
{},
pytest.raises(ValueError, match="Invalid metadata name"),
),
(
# name that overlaps with existing attribute
[],
{"__dir__": np.zeros(4)},
pytest.raises(ValueError, match="Invalid metadata name"),
pytest.warns(UserWarning, match="contains a special character"),
),
(
# metadata with wrong length
Expand All @@ -633,25 +667,33 @@ def test_add_metadata_df(self, obj, info, obj_len):
),
],
)
def test_add_metadata_error(self, obj, args, kwargs, expected):
def test_add_metadata_error(self, obj, obj_len, args, kwargs, expected):
# trim to appropriate length
if len(args):
if isinstance(args[0], pd.DataFrame):
metadata = args[0].iloc[:obj_len]
elif isinstance(args[0], dict):
metadata = {k: v[:obj_len] for k, v in args[0].items()}
else:
metadata = None
with expected:
obj.set_info(*args, **kwargs)

def test_add_metadata_key_error(self, obj, obj_len):
# type specific key errors
info = np.ones(obj_len)
if isinstance(obj, nap.IntervalSet):
with pytest.raises(RuntimeError, match="IntervalSet is immutable"):
obj[0] = info
with pytest.raises(RuntimeError, match="IntervalSet is immutable"):
obj["start"] = info
with pytest.raises(RuntimeError, match="IntervalSet is immutable"):
obj["end"] = info

elif isinstance(obj, nap.TsGroup):
# currently obj[0] does not raise an error for TsdFrame
with pytest.raises(TypeError, match="Metadata keys must be strings!"):
obj[0] = info
obj.set_info(metadata, **kwargs)

# def test_add_metadata_key_error(self, obj, obj_len):
# # type specific key errors
# info = np.ones(obj_len)
# if isinstance(obj, nap.IntervalSet):
# with pytest.raises(RuntimeError, match="IntervalSet is immutable"):
# obj[0] = info
# with pytest.raises(RuntimeError, match="IntervalSet is immutable"):
# obj["start"] = info
# with pytest.raises(RuntimeError, match="IntervalSet is immutable"):
# obj["end"] = info

# elif isinstance(obj, nap.TsGroup):
# # currently obj[0] does not raise an error for TsdFrame
# with pytest.raises(TypeError, match="Metadata keys must be strings!"):
# obj[0] = info

def test_overwrite_metadata(self, obj, obj_len):
# add metadata
Expand Down
Loading

0 comments on commit 21d4893

Please sign in to comment.