diff --git a/CHANGELOG.md b/CHANGELOG.md index 062bd49e8..3e09375e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ All notable changes to this project will be documented in this file. - Fixes bug that caused file paths on windows machines to be incorrect in Visual behavior user-facing classes - Updates to support MESO.2 - Loosens/updates required versions for several dependencies +- Updates in order to generate valid NWB files for Neuropixels Visual Coding data collected between 2019 and 2021 ## [2.13.1] = 2021-10-04 - Fixes bug that was preventing the BehaviorSession from properly instantiating passive sessions. diff --git a/allensdk/brain_observatory/comparison_utils.py b/allensdk/brain_observatory/comparison_utils.py index b6ffbf99f..50fbe09f7 100644 --- a/allensdk/brain_observatory/comparison_utils.py +++ b/allensdk/brain_observatory/comparison_utils.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd import xarray as xr -from pandas.util.testing import assert_frame_equal +from pandas.testing import assert_frame_equal def compare_fields(x1: Any, x2: Any, err_msg="", diff --git a/allensdk/brain_observatory/ecephys/__init__.py b/allensdk/brain_observatory/ecephys/__init__.py index 62f154428..8e131952a 100644 --- a/allensdk/brain_observatory/ecephys/__init__.py +++ b/allensdk/brain_observatory/ecephys/__init__.py @@ -1,7 +1,6 @@ import numpy as np - UNIT_FILTER_DEFAULTS = { "amplitude_cutoff_maximum": { "value": 0.1, @@ -23,7 +22,7 @@ def get_unit_filter_value(key, pop=True, replace_none=True, **source): value = source.pop(key, UNIT_FILTER_DEFAULTS[key]["value"]) else: value = source.get(key, UNIT_FILTER_DEFAULTS[key]["value"]) - + if value is None and replace_none: value = UNIT_FILTER_DEFAULTS[key]["missing"] diff --git a/allensdk/brain_observatory/ecephys/align_timestamps/__main__.py b/allensdk/brain_observatory/ecephys/align_timestamps/__main__.py index 31773553c..5c6067ed6 100644 --- a/allensdk/brain_observatory/ecephys/align_timestamps/__main__.py +++ b/allensdk/brain_observatory/ecephys/align_timestamps/__main__.py @@ -6,7 +6,8 @@ from ._schemas import InputParameters, OutputParameters from .barcode_sync_dataset import BarcodeSyncDataset from .channel_states import extract_barcodes_from_states, \ - extract_splits_from_states + extract_splits_from_states, \ + extract_splits_from_barcode_times from .probe_synchronizer import ProbeSynchronizer @@ -29,6 +30,12 @@ def align_timestamps(args): channel_states, timestamps, probe["sampling_rate"] ) + barcode_split_times = extract_splits_from_barcode_times( + probe_barcode_times + ) + + probe_split_times = np.union1d(probe_split_times, barcode_split_times) + print("Split times:") print(probe_split_times) @@ -92,6 +99,7 @@ def align_timestamps(args): "global_probe_lfp_sampling_rate"] = lfp_sampling_rate this_probe_output_info["output_paths"] = mapped_files this_probe_output_info["name"] = probe["name"] + this_probe_output_info["split_times"] = probe_split_times probe_output_info.append(this_probe_output_info) diff --git a/allensdk/brain_observatory/ecephys/align_timestamps/_schemas.py b/allensdk/brain_observatory/ecephys/align_timestamps/_schemas.py index 5cb0e7563..df3a5fa4f 100644 --- a/allensdk/brain_observatory/ecephys/align_timestamps/_schemas.py +++ b/allensdk/brain_observatory/ecephys/align_timestamps/_schemas.py @@ -1,6 +1,6 @@ -from argschema import ArgSchema, ArgSchemaParser +from argschema import ArgSchema from argschema.schemas import DefaultSchema -from argschema.fields import Nested, InputDir, String, Float, 
Dict, Int, List +from argschema.fields import Nested, String, Float, Dict, Int, List class ProbeMappable(DefaultSchema): @@ -10,11 +10,13 @@ class ProbeMappable(DefaultSchema): ) input_path = String( required=True, - help="Input path for this file. Should point to a file containing a 1D timestamps array with values in probe samples.", + help="""Input path for this file. Should point to a file containing a 1D + timestamps array with values in probe samples.""", ) output_path = String( required=True, - help="Output path for the mapped version of this file. Will write a 1D timestamps array with values in seconds on the master clock.", + help="""Output path for the mapped version of this file. Will write a 1D + timestamps array with values in seconds on the master clock.""" ) @@ -22,28 +24,36 @@ class ProbeInputParameters(DefaultSchema): name = String(required=True, help="Identifier for this probe") sampling_rate = Float( required=True, - help="The sampling rate of the probe, in Hz, assessed on the probe clock.", + help="""The sampling rate of the probe, in Hz, assessed on + the probe clock.""", ) lfp_sampling_rate = Float( - required=True, help="The sampling rate of the LFP collected on this probe." + required=True, help="""The sampling rate of the LFP collected on this + probe.""" ) start_index = Int( - default=0, help="Sample index of probe recording start time. Defaults to 0." + default=0, help="""Sample index of probe recording start time. + Defaults to 0.""" ) barcode_channel_states_path = String( required=True, - help="Path to the channel states file. This file contains a 1-dimensional array whose axis is events and whose " - "values indicate the state of the channel line (rising or falling) at that event.", + help="""Path to the channel states file. This file contains a + 1-dimensional array whose axis is events and whose + values indicate the state of the channel line (rising or + falling) at that event.""", ) barcode_timestamps_path = String( required=True, - help="Path to the timestamps file. This file contains a 1-dimensional array whose axis is events and whose " - "values indicate the sample on which each event was detected.", + help="""Path to the timestamps file. This file contains a 1-dimensional + array whose axis is events and whose values indicate the sample + on which each event was detected.""", ) mappable_timestamp_files = Nested( ProbeMappable, many=True, - help="Timestamps files for this probe. Describe the times (in probe samples) when e.g. lfp samples were taken or spike events occured", + help="""Timestamps files for this probe. Describe the times (in probe + samples) when e.g. 
lfp samples were taken or spike events + occured""", ) @@ -54,7 +64,9 @@ class InputParameters(ArgSchema): help="Probes whose data will be aligned to the master clock.", ) sync_h5_path = String( - required=True, help="path to h5 file containing syncronization information" + required=True, + help="""path to h5 file containing syncronization + information""" ) @@ -66,15 +78,24 @@ class ProbeOutputParameters(DefaultSchema): ) total_time_shift = Float( required=True, - help="Translation (in seconds) from master->probe times computed for this probe.", + help="""Translation (in seconds) from master->probe times computed + for this probe.""", ) global_probe_sampling_rate = Float( required=True, - help="The sampling rate of this probe in Hz, assessed on the master clock.", + help="""The sampling rate of this probe in Hz, assessed on the master + clock.""", ) global_probe_lfp_sampling_rate = Float( required=True, - help="The sampling rate of LFP collected on this probe in Hz, assessed on the master clock.", + help="""The sampling rate of LFP collected on this probe in Hz, + assessed on the master clock.""", + ) + split_times = List( + Float(), + required=True, + help="""Start/stop times of likely dropped data, due to gaps in + recording or irregular barcode intervals""" ) diff --git a/allensdk/brain_observatory/ecephys/align_timestamps/barcode.py b/allensdk/brain_observatory/ecephys/align_timestamps/barcode.py index 850308ee3..31503e54f 100644 --- a/allensdk/brain_observatory/ecephys/align_timestamps/barcode.py +++ b/allensdk/brain_observatory/ecephys/align_timestamps/barcode.py @@ -21,7 +21,7 @@ def extract_barcodes_from_times( Minimun duration of time between barcodes. bar_duration : numeric, optional A value slightly shorter than the expected duration of each bar - barcode_duration_ceiling : numeric, optional + barcode_duration_ceiling : numeric, optional The maximum duration of a single barcode nbits : int, optional The bit-depth of each barcode @@ -36,7 +36,8 @@ def extract_barcodes_from_times( Notes ----- ignores first code in prod (ok, but not intended) - ignores first on pulse (intended - this is needed to identify that a barcode is starting) + ignores first on pulse (intended - this is needed to identify that a + barcode is starting) """ @@ -50,51 +51,57 @@ def extract_barcodes_from_times( oncode = on_times[ np.where( - np.logical_and(on_times > t, on_times < t + barcode_duration_ceiling) + np.logical_and(on_times > t, + on_times < t + barcode_duration_ceiling) )[0] ] offcode = off_times[ np.where( - np.logical_and(off_times > t, off_times < t + barcode_duration_ceiling) + np.logical_and(off_times > t, + off_times < t + barcode_duration_ceiling) )[0] ] - currTime = offcode[0] + if len(offcode) > 0: - bits = np.zeros((nbits,)) + currTime = offcode[0] - for bit in range(0, nbits): + bits = np.zeros((nbits,)) - nextOn = np.where(oncode > currTime)[0] - nextOff = np.where(offcode > currTime)[0] + for bit in range(0, nbits): - if nextOn.size > 0: - nextOn = oncode[nextOn[0]] - else: - nextOn = t + inter_barcode_interval + nextOn = np.where(oncode > currTime)[0] + nextOff = np.where(offcode > currTime)[0] - if nextOff.size > 0: - nextOff = offcode[nextOff[0]] - else: - nextOff = t + inter_barcode_interval + if nextOn.size > 0: + nextOn = oncode[nextOn[0]] + else: + nextOn = t + inter_barcode_interval - if nextOn < nextOff: - bits[bit] = 1 + if nextOff.size > 0: + nextOff = offcode[nextOff[0]] + else: + nextOff = t + inter_barcode_interval - currTime += bar_duration + if nextOn < nextOff: + 
bits[bit] = 1 - barcode = 0 + currTime += bar_duration - # least sig left - for bit in range(0, nbits): - barcode += bits[bit] * pow(2, bit) + barcode = 0 - barcodes.append(barcode) + # least sig left + for bit in range(0, nbits): + barcode += bits[bit] * pow(2, bit) + + barcodes.append(barcode) return barcode_start_times, barcodes -def find_matching_index(master_barcodes, probe_barcodes, alignment_type="start"): +def find_matching_index(master_barcodes, + probe_barcodes, + alignment_type="start"): """Given a set of barcodes for the master clock and the probe clock, find the indices of a matching set, either starting from the beginning or the end of the list. @@ -147,21 +154,24 @@ def find_matching_index(master_barcodes, probe_barcodes, alignment_type="start") def match_barcodes(master_times, master_barcodes, probe_times, probe_barcodes): - """Given sequences of barcode values and (local) times on a probe line and a master - line, find the time points on each clock corresponding to the first and last shared - barcode. + """Given sequences of barcode values and (local) times on a probe line + and a master line, find the time points on each clock corresponding to + the first and last shared barcode. - If there's only one probe barcode, only the first matching timepoint is returned. + If there's only one probe barcode, only the first matching timepoint + is returned. Parameters ---------- master_times : np.ndarray - start times of barcodes (according to the master clock) on the master line. + start times of barcodes (according to the master clock) on the + master line. One per barcode. master_barcodes : np.ndarray barcode values on the master line. One per barcode probe_times : np.ndarray - start times (according to the probe clock) of barcodes on the probe line. + start times (according to the probe clock) of barcodes on the + probe line. One per barcode probe_barcodes : np.ndarray barcode values on the probe_line. One per barcode @@ -169,9 +179,11 @@ def match_barcodes(master_times, master_barcodes, probe_times, probe_barcodes): Returns ------- probe_interval : np.ndarray - Start and end times of the matched interval according to the probe_clock. + Start and end times of the matched interval according to the + probe_clock. master_interval : np.ndarray - Start and end times of the matched interval according to the master clock + Start and end times of the matched interval according to the + master clock """ @@ -190,8 +202,11 @@ def match_barcodes(master_times, master_barcodes, probe_times, probe_barcodes): print("Master start index: " + str(master_start_index)) if len(probe_barcodes) > 2: - master_end_index, probe_end_index = find_matching_index(master_barcodes, probe_barcodes, alignment_type='end') - + master_end_index, probe_end_index = \ + find_matching_index(master_barcodes, + probe_barcodes, + alignment_type='end') + if probe_end_index is not None: print("Probe end index: " + str(probe_end_index)) t_m_end = master_times[master_end_index] @@ -218,14 +233,14 @@ def linear_transform_from_intervals(master, probe): Returns ------- scale : float - Scale factor. If > 1.0, the probe clock is running fast compared to the + Scale factor. If > 1.0, the probe clock is running fast compared to the master clock. If < 1.0, the probe clock is running slow. translation : float If > 0, the probe clock started before the master clock. If > 0, after. 
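The Notes that follow state the relation being solved, (master + translation) * scale = probe. As a quick illustration of how one matched pair of interval endpoints determines both outputs, here is a minimal sketch; the helper name and numbers are illustrative, not the AllenSDK implementation:

import numpy as np

def solve_interval_transform(master, probe):
    # Solve (master + translation) * scale = probe for scale and translation,
    # given start/end times of one matched barcode interval on each clock.
    m0, m1 = np.asarray(master, dtype=float)
    p0, p1 = np.asarray(probe, dtype=float)
    scale = (p1 - p0) / (m1 - m0)      # > 1.0 means the probe clock runs fast
    translation = p0 / scale - m0      # > 0 means the probe started before the master clock
    return scale, translation

# A probe that started 2 s before the master clock and runs 0.01% fast:
scale, translation = solve_interval_transform((10.0, 110.0), (12.0012, 112.0112))
# scale ~= 1.0001, translation ~= 2.0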
Notes ----- - solves + solves (master + translation) * scale = probe for scale and translation """ @@ -251,17 +266,20 @@ def get_probe_time_offset( acq_start_index, local_probe_rate, ): - """Time offset between master clock and recording probes. For converting probe time to master clock. - + """Time offset between master clock and recording probes. For converting + probe time to master clock. + Parameters ---------- master_times : np.ndarray - start times of barcodes (according to the master clock) on the master line. + start times of barcodes (according to the master clock) on the master + line. One per barcode. master_barcodes : np.ndarray barcode values on the master line. One per barcode probe_times : np.ndarray - start times (according to the probe clock) of barcodes on the probe line. + start times (according to the probe clock) of barcodes on the probe + line. One per barcode probe_barcodes : np.ndarray barcode values on the probe_line. One per barcode @@ -269,18 +287,20 @@ def get_probe_time_offset( sample index of probe acquisition start time local_probe_rate : float the probe's apparent sampling rate - + Returns ------- total_time_shift : float - Time at which the probe started acquisition, assessed on - the master clock. If < 0, the probe started earlier than the master line. + Time at which the probe started acquisition, assessed on + the master clock. If < 0, the probe started earlier than the master + line. probe_rate : float The probe's sampling rate, assessed on the master clock master_endpoints : iterable - Defines the start and end times of the sync interval on the master clock - + Defines the start and end times of the sync interval on the master + clock + """ probe_endpoints, master_endpoints = match_barcodes( diff --git a/allensdk/brain_observatory/ecephys/align_timestamps/barcode_sync_dataset.py b/allensdk/brain_observatory/ecephys/align_timestamps/barcode_sync_dataset.py index 2f151c880..83804f1d8 100644 --- a/allensdk/brain_observatory/ecephys/align_timestamps/barcode_sync_dataset.py +++ b/allensdk/brain_observatory/ecephys/align_timestamps/barcode_sync_dataset.py @@ -19,6 +19,8 @@ def barcode_line(self): return self.line_labels.index("barcode") elif "barcodes" in self.line_labels: return self.line_labels.index("barcodes") + elif "barcode_ephys" in self.line_labels: + return self.line_labels.index("barcode_ephys") else: raise ValueError("no barcode line found") @@ -27,10 +29,10 @@ def extract_barcodes(self, **barcode_kwargs): Parameters ---------- - **barcode_kwargs : + **barcode_kwargs : Will be passed to .barcode.extract_barcodes_from_times - Returns + Returns ------- times : np.ndarray The start times of each detected barcode. @@ -57,12 +59,13 @@ def get_barcode_table(self, **barcode_kwargs): Notes ----- - This method is deprecated! + This method is deprecated! """ warnings.warn( np.VisibleDeprecationWarning( - "This function is deprecated as unecessary (and slated for removal). Instead, simply use extract_barcodes." + """This function is deprecated as unecessary (and slated for + removal). Instead, simply use extract_barcodes.""" ) ) diff --git a/allensdk/brain_observatory/ecephys/align_timestamps/channel_states.py b/allensdk/brain_observatory/ecephys/align_timestamps/channel_states.py index ce33d09e5..91727b0da 100644 --- a/allensdk/brain_observatory/ecephys/align_timestamps/channel_states.py +++ b/allensdk/brain_observatory/ecephys/align_timestamps/channel_states.py @@ -16,7 +16,7 @@ def extract_barcodes_from_states( Sample index of each event. 
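extract_barcodes_from_states receives rising/falling channel states, their sample indices, and the sampling rate; conceptually it only needs to turn those events into on/off times in seconds before delegating to extract_barcodes_from_times. A rough sketch of that conversion — the sign convention (positive state = rising edge, negative = falling) is an assumption here, not taken from this diff:

import numpy as np

def states_to_edge_times(channel_states, timestamps, sampling_rate):
    # Convert barcode-line events (sample indices) into edge times in seconds.
    states = np.asarray(channel_states)
    seconds = np.asarray(timestamps, dtype=float) / float(sampling_rate)
    on_times = seconds[states > 0]     # assumed: positive state marks a rising edge
    off_times = seconds[states < 0]    # assumed: negative state marks a falling edge
    return on_times, off_times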
sampling_rate : numeric Samples / second - **barcode_kwargs : + **barcode_kwargs : Additional parameters describing the barcodes. @@ -34,7 +34,7 @@ def extract_barcodes_from_states( def extract_splits_from_states( channel_states, timestamps, sampling_rate, **barcode_kwargs ): - """Obtain barcodes from timestamped rising/falling edges. + """Obtain data split times from timestamped rising/falling edges. Parameters ---------- @@ -44,7 +44,7 @@ def extract_splits_from_states( Sample index of each event. sampling_rate : numeric Samples / second - **barcode_kwargs : + **barcode_kwargs : Additional parameters describing the barcodes. @@ -58,3 +58,36 @@ def extract_splits_from_states( T_split = np.array([0]) return T_split + + +def extract_splits_from_barcode_times( + barcode_times, + tolerance=0.0001 +): + """Determine locations of likely dropped data from barcode times + Parameters + ---------- + barcode_times : numpy.ndarray + probe barcode times + tolerance : float + Timing tolerance (relative to median interval) + """ + + barcode_intervals = np.diff(barcode_times) + + median_interval = np.median(barcode_intervals) + + irregular_intervals = np.where(np.abs(barcode_intervals - median_interval) + > tolerance * median_interval)[0] + + T_split = [0] + + for i in irregular_intervals: + + T_split.append(barcode_times[i-1]) + + if i+1 < len(barcode_times): + T_split.append(barcode_times[i+1]) + + return np.array(T_split) +# diff --git a/allensdk/brain_observatory/ecephys/align_timestamps/probe_synchronizer.py b/allensdk/brain_observatory/ecephys/align_timestamps/probe_synchronizer.py index a624e9486..22a239e5a 100644 --- a/allensdk/brain_observatory/ecephys/align_timestamps/probe_synchronizer.py +++ b/allensdk/brain_observatory/ecephys/align_timestamps/probe_synchronizer.py @@ -6,7 +6,7 @@ class ProbeSynchronizer(object): @property def sampling_rate_scale(self): - """ The ratio of the probe's sampling rate assessed on the global clock to the + """ The ratio of the probe's sampling rate assessed on the global clock to the probe's locally assessed sampling rate. """ @@ -57,8 +57,9 @@ def __call__(self, samples, sync_condition="master"): Returns ------- - numpy.ndarray : - Sample timestamps in seconds on the master (default) or probe clock. + numpy.ndarray : + Sample timestamps in seconds on the master (default) or + probe clock. """ @@ -70,7 +71,8 @@ def __call__(self, samples, sync_condition="master"): if self.global_probe_sampling_rate > 0: if sync_condition == "probe": - samples[in_range] = samples[in_range] / self.local_probe_sampling_rate + samples[in_range] = samples[in_range] / \ + self.local_probe_sampling_rate elif sync_condition == "master": samples[in_range] = ( @@ -105,12 +107,14 @@ def compute( Parameters ---------- master_barcode_times : np.ndarray - start times of barcodes (according to the master clock) on the master line. + start times of barcodes (according to the master clock) on the + master line. One per barcode. master_barcodes : np.ndarray barcode values on the master line. One per barcode probe_barcode_times : np.ndarray - start times (according to the probe clock) of barcodes on the probe line. + start times (according to the probe clock) of barcodes on the + probe line. One per barcode probe_barcodes : np.ndarray barcode values on the probe_line. 
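The new extract_splits_from_barcode_times helper (added in the channel_states.py hunk above) flags barcode intervals that deviate from the median spacing, which is how likely dropped data is located. A small worked example, with a 30 s barcode spacing chosen purely for illustration, followed by the union with state-derived split times as done in the align_timestamps entry point:

import numpy as np
from allensdk.brain_observatory.ecephys.align_timestamps.channel_states import (
    extract_splits_from_barcode_times,
)

# Barcodes nominally every 30 s, with the one near t = 120 s missing (dropped data).
probe_barcode_times = np.array([0.0, 30.0, 60.0, 90.0, 150.0, 180.0, 210.0])

barcode_split_times = extract_splits_from_barcode_times(probe_barcode_times)
# -> array([  0.,  60., 150.]): times around the irregular 60 s interval, plus the leading 0

# probe_split_times would normally come from extract_splits_from_states;
# a single-element stand-in is used here.
probe_split_times = np.array([0.0])
probe_split_times = np.union1d(probe_split_times, barcode_split_times)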
One per barcode @@ -122,18 +126,21 @@ def compute( sample index of probe acquisition start time local_probe_sampling_rate : float the probe's apparent sampling rate - + Returns ------- - ProbeSynchronizer : - When called, applies the transform computed here to samples on the probe clock. + ProbeSynchronizer : + When called, applies the transform computed here to samples on the + probe clock. """ times_array = np.array(probe_barcode_times) barcodes_array = np.array(probe_barcodes) - ok_barcodes = np.where((times_array > min_time) * (times_array < max_time))[0] + ok_barcodes = np.where((times_array > min_time) * + (times_array < max_time))[0] + ok_barcodes = ok_barcodes[ok_barcodes < len(barcodes_array)] times_to_align = list(times_array[ok_barcodes]) barcodes_to_align = list(barcodes_array[ok_barcodes]) @@ -141,14 +148,15 @@ def compute( print("Num barcodes: " + str(len(barcodes_to_align))) - total_time_shift, global_probe_sampling_rate, _ = barcode.get_probe_time_offset( - master_barcode_times, - master_barcodes, - times_to_align, - barcodes_to_align, - probe_start_index, - local_probe_sampling_rate, - ) + total_time_shift, global_probe_sampling_rate, _ = \ + barcode.get_probe_time_offset( + master_barcode_times, + master_barcodes, + times_to_align, + barcodes_to_align, + probe_start_index, + local_probe_sampling_rate, + ) else: print("Not enough barcodes...setting sampling rate to 0") diff --git a/allensdk/brain_observatory/ecephys/ecephys_session.py b/allensdk/brain_observatory/ecephys/ecephys_session.py index 8f9a8429a..96fcc458c 100644 --- a/allensdk/brain_observatory/ecephys/ecephys_session.py +++ b/allensdk/brain_observatory/ecephys/ecephys_session.py @@ -9,18 +9,23 @@ import scipy.stats from allensdk.core.lazy_property import LazyPropertyMixin -from allensdk.brain_observatory.ecephys.ecephys_session_api import EcephysSessionApi, EcephysNwbSessionApi, EcephysNwb1Api +from allensdk.brain_observatory.ecephys.ecephys_session_api import ( + EcephysSessionApi, + EcephysNwbSessionApi, + EcephysNwb1Api) from allensdk.brain_observatory.ecephys.stimulus_table import naming_utilities -from allensdk.brain_observatory.ecephys.stimulus_table._schemas import default_stimulus_renames, default_column_renames - +from allensdk.brain_observatory.ecephys.stimulus_table._schemas import ( + default_stimulus_renames, + default_column_renames) +# stimulus_presentation column names not describing a parameter of a stimulus NON_STIMULUS_PARAMETERS = tuple([ 'start_time', 'stop_time', 'duration', 'stimulus_block', "stimulus_condition_id" -]) # stimulus_presentation column names not describing a parameter of a stimulus +]) class EcephysSession(LazyPropertyMixin): @@ -29,35 +34,41 @@ class EcephysSession(LazyPropertyMixin): Attributes ---------- units : pd.Dataframe - A table whose rows are sorted units (putative neurons) and whose columns are characteristics - of those units. + A table whose rows are sorted units (putative neurons) and whose + columns are characteristics of those units. Index is: unit_id : int Unique integer identifier for this unit. Columns are: firing_rate : float - This unit's firing rate (spikes / s) calculated over the window of that unit's activity - (the time from its first detected spike to its last). + This unit's firing rate (spikes / s) calculated over the + window of that unit's activity (the time from its first + detected spike to its last). 
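The firing_rate definition just given (spike count over the window from a unit's first detected spike to its last) can be restated as a one-liner; this sketch assumes at least two spikes and is only meant to mirror the description above:

import numpy as np

def active_window_firing_rate(unit_spike_times):
    # Spikes / s measured from the unit's first detected spike to its last.
    spikes = np.sort(np.asarray(unit_spike_times, dtype=float))
    return spikes.size / (spikes[-1] - spikes[0])

# With an EcephysSession, per-unit spike times come from the spike_times dict,
# e.g. active_window_firing_rate(session.spike_times[unit_id]).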
isi_violations : float - Estamate of this unit's contamination rate (larger means that more of the spikes assigned - to this unit probably originated from other neurons). Calculated as a ratio of the firing - rate of the unit over periods where spikes would be isi-violating vs the total firing - rate of the unit. + Estamate of this unit's contamination rate (larger means + that more of the spikes assigned to this unit probably + originated from other neurons). Calculated as a ratio of the + firing rate of the unit over periods where spikes would be + isi-violating vs the total firing rate of the unit. peak_channel_id : int - Unique integer identifier for this unit's peak channel. A unit's peak channel is the channel on - which its peak-to-trough amplitude difference is maximized. This is assessed using the kilosort 2 - templates rather than the mean waveforms for a unit. + Unique integer identifier for this unit's peak channel. + A unit's peak channel is the channel on which its + peak-to-trough amplitude difference is maximized. This is + assessed using the kilosort 2 templates rather than the mean + waveforms for a unit. snr : float Signal to noise ratio for this unit. probe_horizontal_position : numeric - The horizontal (short-axis) position of this unit's peak channel in microns. + The horizontal (short-axis) position of this unit's peak + channel in microns. probe_vertical_position : numeric - The vertical (long-axis, lower values are closer to the probe base) position of - this unit's peak channel in microns. + The vertical (long-axis, lower values are closer to the probe + base) position of this unit's peak channel in microns. probe_id : int Unique integer identifier for this unit's probe. probe_description : str - Human-readable description carrying miscellaneous information about this unit's probe. + Human-readable description carrying miscellaneous information + about this unit's probe. location : str Gross-scale location of this unit's probe. spike_times : dict @@ -69,11 +80,13 @@ class EcephysSession(LazyPropertyMixin): values : np.ndarray Running speed of the experimental subject (in cm / s). mean_waveforms : dict - Maps integer unit ids to xarray.DataArrays containing mean spike waveforms for that unit. + Maps integer unit ids to xarray.DataArrays containing mean spike + waveforms for that unit. stimulus_presentations : pd.DataFrame - Table whose rows are stimulus presentations and whose columns are presentation characteristics. - A stimulus presentation is the smallest unit of distinct stimulus presentation and lasts for - (usually) 1 60hz frame. Since not all parameters are relevant to all stimuli, this table + Table whose rows are stimulus presentations and whose columns are + presentation characteristics. A stimulus presentation is the smallest + unit of distinct stimulus presentation and lasts for (usually) 1 60 Hz + frame. Since not all parameters are relevant to all stimuli, this table contains many 'null' values. Index is stimulus_presentation_id : int @@ -86,14 +99,17 @@ class EcephysSession(LazyPropertyMixin): duration : float stop_time - start_time (s). Included for convenience. stimulus_name : str - Identifies the stimulus family (e.g. "drifting_gratings" or "natural_movie_3") used - for this presentation. The stimulus family, along with relevant parameter values, provides the - information required to reconstruct the stimulus presented during this presentation. The empty - string indicates a blank period. + Identifies the stimulus family (e.g. 
"drifting_gratings" or + "natural_movie_3") used for this presentation. The stimulus + family, along with relevant parameter values, provides the + information required to reconstruct the stimulus presented + during this presentation. The empty string indicates a blank + period. stimulus_block : numeric - A stimulus block is made by sequentially presenting presentations from the same stimulus family. - This value is the index of the block which contains this presentation. During a blank period, - this is 'null'. + A stimulus block is made by sequentially presenting + presentations from the same stimulus family. This value is the + index of the block which contains this presentation. + During a blank period, this is 'null'. TF : float Temporal frequency, or 'null' when not appropriate. SF : float @@ -107,15 +123,18 @@ class EcephysSession(LazyPropertyMixin): Image : numeric Phase : float stimulus_condition_id : integer - identifies the session-unique stimulus condition (permutation of parameters) to which this presentation - belongs + identifies the session-unique stimulus condition (permutation + of parameters) to which this presentation belongs stimulus_conditions : pd.DataFrame - Each row is a unique permutation (within this session) of stimulus parameters presented during this experiment. - Columns are as stimulus presentations, sans start_time, end_time, stimulus_block, and duration. + Each row is a unique permutation (within this session) of stimulus + parameters presented during this experiment. Columns are as stimulus + presentations, sans start_time, end_time, stimulus_block, and duration. inter_presentation_intervals : pd.DataFrame - The elapsed time between each immediately sequential pair of stimulus presentations. This is a dataframe with a - two-level multiindex (levels are 'from_presentation_id' and 'to_presentation_id'). It has a single column, - 'interval', which reports the elapsed time between the two presentations in seconds on the experiment's master + The elapsed time between each immediately sequential pair of stimulus + presentations. This is a dataframe with a two-level multiindex (levels + are 'from_presentation_id' and 'to_presentation_id'). It has a single + column, 'interval', which reports the elapsed time between the + two presentations in seconds on the experiment's master clock. 
''' @@ -205,10 +224,14 @@ def session_type(self): @property def units(self): - return self._units.drop(columns=['width_rf', 'height_rf', - 'on_screen_rf', 'time_to_peak_fl', - 'time_to_peak_rf', 'time_to_peak_sg', - 'sustained_idx_fl', 'time_to_peak_dg'], + return self._units.drop(columns=['width_rf', + 'height_rf', + 'on_screen_rf', + 'time_to_peak_fl', + 'time_to_peak_rf', + 'time_to_peak_sg', + 'sustained_idx_fl', + 'time_to_peak_dg'], errors='ignore') @property @@ -240,7 +263,8 @@ def metadata(self): @property def stimulus_presentations(self): - return self.__class__._remove_detailed_stimulus_parameters(self._stimulus_presentations) + return self.__class__._remove_detailed_stimulus_parameters( + self._stimulus_presentations) @property def spike_times(self): @@ -276,23 +300,38 @@ def __init__( self.api: EcephysSessionApi = api - self.ecephys_session_id = self.LazyProperty(self.api.get_ecephys_session_id) - self.session_start_time = self.LazyProperty(self.api.get_session_start_time) - self.running_speed = self.LazyProperty(self.api.get_running_speed) - self.mean_waveforms = self.LazyProperty(self.api.get_mean_waveforms, wrappers=[self._build_mean_waveforms]) - self._spike_times = self.LazyProperty(self.api.get_spike_times, wrappers=[self._build_spike_times]) - self.optogenetic_stimulation_epochs = self.LazyProperty(self.api.get_optogenetic_stimulation) - self.spike_amplitudes = self.LazyProperty(self.api.get_spike_amplitudes) + self.ecephys_session_id = \ + self.LazyProperty(self.api.get_ecephys_session_id) + self.session_start_time = \ + self.LazyProperty(self.api.get_session_start_time) + self.running_speed = \ + self.LazyProperty(self.api.get_running_speed) + self.mean_waveforms = \ + self.LazyProperty(self.api.get_mean_waveforms, + wrappers=[self._build_mean_waveforms]) + self._spike_times = \ + self.LazyProperty(self.api.get_spike_times, + wrappers=[self._build_spike_times]) + self.optogenetic_stimulation_epochs = \ + self.LazyProperty(self.api.get_optogenetic_stimulation) + self.spike_amplitudes = \ + self.LazyProperty(self.api.get_spike_amplitudes) self.probes = self.LazyProperty(self.api.get_probes) self.channels = self.LazyProperty(self.api.get_channels) - self._stimulus_presentations = self.LazyProperty(self.api.get_stimulus_presentations, - wrappers=[self._build_stimulus_presentations, self._mask_invalid_stimulus_presentations]) - self.inter_presentation_intervals = self.LazyProperty(self._build_inter_presentation_intervals) + self._stimulus_presentations = \ + self.LazyProperty( + self.api.get_stimulus_presentations, + wrappers=[self._build_stimulus_presentations, + self._mask_invalid_stimulus_presentations]) + self.inter_presentation_intervals = \ + self.LazyProperty(self._build_inter_presentation_intervals) self.invalid_times = self.LazyProperty(self.api.get_invalid_times) - self._units = self.LazyProperty(self.api.get_units, wrappers=[self._build_units_table]) + self._units = \ + self.LazyProperty(self.api.get_units, + wrappers=[self._build_units_table]) self._rig_metadata = self.LazyProperty(self.api.get_rig_metadata) self._metadata = self.LazyProperty(self.api.get_metadata) @@ -300,12 +339,14 @@ def __init__( self.api.test() def get_current_source_density(self, probe_id): - """ Obtain current source density (CSD) of trial-averaged response to a flash stimuli for this probe. - See allensdk.brain_observatory.ecephys.current_source_density for details of CSD calculation. 
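The reworked get_current_source_density docstring below notes that the returned CSD is unsmoothed and that spatial smoothing is left to the caller. A usage sketch under that recommendation; the NWB path, probe choice, and smoothing sigma are illustrative, and EcephysSession.from_nwb_path is assumed as the constructor (it is not part of this diff):

from scipy.ndimage import gaussian_filter
from allensdk.brain_observatory.ecephys.ecephys_session import EcephysSession

session = EcephysSession.from_nwb_path("ecephys_session.nwb")   # hypothetical local NWB file
probe_id = session.probes.index.values[0]

csd = session.get_current_source_density(probe_id)   # xr.DataArray: channel x time, V/m^2
# Smooth across the channel axis only, leaving the time axis untouched.
csd_smoothed = csd.copy(data=gaussian_filter(csd.data, sigma=(2.0, 0.0)))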
- - CSD is computed with a 1D method (second spatial derivative) without prior spatial smoothing - User should apply spatial smoothing of their choice (e.g., Gaussian filter) to the computed CSD + """ Obtain current source density (CSD) of trial-averaged response + to a flash stimuli for this probe. See + allensdk.brain_observatory.ecephys.current_source_density + for details of CSD calculation. + CSD is computed with a 1D method (second spatial derivative) without + prior spatial smoothing. User should apply spatial smoothing of their + choice (e.g., Gaussian filter) to the computed CSD Parameters ---------- @@ -315,8 +356,9 @@ def get_current_source_density(self, probe_id): Returns ------- xr.DataArray : - dimensions are channel (id) and time (seconds, relative to stimulus onset). Values are current source - density assessed on that channel at that time (V/m^2) + dimensions are channel (id) and time (seconds, relative to stimulus + onset). Values are current source density assessed on that + channel at that time (V/m^2) """ @@ -330,26 +372,31 @@ def get_lfp(self, probe_id, mask_invalid_intervals=True): probe_id : int identify the probe whose LFP data ought to be loaded mask_invalid_intervals : bool - if True (default) will mask data in the invalid intervals with np.nan + if True (default) will mask data in the invalid intervals with + np.nan Returns ------- xr.DataArray : - dimensions are channel (id) and time (seconds). Values are sampled LFP data. + dimensions are channel (id) and time (seconds). Values are sampled + LFP data. Notes ----- - Unlike many other data access methods on this class. This one does not cache the loaded data in memory due to - the large size of the LFP data. + Unlike many other data access methods on this class. This one does not + cache the loaded data in memory due to the large size of the LFP data. 
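A companion sketch for get_lfp, whose docstring above warns that the result is not cached; session and probe_id are assumed from the previous example:

# LFP is reloaded on every call, so keep a reference if it will be reused.
lfp = session.get_lfp(probe_id, mask_invalid_intervals=True)   # channel x time; invalid spans become NaN
lfp_snippet = lfp.sel(time=slice(100.0, 160.0))                # one minute of data, times in seconds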
''' if mask_invalid_intervals: probe_name = self.probes.loc[probe_id]["description"] fail_tags = ["all_probes", probe_name] - invalid_time_intervals = self._filter_invalid_times_by_tags(fail_tags) + invalid_time_intervals = \ + self._filter_invalid_times_by_tags(fail_tags) lfp = self.api.get_lfp(probe_id) time_points = lfp.time - valid_time_points = self._get_valid_time_points(time_points, invalid_time_intervals) + valid_time_points = \ + self._get_valid_time_points(time_points, + invalid_time_intervals) return lfp.where(cond=valid_time_points) else: return self.api.get_lfp(probe_id) @@ -365,8 +412,12 @@ def _get_valid_time_points(self, time_points, invalid_time_intevals): valid_time_points = all_time_points for ix, invalid_time_interval in invalid_time_intevals.iterrows(): - invalid_time_points = (time_points >= invalid_time_interval['start_time']) & (time_points <= invalid_time_interval['stop_time']) - valid_time_points = np.logical_and(valid_time_points, np.logical_not(invalid_time_points)) + invalid_time_points = \ + ((time_points >= invalid_time_interval['start_time']) + & (time_points <= invalid_time_interval['stop_time'])) + valid_time_points = \ + np.logical_and(valid_time_points, + np.logical_not(invalid_time_points)) return valid_time_points @@ -385,13 +436,15 @@ def _filter_invalid_times_by_tags(self, tags): """ invalid_times = self.invalid_times.copy() if not invalid_times.empty: - mask = invalid_times['tags'].apply(lambda x: any([t in x for t in tags])) + mask = invalid_times['tags'].apply(lambda x: + any([t in x for t in tags])) invalid_times = invalid_times[mask] return invalid_times def get_inter_presentation_intervals_for_stimulus(self, stimulus_names): - ''' Get a subset of this session's inter-presentation intervals, filtered by stimulus name. + ''' Get a subset of this session's inter-presentation intervals, + filtered by stimulus name. Parameters ---------- @@ -401,21 +454,39 @@ def get_inter_presentation_intervals_for_stimulus(self, stimulus_names): Returns ------- pd.DataFrame : - inter-presentation intervals, filtered to the requested stimulus names. + inter-presentation intervals, filtered to the requested stimulus + names. 
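A one-line usage sketch for the filter documented above; the stimulus name is illustrative and session is an EcephysSession as in the earlier examples:

# Intervals between consecutive drifting-gratings presentations, indexed by
# (from_presentation_id, to_presentation_id) with a single 'interval' column.
dg_intervals = session.get_inter_presentation_intervals_for_stimulus(["drifting_gratings"])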
''' - stimulus_names = coerce_scalar(stimulus_names, f'expected stimulus_names to be a collection (list-like), but found {type(stimulus_names)}: {stimulus_names}') - filtered_presentations = self.stimulus_presentations[self.stimulus_presentations['stimulus_name'].isin(stimulus_names)] + stimulus_names = \ + coerce_scalar( + stimulus_names, + 'expected stimulus_names to be a collection (list-like), ' + f'but found {type(stimulus_names)}: {stimulus_names}') + filtered_presentations = \ + self.stimulus_presentations[ + self.stimulus_presentations[ + 'stimulus_name' + ].isin(stimulus_names)] filtered_ids = set(filtered_presentations.index.values) return self.inter_presentation_intervals[ - (self.inter_presentation_intervals.index.isin(filtered_ids, level='from_presentation_id')) - & (self.inter_presentation_intervals.index.isin(filtered_ids, level='to_presentation_id')) + (self.inter_presentation_intervals.index.isin( + filtered_ids, + level='from_presentation_id')) + & (self.inter_presentation_intervals.index.isin( + filtered_ids, + level='to_presentation_id')) ] - def get_stimulus_table(self, stimulus_names=None, include_detailed_parameters=False, include_unused_parameters=False): - '''Get a subset of stimulus presentations by name, with irrelevant parameters filtered off + def get_stimulus_table( + self, + stimulus_names=None, + include_detailed_parameters=False, + include_unused_parameters=False): + '''Get a subset of stimulus presentations by name, with irrelevant + parameters filtered off Parameters ---------- @@ -425,31 +496,45 @@ def get_stimulus_table(self, stimulus_names=None, include_detailed_parameters=Fa Returns ------- pd.DataFrame : - Rows are filtered presentations, columns are the relevant subset of stimulus parameters + Rows are filtered presentations, columns are the relevant subset + of stimulus parameters ''' if stimulus_names is None: stimulus_names = self.stimulus_names - stimulus_names = coerce_scalar(stimulus_names, f'expected stimulus_names to be a collection (list-like), but found {type(stimulus_names)}: {stimulus_names}') - presentations = self._stimulus_presentations[self._stimulus_presentations['stimulus_name'].isin(stimulus_names)] + stimulus_names = \ + coerce_scalar( + stimulus_names, + 'expected stimulus_names to be a collection (list-like), ' + f'but found {type(stimulus_names)}: {stimulus_names}') + presentations = \ + self._stimulus_presentations[ + self._stimulus_presentations[ + 'stimulus_name' + ].isin(stimulus_names)] if not include_detailed_parameters: - presentations = self.__class__._remove_detailed_stimulus_parameters(presentations) + presentations = \ + self.__class__._remove_detailed_stimulus_parameters( + presentations) if not include_unused_parameters: - presentations = removed_unused_stimulus_presentation_columns(presentations) + presentations = removed_unused_stimulus_presentation_columns( + presentations) return presentations def get_stimulus_epochs(self, duration_thresholds=None): - """ Reports continuous periods of time during which a single kind of stimulus was presented -flipVert + """ Reports continuous periods of time during which a single kind of + stimulus was presented + Parameters --------- duration_thresholds : dict, optional - keys are stimulus names, values are floating point durations in seconds. All epochs with + keys are stimulus names, values are floating point durations in + seconds. 
All epochs with - a given stimulus name - a duration shorter than the associated threshold will be removed from the results @@ -479,10 +564,15 @@ def get_stimulus_epochs(self, duration_thresholds=None): | (epochs["duration"] >= threshold) ] - return epochs.loc[:, ["start_time", "stop_time", "duration", "stimulus_name", "stimulus_block"]] + return epochs.loc[:, ["start_time", + "stop_time", + "duration", + "stimulus_name", + "stimulus_block"]] def get_invalid_times(self): - """ Report invalid time intervals with tags describing the scope of invalid data + """ Report invalid time intervals with tags describing the scope + of invalid data The tags format: [scope,scope_id,label] @@ -490,20 +580,26 @@ def get_invalid_times(self): 'EcephysSession': data is invalid across session 'EcephysProbe': data is invalid for a single probe label: - 'all_probes': gain fluctuations on the Neuropixels probe result in missed spikes and LFP saturation events - 'stimulus' : very long frames (>3x the normal frame length) make any stimulus-locked analysis invalid - 'probe#': probe # stopped sending data during this interval (spikes and LFP samples will be missing) + 'all_probes': gain fluctuations on the Neuropixels probe result in + missed spikes and LFP saturation events + 'stimulus' : very long frames (>3x the normal frame length) make + any stimulus-locked analysis invalid + 'probe#': probe # stopped sending data during this interval + (spikes and LFP samples will be missing) 'optotagging': missing optotagging data Returns ------- pd.DataFrame : - Rows are invalid intervals, columns are 'start_time' (s), 'stop_time' (s), 'tags' + Rows are invalid intervals, columns are 'start_time' (s), + 'stop_time' (s), 'tags' """ return self.invalid_times - def get_screen_gaze_data(self, include_filtered_data=False) -> Optional[pd.DataFrame]: + def get_screen_gaze_data( + self, + include_filtered_data=False) -> Optional[pd.DataFrame]: """Return a dataframe with estimated gaze position on screen. 
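get_stimulus_epochs and get_invalid_times, both shown above, pair naturally: epochs give the coarse stimulus structure and invalid times flag spans to exclude from analysis. A sketch with an illustrative (not default) duration threshold, again assuming session from the earlier examples:

# Continuous blocks of a single stimulus type; drop very short 'spontaneous' epochs.
epochs = session.get_stimulus_epochs(duration_thresholds={"spontaneous": 90.0})

# Invalid intervals with tags as described above (e.g. 'all_probes', 'stimulus').
invalid = session.get_invalid_times()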
Parameters @@ -523,7 +619,8 @@ def get_screen_gaze_data(self, include_filtered_data=False) -> Optional[pd.DataF *_screen_coordinates_spherical_x_deg *_screen_coorindates_spherical_y_deg """ - return self.api.get_screen_gaze_data(include_filtered_data=include_filtered_data) + return self.api.get_screen_gaze_data( + include_filtered_data=include_filtered_data) def get_pupil_data(self) -> Optional[pd.DataFrame]: """Return a dataframe with eye tracking ellipse fit data @@ -545,7 +642,8 @@ def _mask_invalid_stimulus_presentations(self, stimulus_presentations): """Mask invalid stimulus presentations Find stimulus presentations overlapping with invalid times - Mask stimulus names with "invalid_presentation", keep "start_time" and "stop_time", mask remaining data with np.nan + Mask stimulus names with "invalid_presentation", keep "start_time" and + "stop_time", mask remaining data with np.nan Parameters ---------- @@ -569,9 +667,12 @@ def _mask_invalid_stimulus_presentations(self, stimulus_presentations): invalid_interval = it['start_time'], it['stop_time'] if _overlap(stim_epoch, invalid_interval): stimulus_presentations.iloc[ix_sp, :] = np.nan - stimulus_presentations.at[ix_sp, "stimulus_name"] = "invalid_presentation" - stimulus_presentations.at[ix_sp, "start_time"] = stim_epoch[0] - stimulus_presentations.at[ix_sp, "stop_time"] = stim_epoch[1] + stimulus_presentations.at[ix_sp, "stimulus_name"] = \ + "invalid_presentation" + stimulus_presentations.at[ix_sp, "start_time"] = \ + stim_epoch[0] + stimulus_presentations.at[ix_sp, "stop_time"] = \ + stim_epoch[1] return stimulus_presentations @@ -585,26 +686,29 @@ def presentationwise_spike_counts( large_bin_size_threshold=0.001, time_domain_callback=None ): - ''' Build an array of spike counts surrounding stimulus onset per unit and stimulus frame. + ''' Build an array of spike counts surrounding stimulus onset per + unit and stimulus frame. Parameters --------- bin_edges : numpy.ndarray - Spikes will be counted into the bins defined by these edges. Values are in seconds, relative - to stimulus onset. + Spikes will be counted into the bins defined by these edges. + Values are in seconds, relative to stimulus onset. stimulus_presentation_ids : array-like Filter to these stimulus presentations unit_ids : array-like Filter to these units binarize : bool, optional - If true, all counts greater than 0 will be treated as 1. This results in lower storage overhead, - but is only reasonable if bin sizes are fine (<= 1 millisecond). + If true, all counts greater than 0 will be treated as 1. This + results in lower storage overhead, but is only reasonable if bin + sizes are fine (<= 1 millisecond). large_bin_size_threshold : float, optional - If binarize is True and the largest bin width is greater than this value, a warning will be emitted. + If binarize is True and the largest bin width is greater than + this value, a warning will be emitted. time_domain_callback : callable, optional The time domain is a numpy array whose values are trial-aligned bin - edges (each row is aligned to a different trial). This optional function will be - applied to the time domain before counting spikes. + edges (each row is aligned to a different trial). This optional + function will be applied to the time domain before counting spikes. 
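A typical call to presentationwise_spike_counts using the parameters documented above; the stimulus name, bin width, and unit selection are illustrative, with session as before:

import numpy as np

presentation_ids = session.get_stimulus_table(["drifting_gratings"]).index.values
unit_ids = session.units.index.values

bin_edges = np.arange(-0.1, 0.5 + 0.01, 0.01)   # 10 ms bins, 100 ms before to 500 ms after onset
counts = session.presentationwise_spike_counts(
    bin_edges=bin_edges,
    stimulus_presentation_ids=presentation_ids,
    unit_ids=unit_ids,
)
# counts: xarray.DataArray with dims
# (stimulus_presentation_id, time_relative_to_stimulus_onset, unit_id)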
Returns ------- @@ -614,24 +718,35 @@ def presentationwise_spike_counts( ''' - stimulus_presentations = self._filter_owned_df('stimulus_presentations', ids=stimulus_presentation_ids) + stimulus_presentations = self._filter_owned_df( + 'stimulus_presentations', + ids=stimulus_presentation_ids) units = self._filter_owned_df('units', ids=unit_ids) largest_bin_size = np.amax(np.diff(bin_edges)) if binarize and largest_bin_size > large_bin_size_threshold: warnings.warn( - f'You\'ve elected to binarize spike counts, but your maximum bin width is {largest_bin_size:2.5f} seconds. ' - 'Binarizing spike counts with such a large bin width can cause significant loss of accuracy! ' - f'Please consider only binarizing spike counts when your bins are <= {large_bin_size_threshold} seconds wide.' + 'You\'ve elected to binarize spike counts, but your maximum ' + f'bin width is {largest_bin_size:2.5f} seconds. ' + 'Binarizing spike counts with such a large bin width can ' + 'cause significant loss of accuracy! ' + 'Please consider only binarizing spike counts ' + f'when your bins are <= {large_bin_size_threshold} ' + 'seconds wide.' ) bin_edges = np.array(bin_edges) - domain = build_time_window_domain(bin_edges, stimulus_presentations['start_time'].values, callback=time_domain_callback) + domain = build_time_window_domain( + bin_edges, + stimulus_presentations['start_time'].values, + callback=time_domain_callback) out_of_order = np.where(np.diff(domain, axis=1) < 0) if len(out_of_order[0]) > 0: - out_of_order_time_bins = [(row, col) for row, col in zip(out_of_order)] - raise ValueError(f"The time domain specified contains out-of-order bin edges at indices: {out_of_order_time_bins}") + out_of_order_time_bins = \ + [(row, col) for row, col in zip(out_of_order)] + raise ValueError("The time domain specified contains out-of-order " + f"bin edges at indices: {out_of_order_time_bins}") ends = domain[:, -1] starts = domain[:, 0] @@ -639,30 +754,46 @@ def presentationwise_spike_counts( overlapping = np.where(time_diffs < 0)[0] if len(overlapping) > 0: - # Ignoring intervals that overlaps multiple time bins because trying to figure that out would take O(n) + # Ignoring intervals that overlaps multiple time bins because + # trying to figure that out would take O(n) overlapping = [(s, s + 1) for s in overlapping] - warnings.warn(f"You've specified some overlapping time intervals between neighboring rows: {overlapping}, " - f"with a maximum overlap of {np.abs(np.min(time_diffs))} seconds.") + warnings.warn("You've specified some overlapping time intervals " + f"between neighboring rows: {overlapping}, " + "with a maximum overlap of" + f" {np.abs(np.min(time_diffs))} seconds.") tiled_data = build_spike_histogram( - domain, self.spike_times, units.index.values, dtype=dtype, binarize=binarize + domain, + self.spike_times, + units.index.values, + dtype=dtype, + binarize=binarize ) + stim_presentation_id = stimulus_presentations.index.values + tiled_data = xr.DataArray( name='spike_counts', data=tiled_data, coords={ - 'stimulus_presentation_id': stimulus_presentations.index.values, - 'time_relative_to_stimulus_onset': bin_edges[:-1] + np.diff(bin_edges) / 2, + 'stimulus_presentation_id': stim_presentation_id, + 'time_relative_to_stimulus_onset': (bin_edges[:-1] + + np.diff(bin_edges) / 2), 'unit_id': units.index.values }, - dims=['stimulus_presentation_id', 'time_relative_to_stimulus_onset', 'unit_id'] + dims=['stimulus_presentation_id', + 'time_relative_to_stimulus_onset', + 'unit_id'] ) return tiled_data - def 
presentationwise_spike_times(self, stimulus_presentation_ids=None, unit_ids=None): - ''' Produce a table associating spike times with units and stimulus presentations + def presentationwise_spike_times( + self, + stimulus_presentation_ids=None, + unit_ids=None): + ''' Produce a table associating spike times with units and + stimulus presentations Parameters ---------- @@ -684,12 +815,16 @@ def presentationwise_spike_times(self, stimulus_presentation_ids=None, unit_ids= The unit that emitted this spike. ''' - stimulus_presentations = self._filter_owned_df('stimulus_presentations', ids=stimulus_presentation_ids) + stimulus_presentations = \ + self._filter_owned_df('stimulus_presentations', + ids=stimulus_presentation_ids) units = self._filter_owned_df('units', ids=unit_ids) presentation_times = np.zeros([stimulus_presentations.shape[0] * 2]) - presentation_times[::2] = np.array(stimulus_presentations['start_time']) - presentation_times[1::2] = np.array(stimulus_presentations['stop_time']) + presentation_times[::2] = \ + np.array(stimulus_presentations['start_time']) + presentation_times[1::2] = \ + np.array(stimulus_presentations['stop_time']) all_presentation_ids = np.array(stimulus_presentations.index.values) presentation_ids = [] @@ -701,30 +836,39 @@ def presentationwise_spike_times(self, stimulus_presentation_ids=None, unit_ids= indices = np.searchsorted(presentation_times, data) - 1 index_valid = indices % 2 == 0 - presentations = all_presentation_ids[np.floor(indices / 2).astype(int)] + presentations = \ + all_presentation_ids[np.floor(indices / 2).astype(int)] sorder = np.argsort(presentations) presentations = presentations[sorder] index_valid = index_valid[sorder] data = data[sorder] - changes = np.where(np.ediff1d(presentations, to_begin=1, to_end=1))[0] + changes = \ + np.where(np.ediff1d(presentations, to_begin=1, to_end=1))[0] for ii, jj in zip(changes[:-1], changes[1:]): values = data[ii:jj][index_valid[ii:jj]] if values.size == 0: continue unit_ids.append(np.zeros([values.size]) + unit_id) - presentation_ids.append(np.zeros([values.size]) + presentations[ii]) + presentation_ids.append(np.zeros([values.size]) + + presentations[ii]) spike_times.append(values) if not spike_times: - # If there are no units firing during the given stimulus return an empty dataframe - return pd.DataFrame(columns=['spike_times', 'stimulus_presentation', - 'unit_id', 'time_since_stimulus_presentation_onset']) + # If there are no units firing during the given stimulus return an + # empty dataframe + return pd.DataFrame(columns=[ + 'spike_times', + 'stimulus_presentation', + 'unit_id', + 'time_since_stimulus_presentation_onset']) + + pres_ids = np.concatenate(presentation_ids).astype(int) spike_df = pd.DataFrame({ - 'stimulus_presentation_id': np.concatenate(presentation_ids).astype(int), + 'stimulus_presentation_id': pres_ids, 'unit_id': np.concatenate(unit_ids).astype(int) }, index=pd.Index(np.concatenate(spike_times), name='spike_time')) @@ -740,13 +884,18 @@ def presentationwise_spike_times(self, stimulus_presentation_ids=None, unit_ids= spikes_with_onset.drop(columns=["start_time"], inplace=True) return spikes_with_onset - def conditionwise_spike_statistics(self, stimulus_presentation_ids=None, unit_ids=None, use_rates=False): + def conditionwise_spike_statistics( + self, + stimulus_presentation_ids=None, + unit_ids=None, + use_rates=False): """ Produce summary statistics for each distinct stimulus condition Parameters ---------- stimulus_presentation_ids : array-like - identifies stimulus 
presentations from which spikes will be considered + identifies stimulus presentations from which spikes will be + considered unit_ids : array-like identifies units whose spikes will be considered use_rates : bool, optional @@ -755,37 +904,59 @@ def conditionwise_spike_statistics(self, stimulus_presentation_ids=None, unit_id Returns ------- pd.DataFrame : - Rows are indexed by unit id and stimulus condition id. Values are summary statistics describing spikes - emitted by a specific unit across presentations within a specific condition. + Rows are indexed by unit id and stimulus condition id. Values are + summary statistics describing spikes emitted by a specific unit + across presentations within a specific condition. """ - # TODO: Need to return an empty df if no matching unit-ids or presentation-ids are found - # TODO: To use filter_owned_df() make sure to convert the results from a Series to a Dataframe - stimulus_presentation_ids = (stimulus_presentation_ids if stimulus_presentation_ids is not None - else self.stimulus_presentations.index.values) # In case - presentations = self.stimulus_presentations.loc[stimulus_presentation_ids, ["stimulus_condition_id", "duration"]] + # TODO: Need to return an empty df if no matching unit-ids or + # presentation-ids are found + # TODO: To use filter_owned_df() make sure to convert the results + # from a Series to a Dataframe + stimulus_presentation_ids = ( + stimulus_presentation_ids if stimulus_presentation_ids is not None + else self.stimulus_presentations.index.values) # In case + presentations = self.stimulus_presentations.loc[ + stimulus_presentation_ids, ["stimulus_condition_id", "duration"] + ] spikes = self.presentationwise_spike_times( - stimulus_presentation_ids=stimulus_presentation_ids, unit_ids=unit_ids + stimulus_presentation_ids=stimulus_presentation_ids, + unit_ids=unit_ids ) if spikes.empty: # In the case there are no spikes - spike_counts = pd.DataFrame({'spike_count': 0}, - index=pd.MultiIndex.from_product([stimulus_presentation_ids, unit_ids], - names=['stimulus_presentation_id', 'unit_id'])) + spike_counts = pd.DataFrame( + {'spike_count': 0}, + index=pd.MultiIndex.from_product([ + stimulus_presentation_ids, + unit_ids], + names=['stimulus_presentation_id', 'unit_id'])) else: spike_counts = spikes.copy() spike_counts["spike_count"] = np.zeros(spike_counts.shape[0]) - spike_counts = spike_counts.groupby(["stimulus_presentation_id", "unit_id"]).count() - unit_ids = unit_ids if unit_ids is not None else spikes['unit_id'].unique() # If not explicity stated get unit ids from spikes table. - spike_counts = spike_counts.reindex(pd.MultiIndex.from_product([stimulus_presentation_ids, - unit_ids], - names=['stimulus_presentation_id', - 'unit_id']), fill_value=0) - - sp = pd.merge(spike_counts, presentations, left_on="stimulus_presentation_id", right_index=True, how="left") + spike_counts = \ + spike_counts.groupby(["stimulus_presentation_id", + "unit_id"]).count() + + # If not explicity stated get unit ids from spikes table. 
+ unit_ids = unit_ids if unit_ids is not None \ + else spikes['unit_id'].unique() + spike_counts = \ + spike_counts.reindex( + pd.MultiIndex.from_product( + [stimulus_presentation_ids, + unit_ids], + names=['stimulus_presentation_id', + 'unit_id']), fill_value=0) + + sp = pd.merge(spike_counts, + presentations, + left_on="stimulus_presentation_id", + right_index=True, + how="left") sp.reset_index(inplace=True) if use_rates: @@ -800,9 +971,14 @@ def conditionwise_spike_statistics(self, stimulus_presentation_ids=None, unit_id for ind, gr in sp.groupby(["stimulus_condition_id", "unit_id"]): summary.append(extractor(ind, gr)) - return pd.DataFrame(summary).set_index(keys=["unit_id", "stimulus_condition_id"]) + return pd.DataFrame(summary).set_index(keys=[ + "unit_id", + "stimulus_condition_id"]) - def get_parameter_values_for_stimulus(self, stimulus_name, drop_nulls=True): + def get_parameter_values_for_stimulus( + self, + stimulus_name, + drop_nulls=True): """ For each stimulus parameter, report the unique values taken on by that parameter while a named stimulus was presented. @@ -818,17 +994,24 @@ def get_parameter_values_for_stimulus(self, stimulus_name, drop_nulls=True): """ - presentation_ids = self.get_stimulus_table([stimulus_name]).index.values - return self.get_stimulus_parameter_values(presentation_ids, drop_nulls=drop_nulls) + presentation_ids = \ + self.get_stimulus_table([stimulus_name]).index.values + return self.get_stimulus_parameter_values( + presentation_ids, + drop_nulls=drop_nulls) - def get_stimulus_parameter_values(self, stimulus_presentation_ids=None, drop_nulls=True): + def get_stimulus_parameter_values( + self, + stimulus_presentation_ids=None, + drop_nulls=True): ''' For each stimulus parameter, report the unique values taken on by that parameter throughout the course of the session. Parameters ---------- stimulus_presentation_ids : array-like, optional - If provided, only parameter values from these stimulus presentations will be considered. + If provided, only parameter values from these stimulus + presentations will be considered. Returns ------- @@ -837,9 +1020,15 @@ def get_stimulus_parameter_values(self, stimulus_presentation_ids=None, drop_nul ''' - stimulus_presentations = self._filter_owned_df('stimulus_presentations', ids=stimulus_presentation_ids) - stimulus_presentations = stimulus_presentations.drop(columns=list(NON_STIMULUS_PARAMETERS) + ['stimulus_name']) - stimulus_presentations = removed_unused_stimulus_presentation_columns(stimulus_presentations) + stimulus_presentations = \ + self._filter_owned_df('stimulus_presentations', + ids=stimulus_presentation_ids) + stimulus_presentations = \ + stimulus_presentations.drop( + columns=list(NON_STIMULUS_PARAMETERS) + ['stimulus_name']) + stimulus_presentations = \ + removed_unused_stimulus_presentation_columns( + stimulus_presentations) parameters = {} for colname in stimulus_presentations.columns: @@ -858,7 +1047,8 @@ def get_stimulus_parameter_values(self, stimulus_presentation_ids=None, drop_nul def channel_structure_intervals(self, channel_ids): - """ find on a list of channels the intervals of channels inserted into particular structures + """ find on a list of channels the intervals of channels inserted + into particular structures Parameters ---------- @@ -874,7 +1064,8 @@ def channel_structure_intervals(self, channel_ids): labels : np.ndarray for each detected interval, the label associated with that interval intervals : np.ndarray - one element longer than labels. Start and end indices for intervals. 
+ one element longer than labels. Start and end indices for + intervals. """ structure_id_key = "ecephys_structure_id" @@ -884,7 +1075,8 @@ def channel_structure_intervals(self, channel_ids): unique_probes = table["probe_id"].unique() if len(unique_probes) > 1: - warnings.warn("Calculating structure boundaries across channels from multiple probes.") + warnings.warn("Calculating structure boundaries across channels " + "from multiple probes.") intervals = nan_intervals(table[structure_id_key].values) labels = table[structure_label_key].iloc[intervals[:-1]].values @@ -903,34 +1095,54 @@ def _build_spike_times(self, spike_times): return output_spike_times - def _build_stimulus_presentations(self, stimulus_presentations, nonapplicable="null"): + def _build_stimulus_presentations( + self, + stimulus_presentations, + nonapplicable="null"): stimulus_presentations.index.name = 'stimulus_presentation_id' - stimulus_presentations = stimulus_presentations.drop(columns=['stimulus_index']) - - # TODO: putting these here for now; after SWDB 2019, will rerun stimulus table module for all sessions - # and can remove these - stimulus_presentations = naming_utilities.collapse_columns(stimulus_presentations) - stimulus_presentations = naming_utilities.standardize_movie_numbers(stimulus_presentations) - stimulus_presentations = naming_utilities.add_number_to_shuffled_movie(stimulus_presentations) - stimulus_presentations = naming_utilities.map_stimulus_names( - stimulus_presentations, default_stimulus_renames - ) - stimulus_presentations = naming_utilities.map_column_names(stimulus_presentations, default_column_renames) - - # pandas groupby ops ignore nans, so we need a new "nonapplicable" value that pandas does not recognize as null ... + stimulus_presentations = \ + stimulus_presentations.drop(columns=['stimulus_index']) + + # TODO: putting these here for now; after SWDB 2019, will rerun + # stimulus table module for all sessions and can remove these + stimulus_presentations = \ + naming_utilities.collapse_columns(stimulus_presentations) + stimulus_presentations = \ + naming_utilities.standardize_movie_numbers(stimulus_presentations) + stimulus_presentations = \ + naming_utilities.add_number_to_shuffled_movie( + stimulus_presentations) + stimulus_presentations = \ + naming_utilities.map_stimulus_names( + stimulus_presentations, default_stimulus_renames) + stimulus_presentations = \ + naming_utilities.map_column_names( + stimulus_presentations, + default_column_renames, + ignore_case=False) + + # pandas groupby ops ignore nans, so we need a new "nonapplicable" + # value that pandas does not recognize as null ... stimulus_presentations.replace("", nonapplicable, inplace=True) stimulus_presentations.fillna(nonapplicable, inplace=True) - stimulus_presentations['duration'] = stimulus_presentations['stop_time'] - stimulus_presentations['start_time'] + stimulus_presentations['duration'] = \ + stimulus_presentations['stop_time'] - \ + stimulus_presentations['start_time'] # TODO: database these stimulus_conditions = {} presentation_conditions = [] cid_counter = -1 - # TODO: Can we have parameters on what columns to omit? If stimulus_block or duration is left in it can affect - # how conditionwise_spike_statistics counts spikes - params_only = stimulus_presentations.drop(columns=["start_time", "stop_time", "duration", "stimulus_block"]) + # TODO: Can we have parameters on what columns to omit? 
+ # If stimulus_block or duration is left in it can affect + # how conditionwise_spike_statistics counts spikes + params_only = \ + stimulus_presentations.drop(columns=["start_time", + "stop_time", + "duration", + "stimulus_block"]) for row in params_only.itertuples(index=False): if row in stimulus_conditions: @@ -949,8 +1161,11 @@ def _build_stimulus_presentations(self, stimulus_presentations, nonapplicable="n cond_ids.append(ci) cond_vals.append(cv) - self._stimulus_conditions = pd.DataFrame(cond_vals, index=pd.Index(data=cond_ids, name="stimulus_condition_id")) - stimulus_presentations["stimulus_condition_id"] = presentation_conditions + self._stimulus_conditions = \ + pd.DataFrame(cond_vals, index=pd.Index(data=cond_ids, + name="stimulus_condition_id")) + stimulus_presentations["stimulus_condition_id"] = \ + presentation_conditions return stimulus_presentations @@ -959,8 +1174,16 @@ def _build_units_table(self, units_table): probes = self.probes.copy() self._unmerged_units = units_table.copy() - table = pd.merge(units_table, channels, left_on='peak_channel_id', right_index=True, suffixes=['_unit', '_channel']) - table = pd.merge(table, probes, left_on='probe_id', right_index=True, suffixes=['_unit', '_probe']) + table = pd.merge(units_table, + channels, + left_on='peak_channel_id', + right_index=True, + suffixes=['_unit', '_channel']) + table = pd.merge(table, + probes, + left_on='probe_id', + right_index=True, + suffixes=['_unit', '_probe']) table.index.name = 'unit_id' table = table.rename(columns={ @@ -982,14 +1205,22 @@ def _build_units_table(self, units_table): 'pref_images_multi_ns': 'pref_image_multi_ns', }) - return table.sort_values(by=['probe_description', 'probe_vertical_position', 'probe_horizontal_position']) + return table.sort_values(by=['probe_description', + 'probe_vertical_position', + 'probe_horizontal_position']) def _build_nwb1_waveforms(self, mean_waveforms): - # _build_mean_waveforms() assumes every unit has the same number of waveforms and that a unit-waveform exists - # for all channels. This is not true for NWB 1 files where each unit has ONE waveform on ONE channel + # _build_mean_waveforms() assumes every unit has the same number of + # waveforms and that a unit-waveform exists for all channels. This + # is not true for NWB 1 files where each unit has ONE waveform on + # ONE channel units_df = self._units output_waveforms = {} - sampling_rate_lu = {uid: self.probes.loc[r['probe_id']]['sampling_rate'] for uid, r in units_df.iterrows()} + sampling_rate_lu = { + uid: self.probes.loc[ + r['probe_id'] + ]['sampling_rate'] for uid, r in units_df.iterrows() + } for uid in list(mean_waveforms.keys()): data = mean_waveforms.pop(uid) @@ -1012,35 +1243,54 @@ def _build_mean_waveforms(self, mean_waveforms): for cid, row in self.channels.iterrows(): channel_id_lut[(row["local_index"], row["probe_id"])] = cid - probe_id_lut = {uid: row['probe_id'] for uid, row in self._units.iterrows()} + probe_id_lut = { + uid: row['probe_id'] for uid, row in self._units.iterrows() + } output_waveforms = {} for uid in list(mean_waveforms.keys()): data = mean_waveforms.pop(uid) - if uid not in probe_id_lut: # It's been filtered out during unit table generation! + # It's been filtered out during unit table generation! 
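# --- Illustrative usage sketch (annotation only, not part of this patch). ---
# _build_mean_waveforms packages each unit's mean waveform as an xr.DataArray
# with dims (channel_id, time). Downstream access looks roughly like this;
# the NWB path is a placeholder and `mean_waveforms` is assumed to be the
# public session property backed by this builder.
from allensdk.brain_observatory.ecephys.ecephys_session import EcephysSession

session = EcephysSession.from_nwb_path("/data/session_715093703.nwb")
unit_id = session.units.index.values[0]
waveform = session.mean_waveforms[unit_id]        # dims: channel_id, time
peak_channel = session.units.loc[unit_id, "peak_channel_id"]
trace = waveform.sel(channel_id=peak_channel)     # waveform on the peak channel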
+ if uid not in probe_id_lut: continue probe_id = probe_id_lut[uid] + + time_vals = np.arange(data.shape[1]) / \ + self.probes.loc[probe_id]['sampling_rate'] + output_waveforms[uid] = xr.DataArray( data=data, dims=['channel_id', 'time'], coords={ - 'channel_id': [channel_id_lut[(ii, probe_id)] for ii in range(data.shape[0])], - 'time': np.arange(data.shape[1]) / self.probes.loc[probe_id]['sampling_rate'] + 'channel_id': [channel_id_lut[(ii, probe_id)] + for ii in range(data.shape[0])], + 'time': time_vals } ) - output_waveforms[uid] = output_waveforms[uid][output_waveforms[uid]["channel_id"] != -1] + + output_waveforms[uid] = \ + output_waveforms[uid][ + output_waveforms[uid]["channel_id"] != -1 + ] return output_waveforms def _build_inter_presentation_intervals(self): + + from_presentation_id = self.stimulus_presentations.index.values[:-1] + to_presentation_id = self.stimulus_presentations.index.values[1:] + interval1 = self.stimulus_presentations['start_time'].values[1:] + interval2 = self.stimulus_presentations['stop_time'].values[:-1] + intervals = pd.DataFrame({ - 'from_presentation_id': self.stimulus_presentations.index.values[:-1], - 'to_presentation_id': self.stimulus_presentations.index.values[1:], - 'interval': self.stimulus_presentations['start_time'].values[1:] - self.stimulus_presentations['stop_time'].values[:-1] + 'from_presentation_id': from_presentation_id, + 'to_presentation_id': to_presentation_id, + 'interval': interval1 - interval2 }) - return intervals.set_index(['from_presentation_id', 'to_presentation_id'], inplace=False) + return intervals.set_index(['from_presentation_id', + 'to_presentation_id'], inplace=False) def _filter_owned_df(self, key, ids=None, copy=True): df = getattr(self, key) @@ -1051,7 +1301,9 @@ def _filter_owned_df(self, key, ids=None, copy=True): if ids is None: return df - ids = coerce_scalar(ids, f'a scalar ({ids}) was provided as ids, filtering to a single row of {key}.') + ids = coerce_scalar( + ids, f'a scalar ({ids}) was ' + f'provided as ids, filtering to a single row of {key}.') df = df.loc[ids] @@ -1068,8 +1320,9 @@ def _remove_detailed_stimulus_parameters(cls, presentations): @classmethod def from_nwb_path(cls, path, nwb_version=2, api_kwargs=None, **kwargs): api_kwargs = {} if api_kwargs is None else api_kwargs - # TODO: Is there a way for pynwb to check the file before actually loading it with io read? If so we could - # automatically check what NWB version is being inputed + # TODO: Is there a way for pynwb to check the file before actually + # loading it with io read? If so we could automatically check + # what NWB version is being inputed nwb_version = int(nwb_version) # only use major version if nwb_version >= 2: @@ -1079,9 +1332,11 @@ def from_nwb_path(cls, path, nwb_version=2, api_kwargs=None, **kwargs): NWBAdaptorCls = EcephysNwb1Api else: - raise Exception(f'specified NWB version {nwb_version} not supported. Supported versions are: 2.X, 1.X') + raise Exception(f'specified NWB version {nwb_version} not ' + 'supported. 
Supported versions are: 2.X, 1.X') - return cls(api=NWBAdaptorCls.from_path(path=path, **api_kwargs), **kwargs) + return cls(api=NWBAdaptorCls.from_path(path=path, + **api_kwargs), **kwargs) def _warn_invalid_spike_intervals(self): @@ -1090,11 +1345,17 @@ def _warn_invalid_spike_intervals(self): invalid_time_intervals = self._filter_invalid_times_by_tags(fail_tags) if not invalid_time_intervals.empty: - warnings.warn("Session includes invalid time intervals that could be accessed with the attribute 'invalid_times'," - "Spikes within these intervals are invalid and may need to be excluded from the analysis.") + warnings.warn("Session includes invalid time intervals that could " + "be accessed with the attribute 'invalid_times'," + "Spikes within these intervals are invalid and may " + "need to be excluded from the analysis.") -def build_spike_histogram(time_domain, spike_times, unit_ids, dtype=None, binarize=False): +def build_spike_histogram(time_domain, + spike_times, + unit_ids, + dtype=None, + binarize=False): time_domain = np.array(time_domain) unit_ids = np.array(unit_ids) @@ -1139,7 +1400,8 @@ def removed_unused_stimulus_presentation_columns(stimulus_presentations): def nan_intervals(array, nan_like=["null"]): - """ find interval bounds (bounding consecutive identical values) in an array, which may contain nans + """ find interval bounds (bounding consecutive identical values) in an + array, which may contain nans Parameters ----------- @@ -1148,7 +1410,8 @@ def nan_intervals(array, nan_like=["null"]): Returns ------- np.ndarray : - start and end indices of detected intervals (one longer than the number of intervals) + start and end indices of detected intervals (one longer than the + number of intervals) """ @@ -1184,7 +1447,8 @@ def array_intervals(array): Returns ------- np.ndarray : - start and end indices of detected intervals (one longer than the number of intervals) + start and end indices of detected intervals (one longer than the + number of intervals) """ diff --git a/allensdk/brain_observatory/ecephys/file_io/ecephys_sync_dataset.py b/allensdk/brain_observatory/ecephys/file_io/ecephys_sync_dataset.py index fcafbbd4b..9f182b285 100644 --- a/allensdk/brain_observatory/ecephys/file_io/ecephys_sync_dataset.py +++ b/allensdk/brain_observatory/ecephys/file_io/ecephys_sync_dataset.py @@ -1,4 +1,3 @@ -from itertools import product import functools from collections import defaultdict import logging @@ -8,38 +7,37 @@ from allensdk.brain_observatory.sync_dataset import Dataset from allensdk.brain_observatory.ecephys import stimulus_sync -from allensdk.brain_observatory import sync_utilities class EcephysSyncDataset(Dataset): - + @property def sample_frequency(self): return self.meta_data['ni_daq']['counter_output_freq'] - @sample_frequency.setter def sample_frequency(self, value): if not hasattr(self, 'meta_data'): self.meta_data = defaultdict(dict) self.meta_data['ni_daq']['counter_output_freq'] = value - def __init__(self): - '''In-memory representation of a sync h5 file as produced by the sync package. + '''In-memory representation of a sync h5 file as produced by the sync package. Notes ----- - base is from here: http://aibspi/mpe_apps/sync/blob/master/sync/dataset.py - Construction works slightly differently for this class as its base. In particular, - this class' __init__ method merely constructs the object. 
To make a new SyncDataset in client code, use the + base is from http://aibspi/mpe_apps/sync/blob/master/sync/dataset.py + Construction works slightly differently for this class as its base. + In particular, this class' __init__ method merely constructs the + object. To make a new SyncDataset in client code, use the factory classmethod. This is done for ease of testability. ''' pass - - def extract_led_times(self, keys=Dataset.OPTOGENETIC_STIMULATION_KEYS, fallback_line=18): + def extract_led_times(self, + keys=Dataset.OPTOGENETIC_STIMULATION_KEYS, + fallback_line=18): try: led_times = self.get_edges( @@ -48,55 +46,130 @@ def extract_led_times(self, keys=Dataset.OPTOGENETIC_STIMULATION_KEYS, fallback_ units="seconds" ) except KeyError: - warnings.warn(f"unable to find LED times using line labels {keys}, returning line {fallback_line}") + warnings.warn("unable to find LED times using line labels" + + f"{keys}, returning line {fallback_line}") led_times = self.get_rising_edges(fallback_line, units="seconds") return led_times + def remove_zero_frames(self, frame_times): + + D = np.diff(frame_times) + + a = np.where(D < 0.01)[0] + b = np.where((D > 0.018) * (D < 0.1))[0] + + def find_match(b, value): + try: + return b[np.max(np.where((b < value))[0])] - value + except ValueError: + return None + + c = [find_match(b, A) for A in a] + + ft = np.copy(D) + + for idx, d in enumerate(a): + if c[idx] is not None: + if c[idx] > -100: + ft[d+c[idx]] = np.median(D) + ft[d] = np.median(D) + + t = np.concatenate(([np.min(frame_times)], + np.cumsum(ft) + np.min(frame_times))) + + return t + + def extract_frame_times_from_photodiode( + self, + photodiode_cycle=60, + frame_keys=Dataset.FRAME_KEYS, + photodiode_keys=Dataset.PHOTODIODE_KEYS, + trim_discontiguous_frame_times=True): - def extract_frame_times_from_photodiode(self, photodiode_cycle=60, frame_keys=Dataset.FRAME_KEYS, photodiode_keys=Dataset.PHOTODIODE_KEYS): photodiode_times = self.get_edges('all', photodiode_keys) vsync_times = self.get_edges('falling', frame_keys) - vsync_times = sync_utilities.trim_discontiguous_times(vsync_times) - - logging.info(f"Total vsyncs: {len(vsync_times)}") - photodiode_times = stimulus_sync.trim_border_pulses(photodiode_times, vsync_times) - photodiode_times = stimulus_sync.correct_on_off_effects(photodiode_times) - photodiode_times = stimulus_sync.fix_unexpected_edges(photodiode_times, cycle=photodiode_cycle) + if trim_discontiguous_frame_times: + vsync_times = stimulus_sync.trim_discontiguous_vsyncs(vsync_times) + + vsync_times_chunked, pd_times_chunked = \ + stimulus_sync.separate_vsyncs_and_photodiode_times( + vsync_times, + photodiode_times, + photodiode_cycle) + + logging.info(f"Total chunks: {len(vsync_times_chunked)}") + + frame_start_times = np.zeros((0,)) + + for i in range(len(vsync_times_chunked)): + + photodiode_times = stimulus_sync.trim_border_pulses( + pd_times_chunked[i], + vsync_times_chunked[i]) + photodiode_times = stimulus_sync.correct_on_off_effects( + photodiode_times) + photodiode_times = stimulus_sync.fix_unexpected_edges( + photodiode_times, + cycle=photodiode_cycle) + + frame_duration = stimulus_sync.estimate_frame_duration( + photodiode_times, + cycle=photodiode_cycle) + irregular_interval_policy = functools.partial( + stimulus_sync.allocate_by_vsync, + np.diff(vsync_times_chunked[i])) + frame_indices, frame_starts, frame_end_times = \ + stimulus_sync.compute_frame_times( + photodiode_times, + frame_duration, + len(vsync_times_chunked[i]), + cycle=photodiode_cycle, + 
irregular_interval_policy=irregular_interval_policy + ) + + frame_start_times = np.concatenate((frame_start_times, + frame_starts)) + + frame_start_times = self.remove_zero_frames(frame_start_times) - frame_duration = stimulus_sync.estimate_frame_duration(photodiode_times, cycle=photodiode_cycle) - irregular_interval_policy = functools.partial(stimulus_sync.allocate_by_vsync, np.diff(vsync_times)) - frame_indices, frame_start_times, frame_end_times = stimulus_sync.compute_frame_times( - photodiode_times, frame_duration, len(vsync_times), - cycle=photodiode_cycle, irregular_interval_policy=irregular_interval_policy - ) + logging.info(f"Total vsyncs: {len(vsync_times)}") return frame_start_times - - def extract_frame_times_from_vsyncs(self, photodiode_cycle=60, + def extract_frame_times_from_vsyncs( + self, + photodiode_cycle=60, frame_keys=Dataset.FRAME_KEYS, photodiode_keys=Dataset.PHOTODIODE_KEYS ): raise NotImplementedError() - - def extract_frame_times(self, strategy, photodiode_cycle=60, - frame_keys=Dataset.FRAME_KEYS, photodiode_keys=Dataset.PHOTODIODE_KEYS - ): + def extract_frame_times( + self, + strategy, + photodiode_cycle=60, + frame_keys=Dataset.FRAME_KEYS, + photodiode_keys=Dataset.PHOTODIODE_KEYS, + trim_discontiguous_frame_times=True + ): if strategy == 'use_photodiode': return self.extract_frame_times_from_photodiode( - photodiode_cycle=photodiode_cycle, frame_keys=frame_keys, photodiode_keys=photodiode_keys + photodiode_cycle=photodiode_cycle, + frame_keys=frame_keys, + photodiode_keys=photodiode_keys, + trim_discontiguous_frame_times=trim_discontiguous_frame_times ) elif strategy == 'use_vsyncs': return self.extract_frame_times_from_vsyncs( - photodiode_cycle=photodiode_cycle, frame_keys=frame_keys, photodiode_keys=photodiode_keys + photodiode_cycle=photodiode_cycle, + frame_keys=frame_keys, + photodiode_keys=photodiode_keys ) else: raise ValueError('unrecognized strategy: {}'.format(strategy)) - @classmethod def factory(cls, path): ''' Build a new SyncDataset. @@ -104,11 +177,12 @@ def factory(cls, path): Parameters ---------- path : str - Filesystem path to the h5 file containing sync information to be loaded. + Filesystem path to the h5 file containing sync information + to be loaded. ''' obj = cls() obj.load(path) return obj - +# diff --git a/allensdk/brain_observatory/ecephys/lfp_subsampling/_schemas.py b/allensdk/brain_observatory/ecephys/lfp_subsampling/_schemas.py index 6c7424fef..5cdaba772 100644 --- a/allensdk/brain_observatory/ecephys/lfp_subsampling/_schemas.py +++ b/allensdk/brain_observatory/ecephys/lfp_subsampling/_schemas.py @@ -33,56 +33,113 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. 
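# --- Illustrative usage sketch (annotation only, not part of this patch). ---
# The EcephysSyncDataset changes above add chunked photodiode alignment and a
# trim_discontiguous_frame_times switch. Typical use looks roughly like this;
# the sync h5 path is a placeholder.
from allensdk.brain_observatory.ecephys.file_io.ecephys_sync_dataset import (
    EcephysSyncDataset)

sync_data = EcephysSyncDataset.factory("/data/session_715093703_sync.h5")
frame_times = sync_data.extract_frame_times(
    strategy="use_photodiode",
    trim_discontiguous_frame_times=True)  # pass False for chunked stimuli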
# -from argschema import ArgSchema, ArgSchemaParser +from argschema import ArgSchema from argschema.schemas import DefaultSchema -from argschema.fields import Nested, InputDir, String, Boolean, Float, Dict, Int, NumpyArray +from argschema.fields import Nested, String, Boolean, Float, Int, NumpyArray class ProbeInputParameters(DefaultSchema): name = String(required=True, help='Identifier for this probe') - lfp_input_file_path = String(required=True, description="path to original LFP .dat file") - lfp_timestamps_input_path = String(required=True, description="path to LFP timestamps") - lfp_data_path = String(required=True, help="Path to LFP data continuous file") - lfp_timestamps_path = String(required=True, help="Path to LFP timestamps aligned to master clock") - lfp_channel_info_path = String(required=True, help="Path to LFP channel info") - total_channels = Int(default=384, help='Total channel count for this probe.') - surface_channel = Int(required=True, help="Probe surface channel") - reference_channels = NumpyArray(required=False, help="Probe reference channels") - lfp_sampling_rate = Float(required=True, help="Sampling rate of LFP data") - noisy_channels = NumpyArray(required=False, help="Noisy channels to remove") + lfp_input_file_path = String( + required=True, + description="path to original LFP .dat file") + lfp_timestamps_input_path = String( + required=True, + description="path to LFP timestamps") + lfp_data_path = String( + required=True, + help="Path to LFP data continuous file") + lfp_timestamps_path = String( + required=True, + help="Path to LFP timestamps aligned to master clock") + lfp_channel_info_path = String( + required=True, + help="Path to LFP channel info") + total_channels = Int( + default=384, + help='Total channel count for this probe.') + surface_channel = Int( + required=True, + help="Probe surface channel") + reference_channels = NumpyArray( + required=False, + help="Probe reference channels") + lfp_sampling_rate = Float( + required=True, + help="Sampling rate of LFP data") + noisy_channels = NumpyArray( + required=False, + help="Noisy channels to remove") class LfpSubsamplingParameters(DefaultSchema): - temporal_subsampling_factor = Int(default=2, description="Ratio of input samples to output samples in time") - channel_stride = Int(default=4, description="Distance between channels to keep") - surface_padding = Int(default=40, description="Number of channels above surface to include") - start_channel_offset = Int(default=2, description="Offset of first channel (from bottom of the probe)") - reorder_channels = Boolean(default=True, description="Implement channel reordering") - cutoff_frequency = Float(default=0.1, description="Cutoff frequency for DC offset filter (Butterworth)") - filter_order = Int(default=1, description="Order of DC offset filter (Butterworth)") - remove_reference_channels = Boolean(default=False, - description="indicates whether references should be removed from output") - remove_channels_out_of_brain = Boolean(default=False, - description="indicates whether to remove channels outside the brain") - remove_noisy_channels = Boolean(default=False, - description="indicates whether noisy channels should be removed from output") + temporal_subsampling_factor = Int( + default=2, + description="Ratio of input samples to output samples in time") + channel_stride = Int( + default=4, + description="Distance between channels to keep") + surface_padding = Int( + default=40, + description="Number of channels above surface to include") + 
start_channel_offset = Int( + default=2, + description="Offset of first channel (from bottom of the probe)") + reorder_channels = Boolean( + default=False, + description="Implement channel reordering") + cutoff_frequency = Float( + default=0.1, + description="Cutoff frequency for DC offset filter (Butterworth)") + filter_order = Int( + default=1, + description="Order of DC offset filter (Butterworth)") + remove_reference_channels = Boolean( + default=False, + description="indicates whether references should be removed") + remove_channels_out_of_brain = Boolean( + default=False, + description="indicates whether to remove channels outside the brain") + remove_noisy_channels = Boolean( + default=False, + description="indicates whether noisy channels should be removed") class InputParameters(ArgSchema): - probes = Nested(ProbeInputParameters, many=True, help='Probes for LFP subsampling') - lfp_subsampling = Nested(LfpSubsamplingParameters, help='Parameters for this module') + probes = Nested( + ProbeInputParameters, + many=True, + help='Probes for LFP subsampling') + lfp_subsampling = Nested( + LfpSubsamplingParameters, + help='Parameters for this module') class OutputSchema(DefaultSchema): - input_parameters = Nested(InputParameters, description="Input parameters the module was run with", required=True) + input_parameters = Nested( + InputParameters, + description="Input parameters the module was run with", + required=True) class ProbeOutputParameters(DefaultSchema): - name = String(required=True, help='Identifier for this probe.') - lfp_data_path = String(required=True, help='Output subsampled data file.') - lfp_timestamps_path = String(required=True, help='Timestamps for subsampled data.') - lfp_channel_info_path = String(required=True, help='LFP channels from that was subsampled.') + name = String( + required=True, + help='Identifier for this probe.') + lfp_data_path = String( + required=True, + help='Output subsampled data file.') + lfp_timestamps_path = String( + required=True, + help='Timestamps for subsampled data.') + lfp_channel_info_path = String( + required=True, + help='LFP channel info for the channels that were subsampled.') class OutputParameters(OutputSchema): - probe_outputs = Nested(ProbeOutputParameters, many=True, required=True, help='probewise outputs') + probe_outputs = Nested( + ProbeOutputParameters, + many=True, + required=True, + help='probewise outputs') diff --git a/allensdk/brain_observatory/ecephys/stimulus_sync.py b/allensdk/brain_observatory/ecephys/stimulus_sync.py index ddcd15d00..a59ae8da6 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_sync.py +++ b/allensdk/brain_observatory/ecephys/stimulus_sync.py @@ -1,5 +1,3 @@ -import warnings - import numpy as np import scipy.spatial.distance as distance @@ -7,60 +5,114 @@ def trimmed_stats(data, pctiles=(10, 90)): low = np.percentile(data, pctiles[0]) high = np.percentile(data, pctiles[1]) - + trimmed = data[np.logical_and( data <= high, data >= low )] - + return np.mean(trimmed), np.std(trimmed) - + + +def trim_discontiguous_vsyncs(vs_times, photodiode_cycle=60): + vs_times = np.array(vs_times) + + breaks = np.where(np.diff(vs_times) > (1/photodiode_cycle)*100)[0] + + if len(breaks) > 0: + chunk_sizes = np.diff(np.concatenate((np.array([0, ]), + breaks, + np.array([len(vs_times), ])))) + largest_chunk = np.argmax(chunk_sizes) + + if largest_chunk == 0: + return vs_times[:np.min(breaks+1)] + elif largest_chunk == len(breaks): + return vs_times[np.max(breaks+1):] + else: + return 
vs_times[breaks[largest_chunk-1]:breaks[largest_chunk]] + else: + return vs_times + + +def separate_vsyncs_and_photodiode_times(vs_times, + pd_times, + photodiode_cycle=60): + + vs_times = np.array(vs_times) + pd_times = np.array(pd_times) + + breaks = np.where(np.diff(vs_times) > (1/photodiode_cycle)*100)[0] + + shift = 2.0 + break_times = [-shift] + break_times.extend(vs_times[breaks].tolist()) + break_times.extend([np.inf]) + + vs_times_out = [] + pd_times_out = [] + + for indx, b in enumerate(break_times[:-1]): + + pd_in_range = np.where((pd_times > break_times[indx] + shift) * + (pd_times <= break_times[indx+1] + shift))[0] + vs_in_range = np.where((vs_times > break_times[indx]) * + (vs_times <= break_times[indx+1]))[0] + + vs_times_out.append(vs_times[vs_in_range]) + pd_times_out.append(pd_times[pd_in_range]) + + return vs_times_out, pd_times_out + def trim_border_pulses(pd_times, vs_times, frame_interval=1/60, num_frames=5): pd_times = np.array(pd_times) return pd_times[np.logical_and( - pd_times >= vs_times[0], + pd_times >= vs_times[0], pd_times <= vs_times[-1] + num_frames * frame_interval )] - - + + def correct_on_off_effects(pd_times): ''' - + Notes ----- - This cannot (without additional info) determine whether an assymmetric offset is odd-long or even-long. + This cannot (without additional info) determine whether an assymmetric + offset is odd-long or even-long. ''' - - pd_diff = np.diff(pd_times) - odd_diff_mean, odd_diff_std = trimmed_stats(pd_diff[1::2]) + + pd_diff = np.diff(pd_times) + odd_diff_mean, odd_diff_std = trimmed_stats(pd_diff[1::2]) even_diff_mean, even_diff_std = trimmed_stats(pd_diff[0::2]) - + half_diff = np.diff(pd_times[0::2]) full_period_mean, full_period_std = trimmed_stats(half_diff) - half_period_mean = full_period_mean / 2 + half_period_mean = full_period_mean / 2 odd_offset = odd_diff_mean - half_period_mean even_offset = even_diff_mean - half_period_mean - + pd_times[::2] -= odd_offset / 2 pd_times[1::2] -= even_offset / 2 - + return pd_times def flag_unexpected_edges(pd_times, ndevs=10): pd_diff = np.diff(pd_times) diff_mean, diff_std = trimmed_stats(pd_diff) - + expected_duration_mask = np.ones(pd_diff.size) expected_duration_mask[np.logical_or( pd_diff < diff_mean - ndevs * diff_std, pd_diff > diff_mean + ndevs * diff_std )] = 0 - expected_duration_mask[1:] = np.logical_and(expected_duration_mask[:-1], expected_duration_mask[1:]) - expected_duration_mask = np.concatenate([expected_duration_mask, [expected_duration_mask[-1]]]) - + expected_duration_mask[1:] = np.logical_and(expected_duration_mask[:-1], + expected_duration_mask[1:]) + expected_duration_mask = np.concatenate([expected_duration_mask, + [expected_duration_mask[-1]]]) + return expected_duration_mask @@ -69,36 +121,39 @@ def fix_unexpected_edges(pd_times, ndevs=10, cycle=60, max_frame_offset=4): expected_duration_mask = flag_unexpected_edges(pd_times, ndevs=ndevs) diff_mean, diff_std = trimmed_stats(np.diff(pd_times)) frame_interval = diff_mean / cycle - + bad_edges = np.where(expected_duration_mask == 0)[0] bad_blocks = np.sort(np.unique(np.concatenate([ [0], np.where(np.diff(bad_edges) > 1)[0] + 1, [len(bad_edges)] ]))) - + output_edges = [] for low, high in zip(bad_blocks[:-1], bad_blocks[1:]): current_bad_edge_indices = bad_edges[low: high-1] current_bad_edges = pd_times[current_bad_edge_indices] low_bound = pd_times[current_bad_edge_indices[0]] high_bound = pd_times[current_bad_edge_indices[-1] + 1] - + edges_missing = int(np.around((high_bound - low_bound) / diff_mean)) expected 
= np.linspace(low_bound, high_bound, edges_missing + 1) - - distances = distance.cdist(current_bad_edges[:, None], expected[:, None]) + + distances = distance.cdist(current_bad_edges[:, None], + expected[:, None]) distances = np.around(distances / frame_interval).astype(int) - + min_offsets = np.amin(distances, axis=0) min_offset_indices = np.argmin(distances, axis=0) output_edges = np.concatenate([ output_edges, expected[min_offsets > max_frame_offset], - current_bad_edges[min_offset_indices[min_offsets <= max_frame_offset]] + current_bad_edges[min_offset_indices[min_offsets <= + max_frame_offset]] ]) - - return np.sort(np.concatenate([output_edges, pd_times[expected_duration_mask > 0]])) + + return np.sort(np.concatenate([output_edges, + pd_times[expected_duration_mask > 0]])) def estimate_frame_duration(pd_times, cycle=60): @@ -110,7 +165,13 @@ def assign_to_last(index, starts, ends, frame_duration, irregularity, cycle): return starts, ends -def allocate_by_vsync(vs_diff, index, starts, ends, frame_duration, irregularity, cycle): +def allocate_by_vsync(vs_diff, + index, + starts, + ends, + frame_duration, + irregularity, + cycle): current_vs_diff = vs_diff[index * cycle: (index + 1) * cycle] sign = np.sign(irregularity) @@ -125,35 +186,51 @@ def allocate_by_vsync(vs_diff, index, starts, ends, frame_duration, irregularity return starts, ends -def compute_frame_times(photodiode_times, frame_duration, num_frames, cycle, irregular_interval_policy=assign_to_last): +def compute_frame_times(photodiode_times, + frame_duration, + num_frames, + cycle, + irregular_interval_policy=assign_to_last): indices = np.arange(num_frames) starts = np.zeros(num_frames, dtype=float) ends = np.zeros(num_frames, dtype=float) num_intervals = len(photodiode_times) - 1 - for start_index, (start_time, end_time) in enumerate(zip(photodiode_times[:-1], photodiode_times[1:])): + for start_index, (start_time, end_time) in \ + enumerate(zip(photodiode_times[:-1], photodiode_times[1:])): interval_duration = end_time - start_time - irregularity = int(np.around((interval_duration) / frame_duration)) - cycle + irregularity = \ + int(np.around((interval_duration) / frame_duration)) - cycle local_frame_duration = interval_duration / (cycle + irregularity) - durations = np.zeros(cycle + ( start_index == num_intervals - 1 )) + local_frame_duration - + durations = \ + np.zeros(cycle + + (start_index == num_intervals - 1)) + local_frame_duration + current_ends = np.cumsum(durations) + start_time current_starts = current_ends - durations while irregularity != 0: current_starts, current_ends = irregular_interval_policy( - start_index, current_starts, current_ends, local_frame_duration, irregularity, cycle + start_index, + current_starts, + current_ends, + local_frame_duration, + irregularity, cycle ) irregularity += -1 * np.sign(irregularity) early_frame = start_index * cycle - late_frame = (start_index + 1) * cycle + ( start_index == num_intervals - 1 ) + late_frame = \ + (start_index + 1) * cycle + (start_index == num_intervals - 1) remaining = starts[early_frame: late_frame].size starts[early_frame: late_frame] = current_starts[:remaining] ends[early_frame: late_frame] = current_ends[:remaining] - return indices, starts, ends \ No newline at end of file + return indices, starts, ends + +# +# diff --git a/allensdk/brain_observatory/ecephys/stimulus_table/__main__.py b/allensdk/brain_observatory/ecephys/stimulus_table/__main__.py index 167cbfe89..7a354a5a5 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_table/__main__.py 
+++ b/allensdk/brain_observatory/ecephys/stimulus_table/__main__.py @@ -36,7 +36,11 @@ def build_stimulus_table( sync_dataset = EcephysSyncDataset.factory(sync_h5_path) frame_times = sync_dataset.extract_frame_times( - strategy=frame_time_strategy) + strategy=frame_time_strategy, + trim_discontiguous_frame_times=kwargs.get( + 'trim_discontiguous_frame_times', + True) + ) def seconds_to_frames(seconds): return \ @@ -71,6 +75,8 @@ def seconds_to_frames(seconds): stim_table_full, maximum_expected_spontanous_activity_duration ) + print(stim_table_full.keys()) + stim_table_full = naming_utilities.collapse_columns(stim_table_full) stim_table_full = naming_utilities.drop_empty_columns(stim_table_full) stim_table_full = naming_utilities.standardize_movie_numbers( @@ -80,8 +86,15 @@ def seconds_to_frames(seconds): stim_table_full = naming_utilities.map_stimulus_names( stim_table_full, stimulus_name_map ) + + print(stim_table_full.keys()) + print(column_name_map) + stim_table_full = naming_utilities.map_column_names(stim_table_full, - column_name_map) + column_name_map, + ignore_case=False) + + print(stim_table_full.keys()) stim_table_full.to_csv(output_stimulus_table_path, index=False) np.save(output_frame_times_path, frame_times, allow_pickle=False) diff --git a/allensdk/brain_observatory/ecephys/stimulus_table/_schemas.py b/allensdk/brain_observatory/ecephys/stimulus_table/_schemas.py index 753e0aa50..e7ff7228d 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_table/_schemas.py +++ b/allensdk/brain_observatory/ecephys/stimulus_table/_schemas.py @@ -1,39 +1,40 @@ import sys -from argschema import ArgSchema, ArgSchemaParser +from argschema import ArgSchema from argschema.schemas import DefaultSchema -from argschema.fields import Nested, InputDir, String, Float, Dict, Int, List, Bool - -from . 
import naming_utilities as nu +from argschema.fields import Nested, String, Float, Dict, List, Bool default_stimulus_renames = { "": "spontaneous", - "natural_movie_1" : "natural_movie_one", - "natural_movie_3" : "natural_movie_three", + "natural_movie_1": "natural_movie_one", + "natural_movie_3": "natural_movie_three", "Natural Images": "natural_scenes", "flash_250ms": "flashes", "gabor_20_deg_250ms": "gabors", - "drifting_gratings" : "drifting_gratings", - "static_gratings" : "static_gratings", + "drifting_gratings": "drifting_gratings", + "static_gratings": "static_gratings", "contrast_response": "drifting_gratings_contrast", - "natural_movie_1_more_repeats" : "natural_movie_one", - "natural_movie_shuffled" : "natural_movie_one_shuffled", - "motion_stimulus" : "dot_motion", - "drifting_gratings_more_repeats" : "drifting_gratings_75_repeats", - + + "Natural_Images_Shuffled": "natural_scenes_shuffled", + "Natural_Images_Sequential": "natural_scenes_sequential", + "natural_movie_1_more_repeats": "natural_movie_one", + "natural_movie_shuffled": "natural_movie_one_shuffled", + "motion_stimulus": "dot_motion", + "drifting_gratings_more_repeats": "drifting_gratings_75_repeats", + "signal_noise_test_0_200_repeats": "test_movie_one", "signal_noise_test_0": "test_movie_one", - "signal_noise_test_0": "test_movie_two", - "signal_noise_session_1" : "dense_movie_one", - "signal_noise_session_2" : "dense_movie_two", - "signal_noise_session_3" : "dense_movie_three", - "signal_noise_session_4" : "dense_movie_four", - "signal_noise_session_5" : "dense_movie_five", - "signal_noise_session_6" : "dense_movie_six", + "signal_noise_test_1": "test_movie_two", + "signal_noise_session_1": "dense_movie_one", + "signal_noise_session_2": "dense_movie_two", + "signal_noise_session_3": "dense_movie_three", + "signal_noise_session_4": "dense_movie_four", + "signal_noise_session_5": "dense_movie_five", + "signal_noise_session_6": "dense_movie_six", } @@ -52,26 +53,36 @@ class InputParameters(ArgSchema): stimulus_pkl_path = String( - required=True, help="path to pkl file containing raw stimulus information" + required=True, + help="""path to pkl file containing raw stimulus information""" ) sync_h5_path = String( - required=True, help="path to h5 file containing syncronization information" + required=True, + help="""path to h5 file containing syncronization information""" ) output_stimulus_table_path = String( - required=True, help="the output stimulus table csv will be written here" + required=True, + help="""the output stimulus table csv will be written here""" ) - output_frame_times_path = String(required=True, help="output all frame times here") + output_frame_times_path = String( + required=True, + help="""output all frame times here""") minimum_spontaneous_activity_duration = Float( default=sys.float_info.epsilon, - help="detected spontaneous activity sweeps will be rejected if they last fewer that this many seconds", + help="""detected spontaneous activity sweeps will be rejected if + they last fewer that this many seconds""", ) maximum_expected_spontanous_activity_duration = Float( default=1225.02541, - help="validation will fail if a spontanous activity epoch longer than this one is computed.", + help="""validation will fail if a spontanous activity epoch longer + than this one is computed.""", ) frame_time_strategy = String( default="use_photodiode", - help="technique used to align frame times. 
Options are 'use_photodiode', which interpolates frame times between photodiode edge times (preferred when vsync times are unreliable) and 'use_vsyncs', which is preferred when reliable vsync times are available.", + help="""technique used to align frame times. Options are 'use_photodiode', + which interpolates frame times between photodiode edge times + (preferred when vsync times are unreliable) and 'use_vsyncs', + which is preferred when reliable vsync times are available.""", ) stimulus_name_map = Dict( keys=String(), @@ -80,9 +91,9 @@ class InputParameters(ArgSchema): default=default_stimulus_renames ) column_name_map = Dict( - keys=String(), - values=String(), - help="optionally rename stimulus parameters", + keys=String(), + values=String(), + help="optionally rename stimulus parameters", default=default_column_renames ) extract_const_params_from_repr = Bool(default=True) @@ -94,7 +105,14 @@ class InputParameters(ArgSchema): fail_on_negative_duration = Bool( default=False, - help="Determine if the module should fail if a stimulus epoch has a negative duration." + help="""Determine if the module should fail if a + stimulus epoch has a negative duration.""" + ) + + trim_discontiguous_frame_times = Bool( + default=True, + help="""set to False if stimulus was shown in chunks, + and discontiguous vsyncs are expected""" ) @@ -106,4 +124,4 @@ class OutputSchema(DefaultSchema): ) output_path = String(help="Path to output csv file") output_frame_times_path = String(help="output all frame times here") - +# diff --git a/allensdk/brain_observatory/ecephys/stimulus_table/ephys_pre_spikes.py b/allensdk/brain_observatory/ecephys/stimulus_table/ephys_pre_spikes.py index a22d9d02f..e8bfdf2cd 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_table/ephys_pre_spikes.py +++ b/allensdk/brain_observatory/ecephys/stimulus_table/ephys_pre_spikes.py @@ -1,19 +1,8 @@ -# -*- coding: utf-8 -*- -""" -Created on Fri Dec 16 15:11:23 2016 - -@author: Xiaoxuan Jia -""" - -import ast -import re import logging import numpy as np import pandas as pd -import warnings - from . import stimulus_parameter_extraction as spe @@ -30,25 +19,29 @@ def create_stim_table( Parameters ---------- stimuli : list of dict - Each element is a stimulus dictionary, as provided by the stim.pkl file. - stimulus_tabler : function - A function which takes a single stimulus dictionary as its argument and returns a stimulus table dataframe. + Each element is a stimulus dictionary, + as provided by the stim.pkl file. + stimulus_tabler : function + A function which takes a single stimulus dictionary + as its argument and returns a stimulus table dataframe. spontaneous_activity_tabler : function - A function which takes a list of stimulus tables as arguments and returns a list of 0 or more tables + A function which takes a list of stimulus tables as + arguments and returns a list of 0 or more tables describing spontaneous activity sweeps. sort_key : str, optional - Sort the final stimulus table in ascending order by this key. Defaults to 'Start'. + Sort the final stimulus table in ascending order by this key. + Defaults to 'Start'. Returns ------- stim_table_full : pandas.DataFrame - Each row is a sweep. Has columns describing (in frames) the start and end times of each sweep. Other columns + Each row is a sweep. Has columns describing (in frames) the start + and end times of each sweep. Other columns describe the values of stimulus parameters on those sweeps. 
""" stimulus_tables = [] - max_index = 0 for ii, stimulus in enumerate(stimuli): current_tables = stimulus_tabler(stimulus) @@ -57,7 +50,8 @@ def create_stim_table( stimulus_tables.extend(current_tables) - stimulus_tables = sorted(stimulus_tables, key=lambda df: min(df[sort_key].values)) + stimulus_tables = sorted(stimulus_tables, + key=lambda df: min(df[sort_key].values)) for ii, stim_table in enumerate(stimulus_tables): stim_table[block_key] = ii @@ -73,8 +67,8 @@ def create_stim_table( def make_spontaneous_activity_tables( stimulus_tables, start_key="Start", end_key="End", duration_threshold=0.0 ): - """ Fills in frame gaps in a set of stimulus tables. Suitable for use as the spontaneous_activity_tabler in - create_stim_table. + """ Fills in frame gaps in a set of stimulus tables. Suitable for use as + the spontaneous_activity_tabler in create_stim_table. Parameters ---------- @@ -85,13 +79,14 @@ def make_spontaneous_activity_tables( end_key : str, optional Column name for the end of a sweep. Defaults to 'End'. duration_threshold : numeric or None - If not None (default is 0), remove spontaneous activity sweeps whose duration is - less than this threshold. + If not None (default is 0), remove spontaneous activity sweeps + whose duration is less than this threshold. Returns ------- - list : - Either empty, or contains a single pd.DataFrame. The rows of the dataframe are spontenous activity sweeps. + list : + Either empty, or contains a single pd.DataFrame. + The rows of the dataframe are spontaneous activity sweeps. """ @@ -111,7 +106,9 @@ def make_spontaneous_activity_tables( if duration_threshold is not None: spon_sweeps = spon_sweeps[ - np.fabs(spon_sweeps[start_key] - spon_sweeps[end_key]) > duration_threshold + np.fabs(spon_sweeps[start_key] + - spon_sweeps[end_key]) + > duration_threshold ] spon_sweeps.reset_index(drop=True, inplace=True) @@ -130,14 +127,18 @@ def apply_frame_times( Parameters ---------- stimulus_table : pd.DataFrame - Rows are sweeps. Columns are stimulus parameters as well as start and end frames for each sweep. + Rows are sweeps. Columns are stimulus parameters as well as start + and end frames for each sweep. frame_times : numpy.ndarrray Gives the time in seconds at which each frame (indices) began. frames_per_second : numeric, optional - If provided, and extra_frame_time is True, will be used to calculcate the extra_frame_time. + If provided, and extra_frame_time is True, will be used to calculcate + the extra_frame_time. extra_frame_time : float, optional - If provided, an additional frame time will be appended. The time will be incremented by extra_frame_time from - the previous last frame time, to denote the time at which the last frame ended. If False, no extra time will be + If provided, an additional frame time will be appended. The time will + be incremented by extra_frame_time from + the previous last frame time, to denote the time at which the last + frame ended. If False, no extra time will be appended. If None (default), the increment will be 1.0/fps. map_columns : tuple of str, optional Which columns to replace with times. 
Defaults to 'Start' and 'End @@ -154,7 +155,8 @@ def apply_frame_times( if extra_frame_time is True and frames_per_second is not None: extra_frame_time = 1.0 / frames_per_second if extra_frame_time is not False: - frame_times = np.append(frame_times, frame_times[-1] + extra_frame_time) + frame_times = np.append(frame_times, frame_times[-1] + + extra_frame_time) for column in map_columns: stimulus_table[column] = frame_times[ @@ -172,34 +174,36 @@ def apply_display_sequence( diff_key="dif", block_key="stimulus_block", ): - """ Adjust raw sweep frames for a stimulus based on the display sequence + """ Adjust raw sweep frames for a stimulus based on the display sequence for that stimulus. Parameters ---------- sweep_frames_table : pd.DataFrame - Each row is a sweep. Has two columns, 'start' and 'end', + Each row is a sweep. Has two columns, 'start' and 'end', which describe (in frames) when that sweep began and ended. frame_display_sequence : np.ndarray - 2D array. Rows are display intervals. The 0th column is the start frame of - that interval, the 1st the end frame. + 2D array. Rows are display intervals. The 0th column is the start + frame of that interval, the 1st the end frame. Returns ------- sweep_frames_table : pd.DataFrame - As above, but start and end frames have been adjusted based on the display sequence. + As above, but start and end frames have been adjusted based on + the display sequence. Notes ----- - The frame values in the raw sweep_frames_table are given in 0-indexed offsets from the - start of display for this stimulus. This domain only takes into account frames which are part - of a display interval for that stimulus, so the frame ids need to be adjusted to lie on the global + The frame values in the raw sweep_frames_table are given in 0-indexed + offsets from the start of display for this stimulus. This domain only + takes into account frames which are part of a display interval for that + stimulus, so the frame ids need to be adjusted to lie on the global frame sequence. """ sweep_frames_table = sweep_frames_table.copy() - if not block_key in sweep_frames_table.columns.values: + if block_key not in sweep_frames_table.columns.values: sweep_frames_table[block_key] = np.zeros( (sweep_frames_table.shape[0]), dtype=int ) @@ -210,7 +214,8 @@ def apply_display_sequence( sweep_frames_table[start_key] += frame_display_sequence[0, 0] for seg in range(len(frame_display_sequence) - 1): - match_inds = sweep_frames_table[start_key] >= frame_display_sequence[seg, 1] + match_inds = sweep_frames_table[start_key] \ + >= frame_display_sequence[seg, 1] sweep_frames_table.loc[match_inds, start_key] += ( frame_display_sequence[seg + 1, 0] - frame_display_sequence[seg, 1] @@ -232,21 +237,27 @@ def apply_display_sequence( def read_stimulus_name_from_path(stimulus): - """Obtains a human-readable stimulus name by looking at the filename of the 'stim_path' item. + """Obtains a human-readable stimulus name by looking at the filename of + the 'stim_path' item. Parameters ---------- stimulus : dict must contain a 'stim_path' item. 
- + Returns ------- - str : + str : name of stimulus """ - return stimulus["stim_path"].split("\\")[-1].split(".")[0] + stim_name = stimulus["stim_path"].split("\\")[-1].split(".")[0] + + if len(stim_name) == 0: + stim_name = stimulus["stim_path"].split("\\\\")[-2] + + return stim_name def build_stimuluswise_table( @@ -260,29 +271,30 @@ def build_stimuluswise_table( extract_const_params_from_repr=False, drop_const_params=spe.DROP_PARAMS, ): - """ Construct a table of sweeps, including their times on the experiment-global clock - and the values of each relevant parameter. + """ Construct a table of sweeps, including their times on the + experiment-global clock and the values of each relevant parameter. Parameters ---------- stimulus : dict - Describes presentation of a stimulus on a particular experiment. Has a number of fields, - of which we are using: + Describes presentation of a stimulus on a particular experiment. Has + a number of fields, of which we are using: stim_path : str windows file path to the stimulus data sweep_frames : list of lists - rows are sweeps, columns are start and end frames of that sweep + rows are sweeps, columns are start and end frames of that sweep (in the stimulus-specific frame domain). C-order. sweep_order : list of int indices are frames, values are the sweep on that frame display_sequence : list of list - rows are intervals in which the stimulus was displayed. Columns are start - and end times (s, global) of the display. C-order. + rows are intervals in which the stimulus was displayed. + Columns are start and end times (s, global) of the display. + C-order. dimnames : list of str Names of parameters for this stimulus (such as "Contrast") sweep_table : list of tuple - Each element is a tuple of parameter values (1 per dimname) describing - a single sweep. + Each element is a tuple of parameter values (1 per dimname) + describing a single sweep. seconds_to_frames : function Converts experiment seconds to frames start_key : str, optional @@ -294,13 +306,15 @@ def build_stimuluswise_table( block_key : str, optional key to use for the 0-index position of this stimulus block get_stimulus_name : function | dict -> str, optional - extracts stimulus name from the stimulus dictionary. Default is read_stimulus_name_from_path + extracts stimulus name from the stimulus dictionary. Default is + read_stimulus_name_from_path Returns ------- list of pandas.DataFrame : - Each table corresponds to an entry in the display sequence. - Rows are sweeps, columns are stimulus parameter values as well as "Start" and "End". + Each table corresponds to an entry in the display sequence. + Rows are sweeps, columns are stimulus parameter values as well as + "Start" and "End". 
""" @@ -312,7 +326,8 @@ def build_stimuluswise_table( sweep_frames_table = pd.DataFrame( stimulus["sweep_frames"], columns=(start_key, end_key) ) - sweep_frames_table[block_key] = np.zeros([sweep_frames_table.shape[0]], dtype=int) + sweep_frames_table[block_key] = np.zeros([sweep_frames_table.shape[0]], + dtype=int) sweep_frames_table = apply_display_sequence( sweep_frames_table, frame_display_sequence, block_key=block_key ) @@ -355,14 +370,18 @@ def build_stimuluswise_table( existing = const_param_key in existing_columns if not (existing_cap or existing_upper or existing): - stim_table[const_param_key] = [const_param_value] * stim_table.shape[0] + stim_table[const_param_key] = [const_param_value] * \ + stim_table.shape[0] else: logging.info( - f"found sweep_param named: {const_param_key}, ignoring const param of the same name (value: {const_param_value})" + f"""found sweep_param named: {const_param_key}, + ignoring const param of the same name (value: + {const_param_value})""" ) unique_indices = np.unique(stim_table[block_key].values) - output = [stim_table.loc[stim_table[block_key] == ii, :] for ii in unique_indices] + output = [stim_table.loc[stim_table[block_key] == ii, :] + for ii in unique_indices] return output @@ -373,13 +392,15 @@ def split_column(table, column, new_columns, drop_old=True): Parameters ---------- table : pandas.DataFrame - Columns will be drawn from and assigned to this dataframe. This dataframe will NOT be modified inplace. + Columns will be drawn from and assigned to this dataframe. This + dataframe will NOT be modified inplace. column : str This column will be split. new_columns : dict, mapping strings to functions - Each key will be the name of a new column, while its value (a function) will be used to build the - new column's values. The functions should map from a single value of the original column to a single value - of the new column. + Each key will be the name of a new column, while its value (a function) + will be used to build the new column's values. The functions should map + from a single value of the original column to a single value + of the new column. drop_old : bool, optional If True, the original column will be dropped from the table. @@ -390,7 +411,7 @@ def split_column(table, column, new_columns, drop_old=True): """ - if not column in table: + if column not in table: return table table = table.copy() @@ -409,21 +430,25 @@ def assign_sweep_values( drop=True, tmp_suffix="_stimtable_todrop", ): - """ Left joins a stimulus table to a sweep table in order to associate epochs in time with stimulus characteristics. - + """ Left joins a stimulus table to a sweep table in order to associate + epochs in time with stimulus characteristics. + Parameters ---------- stim_table : pd.DataFrame - Each row is a stimulus epoch, with start and end times and a foreign key onto a particular sweep. + Each row is a stimulus epoch, with start and end times and a foreign + key onto a particular sweep. sweep_table : pd.DataFrame - Each row is a sweep. Should have columns in common with the stim_table - the resulting table will use values from - the sweep_table. + Each row is a sweep. Should have columns in common with the stim_table + - the resulting table will use values from the sweep_table. on : str, optional Column on which to join. drop : bool, optional - If True (default), the join column (argument on) will be dropped from the output. + If True (default), the join column (argument on) will be dropped from + the output. 
tmp_suffix : str, optional - Will be used to identify overlapping columns. Should not appear in the name of any column in either dataframe. + Will be used to identify overlapping columns. Should not appear in the + name of any column in either dataframe. """ diff --git a/allensdk/brain_observatory/ecephys/stimulus_table/naming_utilities.py b/allensdk/brain_observatory/ecephys/stimulus_table/naming_utilities.py index 56f68141f..3f3220219 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_table/naming_utilities.py +++ b/allensdk/brain_observatory/ecephys/stimulus_table/naming_utilities.py @@ -1,14 +1,16 @@ import re import warnings -import functools -import pandas as pd import numpy as np -GABOR_DIAMETER_RE = re.compile(r"gabor_(\d*\.{0,1}\d*)_{0,1}deg(?:_\d+ms){0,1}") +GABOR_DIAMETER_RE = \ + re.compile(r"gabor_(\d*\.{0,1}\d*)_{0,1}deg(?:_\d+ms){0,1}") + GENERIC_MOVIE_RE = re.compile( - r"natural_movie_(?P\d+|one|two|three|four|five|six|seven|eight|nine)(_shuffled){0,1}(_more_repeats){0,1}" + r"natural_movie_" + + r"(?P\d+|one|two|three|four|five|six|seven|eight|nine)" + + r"(_shuffled){0,1}(_more_repeats){0,1}" ) DIGIT_NAMES = { "1": "one", @@ -40,8 +42,9 @@ def drop_empty_columns(table): def collapse_columns(table): - """ merge, where possible, columns that describe the same parameter. This is pretty conservative - it - only matches columns by capitalization and it only overrides nans. + """ merge, where possible, columns that describe the same parameter. This + is pretty conservative - it only matches columns by capitalization and + it only overrides nans. """ colnames = set(table.columns) @@ -74,23 +77,24 @@ def add_number_to_shuffled_movie( template="natural_movie_{}_shuffled", tmp_colname="__movie_number__", ): - """ + """ """ if not table[stim_colname].str.contains(SHUFFLED_MOVIE_RE).any(): return table table = table.copy() - table[tmp_colname] = table[stim_colname].str.extract(natural_movie_re, expand=True)[ - "number" - ] + table[tmp_colname] = \ + table[stim_colname].str.extract(natural_movie_re, + expand=True)["number"] unique_numbers = [ item for item in table[tmp_colname].dropna(inplace=False).unique() ] if len(unique_numbers) != 1: raise ValueError( - f"unable to uniquely determine a movie number for this session. Candidates: {unique_numbers}" + "unable to uniquely determine a movie number for this session. " + + f"Candidates: {unique_numbers}" ) movie_number = unique_numbers[0] @@ -103,6 +107,7 @@ def renamer(row): return template.format(movie_number) table[stim_colname] = table.apply(renamer, axis=1) + print(table.keys()) table.drop(columns=tmp_colname, inplace=True) return table @@ -114,9 +119,10 @@ def standardize_movie_numbers( digit_names=DIGIT_NAMES, stim_colname="stimulus_name", ): - """ Natural movie stimuli in visual coding are numbered using words, like "natural_movie_two" rather than - "natural_movie_2". This function ensures that all of the natural movie stimuli in an experiment are named by - that convention. + """ Natural movie stimuli in visual coding are numbered using words, like + "natural_movie_two" rather than "natural_movie_2". This function ensures + that all of the natural movie stimuli in an experiment are named by that + convention. 
Parameters ---------- @@ -134,19 +140,20 @@ def standardize_movie_numbers( Returns ------- table : pd.DataFrame - the stimulus table with movie numerals having been mapped to english words + the stimulus table with movie numerals having been mapped to english + words """ - replace = lambda match_obj: digit_names[match_obj["number"]] + def replace(match_obj): + return digit_names[match_obj["number"]] # for some reason pandas really wants us to use the captures warnings.filterwarnings("ignore", "This pattern has match groups") movie_rows = table[stim_colname].str.contains(movie_re, na=False) - table.loc[movie_rows, stim_colname] = table.loc[ - movie_rows, stim_colname - ].str.replace(numeral_re, replace) + table.loc[movie_rows, stim_colname] = \ + table.loc[movie_rows, stim_colname].str.replace(numeral_re, replace) return table @@ -162,18 +169,20 @@ def map_stimulus_names(table, name_map=None, stim_colname="stimulus_name"): rename the stimuli according to this mapping stim_colname: str, optional look in this column for stimulus names - + """ if name_map is None: return table - if "" in name_map: - name_map[np.nan] = name_map[""] + name_map[np.nan] = "spontaneous" table[stim_colname] = table[stim_colname].replace( to_replace=name_map, inplace=False ) + + name_map.pop(np.nan) + return table @@ -181,8 +190,8 @@ def map_column_names(table, name_map=None, ignore_case=True): if ignore_case and name_map is not None: name_map = {key.lower(): value for key, value in name_map.items()} - mapper = lambda name: name if name.lower() not in name_map else name_map[name.lower()] - else: - mapper = name_map - return table.rename(columns=mapper) \ No newline at end of file + output = table.rename(columns=name_map) + + return output +# diff --git a/allensdk/brain_observatory/ecephys/stimulus_table/output_validation.py b/allensdk/brain_observatory/ecephys/stimulus_table/output_validation.py index a62cc6ac8..b1bcec1b3 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_table/output_validation.py +++ b/allensdk/brain_observatory/ecephys/stimulus_table/output_validation.py @@ -2,18 +2,23 @@ import warnings -def validate_epoch_durations(table, start_key="Start", end_key="End", fail_on_negative_durations=False): +def validate_epoch_durations(table, + start_key="Start", + end_key="End", + fail_on_negative_durations=False): durations = table[end_key] - table[start_key] min_duration_index = durations.idxmin() min_duration = durations[min_duration_index] if min_duration == 0: warnings.warn( - f"there is an epoch in this stimulus table (index: {min_duration_index}) with duration = {min_duration}", + f"""there is an epoch in this stimulus table (index: + {min_duration_index}) with duration = {min_duration}""", UserWarning, ) if min_duration < 0: - msg = f"there is an epoch with negative duration (index: {min_duration_index})" + msg = f"""there is an epoch with negative duration (index: + {min_duration_index})""" if fail_on_negative_durations: raise ValueError(msg) warnings.warn(msg) @@ -34,14 +39,24 @@ def validate_max_spontaneous_epoch_duration( end_key="End", ): if get_spontanous_epochs is None: - get_spontanous_epochs = lambda table: table[np.isnan(table[index_key])] + def get_spontanous_epochs(table): + table[np.isnan(table[index_key])] spontaneous_epochs = get_spontanous_epochs(table) - durations = ( - spontaneous_epochs[end_key].values - spontaneous_epochs[start_key].values - ) - if np.amax(durations) > max_duration: - warnings.warn( - f"there is a spontaneous activity duration longer than {max_duration}", - 
UserWarning, + + if spontaneous_epochs is not None: + + durations = ( + spontaneous_epochs[end_key].values + - spontaneous_epochs[start_key].values ) + + try: + if np.amax(durations) > max_duration: + warnings.warn( + f"""there is a spontaneous activity duration longer than + {max_duration}""", + UserWarning, + ) + except ValueError: + warnings.warn("No spontaneous intervals detected.", UserWarning) diff --git a/allensdk/brain_observatory/ecephys/stimulus_table/stimulus_parameter_extraction.py b/allensdk/brain_observatory/ecephys/stimulus_table/stimulus_parameter_extraction.py index 0571e41f6..6e30faf8d 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_table/stimulus_parameter_extraction.py +++ b/allensdk/brain_observatory/ecephys/stimulus_table/stimulus_parameter_extraction.py @@ -21,7 +21,8 @@ def parse_stim_repr( array_re=ARRAY_RE, raise_on_unrecognized=False, ): - """ Read the string representation of a psychopy stimulus and extract stimulus parameters. + """ Read the string representation of a psychopy stimulus and extract + stimulus parameters. Parameters ---------- @@ -33,7 +34,7 @@ def parse_stim_repr( Returns ------- - dict : + dict : maps extracted parameter names to values """ @@ -49,7 +50,8 @@ def parse_stim_repr( return stim_params -# This is not currently in use by the stimulus_table module, but is a potentially handy utility +# This is not currently in use by the stimulus_table module, but is a +# potentially handy utility def extract_stim_class_from_repr(stim_repr, repr_class_re=REPR_CLASS_RE): match = repr_class_re.match(stim_repr) if match is not None and "class_name" in match.groupdict(): @@ -59,14 +61,16 @@ def extract_stim_class_from_repr(stim_repr, repr_class_re=REPR_CLASS_RE): def extract_const_params_from_stim_repr( stim_repr, repr_params_re=REPR_PARAMS_RE, array_re=ARRAY_RE ): - """Parameters which are not set as sweep_params in the stimulus script (usually because they are not - varied during the course of the session) are not output in an easily machine-readable format. This function + """Parameters which are not set as sweep_params in the stimulus script + (usually because they are not varied during the course of the session) are + not output in an easily machine-readable format. This function attempts to recover them by parsing the string repr of the stimulus. Parameters ---------- stim_repr : str - The repr of the camstim stimulus object. Served up per-stimulus in the stim pickle. + The repr of the camstim stimulus object. Served up per-stimulus + in the stim pickle. repr_params_re : re.Pattern Extracts attributes as "="-seperated strings array_re : re.Pattern @@ -75,8 +79,8 @@ def extract_const_params_from_stim_repr( Returns ------- repr_params : dict - dictionary of paramater keys and values extracted from the stim repr. Where possible, the values are converted - to native Python types. + dictionary of paramater keys and values extracted from the stim repr. + Where possible, the values are converted to native Python types. 
""" @@ -93,7 +97,7 @@ def extract_const_params_from_stim_repr( try: v = ast.literal_eval(v) - except ValueError as err: + except ValueError: pass repr_params[k] = v diff --git a/allensdk/brain_observatory/ecephys/write_nwb/__main__.py b/allensdk/brain_observatory/ecephys/write_nwb/__main__.py index 2840ebb68..060747819 100644 --- a/allensdk/brain_observatory/ecephys/write_nwb/__main__.py +++ b/allensdk/brain_observatory/ecephys/write_nwb/__main__.py @@ -15,7 +15,6 @@ from allensdk.config.manifest import Manifest from ._schemas import InputSchema, OutputSchema -from allensdk.brain_observatory.nwb import setup_table_for_invalid_times # noqa: F401 from allensdk.brain_observatory.nwb import ( add_stimulus_presentations, add_stimulus_timestamps, @@ -28,15 +27,18 @@ eye_tracking_data_is_valid ) from allensdk.brain_observatory.argschema_utilities import ( - write_or_print_outputs, optional_lims_inputs + optional_lims_inputs ) from allensdk.brain_observatory import dict_to_indexed_array -from allensdk.brain_observatory.ecephys.file_io.continuous_file import ContinuousFile -from allensdk.brain_observatory.ecephys.nwb import (EcephysProbe, - EcephysElectrodeGroup, - EcephysSpecimen, - EcephysEyeTrackingRigMetadata, - EcephysCSD) +from allensdk.brain_observatory.ecephys.file_io.continuous_file import ( + ContinuousFile +) +from allensdk.brain_observatory.ecephys.nwb import ( + EcephysProbe, + EcephysElectrodeGroup, + EcephysSpecimen, + EcephysEyeTrackingRigMetadata, + EcephysCSD) from allensdk.brain_observatory.sync_dataset import Dataset import allensdk.brain_observatory.sync_utilities as su @@ -64,10 +66,15 @@ def fill_df(df, str_fill=""): return df -def get_inputs_from_lims(host, ecephys_session_id, output_root, job_queue, strategy): +def get_inputs_from_lims(host, + ecephys_session_id, + output_root, + job_queue, + strategy): """ - This is a development / testing utility for running this module from the Allen Institute for Brain Science's - Laboratory Information Management System (LIMS). It will only work if you are on our internal network. + This is a development / testing utility for running this module from the + Allen Institute for Brain Science's Laboratory Information Management + System (LIMS). It will only work if you are on our internal network. Parameters ---------- @@ -87,7 +94,10 @@ def get_inputs_from_lims(host, ecephys_session_id, output_root, job_queue, strat """ - uri = f"{host}/input_jsons?object_id={ecephys_session_id}&object_class=EcephysSession&strategy_class={strategy}&job_queue_name={job_queue}&output_directory={output_root}" + uri = f"{host}/input_jsons?object_id={ecephys_session_id}" + \ + f"&object_class=EcephysSession&strategy_class={strategy}" + \ + f"&job_queue_name={job_queue}&output_directory={output_root}" + response = requests.get(uri) data = response.json() @@ -148,8 +158,9 @@ def read_spike_times_to_dictionary( spike_times_path : str npy file identifying, per spike, the time at which that spike occurred. spike_units_path : str - npy file identifying, per spike, the unit associated with that spike. These are probe-local, so a - local_to_global_unit_map is used to associate spikes with global unit identifiers. + npy file identifying, per spike, the unit associated with that spike. + These are probe-local, so a local_to_global_unit_map is used to + associate spikes with global unit identifiers. 
local_to_global_unit_map : dict, optional Maps probewise local unit indices to global unit ids @@ -178,7 +189,8 @@ def read_spike_amplitudes_to_dictionary( templates = load_and_squeeze_npy(templates_path) spike_templates = load_and_squeeze_npy(spike_templates_path) - inverse_whitening_matrix = load_and_squeeze_npy(inverse_whitening_matrix_path) + inverse_whitening_matrix = \ + load_and_squeeze_npy(inverse_whitening_matrix_path) for temp_idx in range(templates.shape[0]): templates[temp_idx, :, :] = np.dot( @@ -186,11 +198,20 @@ def read_spike_amplitudes_to_dictionary( np.ascontiguousarray(inverse_whitening_matrix) ) - scaled_amplitudes = scale_amplitudes(spike_amplitudes, templates, spike_templates, scale_factor=scale_factor) - return group_1d_by_unit(scaled_amplitudes, spike_units, local_to_global_unit_map) + scaled_amplitudes = scale_amplitudes(spike_amplitudes, + templates, + spike_templates, + scale_factor=scale_factor) + + return group_1d_by_unit(scaled_amplitudes, + spike_units, + local_to_global_unit_map) -def scale_amplitudes(spike_amplitudes, templates, spike_templates, scale_factor=1.0): +def scale_amplitudes(spike_amplitudes, + templates, + spike_templates, + scale_factor=1.0): template_full_amplitudes = templates.max(axis=1) - templates.min(axis=1) template_amplitudes = template_full_amplitudes.max(axis=1) @@ -200,8 +221,13 @@ def scale_amplitudes(spike_amplitudes, templates, spike_templates, scale_factor= return spike_amplitudes -def filter_and_sort_spikes(spike_times_mapping: Dict[int, np.ndarray], - spike_amplitudes_mapping: Dict[int, np.ndarray]) -> Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray]]: +def filter_and_sort_spikes(spike_times_mapping: + Dict[int, np.ndarray], + spike_amplitudes_mapping: + Dict[int, np.ndarray]) -> Tuple[Dict[int, + np.ndarray], + Dict[int, + np.ndarray]]: """Filter out invalid spike timepoints and sort spike data (times + amplitudes) by times. @@ -298,11 +324,13 @@ def read_waveforms_to_dictionary( Parameters ---------- waveforms_path : str - npy file containing waveform data for each unit. Dimensions ought to be units X samples X channels + npy file containing waveform data for each unit. Dimensions ought to + be units X samples X channels local_to_global_unit_map : dict, optional Maps probewise local unit indices to global unit ids peak_channel_map : dict, optional - Maps unit identifiers to indices of peak channels. If provided, the output will contain only samples on the peak + Maps unit identifiers to indices of peak channels. If provided, + the output will contain only samples on the peak channel for each unit. 
Returns @@ -320,7 +348,8 @@ def read_waveforms_to_dictionary( if local_to_global_unit_map is not None: if unit_id not in local_to_global_unit_map: logging.warning( - f"unable to find unit at local position {unit_id} while reading waveforms" + f"""unable to find unit at local position + {unit_id} while reading waveforms""" ) continue unit_id = local_to_global_unit_map[unit_id] @@ -395,7 +424,7 @@ def add_probe_to_nwbfile(nwbfile, probe_id, sampling_rate, lfp_sampling_rate, """ probe_nwb_device = EcephysProbe(name=name, - description="Neuropixels 1.0 Probe", # required field + description="Neuropixels 1.0 Probe", manufacturer="imec", probe_id=probe_id, sampling_rate=sampling_rate) @@ -417,7 +446,8 @@ def add_probe_to_nwbfile(nwbfile, probe_id, sampling_rate, lfp_sampling_rate, def add_ecephys_electrode_columns(nwbfile: pynwb.NWBFile, - columns_to_add: Optional[List[Tuple[str, str]]] = None): + columns_to_add: + Optional[List[Tuple[str, str]]] = None): """Add additional columns to ecephys nwbfile electrode table. Parameters @@ -430,8 +460,10 @@ def add_ecephys_electrode_columns(nwbfile: pynwb.NWBFile, columns are added. """ default_columns = [ - ("probe_vertical_position", "Length-wise position of electrode/channel on device (microns)"), - ("probe_horizontal_position", "Width-wise position of electrode/channel on device (microns)"), + ("probe_vertical_position", + "Length-wise position of electrode/channel on device (microns)"), + ("probe_horizontal_position", + "Width-wise position of electrode/channel on device (microns)"), ("probe_id", "The unique id of this electrode's/channel's device"), ("local_index", "The local index of electrode/channel on device"), ("valid_data", "Whether data from this electrode/channel is usable") @@ -441,7 +473,8 @@ def add_ecephys_electrode_columns(nwbfile: pynwb.NWBFile, columns_to_add = default_columns for col_name, col_description in columns_to_add: - if (not nwbfile.electrodes) or (col_name not in nwbfile.electrodes.colnames): + if (not nwbfile.electrodes) or \ + (col_name not in nwbfile.electrodes.colnames): nwbfile.add_electrode_column(name=col_name, description=col_description) @@ -461,11 +494,16 @@ def add_ecephys_electrodes(nwbfile: pynwb.NWBFile, id: The unique id for a given electrode/channel probe_id: The unique id for an electrode's/channel's device valid_data: Whether the data for an electrode/channel is usable - local_index: The local index of an electrode/channel on a given device - probe_vertical_position: Length-wise position of electrode/channel on device (microns) - probe_horizontal_position: Width-wise position of electrode/channel on device (microns) - manual_structure_id: The LIMS id associated with an anatomical structure - manual_structure_acronym: Acronym associated with an anatomical structure + local_index: The local index of an electrode/channel on a + given device + probe_vertical_position: Length-wise position of electrode/channel + on device (microns) + probe_horizontal_position: Width-wise position of electrode/channel + on device (microns) + manual_structure_id: The LIMS id associated with an anatomical + structure + manual_structure_acronym: Acronym associated with an anatomical + structure anterior_posterior_ccf_coordinate dorsal_ventral_ccf_coordinate left_right_ccf_coordinate @@ -515,7 +553,8 @@ def add_ecephys_electrodes(nwbfile: pynwb.NWBFile, def add_ragged_data_to_dynamic_table( table, data, column_name, column_description="" ): - """ Builds the index and data vectors required for writing ragged array data to a pynwb 
dynamic table + """ Builds the index and data vectors required for writing ragged array + data to a pynwb dynamic table Parameters ---------- @@ -538,7 +577,10 @@ def add_ragged_data_to_dynamic_table( del data table.add_column( - name=column_name, description=column_description, data=values, index=idx + name=column_name, + description=column_description, + data=values, + index=idx ) @@ -629,7 +671,7 @@ def write_probe_lfp_file(session_id, session_metadata, session_start_time, logging.info(f"writing lfp file for probe {probe['id']}") nwbfile = pynwb.NWBFile( - session_description='LFP data and associated channel info for a single Ecephys probe', + session_description='LFP data and associated info for one probe', identifier=f"{probe['id']}", session_id=f"{session_id}", session_start_time=session_start_time, @@ -640,16 +682,18 @@ def write_probe_lfp_file(session_id, session_metadata, session_start_time, nwbfile = add_metadata_to_nwbfile(nwbfile, session_metadata) if probe.get("temporal_subsampling_factor", None) is not None: - probe["lfp_sampling_rate"] = probe["lfp_sampling_rate"] / probe["temporal_subsampling_factor"] - - nwbfile, probe_nwb_device, probe_nwb_electrode_group = add_probe_to_nwbfile( - nwbfile, - probe_id=probe["id"], - name=probe["name"], - sampling_rate=probe["sampling_rate"], - lfp_sampling_rate=probe["lfp_sampling_rate"], - has_lfp_data=probe["lfp"] is not None - ) + probe["lfp_sampling_rate"] = probe["lfp_sampling_rate"] / \ + probe["temporal_subsampling_factor"] + + nwbfile, probe_nwb_device, probe_nwb_electrode_group = \ + add_probe_to_nwbfile( + nwbfile, + probe_id=probe["id"], + name=probe["name"], + sampling_rate=probe["sampling_rate"], + lfp_sampling_rate=probe["lfp_sampling_rate"], + has_lfp_data=probe["lfp"] is not None + ) lfp_channels = np.load(probe['lfp']['input_channels_path'], allow_pickle=False) @@ -659,7 +703,7 @@ def write_probe_lfp_file(session_id, session_metadata, session_start_time, local_index_whitelist=lfp_channels) electrode_table_region = nwbfile.create_electrode_table_region( - region=np.arange(len(nwbfile.electrodes)).tolist(), # must use raw indices here + region=np.arange(len(nwbfile.electrodes)).tolist(), # use raw inds name='electrodes', description=f"lfp channels on probe {probe['id']}" ) @@ -678,17 +722,20 @@ def write_probe_lfp_file(session_id, session_metadata, session_start_time, nwbfile.add_acquisition(lfp.create_electrical_series( name=f"probe_{probe['id']}_lfp_data", data=H5DataIO(data=lfp_data, compression='gzip', compression_opts=9), - timestamps=H5DataIO(data=lfp_timestamps, compression='gzip', compression_opts=9), + timestamps=H5DataIO(data=lfp_timestamps, + compression='gzip', + compression_opts=9), electrodes=electrode_table_region )) nwbfile.add_acquisition(lfp) - csd, csd_times, csd_locs = read_csd_data_from_h5(probe["csd_path"]) - nwbfile = add_csd_to_nwbfile(nwbfile, csd, csd_times, csd_locs) + if ("csd_path" in probe.keys()): + csd, csd_times, csd_locs = read_csd_data_from_h5(probe["csd_path"]) + nwbfile = add_csd_to_nwbfile(nwbfile, csd, csd_times, csd_locs) with pynwb.NWBHDF5IO(probe['lfp']['output_path'], 'w') as lfp_writer: - logging.info(f"writing probe lfp file to {probe['lfp']['output_path']}") + logging.info(f"writing lfp file to {probe['lfp']['output_path']}") lfp_writer.write(nwbfile, cache_spec=True) return {"id": probe["id"], "nwb_path": probe["lfp"]["output_path"]} @@ -726,17 +773,20 @@ def add_csd_to_nwbfile(nwbfile: pynwb.NWBFile, csd: np.ndarray, nwbfile which has had CSD data added """ - csd_mod = 
pynwb.ProcessingModule("current_source_density", "Precalculated current source density from interpolated channel locations.") + csd_mod = pynwb.ProcessingModule("current_source_density", + "Precalculated current source density") nwbfile.add_processing_module(csd_mod) csd_ts = pynwb.base.TimeSeries( name="current_source_density", - data=csd.T, # TimeSeries should have data in (timepoints x channels) format + data=csd.T, # TimeSeries should have data in (time x channels) format timestamps=times, unit=csd_unit ) - x_locs, y_locs = np.split(csd_virt_channel_locs.astype(np.uint64), 2, axis=1) + x_locs, y_locs = np.split(csd_virt_channel_locs.astype(np.uint64), + 2, + axis=1) csd = EcephysCSD(name="ecephys_csd", time_series=csd_ts, @@ -756,8 +806,11 @@ def write_probewise_lfp_files(probes, session_id, session_metadata, output_paths = [] pool = mp.Pool(processes=pool_size) - write = partial(write_probe_lfp_file, session_id, session_metadata, - session_start_time, logging.getLogger("").getEffectiveLevel()) + write = partial(write_probe_lfp_file, + session_id, + session_metadata, + session_start_time, + logging.getLogger("").getEffectiveLevel()) for pout in pool.imap_unordered(write, probes): output_paths.append(pout) @@ -804,18 +857,25 @@ def parse_probes_data(probes: List[Dict[str, Any]]) -> ParsedProbeData: for probe in probes: unit_tables.append(pd.DataFrame(probe['units'])) - local_to_global_unit_map = {unit['cluster_id']: unit['id'] for unit in probe['units']} + local_to_global_unit_map = \ + {unit['cluster_id']: unit['id'] for unit in probe['units']} spike_times.update(read_spike_times_to_dictionary( - probe['spike_times_path'], probe['spike_clusters_file'], local_to_global_unit_map + probe['spike_times_path'], + probe['spike_clusters_file'], + local_to_global_unit_map )) mean_waveforms.update(read_waveforms_to_dictionary( - probe['mean_waveforms_path'], local_to_global_unit_map + probe['mean_waveforms_path'], + local_to_global_unit_map )) spike_amplitudes.update(read_spike_amplitudes_to_dictionary( - probe["spike_amplitudes_path"], probe["spike_clusters_file"], - probe["templates_path"], probe["spike_templates_path"], probe["inverse_whitening_matrix_path"], + probe["spike_amplitudes_path"], + probe["spike_clusters_file"], + probe["templates_path"], + probe["spike_templates_path"], + probe["inverse_whitening_matrix_path"], local_to_global_unit_map=local_to_global_unit_map, scale_factor=probe["amplitude_scale_factor"] )) @@ -826,29 +886,37 @@ def parse_probes_data(probes: List[Dict[str, Any]]) -> ParsedProbeData: def add_probewise_data_to_nwbfile(nwbfile, probes): - """ Adds channel (electrode) and spike data for a single probe to the session-level nwb file. + """ Adds channel (electrode) and spike data for a single probe to + the session-level nwb file. 
""" for probe in probes: logging.info(f'found probe {probe["id"]} with name {probe["name"]}') if probe.get("temporal_subsampling_factor", None) is not None: - probe["lfp_sampling_rate"] = probe["lfp_sampling_rate"] / probe["temporal_subsampling_factor"] - - nwbfile, probe_nwb_device, probe_nwb_electrode_group = add_probe_to_nwbfile( - nwbfile, - probe_id=probe["id"], - name=probe["name"], - sampling_rate=probe["sampling_rate"], - lfp_sampling_rate=probe["lfp_sampling_rate"], - has_lfp_data=probe["lfp"] is not None - ) - - add_ecephys_electrodes(nwbfile, probe["channels"], probe_nwb_electrode_group) - - units_table, spike_times, spike_amplitudes, mean_waveforms = parse_probes_data(probes) - nwbfile.units = pynwb.misc.Units.from_dataframe(fill_df(units_table), name='units') - - sorted_spike_times, sorted_spike_amplitudes = filter_and_sort_spikes(spike_times, spike_amplitudes) + probe["lfp_sampling_rate"] = probe["lfp_sampling_rate"] / \ + probe["temporal_subsampling_factor"] + + nwbfile, probe_nwb_device, probe_nwb_electrode_group = \ + add_probe_to_nwbfile( + nwbfile, + probe_id=probe["id"], + name=probe["name"], + sampling_rate=probe["sampling_rate"], + lfp_sampling_rate=probe["lfp_sampling_rate"], + has_lfp_data=probe["lfp"] is not None + ) + + add_ecephys_electrodes(nwbfile, + probe["channels"], + probe_nwb_electrode_group) + + units_table, spike_times, spike_amplitudes, mean_waveforms = \ + parse_probes_data(probes) + nwbfile.units = pynwb.misc.Units.from_dataframe(fill_df(units_table), + name='units') + + sorted_spike_times, sorted_spike_amplitudes = \ + filter_and_sort_spikes(spike_times, spike_amplitudes) add_ragged_data_to_dynamic_table( table=nwbfile.units, @@ -868,17 +936,20 @@ def add_probewise_data_to_nwbfile(nwbfile, probes): table=nwbfile.units, data=mean_waveforms, column_name="waveform_mean", - column_description="mean waveforms on peak channels (and over samples)", + column_description="mean waveforms on peak channels (over samples)", ) return nwbfile -def add_optotagging_table_to_nwbfile(nwbfile, optotagging_table, tag="optical_stimulation"): +def add_optotagging_table_to_nwbfile(nwbfile, + optotagging_table, + tag="optical_stimulation"): # "name" is a pynwb reserved column name that older versions of the # pre-processed optotagging_table may use. 
if "name" in optotagging_table.columns: - optotagging_table = optotagging_table.rename(columns={"name": "stimulus_name"}) + optotagging_table = \ + optotagging_table.rename(columns={"name": "stimulus_name"}) opto_ts = pynwb.base.TimeSeries( name="optotagging", @@ -887,21 +958,26 @@ def add_optotagging_table_to_nwbfile(nwbfile, optotagging_table, tag="optical_st unit="seconds" ) - opto_mod = pynwb.ProcessingModule("optotagging", "optogenetic stimulution data") + opto_mod = pynwb.ProcessingModule("optotagging", + "optogenetic stimulution data") opto_mod.add_data_interface(opto_ts) nwbfile.add_processing_module(opto_mod) optotagging_table = setup_table_for_epochs(optotagging_table, opto_ts, tag) if len(optotagging_table) > 0: - container = pynwb.epoch.TimeIntervals.from_dataframe(optotagging_table, "optogenetic_stimulation") + container = \ + pynwb.epoch.TimeIntervals.from_dataframe(optotagging_table, + "optogenetic_stimulation") opto_mod.add_data_interface(container) return nwbfile -def add_eye_tracking_rig_geometry_data_to_nwbfile(nwbfile: pynwb.NWBFile, - eye_tracking_rig_geometry: dict) -> pynwb.NWBFile: +def add_eye_tracking_rig_geometry_data_to_nwbfile( + nwbfile: pynwb.NWBFile, + eye_tracking_rig_geometry: dict) -> pynwb.NWBFile: + """ Rig geometry dict should consist of the following fields: monitor_position_mm: [x, y, z] monitor_rotation_deg: [x, y, z] @@ -910,8 +986,9 @@ def add_eye_tracking_rig_geometry_data_to_nwbfile(nwbfile: pynwb.NWBFile, led_position: [x, y, z] equipment: A string describing rig """ - eye_tracking_rig_mod = pynwb.ProcessingModule(name='eye_tracking_rig_metadata', - description='Eye tracking rig metadata module') + eye_tracking_rig_mod = \ + pynwb.ProcessingModule(name='eye_tracking_rig_metadata', + description='Eye tracking rig metadata module') rig_metadata = EcephysEyeTrackingRigMetadata( name="eye_tracking_rig_metadata", @@ -934,16 +1011,19 @@ def add_eye_tracking_rig_geometry_data_to_nwbfile(nwbfile: pynwb.NWBFile, return nwbfile -def add_eye_tracking_data_to_nwbfile(nwbfile: pynwb.NWBFile, - eye_tracking_frame_times: pd.Series, - eye_dlc_tracking_data: Dict[str, pd.DataFrame], - eye_gaze_data: Dict[str, pd.DataFrame]) -> pynwb.NWBFile: +def add_eye_tracking_data_to_nwbfile( + nwbfile: pynwb.NWBFile, + eye_tracking_frame_times: pd.Series, + eye_dlc_tracking_data: Dict[str, pd.DataFrame], + eye_gaze_data: Dict[str, pd.DataFrame]) -> pynwb.NWBFile: if eye_tracking_data_is_valid(eye_dlc_tracking_data=eye_dlc_tracking_data, synced_timestamps=eye_tracking_frame_times): - add_eye_tracking_ellipse_fit_data_to_nwbfile(nwbfile, - eye_dlc_tracking_data=eye_dlc_tracking_data, - synced_timestamps=eye_tracking_frame_times) + + add_eye_tracking_ellipse_fit_data_to_nwbfile( + nwbfile, + eye_dlc_tracking_data=eye_dlc_tracking_data, + synced_timestamps=eye_tracking_frame_times) # --- Add gaze mapped positions to nwb file --- if eye_gaze_data: @@ -961,11 +1041,11 @@ def write_ecephys_nwb( probes, running_speed_path, session_sync_path, - eye_tracking_rig_geometry, - eye_dlc_ellipses_path, - eye_gaze_mapping_path, pool_size, optotagging_table_path=None, + eye_tracking_rig_geometry=None, + eye_dlc_ellipses_path=None, + eye_gaze_mapping_path=None, session_metadata=None, **kwargs ): @@ -975,7 +1055,7 @@ def write_ecephys_nwb( identifier=f"{session_id}", session_id=f"{session_id}", session_start_time=session_start_time, - institution="Allen Institute for Brain Science" + institution="Allen Institute" ) if session_metadata is not None: @@ -985,9 +1065,11 @@ def write_ecephys_nwb( 
"colorSpace", "depth", "interpolate", "pos", "rgbPedestal", "tex", "texRes", "flipHoriz", "flipVert", "rgb", "signalDots" ] - stimulus_table = read_stimulus_table(stimulus_table_path, - columns_to_drop=stimulus_columns_to_drop) - nwbfile = add_stimulus_timestamps(nwbfile, stimulus_table['start_time'].values) # TODO: patch until full timestamps are output by stim table module + stimulus_table = \ + read_stimulus_table(stimulus_table_path, + columns_to_drop=stimulus_columns_to_drop) + nwbfile = \ + add_stimulus_timestamps(nwbfile, stimulus_table['start_time'].values) nwbfile = add_stimulus_presentations(nwbfile, stimulus_table) nwbfile = add_invalid_times(nwbfile, invalid_epochs) @@ -1001,22 +1083,31 @@ def write_ecephys_nwb( add_running_speed_to_nwbfile(nwbfile, running_speed) add_raw_running_data_to_nwbfile(nwbfile, raw_running_data) - add_eye_tracking_rig_geometry_data_to_nwbfile(nwbfile, - eye_tracking_rig_geometry) + if eye_tracking_rig_geometry is not None: + add_eye_tracking_rig_geometry_data_to_nwbfile( + nwbfile, + eye_tracking_rig_geometry + ) # Collect eye tracking/gaze mapping data from files - eye_tracking_frame_times = su.get_synchronized_frame_times(session_sync_file=session_sync_path, - sync_line_label_keys=Dataset.EYE_TRACKING_KEYS) - eye_dlc_tracking_data = read_eye_dlc_tracking_ellipses(Path(eye_dlc_ellipses_path)) - if eye_gaze_mapping_path: - eye_gaze_data = read_eye_gaze_mappings(Path(eye_gaze_mapping_path)) - else: - eye_gaze_data = None + if eye_dlc_ellipses_path is not None: + eye_tracking_frame_times = \ + su.get_synchronized_frame_times( + session_sync_file=session_sync_path, + sync_line_label_keys=Dataset.EYE_TRACKING_KEYS + ) + eye_dlc_tracking_data = \ + read_eye_dlc_tracking_ellipses(Path(eye_dlc_ellipses_path)) + + if eye_gaze_mapping_path is not None: + eye_gaze_data = read_eye_gaze_mappings(Path(eye_gaze_mapping_path)) + else: + eye_gaze_data = None - add_eye_tracking_data_to_nwbfile(nwbfile, - eye_tracking_frame_times, - eye_dlc_tracking_data, - eye_gaze_data) + add_eye_tracking_data_to_nwbfile(nwbfile, + eye_tracking_frame_times, + eye_dlc_tracking_data, + eye_gaze_data) Manifest.safe_make_parent_dirs(output_path) with pynwb.NWBHDF5IO(output_path, mode='w') as io: @@ -1024,11 +1115,16 @@ def write_ecephys_nwb( io.write(nwbfile, cache_spec=True) probes_with_lfp = [p for p in probes if p["lfp"] is not None] + probes_without_lfp = [p for p in probes if p["lfp"] is None] + probe_outputs = write_probewise_lfp_files(probes_with_lfp, session_id, session_metadata, session_start_time, pool_size=pool_size) + probe_outputs += \ + [{'id': p["id"], "nwb_path": ""} for p in probes_without_lfp] + return { 'nwb_path': output_path, "probe_outputs": probe_outputs @@ -1040,10 +1136,18 @@ def main(): format="%(asctime)s - %(process)s - %(levelname)s - %(message)s" ) - parser = optional_lims_inputs(sys.argv, InputSchema, OutputSchema, get_inputs_from_lims) + parser = optional_lims_inputs( + sys.argv, + InputSchema, + OutputSchema, + get_inputs_from_lims + ) + + write_ecephys_nwb(**parser.args) + + # output = write_ecephys_nwb(**parser.args) - output = write_ecephys_nwb(**parser.args) - write_or_print_outputs(output, parser) + # write_or_print_outputs(output, parser) if __name__ == "__main__": diff --git a/allensdk/brain_observatory/ecephys/write_nwb/_schemas.py b/allensdk/brain_observatory/ecephys/write_nwb/_schemas.py index f704c1666..102cc840f 100644 --- a/allensdk/brain_observatory/ecephys/write_nwb/_schemas.py +++ b/allensdk/brain_observatory/ecephys/write_nwb/_schemas.py 
@@ -106,33 +106,49 @@ class Probe(RaisingSchema): mean_waveforms_path = String(required=True, validate=check_read_access) channels = Nested(Channel, many=True, required=True) units = Nested(Unit, many=True, required=True) - lfp = Nested(Lfp, many=False, required=True, allow_none=True) - csd_path = String(required=True, + lfp = Nested(Lfp, many=False, required=False, allow_none=True) + csd_path = String(required=False, validate=check_read_access, allow_none=True, - help="path to h5 file containing calculated current source density") - sampling_rate = Float(default=30000.0, help="sampling rate (Hz, master clock) at which raw data were acquired on this probe") - lfp_sampling_rate = Float(default=2500.0, allow_none=True, help="sampling rate of LFP data on this probe") - temporal_subsampling_factor = Float(default=2.0, allow_none=True, help="subsampling factor applied to lfp data for this probe (across time)") + help="""path to h5 file containing calculated current + source density""") + sampling_rate = Float( + default=30000.0, + help="""sampling rate (Hz, master clock) at which raw data were + acquired on this probe""") + lfp_sampling_rate = Float( + default=2500.0, + allow_none=True, + help="""sampling rate of LFP data on this probe""") + temporal_subsampling_factor = Float( + default=2.0, + allow_none=True, + help="""subsampling factor applied to lfp data for + this probe (across time)""") spike_amplitudes_path = String( validate=check_read_access, - help="path to npy file containing scale factor applied to the kilosort template used to extract each spike" + help="""path to npy file containing scale factor applied to the + kilosort template used to extract each spike""" ) spike_templates_path = String( validate=check_read_access, - help="path to file associating each spike with a kilosort template" + help="""path to file associating each spike with a kilosort template""" ) templates_path = String( validate=check_read_access, - help="path to file contianing an (nTemplates)x(nSamples)x(nUnits) array of kilosort templates" + help="""path to file containing an (nTemplates)x(nSamples)x(nUnits) + array of kilosort templates""" ) inverse_whitening_matrix_path = String( validate=check_read_access, - help="Kilosort templates are whitened. In order to use them for scaling spike amplitudes to volts, we need to remove the whitening" + help="""Kilosort templates are whitened. In order to use them for + scaling spike amplitudes to volts, we need to remove + the whitening""" ) amplitude_scale_factor = Float( default=0.195e-6, - help="amplitude scale factor converting raw amplitudes to Volts. Default converts from bits -> uV -> V" + help="""amplitude scale factor converting raw amplitudes to Volts. + Default converts from bits -> uV -> V""" ) @@ -193,37 +209,49 @@ class Meta: ) running_speed_path = String( required=True, - help="data collected about the running behavior of the experiment's subject", + help="""data collected about the running behavior of the experiment's + subject""", ) session_sync_path = String( required=True, validate=check_read_access, - help="Path to an h5 experiment session sync file (*.sync). This file relates events from different acquisition modalities to one another in time." + help="""Path to an h5 experiment session sync file (*.sync). 
This file + relates events from different acquisition modalities to one + another in time.""" + ) + pool_size = Int( + default=3, + help="number of child processes used to write probewise lfp files" + ) + optotagging_table_path = String( + required=False, + validate=check_read_access, + help="""file at this path contains information about the optogenetic + stimulation applied during this experiment""" ) eye_tracking_rig_geometry = Dict( - required=True, - help="Mapping containing information about session rig geometry used for eye gaze mapping." + required=False, + help="""Mapping containing information about session rig geometry used + for eye gaze mapping.""" ) eye_dlc_ellipses_path = String( - required=True, + required=False, validate=check_read_access, - help="h5 filepath containing raw ellipse fits produced by Deep Lab Cuts of subject eye, pupil, and corneal reflections during experiment" + help="""h5 filepath containing raw ellipse fits produced by Deep Lab + Cuts of subject eye, pupil, and corneal reflections during + experiment""" ) eye_gaze_mapping_path = String( required=False, allow_none=True, - help="h5 filepath containing eye gaze behavior of the experiment's subject" + help="""h5 filepath containing eye gaze behavior of the + experiment's subject""" ) - pool_size = Int( - default=3, - help="number of child processes used to write probewise lfp files" - ) - optotagging_table_path = String( + session_metadata = Nested( + SessionMetadata, + allow_none=True, required=False, - validate=check_read_access, - help="file at this path contains information about the optogenetic stimulation applied during this " - ) - session_metadata = Nested(SessionMetadata, allow_none=True, required=False, help="miscellaneous information describing this session") + help="miscellaneous information describing this session""") class ProbeOutputs(RaisingSchema): diff --git a/allensdk/brain_observatory/extract_running_speed/__main__.py b/allensdk/brain_observatory/extract_running_speed/__main__.py index 43b9ea375..62935ac9a 100644 --- a/allensdk/brain_observatory/extract_running_speed/__main__.py +++ b/allensdk/brain_observatory/extract_running_speed/__main__.py @@ -102,7 +102,9 @@ def main( # occasionally an extra set of frame times are acquired after the rest of # the signals. 
We detect and remove these - frame_times = sync_utilities.trim_discontiguous_times(frame_times) + if kwargs.get('trim_discontiguous_frame_times', True): + frame_times = sync_utilities.trim_discontiguous_times(frame_times) + num_raw_timestamps = len(frame_times) dx_deg = running_from_stim_file(stim_file, "dx", num_raw_timestamps) @@ -116,6 +118,12 @@ def main( vsig = running_from_stim_file(stim_file, "vsig", num_raw_timestamps) vin = running_from_stim_file(stim_file, "vin", num_raw_timestamps) + if len(vin) != len(dx_deg): + vin = np.concatenate((vin, np.zeros((len(dx_deg) - len(vin))))) + + if len(vsig) != len(dx_deg): + vsig = np.concatenate((vsig, np.zeros((len(dx_deg) - len(vsig))))) + velocities = extract_running_speeds( frame_times=frame_times, dx_deg=dx_deg, diff --git a/allensdk/brain_observatory/extract_running_speed/_schemas.py b/allensdk/brain_observatory/extract_running_speed/_schemas.py index 4ec30e7b2..e040ce49a 100644 --- a/allensdk/brain_observatory/extract_running_speed/_schemas.py +++ b/allensdk/brain_observatory/extract_running_speed/_schemas.py @@ -1,6 +1,6 @@ -from argschema import ArgSchema, ArgSchemaParser +from argschema import ArgSchema from argschema.schemas import DefaultSchema -from argschema.fields import Nested, InputDir, String, Float, Dict, Int, Boolean +from argschema.fields import Nested, String, Float, Boolean class InputParameters(ArgSchema): @@ -16,11 +16,20 @@ class InputParameters(ArgSchema): wheel_radius = Float(default=8.255, help="radius, in cm, of running wheel") subject_position = Float( default=2 / 3, - help="normalized distance of the subject from the center of the running wheel (1 is rim, 0 is center)", + help="normalized distance of the subject from the center " + + "of the running wheel (1 is rim, 0 is center)", ) use_median_duration = Boolean( default=True, - help="frame timestamps are often too noisy to use as the denominator in the velocity calculation. Can instead use the median frame duration." + help="frame timestamps are often too noisy to use as the " + + "denominator in the velocity calculation. " + + "Can instead use the median frame duration." + ) + + trim_discontiguous_frame_times = Boolean( + default=True, + help="set to False if stimulus was shown in chunks, " + + "and discontiguous vsyncs are expected." 
) diff --git a/allensdk/brain_observatory/sync_dataset.py b/allensdk/brain_observatory/sync_dataset.py index 9f1b9d259..7f5c742d6 100644 --- a/allensdk/brain_observatory/sync_dataset.py +++ b/allensdk/brain_observatory/sync_dataset.py @@ -93,7 +93,7 @@ class Dataset(object): """ - FRAME_KEYS = ('frames', 'stim_vsync') + FRAME_KEYS = ('frames', 'stim_vsync', 'vsync_stim') PHOTODIODE_KEYS = ('photodiode', 'stim_photodiode') OPTOGENETIC_STIMULATION_KEYS = ("LED_sync", "opto_trial") EYE_TRACKING_KEYS = ("eye_frame_received", # Expected eye tracking diff --git a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_dot_motion.py b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_dot_motion.py index 1f819bbed..1f3970c9a 100644 --- a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_dot_motion.py +++ b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_dot_motion.py @@ -3,26 +3,34 @@ import pandas as pd from .conftest import MockSessionApi -from allensdk.brain_observatory.ecephys.stimulus_analysis.dot_motion import DotMotion +from allensdk.brain_observatory.ecephys.stimulus_analysis.dot_motion \ + import DotMotion from allensdk.brain_observatory.ecephys.ecephys_session import EcephysSession class MockDMSessionApi(MockSessionApi): def get_stimulus_presentations(self): - features = np.array(np.meshgrid([0.0, 45.0, 90.0, 135.0, 180.0, 225.0, 270.0, 315.0], # Dir - [0.001, 0.005, 0.01, 0.02]) # Speed + features = np.array(np.meshgrid([0.0, 45.0, 90.0, 135.0, 180.0, + 225.0, 270.0, 315.0], # Dir + [0.001, 0.005, 0.01, 0.02]) # Speed ).reshape(2, 32) - features = np.concatenate((features, np.array([np.nan, np.nan]).reshape((2, 1))), axis=1) # null case + features = np.concatenate((features, + np.array([np.nan, np.nan]).reshape((2, 1))), + axis=1) # null case return pd.DataFrame({ - 'start_time': np.concatenate(([0.0], np.linspace(0.5, 32.5, 33, endpoint=True), [33.5])), - 'stop_time': np.concatenate(([0.5], np.linspace(1.5, 33.5, 33, endpoint=True), [34.0])), - 'stimulus_name': ['spontaneous'] + ['dot_motion']*33 + ['spontaneous'], + 'start_time': np.concatenate(([0.0], np.linspace(0.5, 32.5, 33, + endpoint=True), [33.5])), + 'stop_time': np.concatenate(([0.5], np.linspace(1.5, 33.5, 33, + endpoint=True), [34.0])), + 'stimulus_name': ['spontaneous'] + + ['dot_motion']*33 + + ['spontaneous'], 'stimulus_block': [0] + [1]*33 + [0], 'duration': [0.5] + [1.0]*33 + [0.5], 'stimulus_index': [0] + [1]*33 + [0], - 'Dir': np.concatenate(([np.nan], features[0,:], [np.nan])), + 'Dir': np.concatenate(([np.nan], features[0, :], [np.nan])), 'Speed': np.concatenate(([np.nan], features[1, :], [np.nan])) }, index=pd.Index(name='id', data=np.arange(35))) @@ -30,7 +38,6 @@ def get_invalid_times(self): return pd.DataFrame() - @pytest.fixture def ecephys_api(): return MockDMSessionApi() @@ -53,9 +60,13 @@ def test_stimulus(ecephys_api): dm = DotMotion(ecephys_session=session) assert(isinstance(dm.stim_table, pd.DataFrame)) assert(len(dm.stim_table) == 33) - assert(set(dm.stim_table.columns).issuperset({'Dir', 'Speed', 'start_time', 'stop_time'})) + assert(set(dm.stim_table.columns).issuperset({'Dir', + 'Speed', + 'start_time', + 'stop_time'})) - assert(set(dm.directions) == {0.0, 45.0, 90.0, 135.0, 180.0, 225.0, 270.0, 315.0}) + assert(set(dm.directions) == {0.0, 45.0, 90.0, 135.0, + 180.0, 225.0, 270.0, 315.0}) assert(dm.number_directions == 8) assert(set(dm.speeds) == {0.001, 0.005, 0.01, 0.02}) diff --git a/allensdk/test/brain_observatory/ecephys/test_ecephys_session.py 
b/allensdk/test/brain_observatory/ecephys/test_ecephys_session.py index 120a552f2..926865a6c 100644 --- a/allensdk/test/brain_observatory/ecephys/test_ecephys_session.py +++ b/allensdk/test/brain_observatory/ecephys/test_ecephys_session.py @@ -4,19 +4,21 @@ import xarray as xr import types -from allensdk.brain_observatory.ecephys.ecephys_session_api import EcephysSessionApi -from allensdk.brain_observatory.ecephys.ecephys_session import EcephysSession, nan_intervals, build_spike_histogram +from allensdk.brain_observatory.ecephys.ecephys_session_api import \ + EcephysSessionApi +from allensdk.brain_observatory.ecephys.ecephys_session import \ + EcephysSession, nan_intervals, build_spike_histogram @pytest.fixture def raw_stimulus_table(): return pd.DataFrame({ 'start_time': np.arange(4)/2, - 'stop_time':np.arange(1, 5)/2, - 'stimulus_name':['a', 'a', 'a', 'a_movie'], - 'stimulus_block':[0, 0, 0, 1], + 'stop_time': np.arange(1, 5)/2, + 'stimulus_name': ['a', 'a', 'a', 'a_movie'], + 'stimulus_block': [0, 0, 0, 1], 'TF': np.empty(4) * np.nan, - 'SF':np.empty(4) * np.nan, + 'SF': np.empty(4) * np.nan, 'Ori': np.empty(4) * np.nan, 'Contrast': np.empty(4) * np.nan, 'Pos_x': np.empty(4) * np.nan, @@ -28,6 +30,7 @@ def raw_stimulus_table(): "texRes": np.ones([4]) }, index=pd.Index(name='id', data=np.arange(4))) + @pytest.fixture def raw_invalid_times_table(): return pd.DataFrame({ @@ -51,7 +54,6 @@ def raw_spike_times(): } - @pytest.fixture def raw_mean_waveforms(): return { @@ -106,11 +108,13 @@ def raw_lfp(): ) } + @pytest.fixture -def just_stimulus_table_api(raw_stimulus_table): +def just_stim_table_api(raw_stimulus_table): class EcephysJustStimulusTableApi(EcephysSessionApi): def get_stimulus_presentations(self): return raw_stimulus_table + def get_invalid_times(self): return pd.DataFrame() return EcephysJustStimulusTableApi() @@ -121,12 +125,16 @@ def channels_table_api(raw_channels, raw_probes, raw_lfp, raw_stimulus_table): class EcephysChannelsTableApi(EcephysSessionApi): def get_channels(self): return raw_channels + def get_probes(self): return raw_probes + def get_lfp(self, pid): return raw_lfp[pid] + def get_stimulus_presentations(self): return raw_stimulus_table + def get_invalid_times(self): return pd.DataFrame() @@ -134,16 +142,24 @@ def get_invalid_times(self): @pytest.fixture -def lfp_masking_api(raw_channels, raw_probes, raw_lfp, raw_stimulus_table, raw_invalid_times_table): +def lfp_masking_api(raw_channels, + raw_probes, + raw_lfp, + raw_stimulus_table, + raw_invalid_times_table): class EcephysMaskInvalidLFPApi(EcephysSessionApi): def get_channels(self): return raw_channels + def get_probes(self): return raw_probes + def get_lfp(self, pid): return raw_lfp[pid] + def get_stimulus_presentations(self): return raw_stimulus_table + def get_invalid_times(self): return raw_invalid_times_table return EcephysMaskInvalidLFPApi() @@ -154,47 +170,65 @@ def units_table_api(raw_channels, raw_units, raw_probes): class EcephysUnitsTableApi(EcephysSessionApi): def get_channels(self): return raw_channels + def get_units(self): return raw_units + def get_probes(self): - return raw_probes + return raw_probes return EcephysUnitsTableApi() + @pytest.fixture -def valid_stimulus_table_api(raw_stimulus_table,raw_invalid_times_table): +def valid_stimulus_table_api(raw_stimulus_table, raw_invalid_times_table): class EcephysValidStimulusTableApi(EcephysSessionApi): def get_invalid_times(self): return raw_invalid_times_table + def get_stimulus_presentations(self): return raw_stimulus_table return 
EcephysValidStimulusTableApi() @pytest.fixture -def mean_waveforms_api(raw_mean_waveforms, raw_channels, raw_units, raw_probes): +def mean_waveforms_api(raw_mean_waveforms, + raw_channels, + raw_units, + raw_probes): class EcephysMeanWaveformsApi(EcephysSessionApi): def get_mean_waveforms(self): return raw_mean_waveforms + def get_channels(self): return raw_channels + def get_units(self): return raw_units + def get_probes(self): return raw_probes return EcephysMeanWaveformsApi() @pytest.fixture -def spike_times_api(raw_units, raw_channels, raw_probes, raw_stimulus_table, raw_spike_times): +def spike_times_api(raw_units, + raw_channels, + raw_probes, + raw_stimulus_table, + raw_spike_times): class EcephysSpikeTimesApi(EcephysSessionApi): def get_spike_times(self): return raw_spike_times + def get_channels(self): return raw_channels + def get_units(self): return raw_units + def get_probes(self): return raw_probes + def get_stimulus_presentations(self): return raw_stimulus_table @@ -205,8 +239,8 @@ def get_invalid_times(self): def get_no_spikes_times(self): - # A special method used for testing cases when there are no spikes for a given session, will be swapped out for - # get_spike_times() + # A special method used for testing cases when there are no spikes for a + # given session, will be swapped out for get_spike_times() return { 0: np.array([]), 1: np.array([]), @@ -222,7 +256,7 @@ def get_ecephys_session_id(self): return EcephysSessionMetadataApi() -def test_get_stimulus_epochs(just_stimulus_table_api): +def test_get_stimulus_epochs(just_stim_table_api): expected = pd.DataFrame({ "start_time": [0, 3/2], @@ -232,13 +266,16 @@ def test_get_stimulus_epochs(just_stimulus_table_api): "stimulus_block": [0, 1] }) - session = EcephysSession(api=just_stimulus_table_api) + session = EcephysSession(api=just_stim_table_api) obtained = session.get_stimulus_epochs() print(expected) print(obtained) - pd.testing.assert_frame_equal(expected, obtained, check_like=True, check_dtype=False) + pd.testing.assert_frame_equal(expected, + obtained, + check_like=True, + check_dtype=False) def test_get_invalid_times(valid_stimulus_table_api, raw_invalid_times_table): @@ -249,7 +286,10 @@ def test_get_invalid_times(valid_stimulus_table_api, raw_invalid_times_table): obtained = session.get_invalid_times() - pd.testing.assert_frame_equal(expected, obtained, check_like=True, check_dtype=False) + pd.testing.assert_frame_equal(expected, + obtained, + check_like=True, + check_dtype=False) def test_get_stimulus_presentations(valid_stimulus_table_api): @@ -257,19 +297,26 @@ def test_get_stimulus_presentations(valid_stimulus_table_api): expected = pd.DataFrame({ "start_time": [0, 1/2, 1, 3/2], "stop_time": [1/2, 1, 3/2, 2], - "stimulus_name": ['invalid_presentation', 'invalid_presentation', 'a', 'a_movie'], + "stimulus_name": ['invalid_presentation', + 'invalid_presentation', 'a', 'a_movie'], "phase": [np.nan, np.nan, 120.0, 180.0] }, index=pd.Index(name='stimulus_presentations_id', data=[0, 1, 2, 3])) session = EcephysSession(api=valid_stimulus_table_api) - obtained = session.stimulus_presentations[["start_time", "stop_time", "stimulus_name", "phase"]] + obtained = session.stimulus_presentations[["start_time", + "stop_time", + "stimulus_name", + "phase"]] print(expected) print(obtained) - pd.testing.assert_frame_equal(expected, obtained, check_like=True, check_dtype=False) + pd.testing.assert_frame_equal(expected, + obtained, + check_like=True, + check_dtype=False) -def 
test_get_stimulus_presentations_no_invalid_times(just_stimulus_table_api): +def test_get_stimulus_presentations_no_invalid_times(just_stim_table_api): expected = pd.DataFrame({ "start_time": [0, 1/2, 1, 3/2], @@ -278,13 +325,19 @@ def test_get_stimulus_presentations_no_invalid_times(just_stimulus_table_api): }, index=pd.Index(name='stimulus_presentations_id', data=[0, 1, 2, 3])) - session = EcephysSession(api=just_stimulus_table_api) + session = EcephysSession(api=just_stim_table_api) - obtained = session.stimulus_presentations[["start_time", "stop_time", "stimulus_name"]] + obtained = session.stimulus_presentations[["start_time", + "stop_time", + "stimulus_name"]] print(expected) print(obtained) - pd.testing.assert_frame_equal(expected, obtained, check_like=True, check_dtype=False) + pd.testing.assert_frame_equal(expected, + obtained, + check_like=True, + check_dtype=False) + def test_session_metadata(session_metadata_api): session = EcephysSession(api=session_metadata_api) @@ -292,14 +345,15 @@ def test_session_metadata(session_metadata_api): assert 12345 == session.ecephys_session_id -def test_build_stimulus_presentations(just_stimulus_table_api): +def test_build_stimulus_presentations(just_stim_table_api): expected_columns = [ - 'start_time', 'stop_time', 'stimulus_name', 'stimulus_block', - 'temporal_frequency', 'spatial_frequency', 'orientation', 'contrast', - 'x_position', 'y_position', 'color', 'frame', 'phase', 'duration', "stimulus_condition_id" + 'start_time', 'stop_time', 'stimulus_name', 'stimulus_block', + 'temporal_frequency', 'spatial_frequency', 'orientation', 'contrast', + 'x_position', 'y_position', 'color', 'frame', 'phase', + 'duration', "stimulus_condition_id" ] - session = EcephysSession(api=just_stimulus_table_api) + session = EcephysSession(api=just_stim_table_api) obtained = session.stimulus_presentations print(obtained.head()) @@ -330,7 +384,11 @@ def test_build_units_table(units_table_api): def test_presentationwise_spike_counts(spike_times_api): session = EcephysSession(api=spike_times_api) - obtained = session.presentationwise_spike_counts(np.linspace(-.1, .1, 3), session.stimulus_presentations.index.values, session.units.index.values) + obtained = \ + session.presentationwise_spike_counts( + np.linspace(-.1, .1, 3), + session.stimulus_presentations.index.values, + session.units.index.values) first = obtained.loc[{'unit_id': 2, 'stimulus_presentation_id': 2}] assert np.allclose([0, 3], first) @@ -343,32 +401,34 @@ def test_presentationwise_spike_counts(spike_times_api): @pytest.mark.parametrize("spike_times,time_domain,expected", [ [ - {1: [1.5, 2.5]}, + {1: [1.5, 2.5]}, [[1, 2, 3, 4], [1.1, 2.1, 3.1, 4.1]], np.array([[1, 1, 0], [1, 1, 0]])[:, :, None] ], [ - {1: [1.5, 2.5]}, + {1: [1.5, 2.5]}, [[1, 2, 3, 4], [1.6, 2.0, 4.0, 4.1]], np.array([[1, 1, 0], [0, 1, 0]])[:, :, None] ], [ - {1: [1.5, 2.5], 2: [1.5, 2.5]}, + {1: [1.5, 2.5], 2: [1.5, 2.5]}, [[1, 2, 3, 4], [1.6, 2.0, 4.0, 4.1]], np.stack(([[1, 1, 0], [0, 1, 0]], [[1, 1, 0], [0, 1, 0]]), axis=2) - ] -, + ], [ - {1: [1.5, 2.5], 2: [1.5, 1.55]}, + {1: [1.5, 2.5], 2: [1.5, 1.55]}, [[1, 2, 3, 4], [1.6, 2.0, 4.0, 4.1]], np.stack(([[1, 1, 0], [0, 1, 0]], [[2, 0, 0], [0, 0, 0]]), axis=2) ] ]) @pytest.mark.parametrize("binarize", [True, False]) def test_build_spike_histogram(spike_times, time_domain, expected, binarize): - + unit_ids = [k for k in spike_times.keys()] - obtained = build_spike_histogram(time_domain, spike_times, unit_ids, binarize=binarize) + obtained = build_spike_histogram(time_domain, + 
spike_times, + unit_ids, + binarize=binarize) expected = np.array(expected) if binarize: @@ -380,7 +440,10 @@ def test_build_spike_histogram(spike_times, time_domain, expected, binarize): def test_presentationwise_spike_times(spike_times_api): session = EcephysSession(api=spike_times_api) - obtained = session.presentationwise_spike_times(session.stimulus_presentations.index.values, session.units.index.values) + obtained = \ + session.presentationwise_spike_times( + session.stimulus_presentations.index.values, + session.units.index.values) expected = pd.DataFrame({ 'unit_id': [2, 2, 2], @@ -388,22 +451,32 @@ def test_presentationwise_spike_times(spike_times_api): 'time_since_stimulus_presentation_onset': [0.01, 0.02, 0.03] }, index=pd.Index(name='spike_time', data=[1.01, 1.02, 1.03])) - pd.testing.assert_frame_equal(expected, obtained, check_like=True, check_dtype=False) + pd.testing.assert_frame_equal(expected, + obtained, + check_like=True, + check_dtype=False) def test_empty_presentationwise_spike_times(spike_times_api): - # Test that when there are no spikes presentationwise_spike_times doesn't fail and instead returns a empty dataframe - spike_times_api.get_spike_times = types.MethodType(get_no_spikes_times, spike_times_api) + # Test that when there are no spikes presentationwise_spike_times + # doesn't fail and instead returns a empty dataframe + spike_times_api.get_spike_times = types.MethodType(get_no_spikes_times, + spike_times_api) session = EcephysSession(api=spike_times_api) - obtained = session.presentationwise_spike_times(session.stimulus_presentations.index.values, - session.units.index.values) + obtained = \ + session.presentationwise_spike_times( + session.stimulus_presentations.index.values, + session.units.index.values) + assert(isinstance(obtained, pd.DataFrame)) assert(obtained.empty) def test_conditionwise_spike_statistics(spike_times_api): session = EcephysSession(api=spike_times_api) - obtained = session.conditionwise_spike_statistics(stimulus_presentation_ids=[0, 1, 2]) + obtained = \ + session.conditionwise_spike_statistics( + stimulus_presentation_ids=[0, 1, 2]) pd.set_option('display.max_columns', None) @@ -413,7 +486,9 @@ def test_conditionwise_spike_statistics(spike_times_api): def test_conditionwise_spike_statistics_using_rates(spike_times_api): session = EcephysSession(api=spike_times_api) - obtained = session.conditionwise_spike_statistics(stimulus_presentation_ids=[0, 1, 2], use_rates=True) + obtained = \ + session.conditionwise_spike_statistics( + stimulus_presentation_ids=[0, 1, 2], use_rates=True) pd.set_option('display.max_columns', None) assert np.allclose([0, 0, 6], obtained["spike_mean"].values) @@ -421,7 +496,8 @@ def test_conditionwise_spike_statistics_using_rates(spike_times_api): def test_empty_conditionwise_spike_statistics(spike_times_api): # special case when there are no spikes - spike_times_api.get_spike_times = types.MethodType(get_no_spikes_times, spike_times_api) + spike_times_api.get_spike_times = \ + types.MethodType(get_no_spikes_times, spike_times_api) session = EcephysSession(api=spike_times_api) obtained = session.conditionwise_spike_statistics( stimulus_presentation_ids=session.stimulus_presentations.index.values, @@ -430,30 +506,39 @@ def test_empty_conditionwise_spike_statistics(spike_times_api): assert(len(obtained) == 12) assert(not np.any(obtained['spike_count'])) # check all spike_counts are 0 assert(not np.any(obtained['spike_mean'])) # spike_means are 0 - assert(np.all(np.isnan(obtained['spike_std']))) # std/sem will be 
undefined + assert(np.all(np.isnan(obtained['spike_std']))) # std/sem is undefined assert(np.all(np.isnan(obtained['spike_sem']))) -def test_get_stimulus_parameter_values(just_stimulus_table_api): - session = EcephysSession(api=just_stimulus_table_api) +def test_get_stimulus_parameter_values(just_stim_table_api): + session = EcephysSession(api=just_stim_table_api) obtained = session.get_stimulus_parameter_values() expected = { 'color': [0, 5.5, 11, 16.5], 'phase': [0, 60, 120, 180] } - + for k, v in expected.items(): assert np.allclose(v, obtained[k]) assert len(expected) == len(obtained) @pytest.mark.parametrize("detailed", [True, False]) -def test_get_stimulus_table(detailed, just_stimulus_table_api, raw_stimulus_table): - session = EcephysSession(api=just_stimulus_table_api) - obtained = session.get_stimulus_table(['a'], include_detailed_parameters=detailed) - - expected_columns = ['start_time', 'stop_time', 'stimulus_name', 'stimulus_block', 'Color', 'Phase'] +def test_get_stimulus_table(detailed, + just_stim_table_api, + raw_stimulus_table): + session = EcephysSession(api=just_stim_table_api) + obtained = session.get_stimulus_table( + ['a'], + include_detailed_parameters=detailed) + + expected_columns = ['start_time', + 'stop_time', + 'stimulus_name', + 'stimulus_block', + 'Color', + 'Phase'] if detailed: expected_columns.append("texRes") expected = raw_stimulus_table.loc[:2, expected_columns] @@ -465,27 +550,30 @@ def test_get_stimulus_table(detailed, just_stimulus_table_api, raw_stimulus_tabl print(expected) print(obtained) - pd.testing.assert_frame_equal(expected, obtained, check_like=True, check_dtype=False) + pd.testing.assert_frame_equal(expected, + obtained, + check_like=True, + check_dtype=False) -def test_filter_owned_df(just_stimulus_table_api): - session = EcephysSession(api=just_stimulus_table_api) +def test_filter_owned_df(just_stim_table_api): + session = EcephysSession(api=just_stim_table_api) ids = [0, 2] obtained = session._filter_owned_df('stimulus_presentations', ids) assert np.allclose([0, 120], obtained['phase'].values) -def test_filter_owned_df_scalar(just_stimulus_table_api): - session = EcephysSession(api=just_stimulus_table_api) +def test_filter_owned_df_scalar(just_stim_table_api): + session = EcephysSession(api=just_stim_table_api) ids = 3 obtained = session._filter_owned_df('stimulus_presentations', ids) assert obtained['phase'].values[0] == 180 -def test_build_inter_presentation_intervals(just_stimulus_table_api): - session = EcephysSession(api=just_stimulus_table_api) +def test_build_inter_presentation_intervals(just_stim_table_api): + session = EcephysSession(api=just_stim_table_api) obtained = session.inter_presentation_intervals expected = pd.DataFrame({ @@ -497,11 +585,14 @@ def test_build_inter_presentation_intervals(just_stimulus_table_api): ) ) - pd.testing.assert_frame_equal(expected, obtained, check_like=True, check_dtype=False) + pd.testing.assert_frame_equal(expected, + obtained, + check_like=True, + check_dtype=False) -def test_get_inter_presentation_intervals_for_stimulus(just_stimulus_table_api): - session = EcephysSession(api=just_stimulus_table_api) +def test_get_inter_presentation_intervals_for_stimulus(just_stim_table_api): + session = EcephysSession(api=just_stim_table_api) obtained = session.get_inter_presentation_intervals_for_stimulus('a') expected = pd.DataFrame({ @@ -513,7 +604,10 @@ def test_get_inter_presentation_intervals_for_stimulus(just_stimulus_table_api): ) ) - pd.testing.assert_frame_equal(expected, obtained, 
check_like=True, check_dtype=False) + pd.testing.assert_frame_equal(expected, + obtained, + check_like=True, + check_dtype=False) def test_get_lfp(channels_table_api): diff --git a/allensdk/test/brain_observatory/ecephys/test_write_nwb.py b/allensdk/test/brain_observatory/ecephys/test_write_nwb.py index 3d1298c44..405750874 100644 --- a/allensdk/test/brain_observatory/ecephys/test_write_nwb.py +++ b/allensdk/test/brain_observatory/ecephys/test_write_nwb.py @@ -15,10 +15,11 @@ from allensdk.brain_observatory.ecephys.current_source_density.__main__ \ import write_csd_to_h5 import allensdk.brain_observatory.ecephys.write_nwb.__main__ as write_nwb -from allensdk.brain_observatory.ecephys.ecephys_session_api import \ - EcephysNwbSessionApi +from allensdk.brain_observatory.ecephys.ecephys_session_api \ + import EcephysNwbSessionApi from allensdk.test.brain_observatory.behavior.test_eye_tracking_processing \ import create_preload_eye_tracking_df +from allensdk.brain_observatory.nwb import setup_table_for_invalid_times @pytest.fixture @@ -88,24 +89,24 @@ def test_roundtrip_basic_metadata(roundtripper): @pytest.mark.parametrize("metadata, expected_metadata", [ ({ - "specimen_name": "mouse_1", - "age_in_days": 100.0, - "full_genotype": "wt", - "strain": "c57", - "sex": "F", - "stimulus_name": "brain_observatory_2.0", - "donor_id": 12345, - "species": "Mus musculus"}, + "specimen_name": "mouse_1", + "age_in_days": 100.0, + "full_genotype": "wt", + "strain": "c57", + "sex": "F", + "stimulus_name": "brain_observatory_2.0", + "donor_id": 12345, + "species": "Mus musculus"}, { - "specimen_name": "mouse_1", - "age_in_days": 100.0, - "age": "P100D", - "full_genotype": "wt", - "strain": "c57", - "sex": "F", - "stimulus_name": "brain_observatory_2.0", - "subject_id": "12345", - "species": "Mus musculus"}) + "specimen_name": "mouse_1", + "age_in_days": 100.0, + "age": "P100D", + "full_genotype": "wt", + "strain": "c57", + "sex": "F", + "stimulus_name": "brain_observatory_2.0", + "subject_id": "12345", + "species": "Mus musculus"}) ]) def test_add_metadata(nwbfile, roundtripper, metadata, expected_metadata): nwbfile = write_nwb.add_metadata_to_nwbfile(nwbfile, metadata) @@ -120,8 +121,8 @@ def test_add_metadata(nwbfile, roundtripper, metadata, expected_metadata): if obtained[key] != value: misses[key] = {"expected": value, "obtained": obtained[key]} - assert len( - misses) == 0, f"the following metadata items were mismatched: {misses}" + assert len(misses) == 0, \ + f"the following metadata items were mismatched: {misses}" @pytest.mark.parametrize("presentations", [ @@ -150,13 +151,16 @@ def test_add_stimulus_presentations(nwbfile, presentations, roundtripper): api = roundtripper(nwbfile, EcephysNwbSessionApi) obtained_stimulus_table = api.get_stimulus_presentations() - pd.testing.assert_frame_equal(presentations, obtained_stimulus_table, - check_dtype=False) + pd.testing.assert_frame_equal( + presentations, + obtained_stimulus_table, + check_dtype=False) -def test_add_stimulus_presentations_color(nwbfile, - stimulus_presentations_color, - roundtripper): +def test_add_stimulus_presentations_color( + nwbfile, + stimulus_presentations_color, + roundtripper): write_nwb.add_stimulus_timestamps(nwbfile, [0, 1]) write_nwb.add_stimulus_presentations(nwbfile, stimulus_presentations_color) @@ -171,8 +175,8 @@ def test_add_stimulus_presentations_color(nwbfile, if expected != obtained: mismatched = True - assert not mismatched, f"expected: {expected_color}, obtain" \ - f"ed: {obtained_color}" + assert not mismatched, \ + 
f"expected: {expected_color}, obtained: {obtained_color}" @pytest.mark.parametrize("opto_table, expected", [ @@ -203,8 +207,12 @@ def test_add_stimulus_presentations_color(nwbfile, "stimulus_name": ["w", "x", "y", "z"]}), None) ]) -def test_add_optotagging_table_to_nwbfile(nwbfile, roundtripper, opto_table, - expected): +def test_add_optotagging_table_to_nwbfile( + nwbfile, + roundtripper, + opto_table, + expected): + opto_table["duration"] = opto_table["stop_time"] - opto_table["start_time"] nwbfile = write_nwb.add_optotagging_table_to_nwbfile(nwbfile, opto_table) @@ -235,8 +243,17 @@ def test_add_optotagging_table_to_nwbfile(nwbfile, roundtripper, opto_table, }, index=pd.Index([12], name="id")) ] ]) -def test_add_probe_to_nwbfile(nwbfile, roundtripper, roundtrip, pid, name, - srate, lfp_srate, has_lfp, expected): +def test_add_probe_to_nwbfile( + nwbfile, + roundtripper, + roundtrip, + pid, + name, + srate, + lfp_srate, + has_lfp, + expected): + nwbfile, _, _ = write_nwb.add_probe_to_nwbfile(nwbfile, pid, name=name, sampling_rate=srate, @@ -265,6 +282,7 @@ def test_add_probe_to_nwbfile(nwbfile, roundtripper, roundtrip, pid, name, ]) def test_add_ecephys_electrode_columns(nwbfile, columns_to_add, expected_columns): + write_nwb.add_ecephys_electrode_columns(nwbfile, columns_to_add) assert set(nwbfile.electrodes.colnames) == expected_columns @@ -272,54 +290,53 @@ def test_add_ecephys_electrode_columns(nwbfile, columns_to_add, @pytest.mark.parametrize(("channels, local_index_whitelist, " "expected_electrode_table"), [ - ([{"id": 1, - "probe_id": 1234, - "valid_data": True, - "local_index": 23, - "probe_vertical_position": 10, - "probe_horizontal_position": 10, - "anterior_posterior_ccf_coordinate": 15.0, - "dorsal_ventral_ccf_coordinate": 20.0, - "left_right_ccf_coordinate": 25.0, - "manual_structure_acronym": "CA1", - "impedence": np.nan, - "filtering": "AP band: 500 Hz high-pass; LFP " - "band: 1000 Hz low-pass"}, - {"id": 2, - "probe_id": 1234, - "valid_data": True, - "local_index": 15, - "probe_vertical_position": 20, - "probe_horizontal_position": 20, - "anterior_posterior_ccf_coordinate": 25.0, - "dorsal_ventral_ccf_coordinate": 30.0, - "left_right_ccf_coordinate": 35.0, - "manual_structure_acronym": "CA3", - "impedence": 42.0, - "filtering": "custom"}], - - [15, 23], - - pd.DataFrame({ - "id": [2, 1], - "probe_id": [1234, 1234], - "valid_data": [True, True], - "local_index": [15, 23], - "probe_vertical_position": [20, 10], - "probe_horizontal_position": [20, 10], - "x": [25.0, 15.0], - "y": [30.0, 20.0], - "z": [35.0, 25.0], - "location": ["CA3", "CA1"], - "imp": [42.0, np.nan], - "filtering": ["custom", - "AP band: 500 Hz high-pass; " - "LFP band: 1000 Hz low-pass"] - }).set_index("id")) - - ]) + ([{"id": 1, + "probe_id": 1234, + "valid_data": True, + "local_index": 23, + "probe_vertical_position": 10, + "probe_horizontal_position": 10, + "anterior_posterior_ccf_coordinate": 15.0, + "dorsal_ventral_ccf_coordinate": 20.0, + "left_right_ccf_coordinate": 25.0, + "manual_structure_acronym": "CA1", + "impedence": np.nan, + "filtering": "AP band: 500 Hz high-pass; LFP band: 1000 Hz low-pass"}, + {"id": 2, + "probe_id": 1234, + "valid_data": True, + "local_index": 15, + "probe_vertical_position": 20, + "probe_horizontal_position": 20, + "anterior_posterior_ccf_coordinate": 25.0, + "dorsal_ventral_ccf_coordinate": 30.0, + "left_right_ccf_coordinate": 35.0, + "manual_structure_acronym": "CA3", + "impedence": 42.0, + "filtering": "custom"}], + + [15, 23], + + pd.DataFrame({ + "id": [2, 1], + 
"probe_id": [1234, 1234], + "valid_data": [True, True], + "local_index": [15, 23], + "probe_vertical_position": [20, 10], + "probe_horizontal_position": [20, 10], + "x": [25.0, 15.0], + "y": [30.0, 20.0], + "z": [35.0, 25.0], + "location": ["CA3", "CA1"], + "imp": [42.0, np.nan], + "filtering": ["custom", + "AP band: 500 Hz high-pass; LFP band: 1000 Hz low-pass"] + }).set_index("id")) + +]) def test_add_ecephys_electrodes(nwbfile, channels, local_index_whitelist, expected_electrode_table): + mock_device = pynwb.device.Device(name="mock_device") mock_electrode_group = pynwb.ecephys.ElectrodeGroup(name="mock_group", description="", @@ -329,8 +346,8 @@ def test_add_ecephys_electrodes(nwbfile, channels, local_index_whitelist, write_nwb.add_ecephys_electrodes(nwbfile, channels, mock_electrode_group, local_index_whitelist) - obt_electrode_table = nwbfile.electrodes.to_dataframe().drop( - columns=["group", "group_name"]) + obt_electrode_table = \ + nwbfile.electrodes.to_dataframe().drop(columns=["group", "group_name"]) pd.testing.assert_frame_equal(obt_electrode_table, expected_electrode_table, @@ -341,12 +358,14 @@ def test_add_ecephys_electrodes(nwbfile, channels, local_index_whitelist, [{"a": [1, 2, 3], "b": [4, 5, 6]}, ["a", "b"], [3, 6], [1, 2, 3, 4, 5, 6]] ]) def test_dict_to_indexed_array(dc, order, exp_idx, exp_data): + obt_idx, obt_data = write_nwb.dict_to_indexed_array(dc, order) assert np.allclose(exp_idx, obt_idx) assert np.allclose(exp_data, obt_data) def test_add_ragged_data_to_dynamic_table(units_table, spike_times): + write_nwb.add_ragged_data_to_dynamic_table( table=units_table, data=spike_times, @@ -362,8 +381,13 @@ def test_add_ragged_data_to_dynamic_table(units_table, spike_times): [True, True], [True, False] ]) -def test_add_running_speed_to_nwbfile(nwbfile, running_speed, roundtripper, - roundtrip, include_rotation): +def test_add_running_speed_to_nwbfile( + nwbfile, + running_speed, + roundtripper, + roundtrip, + include_rotation): + nwbfile = write_nwb.add_running_speed_to_nwbfile(nwbfile, running_speed) if roundtrip: api_obt = roundtripper(nwbfile, EcephysNwbSessionApi) @@ -379,10 +403,15 @@ def test_add_running_speed_to_nwbfile(nwbfile, running_speed, roundtripper, @pytest.mark.parametrize("roundtrip", [[True]]) -def test_add_raw_running_data_to_nwbfile(nwbfile, raw_running_data, - roundtripper, roundtrip): - nwbfile = write_nwb.add_raw_running_data_to_nwbfile(nwbfile, - raw_running_data) +def test_add_raw_running_data_to_nwbfile( + nwbfile, + raw_running_data, + roundtripper, + roundtrip): + + nwbfile = write_nwb.add_raw_running_data_to_nwbfile( + nwbfile, + raw_running_data) if roundtrip: api_obt = roundtripper(nwbfile, EcephysNwbSessionApi) else: @@ -391,7 +420,8 @@ def test_add_raw_running_data_to_nwbfile(nwbfile, raw_running_data, obtained = api_obt.get_raw_running_data() expected = raw_running_data.rename( - columns={"dx": "net_rotation", "vsig": "signal_voltage", + columns={"dx": "net_rotation", + "vsig": "signal_voltage", "vin": "supply_voltage"}) pd.testing.assert_frame_equal(expected, obtained, check_like=True) @@ -399,43 +429,61 @@ def test_add_raw_running_data_to_nwbfile(nwbfile, raw_running_data, @pytest.mark.parametrize( "presentations, column_renames_map, columns_to_drop, expected", [ (pd.DataFrame({'alpha': [0.5, 0.4, 0.3, 0.2, 0.1], - 'stimulus_name': ['gabors', 'gabors', 'random', 'movie', + 'stimulus_name': ['gabors', + 'gabors', + 'random', + 'movie', 'gabors'], 'start_time': [1., 2., 4., 5., 6.], 'stop_time': [2., 4., 5., 6., 8.]}), {"alpha": 
"beta"}, None, pd.DataFrame({'beta': [0.5, 0.4, 0.3, 0.2, 0.1], - 'stimulus_name': ['gabors', 'gabors', 'random', 'movie', + 'stimulus_name': ['gabors', + 'gabors', + 'random', + 'movie', 'gabors'], 'start_time': [1., 2., 4., 5., 6.], 'stop_time': [2., 4., 5., 6., 8.]})), (pd.DataFrame({'alpha': [0.5, 0.4, 0.3, 0.2, 0.1], - 'stimulus_name': ['gabors', 'gabors', 'random', 'movie', + 'stimulus_name': ['gabors', + 'gabors', + 'random', + 'movie', 'gabors'], 'start_time': [1., 2., 4., 5., 6.], 'stop_time': [2., 4., 5., 6., 8.]}), {"alpha": "beta"}, ["Nonexistant_column_to_drop"], pd.DataFrame({'beta': [0.5, 0.4, 0.3, 0.2, 0.1], - 'stimulus_name': ['gabors', 'gabors', 'random', 'movie', + 'stimulus_name': ['gabors', + 'gabors', + 'random', + 'movie', 'gabors'], 'start_time': [1., 2., 4., 5., 6.], 'stop_time': [2., 4., 5., 6., 8.]})), (pd.DataFrame({'alpha': [0.5, 0.4, 0.3, 0.2, 0.1], - 'stimulus_name': ['gabors', 'gabors', 'random', 'movie', + 'stimulus_name': ['gabors', + 'gabors', + 'random', + 'movie', 'gabors'], 'Start': [1., 2., 4., 5., 6.], 'End': [2., 4., 5., 6., 8.]}), None, ["alpha"], - pd.DataFrame({'stimulus_name': ['gabors', 'gabors', 'random', 'movie', + pd.DataFrame({'stimulus_name': ['gabors', + 'gabors', + 'random', + 'movie', 'gabors'], 'start_time': [1., 2., 4., 5., 6.], 'stop_time': [2., 4., 5., 6., 8.]})), - ]) + ]) def test_read_stimulus_table(tmpdir_factory, presentations, column_renames_map, columns_to_drop, expected): dirname = str(tmpdir_factory.mktemp("ecephys_nwb_test")) @@ -449,8 +497,6 @@ def test_read_stimulus_table(tmpdir_factory, presentations, pd.testing.assert_frame_equal(obt, expected) -# read_spike_times_to_dictionary(spike_times_path, spike_units_path, -# local_to_global_unit_map=None) def test_read_spike_times_to_dictionary(tmpdir_factory): dirname = str(tmpdir_factory.mktemp("ecephys_nwb_spike_times")) spike_times_path = os.path.join(dirname, "spike_times.npy") @@ -464,13 +510,14 @@ def test_read_spike_times_to_dictionary(tmpdir_factory): local_to_global_unit_map = {ii: -ii for ii in spike_units} - obtained = \ - write_nwb.read_spike_times_to_dictionary(spike_times_path, - spike_units_path, - local_to_global_unit_map) + obtained = write_nwb.read_spike_times_to_dictionary( + spike_times_path, + spike_units_path, + local_to_global_unit_map) for ii in range(15): - assert np.allclose(obtained[-ii], - sorted([spike_times[ii], spike_times[15 + ii]])) + assert np.allclose( + obtained[-ii], + sorted([spike_times[ii], spike_times[15 + ii]])) def test_read_waveforms_to_dictionary(tmpdir_factory): @@ -486,8 +533,9 @@ def test_read_waveforms_to_dictionary(tmpdir_factory): mean_waveforms = np.random.rand(nunits, nsamples, nchannels) np.save(waveforms_path, mean_waveforms, allow_pickle=False) - obtained = write_nwb.read_waveforms_to_dictionary(waveforms_path, - local_to_global_unit_map) + obtained = write_nwb.read_waveforms_to_dictionary( + waveforms_path, + local_to_global_unit_map) for ii in range(nunits): assert np.allclose(mean_waveforms[ii, :, :], obtained[-ii]) @@ -498,9 +546,10 @@ def lfp_data(): subsample_channels = np.array([3, 2]) return { - "data": np.arange(total_timestamps * len(subsample_channels), - dtype=np.int16).reshape( - (total_timestamps, len(subsample_channels))), + "data": np.arange( + total_timestamps * len(subsample_channels), + dtype=np.int16).reshape((total_timestamps, + len(subsample_channels))), "timestamps": np.linspace(0, 1, total_timestamps), "subsample_channels": subsample_channels } @@ -585,13 +634,14 @@ def csd_data(): def 
test_write_probe_lfp_file(tmpdir_factory, lfp_data, probe_data, csd_data): + tmpdir = Path(tmpdir_factory.mktemp("probe_lfp_nwb")) input_data_path = tmpdir / Path("lfp_data.dat") input_timestamps_path = tmpdir / Path("lfp_timestamps.npy") input_channels_path = tmpdir / Path("lfp_channels.npy") input_csd_path = tmpdir / Path("csd.h5") - output_path = str( - tmpdir / Path("lfp.nwb")) # pynwb.NWBHDF5IO chokes on Path + output_path = str(tmpdir / Path("lfp.nwb")) + # pynwb.NWBHDF5IO chokes on Path test_lfp_paths = { "input_data_path": input_data_path, @@ -616,17 +666,24 @@ def test_write_probe_lfp_file(tmpdir_factory, lfp_data, probe_data, csd_data): write_csd_to_h5(path=input_csd_path, **csd_data) - np.save(input_timestamps_path, lfp_data["timestamps"], allow_pickle=False) - np.save(input_channels_path, lfp_data["subsample_channels"], + np.save(input_timestamps_path, + lfp_data["timestamps"], + allow_pickle=False) + np.save(input_channels_path, + lfp_data["subsample_channels"], allow_pickle=False) with open(input_data_path, "wb") as input_data_file: input_data_file.write(lfp_data["data"].tobytes()) - write_nwb.write_probe_lfp_file(4242, test_session_metadata, datetime.now(), - logging.INFO, probe_data) + write_nwb.write_probe_lfp_file( + 4242, + test_session_metadata, + datetime.now(), + logging.INFO, probe_data) - exp_electrodes = pd.DataFrame(probe_data["channels"]).set_index("id").loc[ - [2, 1], :] + exp_electrodes = \ + pd.DataFrame(probe_data["channels"]).set_index("id").loc[[2, 1], :] + exp_electrodes = exp_electrodes.rename(columns={'impedence': 'imp'}) exp_electrodes.rename(columns={"anterior_posterior_ccf_coordinate": "x", "dorsal_ventral_ccf_coordinate": "y", "left_right_ccf_coordinate": "z", @@ -636,18 +693,20 @@ def test_write_probe_lfp_file(tmpdir_factory, lfp_data, probe_data, csd_data): with pynwb.NWBHDF5IO(output_path, "r") as obt_io: obt_f = obt_io.read() - obt_ser = obt_f.get_acquisition("probe_12345_lfp").electrical_series[ - "probe_12345_lfp_data"] + obt_acq = \ + obt_f.get_acquisition("probe_12345_lfp") + obt_ser = obt_acq.electrical_series["probe_12345_lfp_data"] assert np.allclose(lfp_data["data"], obt_ser.data[:]) assert np.allclose(lfp_data["timestamps"], obt_ser.timestamps[:]) - obt_electrodes = obt_f.electrodes.to_dataframe().loc[ - :, ["local_index", "probe_horizontal_position", - "probe_id", "probe_vertical_position", - "valid_data", "x", "y", "z", "location", - "filtering"] - ] - obt_electrodes["impedence"] = np.nan + obt_electrodes_df = obt_f.electrodes.to_dataframe() + + obt_electrodes = obt_electrodes_df.loc[ + :, ["local_index", "probe_horizontal_position", + "probe_id", "probe_vertical_position", + "valid_data", "x", "y", "z", "location", "imp", + "filtering"] + ] assert obt_f.session_id == "4242" assert obt_f.subject.subject_id == "42" @@ -656,30 +715,41 @@ def test_write_probe_lfp_file(tmpdir_factory, lfp_data, probe_data, csd_data): # that are causing tests to fail. 
# Perhaps related to: https://stackoverflow.com/a/36279549 if platform.system() == "Windows": - pd.testing.assert_frame_equal(obt_electrodes, exp_electrodes, - check_like=True, check_dtype=False) + pd.testing.assert_frame_equal( + obt_electrodes, + exp_electrodes, + check_like=True, + check_dtype=False) else: - pd.testing.assert_frame_equal(obt_electrodes, exp_electrodes, - check_like=True) + pd.testing.assert_frame_equal( + obt_electrodes, + exp_electrodes, + check_like=True) + + processing_module = \ + obt_f.get_processing_module("current_source_density") - csd_series = obt_f.get_processing_module("current_source_density")[ - "ecephys_csd"] + csd_series = processing_module["ecephys_csd"] assert np.allclose(csd_data["csd"], csd_series.time_series.data[:].T) assert np.allclose(csd_data["relative_window"], csd_series.time_series.timestamps[:]) - obt_channel_locations = np.stack( - (csd_series.virtual_electrode_x_positions, - csd_series.virtual_electrode_y_positions), - axis=1) - assert np.allclose([[1, 2], [3, 3]], - obt_channel_locations) # csd interpolated - # channel locations + obt_channel_locations = \ + np.stack((csd_series.virtual_electrode_x_positions, + csd_series.virtual_electrode_y_positions), axis=1) + + # csd interpolated channel locations + assert np.allclose([[1, 2], [3, 3]], obt_channel_locations) @pytest.mark.parametrize("roundtrip", [True, False]) -def test_write_probe_lfp_file_roundtrip(tmpdir_factory, roundtrip, lfp_data, - probe_data, csd_data): +def test_write_probe_lfp_file_roundtrip( + tmpdir_factory, + roundtrip, + lfp_data, + probe_data, + csd_data): + expected_csd = xr.DataArray( name="CSD", data=csd_data["csd"], @@ -687,10 +757,14 @@ def test_write_probe_lfp_file_roundtrip(tmpdir_factory, roundtrip, lfp_data, coords={ "virtual_channel_index": np.arange(csd_data["csd"].shape[0]), "time": csd_data["relative_window"], - "vertical_position": - (("virtual_channel_index",), csd_data["csd_locations"][:, 1]), + "vertical_position": ( + ("virtual_channel_index",), + csd_data["csd_locations"][:, 1] + ), "horizontal_position": ( - ("virtual_channel_index",), csd_data["csd_locations"][:, 0]), + ("virtual_channel_index",), + csd_data["csd_locations"][:, 0] + ), } ) @@ -706,8 +780,7 @@ def test_write_probe_lfp_file_roundtrip(tmpdir_factory, roundtrip, lfp_data, input_timestamps_path = tmpdir / Path("lfp_timestamps.npy") input_channels_path = tmpdir / Path("lfp_channels.npy") input_csd_path = tmpdir / Path("csd.h5") - output_path = str( - tmpdir / Path("lfp.nwb")) # pynwb.NWBHDF5IO chokes on Path + output_path = str(tmpdir / Path("lfp.nwb")) test_lfp_paths = { "input_data_path": input_data_path, @@ -721,17 +794,25 @@ def test_write_probe_lfp_file_roundtrip(tmpdir_factory, roundtrip, lfp_data, write_csd_to_h5(path=input_csd_path, **csd_data) - np.save(input_timestamps_path, lfp_data["timestamps"], allow_pickle=False) - np.save(input_channels_path, lfp_data["subsample_channels"], + np.save(input_timestamps_path, + lfp_data["timestamps"], + allow_pickle=False) + np.save(input_channels_path, + lfp_data["subsample_channels"], allow_pickle=False) with open(input_data_path, "wb") as input_data_file: input_data_file.write(lfp_data["data"].tobytes()) - write_nwb.write_probe_lfp_file(4242, None, datetime.now(), logging.INFO, - probe_data) + write_nwb.write_probe_lfp_file( + 4242, + None, + datetime.now(), + logging.INFO, + probe_data) - obt = EcephysNwbSessionApi(path=None, probe_lfp_paths={ - 12345: NWBHDF5IO(output_path, "r").read}) + obt = EcephysNwbSessionApi( + path=None, + 
probe_lfp_paths={12345: NWBHDF5IO(output_path, "r").read}) obtained_lfp = obt.get_lfp(12345) obtained_csd = obt.get_current_source_density(12345) @@ -742,6 +823,7 @@ def test_write_probe_lfp_file_roundtrip(tmpdir_factory, roundtrip, lfp_data, @pytest.fixture def invalid_epochs(): + epochs = [ { "type": "EcephysSession", @@ -770,8 +852,9 @@ def invalid_epochs(): def test_add_invalid_times(invalid_epochs, tmpdir_factory): - nwbfile_name = str( - tmpdir_factory.mktemp("test").join("test_invalid_times.nwb")) + + nwbfile_name = \ + str(tmpdir_factory.mktemp("test").join("test_invalid_times.nwb")) nwbfile = NWBFile( session_description="EcephysSession", @@ -788,12 +871,15 @@ def test_add_invalid_times(invalid_epochs, tmpdir_factory): df = nwbfile.invalid_times.to_dataframe() df_in = nwbfile_in.invalid_times.to_dataframe() - pd.testing.assert_frame_equal(df, df_in, check_like=True, + pd.testing.assert_frame_equal(df, + df_in, + check_like=True, check_dtype=False) def test_roundtrip_add_invalid_times(nwbfile, invalid_epochs, roundtripper): - expected = write_nwb.setup_table_for_invalid_times(invalid_epochs) + + expected = setup_table_for_invalid_times(invalid_epochs) nwbfile = write_nwb.add_invalid_times(nwbfile, invalid_epochs) api = roundtripper(nwbfile, EcephysNwbSessionApi) @@ -803,11 +889,13 @@ def test_roundtrip_add_invalid_times(nwbfile, invalid_epochs, roundtripper): def test_no_invalid_times_table(): + epochs = [] - assert write_nwb.setup_table_for_invalid_times(epochs).empty is True + assert setup_table_for_invalid_times(epochs).empty is True def test_setup_table_for_invalid_times(): + epoch = { "type": "EcephysSession", "id": 739448407, @@ -816,7 +904,7 @@ def test_setup_table_for_invalid_times(): "end_time": 2005.0, } - s = write_nwb.setup_table_for_invalid_times([epoch]).loc[0] + s = setup_table_for_invalid_times([epoch]).loc[0] assert s["start_time"] == epoch["start_time"] assert s["stop_time"] == epoch["end_time"] @@ -856,20 +944,30 @@ def expected_amplitudes(): return np.array([0, 15, 60, 45, 120]) -def test_scale_amplitudes(spike_amplitudes, templates, spike_templates, - expected_amplitudes): +def test_scale_amplitudes( + spike_amplitudes, + templates, + spike_templates, + expected_amplitudes): + scale_factor = 0.195 expected = expected_amplitudes * scale_factor - obtained = write_nwb.scale_amplitudes(spike_amplitudes, templates, - spike_templates, scale_factor) + obtained = write_nwb.scale_amplitudes( + spike_amplitudes, + templates, + spike_templates, + scale_factor) assert np.allclose(expected, obtained) -def test_read_spike_amplitudes_to_dictionary(tmpdir_factory, spike_amplitudes, - templates, spike_templates, - expected_amplitudes): +def test_read_spike_amplitudes_to_dictionary( + tmpdir_factory, + spike_amplitudes, + templates, + spike_templates, + expected_amplitudes): tmpdir = str(tmpdir_factory.mktemp("spike_amps")) spike_amplitudes_path = os.path.join(tmpdir, "spike_amplitudes.npy") @@ -893,7 +991,8 @@ def test_read_spike_amplitudes_to_dictionary(tmpdir_factory, spike_amplitudes, np.save(spike_units_path, spike_units, allow_pickle=False) np.save(templates_path, templates, allow_pickle=False) np.save(spike_templates_path, spike_templates, allow_pickle=False) - np.save(inverse_whitening_matrix_path, inverse_whitening_matrix, + np.save(inverse_whitening_matrix_path, + inverse_whitening_matrix, allow_pickle=False) obtained = write_nwb.read_spike_amplitudes_to_dictionary( @@ -908,29 +1007,31 @@ def test_read_spike_amplitudes_to_dictionary(tmpdir_factory, spike_amplitudes, 
assert np.allclose(expected_amplitudes[3:], obtained[1]) -@pytest.mark.parametrize( - "spike_times_mapping, spike_amplitudes_mapping, expected", [ +@pytest.mark.parametrize("""spike_times_mapping, + spike_amplitudes_mapping, expected""", [ - ({12345: np.array([0, 1, 2, -1, 5, 4])}, # spike_times_mapping + ({12345: np.array([0, 1, 2, -1, 5, 4])}, # spike_times_mapping - {12345: np.array([0, 1, 2, 3, 4, 5])}, # spike_amplitudes_mapping + {12345: np.array([0, 1, 2, 3, 4, 5])}, # spike_amplitudes_mapping - ({12345: np.array([0, 1, 2, 4, 5])}, # expected - {12345: np.array([0, 1, 2, 5, 4])})), + ({12345: np.array([0, 1, 2, 4, 5])}, # expected + {12345: np.array([0, 1, 2, 5, 4])})), - ({12345: np.array([0, 1, 2, -1, 5, 4]), # spike_times_mapping - 54321: np.array([5, 4, 3, -1, 6])}, + ({12345: np.array([0, 1, 2, -1, 5, 4]), # spike_times_mapping + 54321: np.array([5, 4, 3, -1, 6])}, - {12345: np.array([0, 1, 2, 3, 4, 5]), # spike_amplitudes_mapping - 54321: np.array([0, 1, 2, 3, 4])}, + {12345: np.array([0, 1, 2, 3, 4, 5]), # spike_amplitudes_mapping + 54321: np.array([0, 1, 2, 3, 4])}, - ({12345: np.array([0, 1, 2, 4, 5]), # expected - 54321: np.array([3, 4, 5, 6])}, - {12345: np.array([0, 1, 2, 5, 4]), - 54321: np.array([2, 1, 0, 4])})), - ]) -def test_filter_and_sort_spikes(spike_times_mapping, spike_amplitudes_mapping, - expected): + ({12345: np.array([0, 1, 2, 4, 5]), # expected + 54321: np.array([3, 4, 5, 6])}, + {12345: np.array([0, 1, 2, 5, 4]), + 54321: np.array([2, 1, 0, 4])})), +]) +def test_filter_and_sort_spikes( + spike_times_mapping, + spike_amplitudes_mapping, + expected): expected_spike_times, expected_spike_amplitudes = expected obtained_spike_times, obtained_spike_amplitudes = \ @@ -1011,6 +1112,7 @@ def test_filter_and_sort_spikes(spike_times_mapping, spike_amplitudes_mapping, def test_add_probewise_data_to_nwbfile(monkeypatch, nwbfile, roundtripper, roundtrip, probes, parsed_probe_data, expected_units_table): + def mock_parse_probes_data(probes): return parsed_probe_data @@ -1027,7 +1129,7 @@ def mock_parse_probes_data(probes): @pytest.mark.parametrize("roundtrip", [True, False]) -@pytest.mark.parametrize("eye_tracking_rig_geometry, expected", [ +@pytest.mark.parametrize("eye_tracking_rig_geom, expected", [ ({"monitor_position_mm": [1., 2., 3.], "monitor_rotation_deg": [4., 5., 6.], "camera_position_mm": [7., 8., 9.], @@ -1044,14 +1146,16 @@ def mock_parse_probes_data(probes): index=["x", "y", "z"]), "equipment": "test_rig"}), ]) -def test_add_eye_tracking_rig_geometry_data_to_nwbfile( - nwbfile, roundtripper, - roundtrip, - eye_tracking_rig_geometry, - expected): - nwbfile = write_nwb.add_eye_tracking_rig_geometry_data_to_nwbfile( - nwbfile, - eye_tracking_rig_geometry) +def test_add_eye_tracking_rig_geometry_data_to_nwbfile(nwbfile, + roundtripper, + roundtrip, + eye_tracking_rig_geom, + expected): + + nwbfile = \ + write_nwb.add_eye_tracking_rig_geometry_data_to_nwbfile( + nwbfile, + eye_tracking_rig_geom) if roundtrip: obt = roundtripper(nwbfile, EcephysNwbSessionApi) else: @@ -1059,129 +1163,110 @@ def test_add_eye_tracking_rig_geometry_data_to_nwbfile( obtained_metadata = obt.get_rig_metadata() pd.testing.assert_frame_equal(obtained_metadata["geometry"], - expected["geometry"], check_like=True) + expected["geometry"], + check_like=True) assert obtained_metadata["equipment"] == expected["equipment"] @pytest.mark.parametrize("roundtrip", [True, False]) @pytest.mark.parametrize(("eye_tracking_frame_times, eye_dlc_tracking_data, " "eye_gaze_data, expected_pupil_data, " - 
"expected_gaze_data"), - [ - ( - # eye_tracking_frame_times - pd.Series([3., 4., 5., 6., 7.]), - # eye_dlc_tracking_data - { - "pupil_params": - create_preload_eye_tracking_df( - np.full((5, 5), 1.)), - "cr_params": - create_preload_eye_tracking_df( - np.full((5, 5), 2.)), - "eye_params": - create_preload_eye_tracking_df( - np.full((5, 5), 3.))}, - # eye_gaze_data - {"raw_pupil_areas": pd.Series( - [2., 4., 6., 8., 10.]), - "raw_eye_areas": pd.Series( - [3., 5., 7., 9., 11.]), - "raw_screen_coordinates": - pd.DataFrame( - {"y": [2., 4., 6., 8., 10.], - "x": [3., 5., 7., 9., 11.]}), - "raw_screen_coordinates_spherical": - pd.DataFrame( - {"y": [2., 4., 6., 8., 10.], - "x": [3., 5., 7., 9., 11.]}), - "new_pupil_areas": pd.Series( - [2., 4., np.nan, 8., 10.]), - "new_eye_areas": pd.Series( - [3., 5., np.nan, 9., 11.]), - "new_screen_coordinates": - pd.DataFrame( - {"y": [2., 4., np.nan, 8., - 10.], - "x": [3., 5., np.nan, 9., - 11.]}), - "new_screen_coordinates_spherical": - pd.DataFrame( - {"y": [2., 4., np.nan, 8., - 10.], - "x": [3., 5., np.nan, 9., - 11.]}), - "synced_frame_timestamps": pd.Series( - [3., 4., 5., 6., 7.])}, - # expected_pupil_data - pd.DataFrame( - { - "corneal_reflection_center_x": - [2.] * 5, - "corneal_reflection_center_y": - [2.] * 5, - "corneal_reflection_height": - [4.] * 5, - "corneal_reflection_width": - [4.] * 5, - "corneal_reflection_phi": - [2.] * 5, - "pupil_center_x": [1.] * 5, - "pupil_center_y": [1.] * 5, - "pupil_height": [2.] * 5, - "pupil_width": [2.] * 5, - "pupil_phi": [1.] * 5, - "eye_center_x": [3.] * 5, - "eye_center_y": [3.] * 5, - "eye_height": [6.] * 5, - "eye_width": [6.] * 5, - "eye_phi": [3.] * 5}, - index=[3., 4., 5., 6., 7.]), - # expected_gaze_data - pd.DataFrame( - { - "raw_eye_area": - [3., 5., 7., 9., 11.], - "raw_pupil_area": - [2., 4., 6., 8., 10.], - "raw_screen_coordinates_x_cm": - [3., 5., 7., 9., 11.], - "raw_screen_coordinates_y_cm": - [2., 4., 6., 8., 10.], - "raw_screen_coordinates_" - "spherical_x_deg": - [3., 5., 7., 9., 11.], - "raw_screen_coordinates_" - "spherical_y_deg": - [2., 4., 6., 8., 10.], - "filtered_eye_area": - [3., 5., np.nan, 9., 11.], - "filtered_pupil_area": - [2., 4., np.nan, 8., 10.], - "filtered_screen_coordinates_" - "x_cm": [3., 5., np.nan, 9., 11.], - "filtered_screen_coordinates_" - "y_cm": [2., 4., np.nan, 8., 10.], - "filtered_screen_coordinates_" - "spherical_x_deg": - [3., 5., np.nan, 9., 11.], - "filtered_screen_coordinates_" - "spherical_y_deg": - [2., 4., np.nan, 8., 10.]}, - index=[3., 4., 5., 6., 7.]) - ), - ]) -def test_add_eye_tracking_data_to_nwbfile(nwbfile, roundtripper, roundtrip, + "expected_gaze_data"), [ + ( + # eye_tracking_frame_times + pd.Series([3., 4., 5., 6., 7.]), + # eye_dlc_tracking_data + {"pupil_params": create_preload_eye_tracking_df(np.full((5, 5), 1.)), + "cr_params": create_preload_eye_tracking_df(np.full((5, 5), 2.)), + "eye_params": create_preload_eye_tracking_df(np.full((5, 5), 3.))}, + # eye_gaze_data + {"raw_pupil_areas": pd.Series([2., 4., 6., 8., 10.]), + "raw_eye_areas": pd.Series([3., 5., 7., 9., 11.]), + "raw_screen_coordinates": pd.DataFrame( + {"y": [2., 4., 6., 8., 10.], + "x": [3., 5., 7., 9., 11.]}), + "raw_screen_coordinates_spherical": pd.DataFrame( + {"y": [2., 4., 6., 8., 10.], + "x": [3., 5., 7., 9., 11.]}), + "new_pupil_areas": pd.Series([2., 4., np.nan, 8., 10.]), + "new_eye_areas": pd.Series([3., 5., np.nan, 9., 11.]), + "new_screen_coordinates": pd.DataFrame( + {"y": [2., 4., np.nan, 8., 10.], + "x": [3., 5., np.nan, 9., 11.]}), + 
"new_screen_coordinates_spherical": pd.DataFrame({ + "y": [2., 4., np.nan, 8., 10.], + "x": [3., 5., np.nan, 9., 11.]}), + "synced_frame_timestamps": pd.Series([3., 4., 5., 6., 7.])}, + # expected_pupil_data + pd.DataFrame({"corneal_reflection_center_x": [2.] * 5, + "corneal_reflection_center_y": [2.] * 5, + "corneal_reflection_height": [4.] * 5, + "corneal_reflection_width": [4.] * 5, + "corneal_reflection_phi": [2.] * 5, + "pupil_center_x": [1.] * 5, + "pupil_center_y": [1.] * 5, + "pupil_height": [2.] * 5, + "pupil_width": [2.] * 5, + "pupil_phi": [1.] * 5, + "eye_center_x": [3.] * 5, + "eye_center_y": [3.] * 5, + "eye_height": [6.] * 5, + "eye_width": [6.] * 5, + "eye_phi": [3.] * 5}, + index=[3., 4., 5., 6., 7.]), + # expected_gaze_data + pd.DataFrame({"raw_eye_area": [3., 5., 7., 9., 11.], + "raw_pupil_area": [2., 4., 6., 8., 10.], + "raw_screen_coordinates_x_cm": [3., 5., 7., 9., 11.], + "raw_screen_coordinates_y_cm": [2., 4., 6., 8., 10.], + "raw_screen_coordinates_spherical_x_deg": [3., + 5., + 7., + 9., + 11.], + "raw_screen_coordinates_spherical_y_deg": [2., + 4., + 6., + 8., + 10.], + "filtered_eye_area": [3., 5., np.nan, 9., 11.], + "filtered_pupil_area": [2., 4., np.nan, 8., 10.], + "filtered_screen_coordinates_x_cm": [3., + 5., + np.nan, + 9., + 11.], + "filtered_screen_coordinates_y_cm": [2., + 4., + np.nan, + 8., + 10.], + "filtered_screen_coordinates_spherical_x_deg": [3., + 5., + np.nan, + 9., + 11.], + "filtered_screen_coordinates_spherical_y_deg": [2., + 4., + np.nan, + 8., + 10.]}, + index=[3., 4., 5., 6., 7.]) + ), +]) +def test_add_eye_tracking_data_to_nwbfile(nwbfile, + roundtripper, + roundtrip, eye_tracking_frame_times, eye_dlc_tracking_data, eye_gaze_data, expected_pupil_data, expected_gaze_data): - nwbfile = write_nwb.add_eye_tracking_data_to_nwbfile( - nwbfile, - eye_tracking_frame_times, - eye_dlc_tracking_data, - eye_gaze_data) + nwbfile = \ + write_nwb.add_eye_tracking_data_to_nwbfile(nwbfile, + eye_tracking_frame_times, + eye_dlc_tracking_data, + eye_gaze_data) if roundtrip: obt = roundtripper(nwbfile, EcephysNwbSessionApi) @@ -1189,7 +1274,8 @@ def test_add_eye_tracking_data_to_nwbfile(nwbfile, roundtripper, roundtrip, obt = EcephysNwbSessionApi.from_nwbfile(nwbfile) obtained_pupil_data = obt.get_pupil_data() obtained_screen_gaze_data = obt.get_screen_gaze_data( - include_filtered_data=True) + include_filtered_data=True + ) pd.testing.assert_frame_equal(obtained_pupil_data, expected_pupil_data, check_like=True) diff --git a/doc_template/index.rst b/doc_template/index.rst index 1dbae3937..7b2f9facf 100644 --- a/doc_template/index.rst +++ b/doc_template/index.rst @@ -123,6 +123,7 @@ What's New - 2.13.2 - Fixes bug that caused file paths on windows machines to be incorrect in Visual behavior user-facing classes - Updates to support MESO.2 - Loosens/updates required versions for several dependencies +- Updates in order to generate valid NWB files for Neuropixels Visual Coding data collected between 2019 and 2021 What's New - 2.13.1 -----------------------------------------------------------------------