6
6
7
7
try:
8
8
from pynwb import NWBHDF5IO
9
- from pynwb.ophys import TwoPhotonSeries
9
+ from pynwb.ophys import TwoPhotonSeries, OnePhotonSeries
10
10
11
11
HAVE_NWB = True
12
12
except ImportError:
@@ -55,8 +55,8 @@ def __init__(self, file_path: PathType, optical_series_name: Optional[str] = "Tw
55
55
----------
56
56
file_path: str
57
57
The location of the folder containing dataset.nwb file
58
- optical_series_name: str ( optional)
59
- optical series to extract data from
58
+ optical_series_name: string, optional
59
+ The name of the optical series to extract data from.
60
60
"""
61
61
ImagingExtractor.__init__(self)
62
62
self._path = file_path
@@ -73,33 +73,34 @@ def __init__(self, file_path: PathType, optical_series_name: Optional[str] = "Tw
73
73
raise ValueError("No acquisitions found in the .nwb file.")
74
74
self._optical_series_name = a_names[0]
75
75
76
- self.two_photon_series = self.nwbfile.acquisition[self._optical_series_name]
77
- assert isinstance(
78
- self.two_photon_series, TwoPhotonSeries
79
- ), "The optical series must be of type pynwb.TwoPhotonSeries"
76
+ self.photon_series = self.nwbfile.acquisition[self._optical_series_name]
77
+ valid_photon_series_types = [OnePhotonSeries, TwoPhotonSeries]
78
+ assert any(
79
+ [isinstance(self.photon_series, photon_series_type) for photon_series_type in valid_photon_series_types]
80
+ ), "The optical series must be of type pynwb.ophys.OnePhotonSeries or pynwb.ophys.TwoPhotonSeries."
80
81
81
82
# TODO if external file --> return another proper extractor (e.g. TiffImagingExtractor)
82
- assert self.two_photon_series .external_file is None, "Only 'raw' format is currently supported"
83
+ assert self.photon_series .external_file is None, "Only 'raw' format is currently supported"
83
84
84
85
# Load the two video structures that TwoPhotonSeries supports.
85
86
self._data_has_channels_axis = True
86
- if len(self.two_photon_series .data.shape) == 3:
87
+ if len(self.photon_series .data.shape) == 3:
87
88
self._num_channels = 1
88
- self._num_frames, self._columns, self._num_rows = self.two_photon_series .data.shape
89
+ self._num_frames, self._columns, self._num_rows = self.photon_series .data.shape
89
90
else:
90
91
raise_multi_channel_or_depth_not_implemented(extractor_name=self.extractor_name)
91
92
92
93
# Set channel names (This should disambiguate which optical channel)
93
- self._channel_names = [i.name for i in self.two_photon_series .imaging_plane.optical_channel]
94
+ self._channel_names = [i.name for i in self.photon_series .imaging_plane.optical_channel]
94
95
95
96
# Set sampling frequency
96
- if hasattr(self.two_photon_series , "timestamps") and self.two_photon_series .timestamps:
97
- self._sampling_frequency = 1.0 / np.median(np.diff(self.two_photon_series .timestamps))
98
- self._imaging_start_time = self.two_photon_series .timestamps[0]
99
- self.set_times(np.array(self.two_photon_series .timestamps))
97
+ if hasattr(self.photon_series , "timestamps") and self.photon_series .timestamps:
98
+ self._sampling_frequency = 1.0 / np.median(np.diff(self.photon_series .timestamps))
99
+ self._imaging_start_time = self.photon_series .timestamps[0]
100
+ self.set_times(np.array(self.photon_series .timestamps))
100
101
else:
101
- self._sampling_frequency = self.two_photon_series .rate
102
- self._imaging_start_time = self.two_photon_series .fields.get("starting_time", 0.0)
102
+ self._sampling_frequency = self.photon_series .rate
103
+ self._imaging_start_time = self.photon_series .fields.get("starting_time", 0.0)
103
104
104
105
# Fill epochs dictionary
105
106
self._epochs = {}
@@ -158,7 +159,7 @@ def get_frames(self, frame_idxs: ArrayType, channel: Optional[int] = 0):
158
159
slice_start = 0
159
160
slice_stop = self.get_num_frames()
160
161
161
- data = self.two_photon_series .data
162
+ data = self.photon_series .data
162
163
frames = data[slice_start:slice_stop, ...].transpose([0, 2, 1])
163
164
164
165
if isinstance(frame_idxs, int):
@@ -169,7 +170,7 @@ def get_video(self, start_frame=None, end_frame=None, channel: Optional[int] = 0
169
170
start_frame = start_frame if start_frame is not None else 0
170
171
end_frame = end_frame if end_frame is not None else self.get_num_frames()
171
172
172
- video = self.two_photon_series .data
173
+ video = self.photon_series .data
173
174
video = video[start_frame:end_frame].transpose([0, 2, 1])
174
175
return video
175
176
0 commit comments