From c31e837a7934b13267561e61e62ae4b403feee90 Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Mon, 13 Jan 2025 16:42:30 -0600 Subject: [PATCH] add sketch of HEDAnnotations [ci skip] --- mne/__init__.pyi | 2 + mne/annotations.py | 139 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 139 insertions(+), 2 deletions(-) diff --git a/mne/__init__.pyi b/mne/__init__.pyi index d50b5209346..6560854402e 100644 --- a/mne/__init__.pyi +++ b/mne/__init__.pyi @@ -11,6 +11,7 @@ __all__ = [ "Evoked", "EvokedArray", "Forward", + "HEDAnnotations", "Info", "Label", "MixedSourceEstimate", @@ -260,6 +261,7 @@ from ._freesurfer import ( ) from .annotations import ( Annotations, + HEDAnnotations, annotations_from_events, count_annotations, events_from_annotations, diff --git a/mne/annotations.py b/mne/annotations.py index 694836d8188..a784c3ae143 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -52,6 +52,7 @@ verbose, warn, ) +from .utils.check import _soft_import # For testing windows_like_datetime, we monkeypatch "datetime" in this module. # Keep the true datetime object around for _validate_type use. @@ -151,6 +152,7 @@ class Annotations: -------- mne.annotations_from_events mne.events_from_annotations + mne.HEDAnnotations Notes ----- @@ -288,7 +290,7 @@ def orig_time(self): def __eq__(self, other): """Compare to another Annotations instance.""" - if not isinstance(other, Annotations): + if not isinstance(other, type(self)): return False return ( np.array_equal(self.onset, other.onset) @@ -567,6 +569,8 @@ def _sort(self): self.duration = self.duration[order] self.description = self.description[order] self.ch_names = self.ch_names[order] + if hasattr(self, "hed_tags"): + self.hed_tags = self.hed_tags[order] @verbose def crop( @@ -758,7 +762,138 @@ def rename(self, mapping, verbose=None): return self -# TODO: Add support for HED annotations for use in epoching. +class HEDAnnotations(Annotations): + """Annotations object for annotating segments of raw data with HED tags. + + Parameters + ---------- + onset : array of float, shape (n_annotations,) + The starting time of annotations in seconds after ``orig_time``. + duration : array of float, shape (n_annotations,) | float + Durations of the annotations in seconds. If a float, all the + annotations are given the same duration. + description : array of str, shape (n_annotations,) | str + Array of strings containing description for each annotation. If a + string, all the annotations are given the same description. To reject + epochs, use description starting with keyword 'bad'. See example above. + hed_tags : array of str, shape (n_annotations,) | str + Array of strings containing a HED tag for each annotation. If a single string + is provided, all annotations are given the same HED tag. + hed_version : str + The HED schema version against which to validate the HED tags. + orig_time : float | str | datetime | tuple of int | None + A POSIX Timestamp, datetime or a tuple containing the timestamp as the + first element and microseconds as the second element. Determines the + starting time of annotation acquisition. If None (default), + starting time is determined from beginning of raw data acquisition. + In general, ``raw.info['meas_date']`` (or None) can be used for syncing + the annotations with raw data if their acquisition is started at the + same time. If it is a string, it should conform to the ISO8601 format. + More precisely to this '%%Y-%%m-%%d %%H:%%M:%%S.%%f' particular case of + the ISO8601 format where the delimiter between date and time is ' '. + %(ch_names_annot)s + + See Also + -------- + mne.Annotations + + Notes + ----- + + .. versionadded:: 1.10 + """ + + def __init__( + self, + onset, + duration, + description, + hed_tags, + hed_version="latest", # TODO @VisLab what is a sensible default here? + orig_time=None, + ch_names=None, + ): + hed = _soft_import("hed", "validation of HED tags in annotations") # noqa + # TODO is some sort of initialization of the HED cache directory necessary? + super().__init__( + onset=onset, + duration=duration, + description=description, + orig_time=orig_time, + ch_names=ch_names, + ) + # TODO validate the HED version the user claims to be using. + self.hed_version = hed_version + self._update_hed_tags(hed_tags=hed_tags) + + def _update_hed_tags(self, hed_tags): + if len(hed_tags) != len(self): + raise ValueError( + f"Number of HED tags ({len(hed_tags)}) must match the number of " + f"annotations ({len(self)})." + ) + # TODO insert validation of HED tags here + self.hed_tags = hed_tags + + def __eq__(self, other): + """Compare to another HEDAnnotations instance.""" + return ( + super().__eq__(self, other) + and np.array_equal(self.hed_tags, other.hed_tags) + and self.hed_version == other.hed_version + ) + + def __repr__(self): + """Show a textual summary of the object.""" + counter = Counter(self.hed_tags) + kinds = ", ".join(["{} ({})".format(*k) for k in sorted(counter.items())]) + kinds = (": " if len(kinds) > 0 else "") + kinds + ch_specific = ", channel-specific" if self._any_ch_names() else "" + s = ( + f"HEDAnnotations | {len(self.onset)} segment" + f"{_pl(len(self.onset))}{ch_specific}{kinds}" + ) + return "<" + shorten(s, width=77, placeholder=" ...") + ">" + + def __getitem__(self, key, *, with_ch_names=None): + """Propagate indexing and slicing to the underlying numpy structure.""" + result = super().__getitem__(self, key, with_ch_names=with_ch_names) + if isinstance(result, OrderedDict): + result["hed_tags"] = self.hed_tags[key] + else: + key = list(key) if isinstance(key, tuple) else key + hed_tags = self.hed_tags[key] + return HEDAnnotations( + result.onset, + result.duration, + result.description, + hed_tags, + hed_version=self.hed_version, + orig_time=self.orig_time, + ch_names=result.ch_names, + ) + + def append(self, onset, duration, description, ch_names=None): + """TODO.""" + pass + + def count(self): + """TODO. Unlike Annotations.count, keys should be HED tags not descriptions.""" + pass + + def crop( + self, tmin=None, tmax=None, emit_warning=False, use_orig_time=True, verbose=None + ): + """TODO.""" + pass + + def delete(self, idx): + """TODO.""" + pass + + def to_data_frame(self, time_format="datetime"): + """TODO.""" + pass class EpochAnnotationsMixin: