Skip to content

Commit 8ca7ef3

Browse files
Merge pull request #229 from catalystneuro/support_OnePhotonSeries
Refactor `NwbImagingExtractor` to support extracting data from `OnePhotonSeries`
2 parents 1b2eb7d + 475bd02 commit 8ca7ef3

File tree

1 file changed

+20
-19
lines changed

1 file changed

+20
-19
lines changed

src/roiextractors/extractors/nwbextractors/nwbextractors.py

+20-19
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
try:
88
from pynwb import NWBHDF5IO
9-
from pynwb.ophys import TwoPhotonSeries
9+
from pynwb.ophys import TwoPhotonSeries, OnePhotonSeries
1010

1111
HAVE_NWB = True
1212
except ImportError:
@@ -55,8 +55,8 @@ def __init__(self, file_path: PathType, optical_series_name: Optional[str] = "Tw
5555
----------
5656
file_path: str
5757
The location of the folder containing dataset.nwb file
58-
optical_series_name: str (optional)
59-
optical series to extract data from
58+
optical_series_name: string, optional
59+
The name of the optical series to extract data from.
6060
"""
6161
ImagingExtractor.__init__(self)
6262
self._path = file_path
@@ -73,33 +73,34 @@ def __init__(self, file_path: PathType, optical_series_name: Optional[str] = "Tw
7373
raise ValueError("No acquisitions found in the .nwb file.")
7474
self._optical_series_name = a_names[0]
7575

76-
self.two_photon_series = self.nwbfile.acquisition[self._optical_series_name]
77-
assert isinstance(
78-
self.two_photon_series, TwoPhotonSeries
79-
), "The optical series must be of type pynwb.TwoPhotonSeries"
76+
self.photon_series = self.nwbfile.acquisition[self._optical_series_name]
77+
valid_photon_series_types = [OnePhotonSeries, TwoPhotonSeries]
78+
assert any(
79+
[isinstance(self.photon_series, photon_series_type) for photon_series_type in valid_photon_series_types]
80+
), "The optical series must be of type pynwb.ophys.OnePhotonSeries or pynwb.ophys.TwoPhotonSeries."
8081

8182
# TODO if external file --> return another proper extractor (e.g. TiffImagingExtractor)
82-
assert self.two_photon_series.external_file is None, "Only 'raw' format is currently supported"
83+
assert self.photon_series.external_file is None, "Only 'raw' format is currently supported"
8384

8485
# Load the two video structures that TwoPhotonSeries supports.
8586
self._data_has_channels_axis = True
86-
if len(self.two_photon_series.data.shape) == 3:
87+
if len(self.photon_series.data.shape) == 3:
8788
self._num_channels = 1
88-
self._num_frames, self._columns, self._num_rows = self.two_photon_series.data.shape
89+
self._num_frames, self._columns, self._num_rows = self.photon_series.data.shape
8990
else:
9091
raise_multi_channel_or_depth_not_implemented(extractor_name=self.extractor_name)
9192

9293
# Set channel names (This should disambiguate which optical channel)
93-
self._channel_names = [i.name for i in self.two_photon_series.imaging_plane.optical_channel]
94+
self._channel_names = [i.name for i in self.photon_series.imaging_plane.optical_channel]
9495

9596
# Set sampling frequency
96-
if hasattr(self.two_photon_series, "timestamps") and self.two_photon_series.timestamps:
97-
self._sampling_frequency = 1.0 / np.median(np.diff(self.two_photon_series.timestamps))
98-
self._imaging_start_time = self.two_photon_series.timestamps[0]
99-
self.set_times(np.array(self.two_photon_series.timestamps))
97+
if hasattr(self.photon_series, "timestamps") and self.photon_series.timestamps:
98+
self._sampling_frequency = 1.0 / np.median(np.diff(self.photon_series.timestamps))
99+
self._imaging_start_time = self.photon_series.timestamps[0]
100+
self.set_times(np.array(self.photon_series.timestamps))
100101
else:
101-
self._sampling_frequency = self.two_photon_series.rate
102-
self._imaging_start_time = self.two_photon_series.fields.get("starting_time", 0.0)
102+
self._sampling_frequency = self.photon_series.rate
103+
self._imaging_start_time = self.photon_series.fields.get("starting_time", 0.0)
103104

104105
# Fill epochs dictionary
105106
self._epochs = {}
@@ -158,7 +159,7 @@ def get_frames(self, frame_idxs: ArrayType, channel: Optional[int] = 0):
158159
slice_start = 0
159160
slice_stop = self.get_num_frames()
160161

161-
data = self.two_photon_series.data
162+
data = self.photon_series.data
162163
frames = data[slice_start:slice_stop, ...].transpose([0, 2, 1])
163164

164165
if isinstance(frame_idxs, int):
@@ -169,7 +170,7 @@ def get_video(self, start_frame=None, end_frame=None, channel: Optional[int] = 0
169170
start_frame = start_frame if start_frame is not None else 0
170171
end_frame = end_frame if end_frame is not None else self.get_num_frames()
171172

172-
video = self.two_photon_series.data
173+
video = self.photon_series.data
173174
video = video[start_frame:end_frame].transpose([0, 2, 1])
174175
return video
175176

0 commit comments

Comments
 (0)