From ec9fcf3350b8e5809a9968504d6e55ae0ab10fbd Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Wed, 30 Nov 2022 15:46:21 -0600 Subject: [PATCH 001/152] first prototype in separating `Clustering` into multiple steps --- element_array_ephys/spike_sorting/__init__.py | 0 .../spike_sorting/ecephys_spike_sorting.py | 250 ++++++++++++++++++ 2 files changed, 250 insertions(+) create mode 100644 element_array_ephys/spike_sorting/__init__.py create mode 100644 element_array_ephys/spike_sorting/ecephys_spike_sorting.py diff --git a/element_array_ephys/spike_sorting/__init__.py b/element_array_ephys/spike_sorting/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py new file mode 100644 index 00000000..1dc71e7a --- /dev/null +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -0,0 +1,250 @@ +import datajoint as dj +from element_array_ephys import get_logger +from decimal import Decimal +import json +from datetime import datetime, timedelta + +from element_interface.utils import find_full_path +from element_array_ephys.readers import spikeglx, kilosort, openephys, kilosort_triggering + +log = get_logger(__name__) + +schema = dj.schema() + +ephys = None + + +def activate(schema_name, ephys_schema_name, *, create_schema=True, create_tables=True): + """ + activate(schema_name, *, create_schema=True, create_tables=True, activated_ephys=None) + :param schema_name: schema name on the database server to activate the `spike_sorting` schema + :param ephys_schema_name: schema name of the activated ephys element for which this ephys_report schema will be downstream from + :param create_schema: when True (default), create schema in the database if it does not yet exist. + :param create_tables: when True (default), create tables in the database if they do not yet exist. + (The "activation" of this ephys_report module should be evoked by one of the ephys modules only) + """ + global ephys + ephys = dj.create_virtual_module("ephys", ephys_schema_name) + schema.activate( + schema_name, + create_schema=create_schema, + create_tables=create_tables, + add_objects=ephys.__dict__, + ) + + +@schema +class KilosortPreProcessing(dj.Imported): + """A processing table to handle each clustering task. 
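+    Runs the preprocessing modules that do not require a GPU: the CatGT step
+    for SpikeGLX recordings, or the depth_estimation/median_subtraction
+    modules for Open Ephys (see the branch logic in `make` below).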
+ """ + definition = """ + -> ephys.ClusteringTask + --- + params: longblob # finalized parameterset for this run + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration + """ + + @property + def key_source(self): + return (ephys.ClusteringTask * ephys.ClusteringParamSet + & {'task_mode': 'trigger'} + & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")') + + def make(self, key): + """Triggers or imports clustering analysis.""" + execution_time = datetime.utcnow() + + task_mode, output_dir = (ephys.ClusteringTask & key).fetch1( + "task_mode", "clustering_output_dir" + ) + + assert task_mode == "trigger", 'Supporting "trigger" task_mode only' + + if not output_dir: + output_dir = ephys.ClusteringTask.infer_output_dir(key, relative=True, mkdir=True) + # update clustering_output_dir + ephys.ClusteringTask.update1( + {**key, "clustering_output_dir": output_dir.as_posix()} + ) + + kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + acq_software, clustering_method, params = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method", "params") + + assert clustering_method in ("kilosort2", "kilosort2.5", "kilosort3"), 'Supporting "kilosort" clustering_method only' + + # add additional probe-recording and channels details into `params` + params = {**params, **ephys.get_recording_channels_details(key)} + params["fs"] = params["sample_rate"] + + if acq_software == "SpikeGLX": + spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) + spikeglx_recording = spikeglx.SpikeGLX( + spikeglx_meta_filepath.parent + ) + spikeglx_recording.validate_file("ap") + + run_kilosort = kilosort_triggering.SGLXKilosortPipeline( + npx_input_dir=spikeglx_meta_filepath.parent, + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + run_CatGT=True, + ) + run_kilosort.run_CatGT() + elif acq_software == "Open Ephys": + oe_probe = ephys.get_openephys_probe_data(key) + + assert len(oe_probe.recording_info["recording_files"]) == 1 + + # run kilosort + run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( + npx_input_dir=oe_probe.recording_info["recording_files"][0], + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + ) + run_kilosort._modules = ['depth_estimation', 'median_subtraction'] + run_kilosort.run_modules() + + self.insert1({**key, + "params": params, + "execution_time": execution_time, + "execution_duration": (datetime.utcnow() - execution_time).total_seconds() / 3600}) + + +@schema +class KilosortClustering(dj.Imported): + """A processing table to handle each clustering task. 
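+    Runs the GPU-dependent `kilosort_helper` module, i.e. the actual Kilosort
+    (MATLAB) spike sorting, for kilosort2, kilosort2.5, or kilosort3.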
+ """ + definition = """ + -> KilosortPreProcessing + --- + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration + """ + + def make(self, key): + execution_time = datetime.utcnow() + + output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") + kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + acq_software, clustering_method = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method") + assert clustering_method in ("kilosort2", "kilosort2.5", "kilosort3"), 'Supporting "kilosort" clustering_method only' + + params = (KilosortPreProcessing & key).fetch1('params') + + if acq_software == "SpikeGLX": + spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) + spikeglx_recording = spikeglx.SpikeGLX( + spikeglx_meta_filepath.parent + ) + spikeglx_recording.validate_file("ap") + + run_kilosort = kilosort_triggering.SGLXKilosortPipeline( + npx_input_dir=spikeglx_meta_filepath.parent, + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + run_CatGT=True, + ) + run_kilosort._modules = ['kilosort_helper'] + run_kilosort._CatGT_finished = True + run_kilosort.run_modules() + elif acq_software == "Open Ephys": + oe_probe = ephys.get_openephys_probe_data(key) + + assert len(oe_probe.recording_info["recording_files"]) == 1 + + # run kilosort + run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( + npx_input_dir=oe_probe.recording_info["recording_files"][0], + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + ) + run_kilosort._modules = ['kilosort_helper'] + run_kilosort.run_modules() + + self.insert1({**key, + "execution_time": execution_time, + "execution_duration": (datetime.utcnow() - execution_time).total_seconds() / 3600}) + + +@schema +class KilosortPostProcessing(dj.Imported): + """A processing table to handle each clustering task. 
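+    Runs the postprocessing modules that do not require a GPU:
+    kilosort_postprocessing, noise_templates, mean_waveforms, and
+    quality_metrics.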
+ """ + definition = """ + -> KilosortClustering + --- + modules_status: longblob # dictionary of summary status for all modules + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration + """ + + def make(self, key): + execution_time = datetime.utcnow() + + output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") + kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + acq_software, clustering_method = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method") + assert clustering_method in ( + "kilosort2", "kilosort2.5", "kilosort3"), 'Supporting "kilosort" clustering_method only' + + params = (KilosortPreProcessing & key).fetch1('params') + + if acq_software == "SpikeGLX": + spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) + spikeglx_recording = spikeglx.SpikeGLX( + spikeglx_meta_filepath.parent + ) + spikeglx_recording.validate_file("ap") + + run_kilosort = kilosort_triggering.SGLXKilosortPipeline( + npx_input_dir=spikeglx_meta_filepath.parent, + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + run_CatGT=True, + ) + run_kilosort._modules = ['kilosort_postprocessing', + 'noise_templates', + 'mean_waveforms', + 'quality_metrics'] + run_kilosort._CatGT_finished = True + run_kilosort.run_modules() + elif acq_software == "Open Ephys": + oe_probe = ephys.get_openephys_probe_data(key) + + assert len(oe_probe.recording_info["recording_files"]) == 1 + + # run kilosort + run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( + npx_input_dir=oe_probe.recording_info["recording_files"][0], + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + ) + run_kilosort._modules = ['kilosort_postprocessing', + 'noise_templates', + 'mean_waveforms', + 'quality_metrics'] + run_kilosort.run_modules() + + with open(self._modules_input_hash_fp) as f: + modules_status = json.load(f) + + self.insert1({**key, + "modules_status": modules_status, + "execution_time": execution_time, + "execution_duration": (datetime.utcnow() - execution_time).total_seconds() / 3600}) From f5724384952f801086e751d3645437bea2694604 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Thu, 5 Jan 2023 14:06:00 -0600 Subject: [PATCH 002/152] Update ecephys_spike_sorting.py --- .../spike_sorting/ecephys_spike_sorting.py | 145 ++++++++++++------ 1 file changed, 94 insertions(+), 51 deletions(-) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index 1dc71e7a..1592e65e 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -5,7 +5,12 @@ from datetime import datetime, timedelta from element_interface.utils import find_full_path -from element_array_ephys.readers import spikeglx, kilosort, openephys, kilosort_triggering +from element_array_ephys.readers import ( + spikeglx, + kilosort, + openephys, + kilosort_triggering, +) log = get_logger(__name__) @@ -35,8 +40,8 @@ def activate(schema_name, ephys_schema_name, *, create_schema=True, create_table @schema class KilosortPreProcessing(dj.Imported): - """A processing table to handle each clustering task. 
- """ + """A processing table to handle each clustering task.""" + definition = """ -> ephys.ClusteringTask --- @@ -47,9 +52,11 @@ class KilosortPreProcessing(dj.Imported): @property def key_source(self): - return (ephys.ClusteringTask * ephys.ClusteringParamSet - & {'task_mode': 'trigger'} - & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")') + return ( + ephys.ClusteringTask * ephys.ClusteringParamSet + & {"task_mode": "trigger"} + & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")' + ) def make(self, key): """Triggers or imports clustering analysis.""" @@ -62,7 +69,9 @@ def make(self, key): assert task_mode == "trigger", 'Supporting "trigger" task_mode only' if not output_dir: - output_dir = ephys.ClusteringTask.infer_output_dir(key, relative=True, mkdir=True) + output_dir = ephys.ClusteringTask.infer_output_dir( + key, relative=True, mkdir=True + ) # update clustering_output_dir ephys.ClusteringTask.update1( {**key, "clustering_output_dir": output_dir.as_posix()} @@ -71,10 +80,14 @@ def make(self, key): kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) acq_software, clustering_method, params = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method", "params") - assert clustering_method in ("kilosort2", "kilosort2.5", "kilosort3"), 'Supporting "kilosort" clustering_method only' + assert clustering_method in ( + "kilosort2", + "kilosort2.5", + "kilosort3", + ), 'Supporting "kilosort" clustering_method only' # add additional probe-recording and channels details into `params` params = {**params, **ephys.get_recording_channels_details(key)} @@ -82,17 +95,19 @@ def make(self, key): if acq_software == "SpikeGLX": spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX( - spikeglx_meta_filepath.parent - ) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) spikeglx_recording.validate_file("ap") + run_CatGT = ( + params.pop("run_CatGT", True) + and "_tcat." not in spikeglx_meta_filepath.stem + ) run_kilosort = kilosort_triggering.SGLXKilosortPipeline( npx_input_dir=spikeglx_meta_filepath.parent, ks_output_dir=kilosort_dir, params=params, KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=True, + run_CatGT=run_CatGT, ) run_kilosort.run_CatGT() elif acq_software == "Open Ephys": @@ -107,19 +122,26 @@ def make(self, key): params=params, KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', ) - run_kilosort._modules = ['depth_estimation', 'median_subtraction'] + run_kilosort._modules = ["depth_estimation", "median_subtraction"] run_kilosort.run_modules() - self.insert1({**key, - "params": params, - "execution_time": execution_time, - "execution_duration": (datetime.utcnow() - execution_time).total_seconds() / 3600}) + self.insert1( + { + **key, + "params": params, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) @schema class KilosortClustering(dj.Imported): - """A processing table to handle each clustering task. 
- """ + """A processing table to handle each clustering task.""" + definition = """ -> KilosortPreProcessing --- @@ -134,17 +156,19 @@ def make(self, key): kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) acq_software, clustering_method = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - assert clustering_method in ("kilosort2", "kilosort2.5", "kilosort3"), 'Supporting "kilosort" clustering_method only' + assert clustering_method in ( + "kilosort2", + "kilosort2.5", + "kilosort3", + ), 'Supporting "kilosort" clustering_method only' - params = (KilosortPreProcessing & key).fetch1('params') + params = (KilosortPreProcessing & key).fetch1("params") if acq_software == "SpikeGLX": spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX( - spikeglx_meta_filepath.parent - ) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) spikeglx_recording.validate_file("ap") run_kilosort = kilosort_triggering.SGLXKilosortPipeline( @@ -154,7 +178,7 @@ def make(self, key): KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', run_CatGT=True, ) - run_kilosort._modules = ['kilosort_helper'] + run_kilosort._modules = ["kilosort_helper"] run_kilosort._CatGT_finished = True run_kilosort.run_modules() elif acq_software == "Open Ephys": @@ -169,18 +193,25 @@ def make(self, key): params=params, KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', ) - run_kilosort._modules = ['kilosort_helper'] + run_kilosort._modules = ["kilosort_helper"] run_kilosort.run_modules() - self.insert1({**key, - "execution_time": execution_time, - "execution_duration": (datetime.utcnow() - execution_time).total_seconds() / 3600}) + self.insert1( + { + **key, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) @schema class KilosortPostProcessing(dj.Imported): - """A processing table to handle each clustering task. 
- """ + """A processing table to handle each clustering task.""" + definition = """ -> KilosortClustering --- @@ -196,18 +227,19 @@ def make(self, key): kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) acq_software, clustering_method = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") assert clustering_method in ( - "kilosort2", "kilosort2.5", "kilosort3"), 'Supporting "kilosort" clustering_method only' + "kilosort2", + "kilosort2.5", + "kilosort3", + ), 'Supporting "kilosort" clustering_method only' - params = (KilosortPreProcessing & key).fetch1('params') + params = (KilosortPreProcessing & key).fetch1("params") if acq_software == "SpikeGLX": spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX( - spikeglx_meta_filepath.parent - ) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) spikeglx_recording.validate_file("ap") run_kilosort = kilosort_triggering.SGLXKilosortPipeline( @@ -217,10 +249,12 @@ def make(self, key): KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', run_CatGT=True, ) - run_kilosort._modules = ['kilosort_postprocessing', - 'noise_templates', - 'mean_waveforms', - 'quality_metrics'] + run_kilosort._modules = [ + "kilosort_postprocessing", + "noise_templates", + "mean_waveforms", + "quality_metrics", + ] run_kilosort._CatGT_finished = True run_kilosort.run_modules() elif acq_software == "Open Ephys": @@ -235,16 +269,25 @@ def make(self, key): params=params, KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', ) - run_kilosort._modules = ['kilosort_postprocessing', - 'noise_templates', - 'mean_waveforms', - 'quality_metrics'] + run_kilosort._modules = [ + "kilosort_postprocessing", + "noise_templates", + "mean_waveforms", + "quality_metrics", + ] run_kilosort.run_modules() with open(self._modules_input_hash_fp) as f: modules_status = json.load(f) - self.insert1({**key, - "modules_status": modules_status, - "execution_time": execution_time, - "execution_duration": (datetime.utcnow() - execution_time).total_seconds() / 3600}) + self.insert1( + { + **key, + "modules_status": modules_status, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) From edf1578b45425410c6cb53e5777866afa5f04f98 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Thu, 5 Jan 2023 14:07:53 -0600 Subject: [PATCH 003/152] Update ecephys_spike_sorting.py --- element_array_ephys/spike_sorting/ecephys_spike_sorting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index 1592e65e..eb02c251 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -98,7 +98,7 @@ def make(self, key): spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) spikeglx_recording.validate_file("ap") run_CatGT = ( - params.pop("run_CatGT", True) + params.get("run_CatGT", True) and "_tcat." 
not in spikeglx_meta_filepath.stem
             )
 

From 7e267c57571c3396ecdc60d159db0245326e4047 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Thu, 5 Jan 2023 17:15:51 -0600
Subject: [PATCH 004/152] fix typo

---
 element_array_ephys/ephys_no_curation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index 943d3354..9414c49e 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -799,7 +799,7 @@ class Clustering(dj.Imported):
     Attributes:
         ClusteringTask (foreign key): ClusteringTask primary key.
         clustering_time (datetime): Time when clustering results are generated.
-        package_version (varchar(16) ): Package version used for a clustering analysis.
+        package_version (varchar(16)): Package version used for a clustering analysis.

From 6e7ddf15966d51474455efdd758e99e93850b6b7 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Fri, 6 Jan 2023 10:23:28 -0600
Subject: [PATCH 005/152] Update ecephys_spike_sorting.py

---
 .../spike_sorting/ecephys_spike_sorting.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py
index eb02c251..ed6d699e 100644
--- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py
@@ -19,17 +19,23 @@
 ephys = None
 
 
-def activate(schema_name, ephys_schema_name, *, create_schema=True, create_tables=True):
+def activate(
+    schema_name,
+    *,
+    ephys_module,
+    create_schema=True,
+    create_tables=True,
+):
     """
     activate(schema_name, *, create_schema=True, create_tables=True, activated_ephys=None)
     :param schema_name: schema name on the database server to activate the `spike_sorting` schema
-    :param ephys_schema_name: schema name of the activated ephys element for which this ephys_report schema will be downstream from
+    :param ephys_module: the activated ephys element for which this ephys_report schema will be downstream from
     :param create_schema: when True (default), create schema in the database if it does not yet exist.
     :param create_tables: when True (default), create tables in the database if they do not yet exist.
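     Example (a minimal sketch, not part of this patch; `my_ephys` stands for
     whichever ephys module the workflow has already activated):
         from element_array_ephys.spike_sorting import ecephys_spike_sorting
         ecephys_spike_sorting.activate("spike_sorting", ephys_module=my_ephys)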
(The "activation" of this ephys_report module should be evoked by one of the ephys modules only) """ global ephys - ephys = dj.create_virtual_module("ephys", ephys_schema_name) + ephys = ephys_module schema.activate( schema_name, create_schema=create_schema, From 7fd9bb4368208e12a809b49e402190e3d34eaa07 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 6 Jan 2023 11:15:43 -0600 Subject: [PATCH 006/152] improve log messages --- element_array_ephys/readers/kilosort_triggering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/readers/kilosort_triggering.py b/element_array_ephys/readers/kilosort_triggering.py index 8e4b80ff..5d76c7af 100644 --- a/element_array_ephys/readers/kilosort_triggering.py +++ b/element_array_ephys/readers/kilosort_triggering.py @@ -21,13 +21,13 @@ get_noise_channels, ) except Exception as e: - print(f'Error in loading "ecephys_spike_sorting" package - {str(e)}') + print(f'Warning: Failed loading "ecephys_spike_sorting" package - {str(e)}') # import pykilosort package try: import pykilosort except Exception as e: - print(f'Error in loading "pykilosort" package - {str(e)}') + print(f'Warning: Failed loading "pykilosort" package - {str(e)}') class SGLXKilosortPipeline: From 26a56e72f76afa24833f275c01df0606e9d24ec2 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 6 Jan 2023 11:58:01 -0600 Subject: [PATCH 007/152] fix key_source --- element_array_ephys/spike_sorting/ecephys_spike_sorting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index ed6d699e..d7f7865a 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -62,7 +62,7 @@ def key_source(self): ephys.ClusteringTask * ephys.ClusteringParamSet & {"task_mode": "trigger"} & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")' - ) + ) - ephys.Clustering def make(self, key): """Triggers or imports clustering analysis.""" From 654bc522ce8bba576b2793924f39d1a97416bb0f Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Wed, 18 Jan 2023 11:58:13 -0600 Subject: [PATCH 008/152] Update ecephys_spike_sorting.py --- element_array_ephys/spike_sorting/ecephys_spike_sorting.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index d7f7865a..0cf4bea8 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -297,3 +297,6 @@ def make(self, key): / 3600, } ) + + # all finished, insert this `key` into ephys.Clustering + ephys.Clustering.insert1({**key, "clustering_time": datetime.utcnow()}) From f75e14f13a2f413b2c34be401652ade70ce11a94 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Thu, 19 Jan 2023 15:43:15 -0600 Subject: [PATCH 009/152] bugfix --- element_array_ephys/spike_sorting/ecephys_spike_sorting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index 0cf4bea8..3eca46b9 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -283,7 +283,7 @@ def make(self, key): ] run_kilosort.run_modules() - with 
open(self._modules_input_hash_fp) as f: + with open(run_kilosort._modules_input_hash_fp) as f: modules_status = json.load(f) self.insert1( From a32d1d25b895bc1c05f4149f920b206b07646aae Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Thu, 19 Jan 2023 16:23:37 -0600 Subject: [PATCH 010/152] bugfix --- element_array_ephys/spike_sorting/ecephys_spike_sorting.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index 3eca46b9..d33a3752 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -299,4 +299,6 @@ def make(self, key): ) # all finished, insert this `key` into ephys.Clustering - ephys.Clustering.insert1({**key, "clustering_time": datetime.utcnow()}) + ephys.Clustering.insert1( + {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True + ) From 3bea7755245905dc59271a248514caad004c7f10 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 20 Jan 2023 16:50:35 -0600 Subject: [PATCH 011/152] Update kilosort_triggering.py --- element_array_ephys/readers/kilosort_triggering.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/element_array_ephys/readers/kilosort_triggering.py b/element_array_ephys/readers/kilosort_triggering.py index 7f30cac4..4e831d1b 100644 --- a/element_array_ephys/readers/kilosort_triggering.py +++ b/element_array_ephys/readers/kilosort_triggering.py @@ -777,8 +777,7 @@ def _write_channel_map_file( # channels to exclude mask = get_noise_channels(ap_band_file, channel_count, sample_rate, bit_volts) - bad_channel_ind = np.where(mask is False)[0] - connected[bad_channel_ind] = 0 + connected = np.where(mask is False, 0, connected) mdict = { "chanMap": chanMap, From 4f955b32e8049c2961e9c916029b2d27f008476a Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Mon, 23 Jan 2023 10:53:46 -0600 Subject: [PATCH 012/152] fix docstring --- element_array_ephys/spike_sorting/ecephys_spike_sorting.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index d33a3752..cec8f7ac 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -29,10 +29,9 @@ def activate( """ activate(schema_name, *, create_schema=True, create_tables=True, activated_ephys=None) :param schema_name: schema name on the database server to activate the `spike_sorting` schema - :param ephys_module: the activated ephys element for which this ephys_report schema will be downstream from + :param ephys_module: the activated ephys element for which this `spike_sorting` schema will be downstream from :param create_schema: when True (default), create schema in the database if it does not yet exist. :param create_tables: when True (default), create tables in the database if they do not yet exist. 
- (The "activation" of this ephys_report module should be evoked by one of the ephys modules only) """ global ephys ephys = ephys_module From 53854d0a986d564c8e75295f83da4dbd5446d92b Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Mon, 23 Jan 2023 11:05:47 -0600 Subject: [PATCH 013/152] added description --- .../spike_sorting/ecephys_spike_sorting.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index cec8f7ac..d779d0c0 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -1,3 +1,26 @@ +""" +The following DataJoint pipeline implements the sequence of steps in the spike-sorting routine featured in the +"ecephys_spike_sorting" pipeline. +The "ecephys_spike_sorting" was originally developed by the Allen Institute (https://github.com/AllenInstitute/ecephys_spike_sorting) for Neuropixels data acquired with Open Ephys acquisition system. +Then forked by Jennifer Colonell from the Janelia Research Campus (https://github.com/jenniferColonell/ecephys_spike_sorting) to support SpikeGLX acquisition system. + +At DataJoint, we fork from Jennifer's fork and implemented a version that supports both Open Ephys and Spike GLX. +https://github.com/datajoint-company/ecephys_spike_sorting + +The follow pipeline features three tables: +1. KilosortPreProcessing - for preprocessing steps (no GPU required) + - median_subtraction for Open Ephys + - or the CatGT step for SpikeGLX +2. KilosortClustering - kilosort (MATLAB) - requires GPU + - supports kilosort 2.0, 2.5 or 3.0 (https://github.com/MouseLand/Kilosort.git) +3. KilosortPostProcessing - for postprocessing steps (no GPU required) + - kilosort_postprocessing + - noise_templates + - mean_waveforms + - quality_metrics +""" + + import datajoint as dj from element_array_ephys import get_logger from decimal import Decimal From fe9955ca4269b86958541440f1587bdecc5b0c1b Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Wed, 25 Jan 2023 15:33:08 -0600 Subject: [PATCH 014/152] refactor `_supported_kilosort_versions` --- .../spike_sorting/ecephys_spike_sorting.py | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index d779d0c0..fca4e452 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -41,6 +41,12 @@ ephys = None +_supported_kilosort_versions = [ + "kilosort2", + "kilosort2.5", + "kilosort3", +] + def activate( schema_name, @@ -111,11 +117,9 @@ def make(self, key): ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method", "params") - assert clustering_method in ( - "kilosort2", - "kilosort2.5", - "kilosort3", - ), 'Supporting "kilosort" clustering_method only' + assert ( + clustering_method in _supported_kilosort_versions + ), f'Clustering_method "{clustering_method}" is not supported' # add additional probe-recording and channels details into `params` params = {**params, **ephys.get_recording_channels_details(key)} @@ -186,11 +190,6 @@ def make(self, key): acq_software, clustering_method = ( ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - assert 
clustering_method in ( - "kilosort2", - "kilosort2.5", - "kilosort3", - ), 'Supporting "kilosort" clustering_method only' params = (KilosortPreProcessing & key).fetch1("params") @@ -257,11 +256,6 @@ def make(self, key): acq_software, clustering_method = ( ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - assert clustering_method in ( - "kilosort2", - "kilosort2.5", - "kilosort3", - ), 'Supporting "kilosort" clustering_method only' params = (KilosortPreProcessing & key).fetch1("params") From aea325d9bb6a975fd4e2c382f313b209a5be0017 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Wed, 25 Jan 2023 17:27:00 -0600 Subject: [PATCH 015/152] remove unused imports --- element_array_ephys/spike_sorting/ecephys_spike_sorting.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index fca4e452..4de349eb 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -30,8 +30,6 @@ from element_interface.utils import find_full_path from element_array_ephys.readers import ( spikeglx, - kilosort, - openephys, kilosort_triggering, ) From 4f648cc8054a6e237b971ab6c8cb4b88c6b0c568 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Wed, 1 Feb 2023 18:37:39 -0600 Subject: [PATCH 016/152] add new file for spike interface modularized clustering approach --- .../spike_sorting/si_clustering.py | 534 ++++++++++++++++++ 1 file changed, 534 insertions(+) create mode 100644 element_array_ephys/spike_sorting/si_clustering.py diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py new file mode 100644 index 00000000..32384d01 --- /dev/null +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -0,0 +1,534 @@ +""" +The following DataJoint pipeline implements the sequence of steps in the spike-sorting routine featured in the +"ecephys_spike_sorting" pipeline. +The "ecephys_spike_sorting" was originally developed by the Allen Institute (https://github.com/AllenInstitute/ecephys_spike_sorting) for Neuropixels data acquired with Open Ephys acquisition system. +Then forked by Jennifer Colonell from the Janelia Research Campus (https://github.com/jenniferColonell/ecephys_spike_sorting) to support SpikeGLX acquisition system. + +At DataJoint, we fork from Jennifer's fork and implemented a version that supports both Open Ephys and Spike GLX. +https://github.com/datajoint-company/ecephys_spike_sorting + +The follow pipeline features intermediary tables: +1. KilosortPreProcessing - for preprocessing steps (no GPU required) + - median_subtraction for Open Ephys + - or the CatGT step for SpikeGLX +2. KilosortClustering - kilosort (MATLAB) - requires GPU + - supports kilosort 2.0, 2.5 or 3.0 (https://github.com/MouseLand/Kilosort.git) +3. 
KilosortPostProcessing - for postprocessing steps (no GPU required) + - kilosort_postprocessing + - noise_templates + - mean_waveforms + - quality_metrics + + +""" +import datajoint as dj +import os +from element_array_ephys import get_logger +from decimal import Decimal +import json +import numpy as np +from datetime import datetime, timedelta + +from element_interface.utils import find_full_path +from element_array_ephys.readers import ( + spikeglx, + kilosort_triggering, +) +import element_array_ephys.ephys_no_curation as ephys +import element_array_ephys.probe as probe +# from element_array_ephys.ephys_no_curation import ( +# get_ephys_root_data_dir, +# get_session_directory, +# get_openephys_filepath, +# get_spikeglx_meta_filepath, +# get_recording_channels_details, +# ) +import spikeinterface as si +import spikeinterface.extractors as se +import spikeinterface.sorters as ss +import spikeinterface.comparison as sc +import spikeinterface.widgets as sw +import spikeinterface.preprocessing as sip +import probeinterface as pi + +log = get_logger(__name__) + +schema = dj.schema() + +ephys = None + +_supported_kilosort_versions = [ + "kilosort2", + "kilosort2.5", + "kilosort3", +] + + +def activate( + schema_name, + *, + ephys_module, + create_schema=True, + create_tables=True, +): + """ + activate(schema_name, *, create_schema=True, create_tables=True, activated_ephys=None) + :param schema_name: schema name on the database server to activate the `spike_sorting` schema + :param ephys_module: the activated ephys element for which this `spike_sorting` schema will be downstream from + :param create_schema: when True (default), create schema in the database if it does not yet exist. + :param create_tables: when True (default), create tables in the database if they do not yet exist. 
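+    Usage sketch (`my_ephys` is a placeholder for the already-activated ephys
+    module, not a name defined in this package):
+        from element_array_ephys.spike_sorting import si_clustering
+        si_clustering.activate("si_spike_sorting", ephys_module=my_ephys)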
+ """ + global ephys + ephys = ephys_module + schema.activate( + schema_name, + create_schema=create_schema, + create_tables=create_tables, + add_objects=ephys.__dict__, + ) + +@schema +class SI_preprocessing(dj.Imported): + """A table to handle preprocessing of each clustering task.""" + + definition = """ + -> ephys.ClusteringTask + --- + params: longblob # finalized parameterset for this run + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration + """ + + @property + def key_source(self): + return ( + ephys.ClusteringTask * ephys.ClusteringParamSet + & {"task_mode": "trigger"} + & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")' + ) - ephys.Clustering + def make(self, key): + """Triggers or imports clustering analysis.""" + execution_time = datetime.utcnow() + + task_mode, output_dir = (ephys.ClusteringTask & key).fetch1( + "task_mode", "clustering_output_dir" + ) + + assert task_mode == "trigger", 'Supporting "trigger" task_mode only' + + if not output_dir: + output_dir = ephys.ClusteringTask.infer_output_dir( + key, relative=True, mkdir=True + ) + # update clustering_output_dir + ephys.ClusteringTask.update1( + {**key, "clustering_output_dir": output_dir.as_posix()} + ) + + kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + acq_software, clustering_method, params = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method", "params") + + assert ( + clustering_method in _supported_kilosort_versions + ), f'Clustering_method "{clustering_method}" is not supported' + + # add additional probe-recording and channels details into `params` + params = {**params, **ephys.get_recording_channels_details(key)} + params["fs"] = params["sample_rate"] + + if acq_software == "SpikeGLX": + sglx_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) + sglx_filepath = ephys.get_spikeglx_meta_filepath(key) + stream_name = os.path.split(sglx_filepath)[1] + + assert len(oe_probe.recording_info["recording_files"]) == 1 + + # Create SI recording extractor object + # sglx_si_recording = se.SpikeGLXRecordingExtractor(folder_path=sglx_full_path, stream_name=stream_name) + sglx_si_recording = se.read_spikeglx(folder_path=sglx_full_path, stream_name=stream_name) + electrode_query = (probe.ProbeType.Electrode + * probe.ElectrodeConfig.Electrode + * ephys.EphysRecording & key) + + xy_coords = [list(i) for i in zip(electrode_query.fetch('x_coord'),electrode_query.fetch('y_coord'))] + channels_details = ephys.get_recording_channels_details(key) + + # Create SI probe object + probe = pi.Probe(ndim=2, si_units='um') + probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 5}) + probe.create_auto_shape(probe_type='tip') + channel_indices = np.arange(channels_details['num_channels']) + probe.set_device_channel_indices(channel_indices) + oe_si_recording.set_probe(probe=probe) + + # run preprocessing and save results to output folder + sglx_si_recording_filtered = sip.bandpass_filter(sglx_si_recording, freq_min=300, freq_max=6000) + sglx_recording_cmr = sip.common_reference(sglx_si_recording_filtered, reference="global", operator="median") + sglx_recording_cmr.save_to_folder('sglx_recording_cmr', kilosort_dir) + + + elif acq_software == "Open Ephys": + oe_probe = ephys.get_openephys_probe_data(key) + oe_full_path = find_full_path(get_ephys_root_data_dir(),get_session_directory(key)) + 
oe_filepath = get_openephys_filepath(key) + stream_name = os.path.split(oe_filepath)[1] + + assert len(oe_probe.recording_info["recording_files"]) == 1 + + # Create SI recording extractor object + # oe_si_recording = se.OpenEphysBinaryRecordingExtractor(folder_path=oe_full_path, stream_name=stream_name) + oe_si_recording = se.read_openephys(folder_path=oe_full_path, stream_name=stream_name) + electrode_query = (probe.ProbeType.Electrode + * probe.ElectrodeConfig.Electrode + * ephys.EphysRecording & key) + + xy_coords = [list(i) for i in zip(electrode_query.fetch('x_coord'),electrode_query.fetch('y_coord'))] + channels_details = get_recording_channels_details(key) + + # Create SI probe object + probe = pi.Probe(ndim=2, si_units='um') + probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 5}) + probe.create_auto_shape(probe_type='tip') + channel_indices = np.arange(channels_details['num_channels']) + probe.set_device_channel_indices(channel_indices) + oe_si_recording.set_probe(probe=probe) + + # run preprocessing and save results to output folder + oe_si_recording_filtered = sip.bandpass_filter(oe_si_recording, freq_min=300, freq_max=6000) + oe_recording_cmr = sip.common_reference(oe_si_recording_filtered, reference="global", operator="median") + oe_recording_cmr.save_to_folder('oe_recording_cmr', kilosort_dir) + + self.insert1( + { + **key, + "params": params, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) +@schema +class SI_KilosortClustering(dj.Imported): + """A processing table to handle each clustering task.""" + + definition = """ + -> KilosortPreProcessing + --- + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration + """ + + def make(self, key): + execution_time = datetime.utcnow() + + output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") + kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + acq_software, clustering_method = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method") + + params = (KilosortPreProcessing & key).fetch1("params") + + if acq_software == "SpikeGLX": + sglx_probe = ephys.get_openephys_probe_data(key) + oe_si_recording = se.load_from_folder + assert len(oe_probe.recording_info["recording_files"]) == 1 + if clustering_method.startswith('kilosort2.5'): + sorter_name = "kilosort2_5" + else: + sorter_name = clustering_method + sorting_kilosort = si.run_sorter( + sorter_name = sorter_name, + recording = oe_si_recording, + output_folder = kilosort_dir, + docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", + **params + ) + sorting_kilosort.save_to_folder('sorting_kilosort', kilosort_dir) + elif acq_software == "Open Ephys": + oe_probe = ephys.get_openephys_probe_data(key) + oe_si_recording = se.load_from_folder + assert len(oe_probe.recording_info["recording_files"]) == 1 + if clustering_method.startswith('kilosort2.5'): + sorter_name = "kilosort2_5" + else: + sorter_name = clustering_method + sorting_kilosort = si.run_sorter( + sorter_name = sorter_name, + recording = oe_si_recording, + output_folder = kilosort_dir, + docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", + **params + ) + sorting_kilosort.save_to_folder('sorting_kilosort', kilosort_dir) + + self.insert1( + { + **key, + "execution_time": execution_time, + 
"execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) + + + + + +@schema +class KilosortPreProcessing(dj.Imported): + """A processing table to handle each clustering task.""" + + definition = """ + -> ephys.ClusteringTask + --- + params: longblob # finalized parameterset for this run + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration + """ + + @property + def key_source(self): + return ( + ephys.ClusteringTask * ephys.ClusteringParamSet + & {"task_mode": "trigger"} + & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")' + ) - ephys.Clustering + + def make(self, key): + """Triggers or imports clustering analysis.""" + execution_time = datetime.utcnow() + + task_mode, output_dir = (ephys.ClusteringTask & key).fetch1( + "task_mode", "clustering_output_dir" + ) + + assert task_mode == "trigger", 'Supporting "trigger" task_mode only' + + if not output_dir: + output_dir = ephys.ClusteringTask.infer_output_dir( + key, relative=True, mkdir=True + ) + # update clustering_output_dir + ephys.ClusteringTask.update1( + {**key, "clustering_output_dir": output_dir.as_posix()} + ) + + kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + acq_software, clustering_method, params = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method", "params") + + assert ( + clustering_method in _supported_kilosort_versions + ), f'Clustering_method "{clustering_method}" is not supported' + + # add additional probe-recording and channels details into `params` + params = {**params, **ephys.get_recording_channels_details(key)} + params["fs"] = params["sample_rate"] + + + + + if acq_software == "SpikeGLX": + spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) + spikeglx_recording.validate_file("ap") + run_CatGT = ( + params.get("run_CatGT", True) + and "_tcat." 
not in spikeglx_meta_filepath.stem + ) + + run_kilosort = kilosort_triggering.SGLXKilosortPipeline( + npx_input_dir=spikeglx_meta_filepath.parent, + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + run_CatGT=run_CatGT, + ) + run_kilosort.run_CatGT() + elif acq_software == "Open Ephys": + oe_probe = ephys.get_openephys_probe_data(key) + + assert len(oe_probe.recording_info["recording_files"]) == 1 + + # run kilosort + run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( + npx_input_dir=oe_probe.recording_info["recording_files"][0], + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + ) + run_kilosort._modules = ["depth_estimation", "median_subtraction"] + run_kilosort.run_modules() + + self.insert1( + { + **key, + "params": params, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) + + +@schema +class KilosortClustering(dj.Imported): + """A processing table to handle each clustering task.""" + + definition = """ + -> KilosortPreProcessing + --- + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration + """ + + def make(self, key): + execution_time = datetime.utcnow() + + output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") + kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + acq_software, clustering_method = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method") + + params = (KilosortPreProcessing & key).fetch1("params") + + if acq_software == "SpikeGLX": + spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) + spikeglx_recording.validate_file("ap") + + run_kilosort = kilosort_triggering.SGLXKilosortPipeline( + npx_input_dir=spikeglx_meta_filepath.parent, + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + run_CatGT=True, + ) + run_kilosort._modules = ["kilosort_helper"] + run_kilosort._CatGT_finished = True + run_kilosort.run_modules() + elif acq_software == "Open Ephys": + oe_probe = ephys.get_openephys_probe_data(key) + + assert len(oe_probe.recording_info["recording_files"]) == 1 + + # run kilosort + run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( + npx_input_dir=oe_probe.recording_info["recording_files"][0], + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + ) + run_kilosort._modules = ["kilosort_helper"] + run_kilosort.run_modules() + + self.insert1( + { + **key, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) + + +@schema +class KilosortPostProcessing(dj.Imported): + """A processing table to handle each clustering task.""" + + definition = """ + -> KilosortClustering + --- + modules_status: longblob # dictionary of summary status for all modules + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration + """ + + def make(self, key): + execution_time = datetime.utcnow() + + output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") + kilosort_dir = 
find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + acq_software, clustering_method = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method") + + params = (KilosortPreProcessing & key).fetch1("params") + + if acq_software == "SpikeGLX": + spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) + spikeglx_recording.validate_file("ap") + + run_kilosort = kilosort_triggering.SGLXKilosortPipeline( + npx_input_dir=spikeglx_meta_filepath.parent, + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + run_CatGT=True, + ) + run_kilosort._modules = [ + "kilosort_postprocessing", + "noise_templates", + "mean_waveforms", + "quality_metrics", + ] + run_kilosort._CatGT_finished = True + run_kilosort.run_modules() + elif acq_software == "Open Ephys": + oe_probe = ephys.get_openephys_probe_data(key) + + assert len(oe_probe.recording_info["recording_files"]) == 1 + + # run kilosort + run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( + npx_input_dir=oe_probe.recording_info["recording_files"][0], + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + ) + run_kilosort._modules = [ + "kilosort_postprocessing", + "noise_templates", + "mean_waveforms", + "quality_metrics", + ] + run_kilosort.run_modules() + + with open(run_kilosort._modules_input_hash_fp) as f: + modules_status = json.load(f) + + self.insert1( + { + **key, + "modules_status": modules_status, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) + + # all finished, insert this `key` into ephys.Clustering + ephys.Clustering.insert1( + {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True + ) From 60091acad42b3081f3fbea53301d15993bc2e175 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 3 Feb 2023 01:32:18 -0600 Subject: [PATCH 017/152] add spike interface clustering and post processing modules --- .../spike_sorting/si_clustering.py | 123 ++++++++++++++++-- 1 file changed, 111 insertions(+), 12 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 32384d01..9ddddb75 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -44,7 +44,9 @@ # get_recording_channels_details, # ) import spikeinterface as si +import spikeinterface.core as sic import spikeinterface.extractors as se +import spikeinterface.exporters as sie import spikeinterface.sorters as ss import spikeinterface.comparison as sc import spikeinterface.widgets as sw @@ -88,7 +90,7 @@ def activate( ) @schema -class SI_preprocessing(dj.Imported): +class SI_PreProcessing(dj.Imported): """A table to handle preprocessing of each clustering task.""" definition = """ @@ -172,8 +174,8 @@ def make(self, key): elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) - oe_full_path = find_full_path(get_ephys_root_data_dir(),get_session_directory(key)) - oe_filepath = get_openephys_filepath(key) + oe_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) + oe_filepath = ephys.get_openephys_filepath(key) stream_name = os.path.split(oe_filepath)[1] assert 
len(oe_probe.recording_info["recording_files"]) == 1 @@ -186,7 +188,7 @@ def make(self, key): * ephys.EphysRecording & key) xy_coords = [list(i) for i in zip(electrode_query.fetch('x_coord'),electrode_query.fetch('y_coord'))] - channels_details = get_recording_channels_details(key) + channels_details = ephys.get_recording_channels_details(key) # Create SI probe object probe = pi.Probe(ndim=2, si_units='um') @@ -199,7 +201,8 @@ def make(self, key): # run preprocessing and save results to output folder oe_si_recording_filtered = sip.bandpass_filter(oe_si_recording, freq_min=300, freq_max=6000) oe_recording_cmr = sip.common_reference(oe_si_recording_filtered, reference="global", operator="median") - oe_recording_cmr.save_to_folder('oe_recording_cmr', kilosort_dir) + # oe_recording_cmr.save_to_folder('oe_recording_cmr', kilosort_dir) + oe_recording_cmr.dump_to_json('oe_recording_cmr.json', kilosort_dir) self.insert1( { @@ -217,7 +220,7 @@ class SI_KilosortClustering(dj.Imported): """A processing table to handle each clustering task.""" definition = """ - -> KilosortPreProcessing + -> SI_PreProcessing --- execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration @@ -236,16 +239,18 @@ def make(self, key): params = (KilosortPreProcessing & key).fetch1("params") if acq_software == "SpikeGLX": - sglx_probe = ephys.get_openephys_probe_data(key) - oe_si_recording = se.load_from_folder - assert len(oe_probe.recording_info["recording_files"]) == 1 + # sglx_probe = ephys.get_openephys_probe_data(key) + recording_file = kilosort_dir / 'sglx_recording_cmr.json' + # sglx_si_recording = se.load_from_folder(recording_file) + sglx_si_recording = sic.load_extractor(recording_file) + # assert len(oe_probe.recording_info["recording_files"]) == 1 if clustering_method.startswith('kilosort2.5'): sorter_name = "kilosort2_5" else: sorter_name = clustering_method sorting_kilosort = si.run_sorter( sorter_name = sorter_name, - recording = oe_si_recording, + recording = sglx_si_recording, output_folder = kilosort_dir, docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", **params @@ -253,7 +258,7 @@ def make(self, key): sorting_kilosort.save_to_folder('sorting_kilosort', kilosort_dir) elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) - oe_si_recording = se.load_from_folder + oe_si_recording = se.load_from_folder assert len(oe_probe.recording_info["recording_files"]) == 1 if clustering_method.startswith('kilosort2.5'): sorter_name = "kilosort2_5" @@ -266,7 +271,8 @@ def make(self, key): docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", **params ) - sorting_kilosort.save_to_folder('sorting_kilosort', kilosort_dir) + sorting_kilosort.save_to_folder('sorting_kilosort', kilosort_dir, n_jobs=-1, chunk_size=30000) + # sorting_kilosort.save(folder=kilosort_dir, n_jobs=20, chunk_size=30000) self.insert1( { @@ -279,7 +285,100 @@ def make(self, key): } ) +@schema +class SI_KilosortPostProcessing(dj.Imported): + """A processing table to handle each clustering task.""" + + definition = """ + -> SI_KilosortClustering + --- + modules_status: longblob # dictionary of summary status for all modules + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration + """ + + def make(self, key): + execution_time = datetime.utcnow() + + output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") + kilosort_dir = 
find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + acq_software, clustering_method = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method") + + params = (KilosortPreProcessing & key).fetch1("params") + + if acq_software == "SpikeGLX": + sorting_file = kilosort_dir / 'sorting_kilosort' + recording_file = kilosort_dir / 'sglx_recording_cmr.json' + sglx_si_recording = sic.load_extractor(recording_file) + sorting_kilosort = sic.load_extractor(sorting_file) + + we_kilosort = si.WaveformExtractor.create(sglx_si_recording, sorting_kilosort, "waveforms", remove_if_exists=True) + we_kilosort.run_extract_waveforms(n_jobs=-1, chunk_size=30000) + unit_id0 = sorting_kilosort.unit_ids[0] + waveforms = we_kilosort.get_waveforms(unit_id0) + template = we_kilosort.get_template(unit_id0) + snrs = si.compute_snrs(we_kilosort) + + + # QC Metrics + si_violations_ratio, isi_violations_rate, isi_violations_count = si.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) + metrics = si.compute_quality_metrics(we_kilosort, metric_names=["firing_rate","snr","presence_ratio","isi_violation", + "num_spikes","amplitude_cutoff","amplitude_median","sliding_rp_violation","rp_violation","drift"]) + sie.export_report(we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000) + # ["firing_rate","snr","presence_ratio","isi_violation", + # "number_violation","amplitude_cutoff","isolation_distance","l_ratio","d_prime","nn_hit_rate", + # "nn_miss_rate","silhouette_core","cumulative_drift","contamination_rate"]) + + we_kilosort.save_to_folder('we_kilosort',kilosort_dir, n_jobs=-1, chunk_size=30000) + + + + elif acq_software == "Open Ephys": + sorting_file = kilosort_dir / 'sorting_kilosort' + recording_file = kilosort_dir / 'sglx_recording_cmr.json' + sglx_si_recording = sic.load_extractor(recording_file) + sorting_kilosort = sic.load_extractor(sorting_file) + + we_kilosort = si.WaveformExtractor.create(sglx_si_recording, sorting_kilosort, "waveforms", remove_if_exists=True) + we_kilosort.run_extract_waveforms(n_jobs=-1, chunk_size=30000) + unit_id0 = sorting_kilosort.unit_ids[0] + waveforms = we_kilosort.get_waveforms(unit_id0) + template = we_kilosort.get_template(unit_id0) + snrs = si.compute_snrs(we_kilosort) + + + # QC Metrics + si_violations_ratio, isi_violations_rate, isi_violations_count = si.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) + metrics = si.compute_quality_metrics(we_kilosort, metric_names=["firing_rate","snr","presence_ratio","isi_violation", + "num_spikes","amplitude_cutoff","amplitude_median","sliding_rp_violation","rp_violation","drift"]) + sie.export_report(we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000) + + we_kilosort.save_to_folder('we_kilosort',kilosort_dir, n_jobs=-1, chunk_size=30000) + + + + with open(run_kilosort._modules_input_hash_fp) as f: + modules_status = json.load(f) + self.insert1( + { + **key, + "modules_status": modules_status, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) + + # all finished, insert this `key` into ephys.Clustering + ephys.Clustering.insert1( + {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True + ) From bca5fa9593e7736548d253daae6ec0452bfec94e Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 3 Feb 2023 01:47:15 -0600 Subject: [PATCH 018/152] edit typos --- .../spike_sorting/si_clustering.py | 261 +----------------- 1 file changed, 4 
insertions(+), 257 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 9ddddb75..b3391f93 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -236,7 +236,8 @@ def make(self, key): ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - params = (KilosortPreProcessing & key).fetch1("params") + params = (SI_PreProcessing & key).fetch1("params") + if acq_software == "SpikeGLX": # sglx_probe = ephys.get_openephys_probe_data(key) @@ -286,7 +287,7 @@ def make(self, key): ) @schema -class SI_KilosortPostProcessing(dj.Imported): +class SI_PostProcessing(dj.Imported): """A processing table to handle each clustering task.""" definition = """ @@ -307,7 +308,7 @@ def make(self, key): ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - params = (KilosortPreProcessing & key).fetch1("params") + params = (SI_PreProcessing & key).fetch1("params") if acq_software == "SpikeGLX": sorting_file = kilosort_dir / 'sorting_kilosort' @@ -335,7 +336,6 @@ def make(self, key): we_kilosort.save_to_folder('we_kilosort',kilosort_dir, n_jobs=-1, chunk_size=30000) - elif acq_software == "Open Ephys": sorting_file = kilosort_dir / 'sorting_kilosort' recording_file = kilosort_dir / 'sglx_recording_cmr.json' @@ -358,8 +358,6 @@ def make(self, key): we_kilosort.save_to_folder('we_kilosort',kilosort_dir, n_jobs=-1, chunk_size=30000) - - with open(run_kilosort._modules_input_hash_fp) as f: modules_status = json.load(f) @@ -380,254 +378,3 @@ def make(self, key): {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True ) - - -@schema -class KilosortPreProcessing(dj.Imported): - """A processing table to handle each clustering task.""" - - definition = """ - -> ephys.ClusteringTask - --- - params: longblob # finalized parameterset for this run - execution_time: datetime # datetime of the start of this step - execution_duration: float # (hour) execution duration - """ - - @property - def key_source(self): - return ( - ephys.ClusteringTask * ephys.ClusteringParamSet - & {"task_mode": "trigger"} - & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")' - ) - ephys.Clustering - - def make(self, key): - """Triggers or imports clustering analysis.""" - execution_time = datetime.utcnow() - - task_mode, output_dir = (ephys.ClusteringTask & key).fetch1( - "task_mode", "clustering_output_dir" - ) - - assert task_mode == "trigger", 'Supporting "trigger" task_mode only' - - if not output_dir: - output_dir = ephys.ClusteringTask.infer_output_dir( - key, relative=True, mkdir=True - ) - # update clustering_output_dir - ephys.ClusteringTask.update1( - {**key, "clustering_output_dir": output_dir.as_posix()} - ) - - kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - acq_software, clustering_method, params = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method", "params") - - assert ( - clustering_method in _supported_kilosort_versions - ), f'Clustering_method "{clustering_method}" is not supported' - - # add additional probe-recording and channels details into `params` - params = {**params, **ephys.get_recording_channels_details(key)} - params["fs"] = params["sample_rate"] - - - - - if acq_software == "SpikeGLX": - 
spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - spikeglx_recording.validate_file("ap") - run_CatGT = ( - params.get("run_CatGT", True) - and "_tcat." not in spikeglx_meta_filepath.stem - ) - - run_kilosort = kilosort_triggering.SGLXKilosortPipeline( - npx_input_dir=spikeglx_meta_filepath.parent, - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=run_CatGT, - ) - run_kilosort.run_CatGT() - elif acq_software == "Open Ephys": - oe_probe = ephys.get_openephys_probe_data(key) - - assert len(oe_probe.recording_info["recording_files"]) == 1 - - # run kilosort - run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( - npx_input_dir=oe_probe.recording_info["recording_files"][0], - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - ) - run_kilosort._modules = ["depth_estimation", "median_subtraction"] - run_kilosort.run_modules() - - self.insert1( - { - **key, - "params": params, - "execution_time": execution_time, - "execution_duration": ( - datetime.utcnow() - execution_time - ).total_seconds() - / 3600, - } - ) - - -@schema -class KilosortClustering(dj.Imported): - """A processing table to handle each clustering task.""" - - definition = """ - -> KilosortPreProcessing - --- - execution_time: datetime # datetime of the start of this step - execution_duration: float # (hour) execution duration - """ - - def make(self, key): - execution_time = datetime.utcnow() - - output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - acq_software, clustering_method = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method") - - params = (KilosortPreProcessing & key).fetch1("params") - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - spikeglx_recording.validate_file("ap") - - run_kilosort = kilosort_triggering.SGLXKilosortPipeline( - npx_input_dir=spikeglx_meta_filepath.parent, - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=True, - ) - run_kilosort._modules = ["kilosort_helper"] - run_kilosort._CatGT_finished = True - run_kilosort.run_modules() - elif acq_software == "Open Ephys": - oe_probe = ephys.get_openephys_probe_data(key) - - assert len(oe_probe.recording_info["recording_files"]) == 1 - - # run kilosort - run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( - npx_input_dir=oe_probe.recording_info["recording_files"][0], - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - ) - run_kilosort._modules = ["kilosort_helper"] - run_kilosort.run_modules() - - self.insert1( - { - **key, - "execution_time": execution_time, - "execution_duration": ( - datetime.utcnow() - execution_time - ).total_seconds() - / 3600, - } - ) - - -@schema -class KilosortPostProcessing(dj.Imported): - """A processing table to handle each clustering task.""" - - definition = """ - -> KilosortClustering - --- - modules_status: longblob # dictionary of summary status for all modules - execution_time: datetime # datetime of the start of this step - 
execution_duration: float # (hour) execution duration - """ - - def make(self, key): - execution_time = datetime.utcnow() - - output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - acq_software, clustering_method = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method") - - params = (KilosortPreProcessing & key).fetch1("params") - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - spikeglx_recording.validate_file("ap") - - run_kilosort = kilosort_triggering.SGLXKilosortPipeline( - npx_input_dir=spikeglx_meta_filepath.parent, - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=True, - ) - run_kilosort._modules = [ - "kilosort_postprocessing", - "noise_templates", - "mean_waveforms", - "quality_metrics", - ] - run_kilosort._CatGT_finished = True - run_kilosort.run_modules() - elif acq_software == "Open Ephys": - oe_probe = ephys.get_openephys_probe_data(key) - - assert len(oe_probe.recording_info["recording_files"]) == 1 - - # run kilosort - run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( - npx_input_dir=oe_probe.recording_info["recording_files"][0], - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - ) - run_kilosort._modules = [ - "kilosort_postprocessing", - "noise_templates", - "mean_waveforms", - "quality_metrics", - ] - run_kilosort.run_modules() - - with open(run_kilosort._modules_input_hash_fp) as f: - modules_status = json.load(f) - - self.insert1( - { - **key, - "modules_status": modules_status, - "execution_time": execution_time, - "execution_duration": ( - datetime.utcnow() - execution_time - ).total_seconds() - / 3600, - } - ) - - # all finished, insert this `key` into ephys.Clustering - ephys.Clustering.insert1( - {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True - ) From 1c4b0b578f31b2779a65db8ef67a8738a6352ff1 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 3 Feb 2023 14:49:47 -0600 Subject: [PATCH 019/152] removed module_status from table keys --- element_array_ephys/spike_sorting/si_clustering.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index b3391f93..d3c3bbf9 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -293,7 +293,6 @@ class SI_PostProcessing(dj.Imported): definition = """ -> SI_KilosortClustering --- - modules_status: longblob # dictionary of summary status for all modules execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration """ @@ -358,13 +357,10 @@ def make(self, key): we_kilosort.save_to_folder('we_kilosort',kilosort_dir, n_jobs=-1, chunk_size=30000) - with open(run_kilosort._modules_input_hash_fp) as f: - modules_status = json.load(f) self.insert1( { **key, - "modules_status": modules_status, "execution_time": execution_time, "execution_duration": ( datetime.utcnow() - execution_time From 56c9941f12072b3d572520f69ee999025117ffb5 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 3 Feb 2023 17:16:47 -0600 Subject: [PATCH 
020/152] remove _ from SI table names --- .../spike_sorting/si_clustering.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index d3c3bbf9..a72989ed 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -8,8 +8,8 @@ https://github.com/datajoint-company/ecephys_spike_sorting The follow pipeline features intermediary tables: -1. KilosortPreProcessing - for preprocessing steps (no GPU required) - - median_subtraction for Open Ephys +1. SIPreProcessing - for preprocessing steps (no GPU required) + - - or the CatGT step for SpikeGLX 2. KilosortClustering - kilosort (MATLAB) - requires GPU - supports kilosort 2.0, 2.5 or 3.0 (https://github.com/MouseLand/Kilosort.git) @@ -90,7 +90,7 @@ def activate( ) @schema -class SI_PreProcessing(dj.Imported): +class SIPreProcessing(dj.Imported): """A table to handle preprocessing of each clustering task.""" definition = """ @@ -168,8 +168,8 @@ def make(self, key): # run preprocessing and save results to output folder sglx_si_recording_filtered = sip.bandpass_filter(sglx_si_recording, freq_min=300, freq_max=6000) - sglx_recording_cmr = sip.common_reference(sglx_si_recording_filtered, reference="global", operator="median") - sglx_recording_cmr.save_to_folder('sglx_recording_cmr', kilosort_dir) + # sglx_recording_cmr = sip.common_reference(sglx_si_recording_filtered, reference="global", operator="median") + sglx_si_recording_filtered.save_to_folder('sglx_si_recording_filtered', kilosort_dir) elif acq_software == "Open Ephys": @@ -216,7 +216,7 @@ def make(self, key): } ) @schema -class SI_KilosortClustering(dj.Imported): +class SIClustering(dj.Imported): """A processing table to handle each clustering task.""" definition = """ @@ -236,8 +236,7 @@ def make(self, key): ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - params = (SI_PreProcessing & key).fetch1("params") - + params = (SIPreProcessing & key).fetch1("params") if acq_software == "SpikeGLX": # sglx_probe = ephys.get_openephys_probe_data(key) @@ -287,11 +286,11 @@ def make(self, key): ) @schema -class SI_PostProcessing(dj.Imported): +class SIPostProcessing(dj.Imported): """A processing table to handle each clustering task.""" definition = """ - -> SI_KilosortClustering + -> SIClustering --- execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration @@ -307,7 +306,7 @@ def make(self, key): ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - params = (SI_PreProcessing & key).fetch1("params") + params = (SIPreProcessing & key).fetch1("params") if acq_software == "SpikeGLX": sorting_file = kilosort_dir / 'sorting_kilosort' From ce14098041a5292ec8dd9abd9776074d975ad3b2 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 3 Feb 2023 17:23:56 -0600 Subject: [PATCH 021/152] bugfix --- element_array_ephys/spike_sorting/si_clustering.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index a72989ed..97a7dd53 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -202,7 +202,8 @@ def make(self, key): 
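# Reviewer note on the hunk below: sip.common_reference(..., reference="global",
# operator="median") subtracts the per-sample median across channels. A
# numpy-only sketch of that operation, on made-up toy traces, for reference:
import numpy as np

traces = np.random.randn(30000, 64).astype("float32")  # (samples, channels), toy data
reference = np.median(traces, axis=1, keepdims=True)   # median across channels, per sample
traces_cmr = traces - reference                        # common-median-referenced traces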
oe_si_recording_filtered = sip.bandpass_filter(oe_si_recording, freq_min=300, freq_max=6000) oe_recording_cmr = sip.common_reference(oe_si_recording_filtered, reference="global", operator="median") # oe_recording_cmr.save_to_folder('oe_recording_cmr', kilosort_dir) - oe_recording_cmr.dump_to_json('oe_recording_cmr.json', kilosort_dir) + # oe_recording_cmr.dump_to_json('oe_recording_cmr.json', kilosort_dir) + oe_si_recording_filtered.save_to_folder('', kilosort_dir) self.insert1( { @@ -220,7 +221,7 @@ class SIClustering(dj.Imported): """A processing table to handle each clustering task.""" definition = """ - -> SI_PreProcessing + -> SIPreProcessing --- execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration From 7c836f12fd1d47e5b7cf15435eaa19ad1faa7ae4 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 3 Feb 2023 17:33:52 -0600 Subject: [PATCH 022/152] change si related table names --- element_array_ephys/spike_sorting/si_clustering.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 97a7dd53..d97066a6 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -90,7 +90,7 @@ def activate( ) @schema -class SIPreProcessing(dj.Imported): +class PreProcessing(dj.Imported): """A table to handle preprocessing of each clustering task.""" definition = """ @@ -164,7 +164,7 @@ def make(self, key): probe.create_auto_shape(probe_type='tip') channel_indices = np.arange(channels_details['num_channels']) probe.set_device_channel_indices(channel_indices) - oe_si_recording.set_probe(probe=probe) + sglx_si_recording.set_probe(probe=probe) # run preprocessing and save results to output folder sglx_si_recording_filtered = sip.bandpass_filter(sglx_si_recording, freq_min=300, freq_max=6000) @@ -217,7 +217,7 @@ def make(self, key): } ) @schema -class SIClustering(dj.Imported): +class ClusteringModule(dj.Imported): """A processing table to handle each clustering task.""" definition = """ @@ -237,7 +237,7 @@ def make(self, key): ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - params = (SIPreProcessing & key).fetch1("params") + params = (PreProcessing & key).fetch1("params") if acq_software == "SpikeGLX": # sglx_probe = ephys.get_openephys_probe_data(key) @@ -287,11 +287,11 @@ def make(self, key): ) @schema -class SIPostProcessing(dj.Imported): +class PostProcessing(dj.Imported): """A processing table to handle each clustering task.""" definition = """ - -> SIClustering + -> ClusteringModule --- execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration @@ -307,7 +307,7 @@ def make(self, key): ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - params = (SIPreProcessing & key).fetch1("params") + params = (PreProcessing & key).fetch1("params") if acq_software == "SpikeGLX": sorting_file = kilosort_dir / 'sorting_kilosort' From dd6366498d1a4bd974803b89abdfd7ab30a96623 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 3 Feb 2023 17:38:49 -0600 Subject: [PATCH 023/152] bugfix --- element_array_ephys/spike_sorting/si_clustering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py
index d97066a6..cabe5c25 100644
--- a/element_array_ephys/spike_sorting/si_clustering.py
+++ b/element_array_ephys/spike_sorting/si_clustering.py
@@ -221,7 +221,7 @@ class ClusteringModule(dj.Imported):
     """A processing table to handle each clustering task."""
 
     definition = """
-    -> SIPreProcessing
+    -> PreProcessing
     ---
     execution_time: datetime # datetime of the start of this step
     execution_duration: float # (hour) execution duration

From 54888ed7c747bf5452e90d590641824bd70684f1 Mon Sep 17 00:00:00 2001
From: Sidharth Hulyalkar
Date: Fri, 3 Feb 2023 18:02:46 -0600
Subject: [PATCH 024/152] update initial comment

---
 .../spike_sorting/si_clustering.py | 26 ++++++++-----------
 1 file changed, 11 insertions(+), 15 deletions(-)

diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py
index cabe5c25..e69855f1 100644
--- a/element_array_ephys/spike_sorting/si_clustering.py
+++ b/element_array_ephys/spike_sorting/si_clustering.py
@@ -1,22 +1,20 @@
 """
 The following DataJoint pipeline implements the sequence of steps in the spike-sorting routine featured in the
-"ecephys_spike_sorting" pipeline.
-The "ecephys_spike_sorting" was originally developed by the Allen Institute (https://github.com/AllenInstitute/ecephys_spike_sorting) for Neuropixels data acquired with Open Ephys acquisition system.
-Then forked by Jennifer Colonell from the Janelia Research Campus (https://github.com/jenniferColonell/ecephys_spike_sorting) to support SpikeGLX acquisition system.
+"spikeinterface" pipeline.
+SpikeInterface is developed by Alessio Buccino, Samuel Garcia, Cole Hurwitz, Jeremy Magland, and Matthias Hennig (https://github.com/SpikeInterface).
 
-At DataJoint, we fork from Jennifer's fork and implemented a version that supports both Open Ephys and Spike GLX.
-https://github.com/datajoint-company/ecephys_spike_sorting
+The DataJoint pipeline currently incorporates SpikeInterface's approach of running Kilosort in a container.
 
 The follow pipeline features intermediary tables:
-1. SIPreProcessing - for preprocessing steps (no GPU required)
-    - 
-    - or the CatGT step for SpikeGLX
-2. KilosortClustering - kilosort (MATLAB) - requires GPU
+1. PreProcessing - for preprocessing steps (no GPU required)
+    - create recording extractor and link it to a probe
+    - bandpass filtering
+    - common mode referencing
+2. ClusteringModule - kilosort (MATLAB) - requires GPU and docker/singularity containers
     - supports kilosort 2.0, 2.5 or 3.0 (https://github.com/MouseLand/Kilosort.git)
-3. KilosortPostProcessing - for postprocessing steps (no GPU required)
-    - kilosort_postprocessing
-    - noise_templates
-    - mean_waveforms
+3. 
PostProcessing - for postprocessing steps (no GPU required) + - create waveform extractor object + - extract templates, waveforms and snrs - quality_metrics @@ -48,8 +46,6 @@ import spikeinterface.extractors as se import spikeinterface.exporters as sie import spikeinterface.sorters as ss -import spikeinterface.comparison as sc -import spikeinterface.widgets as sw import spikeinterface.preprocessing as sip import probeinterface as pi From f804233b63a74ad935b953c3051ad74de0a4f4b2 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Wed, 8 Feb 2023 19:17:51 -0600 Subject: [PATCH 025/152] fix preprocessing file loading issues --- .../spike_sorting/si_clustering.py | 64 ++++++++----------- 1 file changed, 27 insertions(+), 37 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index e69855f1..e6e129bc 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -32,15 +32,8 @@ spikeglx, kilosort_triggering, ) -import element_array_ephys.ephys_no_curation as ephys import element_array_ephys.probe as probe -# from element_array_ephys.ephys_no_curation import ( -# get_ephys_root_data_dir, -# get_session_directory, -# get_openephys_filepath, -# get_spikeglx_meta_filepath, -# get_recording_channels_details, -# ) + import spikeinterface as si import spikeinterface.core as sic import spikeinterface.extractors as se @@ -92,6 +85,7 @@ class PreProcessing(dj.Imported): definition = """ -> ephys.ClusteringTask --- + file_name: varchar(60) # filename where recording object is saved to params: longblob # finalized parameterset for this run execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration @@ -137,30 +131,27 @@ def make(self, key): params = {**params, **ephys.get_recording_channels_details(key)} params["fs"] = params["sample_rate"] + if acq_software == "SpikeGLX": sglx_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) sglx_filepath = ephys.get_spikeglx_meta_filepath(key) stream_name = os.path.split(sglx_filepath)[1] - assert len(oe_probe.recording_info["recording_files"]) == 1 + # assert len(oe_probe.recording_info["recording_files"]) == 1 # Create SI recording extractor object # sglx_si_recording = se.SpikeGLXRecordingExtractor(folder_path=sglx_full_path, stream_name=stream_name) sglx_si_recording = se.read_spikeglx(folder_path=sglx_full_path, stream_name=stream_name) - electrode_query = (probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * ephys.EphysRecording & key) - xy_coords = [list(i) for i in zip(electrode_query.fetch('x_coord'),electrode_query.fetch('y_coord'))] + xy_coords = [list(i) for i in zip(channels_details['x_coords'],channels_details['y_coords'])] channels_details = ephys.get_recording_channels_details(key) # Create SI probe object - probe = pi.Probe(ndim=2, si_units='um') - probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 5}) - probe.create_auto_shape(probe_type='tip') - channel_indices = np.arange(channels_details['num_channels']) - probe.set_device_channel_indices(channel_indices) - sglx_si_recording.set_probe(probe=probe) + si_probe = pi.Probe(ndim=2, si_units='um') + si_probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 5}) + si_probe.create_auto_shape(probe_type='tip') + si_probe.set_device_channel_indices(channels_details['channel_ind']) + 
sglx_si_recording.set_probe(probe=si_probe) # run preprocessing and save results to output folder sglx_si_recording_filtered = sip.bandpass_filter(sglx_si_recording, freq_min=300, freq_max=6000) @@ -170,29 +161,25 @@ def make(self, key): elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) - oe_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) - oe_filepath = ephys.get_openephys_filepath(key) - stream_name = os.path.split(oe_filepath)[1] - + oe_session_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) + assert len(oe_probe.recording_info["recording_files"]) == 1 + stream_name = os.path.split(oe_probe.recording_info['recording_files'][0])[1] # Create SI recording extractor object # oe_si_recording = se.OpenEphysBinaryRecordingExtractor(folder_path=oe_full_path, stream_name=stream_name) - oe_si_recording = se.read_openephys(folder_path=oe_full_path, stream_name=stream_name) - electrode_query = (probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * ephys.EphysRecording & key) + oe_si_recording = se.read_openephys(folder_path=oe_session_full_path, stream_name=stream_name) - xy_coords = [list(i) for i in zip(electrode_query.fetch('x_coord'),electrode_query.fetch('y_coord'))] + xy_coords = [list(i) for i in zip(channels_details['x_coords'],channels_details['y_coords'])] + channels_details = ephys.get_recording_channels_details(key) # Create SI probe object - probe = pi.Probe(ndim=2, si_units='um') - probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 5}) - probe.create_auto_shape(probe_type='tip') - channel_indices = np.arange(channels_details['num_channels']) - probe.set_device_channel_indices(channel_indices) - oe_si_recording.set_probe(probe=probe) + si_probe = pi.Probe(ndim=2, si_units='um') + si_probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 5}) + si_probe.create_auto_shape(probe_type='tip') + si_probe.set_device_channel_indices(channels_details['channel_ind']) + oe_si_recording.set_probe(probe=si_probe) # run preprocessing and save results to output folder oe_si_recording_filtered = sip.bandpass_filter(oe_si_recording, freq_min=300, freq_max=6000) @@ -219,8 +206,10 @@ class ClusteringModule(dj.Imported): definition = """ -> PreProcessing --- - execution_time: datetime # datetime of the start of this step - execution_duration: float # (hour) execution duration + recording_file: varchar(60) # filename of saved recording object + sorting_file: varchar(60) # filename of saved sorting object + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration """ def make(self, key): @@ -234,10 +223,11 @@ def make(self, key): ).fetch1("acq_software", "clustering_method") params = (PreProcessing & key).fetch1("params") + file_name = (PreProcessing & key).fetch1("file_name") if acq_software == "SpikeGLX": # sglx_probe = ephys.get_openephys_probe_data(key) - recording_file = kilosort_dir / 'sglx_recording_cmr.json' + recording_file = kilosort_dir / file_name # sglx_si_recording = se.load_from_folder(recording_file) sglx_si_recording = sic.load_extractor(recording_file) # assert len(oe_probe.recording_info["recording_files"]) == 1 From 8e1b73dd5013b9eb6da542678726fea800a3dcf6 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Thu, 9 Feb 2023 15:58:12 -0600 Subject: [PATCH 026/152] set file saving and file loading to pickle format --- 
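Reviewer note on this patch: SpikeInterface preprocessing steps are lazy views, so dump_to_json()/dump_to_pickle() serialize only the recipe needed to rebuild the chain (source paths plus the sequence of operations), while save()/save_to_folder() write the processed traces to disk. JSON dumping only works when every step in the chain is JSON-serializable, which is presumably why this patch standardizes on pickle. A minimal round-trip sketch; the file name mirrors the one used in this patch, but the SpikeGLX folder path is a placeholder:

    import spikeinterface.core as sic
    import spikeinterface.extractors as se
    import spikeinterface.preprocessing as sip

    # "path/to/spikeglx_run" is a placeholder; point it at a real SpikeGLX folder
    recording = se.read_spikeglx(folder_path="path/to/spikeglx_run")
    filtered = sip.bandpass_filter(recording, freq_min=300, freq_max=6000)

    filtered.dump_to_pickle(file_path="si_recording.pkl")  # serialize the recipe only
    same_chain = sic.load_extractor("si_recording.pkl")    # rebuild the lazy chain later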
element_array_ephys/spike_sorting/si_clustering.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index e6e129bc..cf909725 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -186,11 +186,14 @@ def make(self, key): oe_recording_cmr = sip.common_reference(oe_si_recording_filtered, reference="global", operator="median") # oe_recording_cmr.save_to_folder('oe_recording_cmr', kilosort_dir) # oe_recording_cmr.dump_to_json('oe_recording_cmr.json', kilosort_dir) - oe_si_recording_filtered.save_to_folder('', kilosort_dir) + save_file_name = 'si_recording.pkl' + save_file_path = kilosort_dir / save_file_name + oe_si_recording_filtered.dump_to_pickle(file_path=save_file_path) self.insert1( { **key, + "file_name": save_file_name, "params": params, "execution_time": execution_time, "execution_duration": ( From f0b7497e7b173396a35efbd9a1545981a095d96c Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 10 Feb 2023 18:09:40 -0600 Subject: [PATCH 027/152] sglx preprocessing modifications --- .../spike_sorting/si_clustering.py | 58 ++++++++++--------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index cf909725..6ec4b6e2 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -34,7 +34,7 @@ ) import element_array_ephys.probe as probe -import spikeinterface as si +import spikeinterface.full as si import spikeinterface.core as sic import spikeinterface.extractors as se import spikeinterface.exporters as sie @@ -85,7 +85,7 @@ class PreProcessing(dj.Imported): definition = """ -> ephys.ClusteringTask --- - file_name: varchar(60) # filename where recording object is saved to + recording_filename: varchar(60) # filename where recording object is saved to params: longblob # finalized parameterset for this run execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration @@ -133,22 +133,19 @@ def make(self, key): if acq_software == "SpikeGLX": - sglx_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) + # sglx_session_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) sglx_filepath = ephys.get_spikeglx_meta_filepath(key) - stream_name = os.path.split(sglx_filepath)[1] - - # assert len(oe_probe.recording_info["recording_files"]) == 1 # Create SI recording extractor object - # sglx_si_recording = se.SpikeGLXRecordingExtractor(folder_path=sglx_full_path, stream_name=stream_name) - sglx_si_recording = se.read_spikeglx(folder_path=sglx_full_path, stream_name=stream_name) - - xy_coords = [list(i) for i in zip(channels_details['x_coords'],channels_details['y_coords'])] + sglx_si_recording = se.read_spikeglx(folder_path=sglx_filepath.parent) + channels_details = ephys.get_recording_channels_details(key) + xy_coords = [list(i) for i in zip(channels_details['x_coords'],channels_details['y_coords'])] + # Create SI probe object si_probe = pi.Probe(ndim=2, si_units='um') - si_probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 5}) + si_probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 12}) si_probe.create_auto_shape(probe_type='tip') 
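# Reviewer note on the probe wiring in this hunk: set_device_channel_indices()
# declares, for each probe contact, which row of the recording's traces it maps
# to, so channels_details['channel_ind'] must follow the recording's channel
# order. A self-contained sanity check of the pattern (geometry and channel
# count are illustrative, not the real Neuropixels layout):
import numpy as np
import probeinterface as pi

num_channels = 4
positions = np.stack([np.zeros(num_channels), 20.0 * np.arange(num_channels)], axis=1)
toy_probe = pi.Probe(ndim=2, si_units="um")
toy_probe.set_contacts(positions=positions, shapes="square", shape_params={"width": 12})
toy_probe.create_auto_shape(probe_type="tip")
toy_probe.set_device_channel_indices(np.arange(num_channels))  # contact i -> trace row i
assert toy_probe.device_channel_indices.shape == (num_channels,)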
si_probe.set_device_channel_indices(channels_details['channel_ind']) sglx_si_recording.set_probe(probe=si_probe) @@ -156,7 +153,10 @@ def make(self, key): # run preprocessing and save results to output folder sglx_si_recording_filtered = sip.bandpass_filter(sglx_si_recording, freq_min=300, freq_max=6000) # sglx_recording_cmr = sip.common_reference(sglx_si_recording_filtered, reference="global", operator="median") - sglx_si_recording_filtered.save_to_folder('sglx_si_recording_filtered', kilosort_dir) + + save_file_name = 'si_recording.pkl' + save_file_path = kilosort_dir / save_file_name + sglx_si_recording_filtered.dump_to_pickle(file_path=save_file_path) elif acq_software == "Open Ephys": @@ -170,22 +170,21 @@ def make(self, key): # oe_si_recording = se.OpenEphysBinaryRecordingExtractor(folder_path=oe_full_path, stream_name=stream_name) oe_si_recording = se.read_openephys(folder_path=oe_session_full_path, stream_name=stream_name) + channels_details = ephys.get_recording_channels_details(key) xy_coords = [list(i) for i in zip(channels_details['x_coords'],channels_details['y_coords'])] - channels_details = ephys.get_recording_channels_details(key) - # Create SI probe object si_probe = pi.Probe(ndim=2, si_units='um') - si_probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 5}) + si_probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 12}) si_probe.create_auto_shape(probe_type='tip') si_probe.set_device_channel_indices(channels_details['channel_ind']) oe_si_recording.set_probe(probe=si_probe) # run preprocessing and save results to output folder + # Switch case to allow for specified preprocessing steps oe_si_recording_filtered = sip.bandpass_filter(oe_si_recording, freq_min=300, freq_max=6000) oe_recording_cmr = sip.common_reference(oe_si_recording_filtered, reference="global", operator="median") - # oe_recording_cmr.save_to_folder('oe_recording_cmr', kilosort_dir) - # oe_recording_cmr.dump_to_json('oe_recording_cmr.json', kilosort_dir) + save_file_name = 'si_recording.pkl' save_file_path = kilosort_dir / save_file_name oe_si_recording_filtered.dump_to_pickle(file_path=save_file_path) @@ -193,7 +192,7 @@ def make(self, key): self.insert1( { **key, - "file_name": save_file_name, + "recording_filename": save_file_name, "params": params, "execution_time": execution_time, "execution_duration": ( @@ -202,15 +201,14 @@ def make(self, key): / 3600, } ) -@schema + @schema class ClusteringModule(dj.Imported): """A processing table to handle each clustering task.""" definition = """ -> PreProcessing --- - recording_file: varchar(60) # filename of saved recording object - sorting_file: varchar(60) # filename of saved sorting object + sorting_filename: varchar(60) # filename of saved sorting object execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration """ @@ -226,13 +224,13 @@ def make(self, key): ).fetch1("acq_software", "clustering_method") params = (PreProcessing & key).fetch1("params") - file_name = (PreProcessing & key).fetch1("file_name") + recording_filename = (PreProcessing & key).fetch1("recording_filename") if acq_software == "SpikeGLX": # sglx_probe = ephys.get_openephys_probe_data(key) - recording_file = kilosort_dir / file_name + recording_fullpath = kilosort_dir / recording_filename # sglx_si_recording = se.load_from_folder(recording_file) - sglx_si_recording = sic.load_extractor(recording_file) + sglx_si_recording = sic.load_extractor(recording_fullpath) # assert 
len(oe_probe.recording_info["recording_files"]) == 1 if clustering_method.startswith('kilosort2.5'): sorter_name = "kilosort2_5" @@ -245,10 +243,11 @@ def make(self, key): docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", **params ) - sorting_kilosort.save_to_folder('sorting_kilosort', kilosort_dir) + sorting_save_path = kilosort_dir / 'sorting_kilosort.pkl' + sorting_kilosort.dump_to_pickle(sorting_save_path) elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) - oe_si_recording = se.load_from_folder + oe_si_recording = sic.load_extractor(recording_fullpath) assert len(oe_probe.recording_info["recording_files"]) == 1 if clustering_method.startswith('kilosort2.5'): sorter_name = "kilosort2_5" @@ -261,7 +260,8 @@ def make(self, key): docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", **params ) - sorting_kilosort.save_to_folder('sorting_kilosort', kilosort_dir, n_jobs=-1, chunk_size=30000) + sorting_save_path = kilosort_dir / 'sorting_kilosort.pkl' + sorting_kilosort.dump_to_pickle(sorting_save_path) # sorting_kilosort.save(folder=kilosort_dir, n_jobs=20, chunk_size=30000) self.insert1( @@ -363,3 +363,7 @@ def make(self, key): {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True ) + + +def preProcessing_switch(preprocess_list): + \ No newline at end of file From 13fe31c49fa5e5286c4ef916605bf017a33a9d87 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Mon, 13 Feb 2023 17:48:12 -0600 Subject: [PATCH 028/152] sglx testing progress --- element_array_ephys/spike_sorting/si_clustering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 6ec4b6e2..08ff86bd 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -128,8 +128,8 @@ def make(self, key): ), f'Clustering_method "{clustering_method}" is not supported' # add additional probe-recording and channels details into `params` - params = {**params, **ephys.get_recording_channels_details(key)} - params["fs"] = params["sample_rate"] + # params = {**params, **ephys.get_recording_channels_details(key)} + # params["fs"] = params["sample_rate"] if acq_software == "SpikeGLX": From d41c7f3c6d0fb2b0c97578d6c8cd28ab1b6a2832 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Tue, 14 Feb 2023 18:46:54 -0600 Subject: [PATCH 029/152] wip parametrize preprocessing --- .../spike_sorting/si_clustering.py | 117 +++++++++++++++--- 1 file changed, 99 insertions(+), 18 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 08ff86bd..f0144ea6 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -131,6 +131,46 @@ def make(self, key): # params = {**params, **ephys.get_recording_channels_details(key)} # params["fs"] = params["sample_rate"] + + preprocess_list = params.pop('PreProcessing_params') + + # If else + if preprocess_list['Filter']: + oe_si_recording = sip.FilterRecording(oe_si_recording) + elif preprocess_list['BandpassFilter']: + oe_si_recording = sip.BandpassFilterRecording(oe_si_recording) + elif preprocess_list['HighpassFilter']: + oe_si_recording = sip.HighpassFilterRecording(oe_si_recording) + elif preprocess_list['NormalizeByQuantile']: + oe_si_recording = sip.NormalizeByQuantileRecording(oe_si_recording) + elif 
preprocess_list['Scale']: + oe_si_recording = sip.ScaleRecording(oe_si_recording) + elif preprocess_list['Center']: + oe_si_recording = sip.CenterRecording(oe_si_recording) + elif preprocess_list['ZScore']: + oe_si_recording = sip.ZScoreRecording(oe_si_recording) + elif preprocess_list['Whiten']: + oe_si_recording = sip.WhitenRecording(oe_si_recording) + elif preprocess_list['CommonReference']: + oe_si_recording = sip.CommonReferenceRecording(oe_si_recording) + elif preprocess_list['PhaseShift']: + oe_si_recording = sip.PhaseShiftRecording(oe_si_recording) + elif preprocess_list['Rectify']: + oe_si_recording = sip.RectifyRecording(oe_si_recording) + elif preprocess_list['Clip']: + oe_si_recording = sip.ClipRecording(oe_si_recording) + elif preprocess_list['BlankSaturation']: + oe_si_recording = sip.BlankSaturationRecording(oe_si_recording) + elif preprocess_list['RemoveArtifacts']: + oe_si_recording = sip.RemoveArtifactsRecording(oe_si_recording) + elif preprocess_list['RemoveBadChannels']: + oe_si_recording = sip.RemoveBadChannelsRecording(oe_si_recording) + elif preprocess_list['ZeroChannelPad']: + oe_si_recording = sip.ZeroChannelPadRecording(oe_si_recording) + elif preprocess_list['DeepInterpolation']: + oe_si_recording = sip.DeepInterpolationRecording(oe_si_recording) + elif preprocess_list['Resample']: + oe_si_recording = sip.ResampleRecording(oe_si_recording) if acq_software == "SpikeGLX": # sglx_session_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) @@ -232,17 +272,23 @@ def make(self, key): # sglx_si_recording = se.load_from_folder(recording_file) sglx_si_recording = sic.load_extractor(recording_fullpath) # assert len(oe_probe.recording_info["recording_files"]) == 1 + + ## Assume that the worker process will trigger this sorting step + # - Will need to store/load the sorter_name, sglx_si_recording object etc. 
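# Reviewer note on the commented-out run_sorter call below: when docker_image is
# given, spikeinterface.sorters.run_sorter runs the sorter inside that container,
# so MATLAB does not need to be installed on the host (a GPU is still required
# for Kilosort). A hedged sketch of the intended call, assuming a worker with
# Docker available; the recording path and output folder are illustrative:
import spikeinterface.core as sic
import spikeinterface.sorters as ss

recording = sic.load_extractor("si_recording.pkl")  # dumped by PreProcessing
sorting = ss.run_sorter(
    sorter_name="kilosort2_5",
    recording=recording,
    output_folder="kilosort_output",
    docker_image="spikeinterface/kilosort2_5-compiled-base:latest",
)
sorting.dump_to_pickle(file_path="sorting_kilosort.pkl")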
+ # - Store in shared EC2 space accessible by all containers (needs to be mounted) + # - Load into the cloud init script, and + # - Option A: Can call this function within a separate container within spike_sorting_worker if clustering_method.startswith('kilosort2.5'): sorter_name = "kilosort2_5" else: sorter_name = clustering_method - sorting_kilosort = si.run_sorter( - sorter_name = sorter_name, - recording = sglx_si_recording, - output_folder = kilosort_dir, - docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", - **params - ) + # sorting_kilosort = si.run_sorter( + # sorter_name = sorter_name, + # recording = sglx_si_recording, + # output_folder = kilosort_dir, + # docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", + # **params + # ) sorting_save_path = kilosort_dir / 'sorting_kilosort.pkl' sorting_kilosort.dump_to_pickle(sorting_save_path) elif acq_software == "Open Ephys": @@ -253,13 +299,13 @@ def make(self, key): sorter_name = "kilosort2_5" else: sorter_name = clustering_method - sorting_kilosort = si.run_sorter( - sorter_name = sorter_name, - recording = oe_si_recording, - output_folder = kilosort_dir, - docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", - **params - ) + # sorting_kilosort = si.run_sorter( + # sorter_name = sorter_name, + # recording = oe_si_recording, + # output_folder = kilosort_dir, + # docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", + # **params + # ) sorting_save_path = kilosort_dir / 'sorting_kilosort.pkl' sorting_kilosort.dump_to_pickle(sorting_save_path) # sorting_kilosort.save(folder=kilosort_dir, n_jobs=20, chunk_size=30000) @@ -363,7 +409,42 @@ def make(self, key): {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True ) - - -def preProcessing_switch(preprocess_list): - \ No newline at end of file +## Example SI parameter set +''' +{'detect_threshold': 6, + 'projection_threshold': [10, 4], + 'preclust_threshold': 8, + 'car': True, + 'minFR': 0.02, + 'minfr_goodchannels': 0.1, + 'nblocks': 5, + 'sig': 20, + 'freq_min': 150, + 'sigmaMask': 30, + 'nPCs': 3, + 'ntbuff': 64, + 'nfilt_factor': 4, + 'NT': None, + 'do_correction': True, + 'wave_length': 61, + 'keep_good_only': False, + 'PreProcessing_params': {'Filter': False, + 'BandpassFilter': True, + 'HighpassFilter': False, + 'NotchFilter': False, + 'NormalizeByQuantile': False, + 'Scale': False, + 'Center': False, + 'ZScore': False, + 'Whiten': False, + 'CommonReference': False, + 'PhaseShift': False, + 'Rectify': False, + 'Clip': False, + 'BlankSaturation': False, + 'RemoveArtifacts': False, + 'RemoveBadChannels': False, + 'ZeroChannelPad': False, + 'DeepInterpolation': False, + 'Resample': False}} +''' \ No newline at end of file From 109a71ad4e8c884659bee146dda2afdbfe4fa77a Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 17 Feb 2023 19:34:53 -0600 Subject: [PATCH 030/152] post processing waveform extractor extensions --- .../spike_sorting/si_clustering.py | 282 +++++++++++------- 1 file changed, 178 insertions(+), 104 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index f0144ea6..13f129d8 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -34,6 +34,7 @@ ) import element_array_ephys.probe as probe +import spikeinterface import spikeinterface.full as si import spikeinterface.core as sic import spikeinterface.extractors as se @@ -78,6 +79,7 @@ def 
activate( add_objects=ephys.__dict__, ) + @schema class PreProcessing(dj.Imported): """A table to handle preprocessing of each clustering task.""" @@ -85,7 +87,7 @@ class PreProcessing(dj.Imported): definition = """ -> ephys.ClusteringTask --- - recording_filename: varchar(60) # filename where recording object is saved to + recording_filename: varchar(30) # filename where recording object is saved to params: longblob # finalized parameterset for this run execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration @@ -98,6 +100,7 @@ def key_source(self): & {"task_mode": "trigger"} & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")' ) - ephys.Clustering + def make(self, key): """Triggers or imports clustering analysis.""" execution_time = datetime.utcnow() @@ -131,101 +134,121 @@ def make(self, key): # params = {**params, **ephys.get_recording_channels_details(key)} # params["fs"] = params["sample_rate"] - - preprocess_list = params.pop('PreProcessing_params') + preprocess_list = params.pop("PreProcessing_params") - # If else - if preprocess_list['Filter']: + # If else + # need to figure out ordering + if preprocess_list["Filter"]: oe_si_recording = sip.FilterRecording(oe_si_recording) - elif preprocess_list['BandpassFilter']: + elif preprocess_list["BandpassFilter"]: oe_si_recording = sip.BandpassFilterRecording(oe_si_recording) - elif preprocess_list['HighpassFilter']: + elif preprocess_list["HighpassFilter"]: oe_si_recording = sip.HighpassFilterRecording(oe_si_recording) - elif preprocess_list['NormalizeByQuantile']: + elif preprocess_list["NormalizeByQuantile"]: oe_si_recording = sip.NormalizeByQuantileRecording(oe_si_recording) - elif preprocess_list['Scale']: + elif preprocess_list["Scale"]: oe_si_recording = sip.ScaleRecording(oe_si_recording) - elif preprocess_list['Center']: + elif preprocess_list["Center"]: oe_si_recording = sip.CenterRecording(oe_si_recording) - elif preprocess_list['ZScore']: + elif preprocess_list["ZScore"]: oe_si_recording = sip.ZScoreRecording(oe_si_recording) - elif preprocess_list['Whiten']: + elif preprocess_list["Whiten"]: oe_si_recording = sip.WhitenRecording(oe_si_recording) - elif preprocess_list['CommonReference']: + elif preprocess_list["CommonReference"]: oe_si_recording = sip.CommonReferenceRecording(oe_si_recording) - elif preprocess_list['PhaseShift']: + elif preprocess_list["PhaseShift"]: oe_si_recording = sip.PhaseShiftRecording(oe_si_recording) - elif preprocess_list['Rectify']: + elif preprocess_list["Rectify"]: oe_si_recording = sip.RectifyRecording(oe_si_recording) - elif preprocess_list['Clip']: + elif preprocess_list["Clip"]: oe_si_recording = sip.ClipRecording(oe_si_recording) - elif preprocess_list['BlankSaturation']: + elif preprocess_list["BlankSaturation"]: oe_si_recording = sip.BlankSaturationRecording(oe_si_recording) - elif preprocess_list['RemoveArtifacts']: + elif preprocess_list["RemoveArtifacts"]: oe_si_recording = sip.RemoveArtifactsRecording(oe_si_recording) - elif preprocess_list['RemoveBadChannels']: + elif preprocess_list["RemoveBadChannels"]: oe_si_recording = sip.RemoveBadChannelsRecording(oe_si_recording) - elif preprocess_list['ZeroChannelPad']: + elif preprocess_list["ZeroChannelPad"]: oe_si_recording = sip.ZeroChannelPadRecording(oe_si_recording) - elif preprocess_list['DeepInterpolation']: + elif preprocess_list["DeepInterpolation"]: oe_si_recording = sip.DeepInterpolationRecording(oe_si_recording) - elif preprocess_list['Resample']: + elif 
preprocess_list["Resample"]: oe_si_recording = sip.ResampleRecording(oe_si_recording) - + if acq_software == "SpikeGLX": # sglx_session_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) sglx_filepath = ephys.get_spikeglx_meta_filepath(key) # Create SI recording extractor object - sglx_si_recording = se.read_spikeglx(folder_path=sglx_filepath.parent) - + sglx_si_recording = se.read_spikeglx(folder_path=sglx_filepath.parent) + channels_details = ephys.get_recording_channels_details(key) - xy_coords = [list(i) for i in zip(channels_details['x_coords'],channels_details['y_coords'])] - - - # Create SI probe object - si_probe = pi.Probe(ndim=2, si_units='um') - si_probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 12}) - si_probe.create_auto_shape(probe_type='tip') - si_probe.set_device_channel_indices(channels_details['channel_ind']) + xy_coords = [ + list(i) + for i in zip(channels_details["x_coords"], channels_details["y_coords"]) + ] + + # Create SI probe object + si_probe = pi.Probe(ndim=2, si_units="um") + si_probe.set_contacts( + positions=xy_coords, shapes="square", shape_params={"width": 12} + ) + si_probe.create_auto_shape(probe_type="tip") + si_probe.set_device_channel_indices(channels_details["channel_ind"]) sglx_si_recording.set_probe(probe=si_probe) # run preprocessing and save results to output folder - sglx_si_recording_filtered = sip.bandpass_filter(sglx_si_recording, freq_min=300, freq_max=6000) + sglx_si_recording_filtered = sip.bandpass_filter( + sglx_si_recording, freq_min=300, freq_max=6000 + ) # sglx_recording_cmr = sip.common_reference(sglx_si_recording_filtered, reference="global", operator="median") - save_file_name = 'si_recording.pkl' + save_file_name = "si_recording.pkl" save_file_path = kilosort_dir / save_file_name sglx_si_recording_filtered.dump_to_pickle(file_path=save_file_path) - elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) - oe_session_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) - + oe_session_full_path = find_full_path( + ephys.get_ephys_root_data_dir(), ephys.get_session_directory(key) + ) + assert len(oe_probe.recording_info["recording_files"]) == 1 - stream_name = os.path.split(oe_probe.recording_info['recording_files'][0])[1] + stream_name = os.path.split(oe_probe.recording_info["recording_files"][0])[ + 1 + ] # Create SI recording extractor object - # oe_si_recording = se.OpenEphysBinaryRecordingExtractor(folder_path=oe_full_path, stream_name=stream_name) - oe_si_recording = se.read_openephys(folder_path=oe_session_full_path, stream_name=stream_name) + # oe_si_recording = se.OpenEphysBinaryRecordingExtractor(folder_path=oe_full_path, stream_name=stream_name) + oe_si_recording = se.read_openephys( + folder_path=oe_session_full_path, stream_name=stream_name + ) channels_details = ephys.get_recording_channels_details(key) - xy_coords = [list(i) for i in zip(channels_details['x_coords'],channels_details['y_coords'])] - - # Create SI probe object - si_probe = pi.Probe(ndim=2, si_units='um') - si_probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 12}) - si_probe.create_auto_shape(probe_type='tip') - si_probe.set_device_channel_indices(channels_details['channel_ind']) + xy_coords = [ + list(i) + for i in zip(channels_details["x_coords"], channels_details["y_coords"]) + ] + + # Create SI probe object + si_probe = pi.Probe(ndim=2, si_units="um") + si_probe.set_contacts( + 
positions=xy_coords, shapes="square", shape_params={"width": 12} + ) + si_probe.create_auto_shape(probe_type="tip") + si_probe.set_device_channel_indices(channels_details["channel_ind"]) oe_si_recording.set_probe(probe=si_probe) # run preprocessing and save results to output folder # Switch case to allow for specified preprocessing steps - oe_si_recording_filtered = sip.bandpass_filter(oe_si_recording, freq_min=300, freq_max=6000) - oe_recording_cmr = sip.common_reference(oe_si_recording_filtered, reference="global", operator="median") + oe_si_recording_filtered = sip.bandpass_filter( + oe_si_recording, freq_min=300, freq_max=6000 + ) + oe_recording_cmr = sip.common_reference( + oe_si_recording_filtered, reference="global", operator="median" + ) - save_file_name = 'si_recording.pkl' + save_file_name = "si_recording.pkl" save_file_path = kilosort_dir / save_file_name oe_si_recording_filtered.dump_to_pickle(file_path=save_file_path) @@ -240,15 +263,17 @@ def make(self, key): ).total_seconds() / 3600, } - ) - @schema -class ClusteringModule(dj.Imported): + ) + + +@schema +class SIClustering(dj.Imported): """A processing table to handle each clustering task.""" definition = """ -> PreProcessing --- - sorting_filename: varchar(60) # filename of saved sorting object + sorting_filename: varchar(30) # filename of saved sorting object execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration """ @@ -263,56 +288,56 @@ def make(self, key): ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - params = (PreProcessing & key).fetch1("params") - recording_filename = (PreProcessing & key).fetch1("recording_filename") + params = (PreProcessing & key).fetch1("params") + recording_filename = (PreProcessing & key).fetch1("recording_filename") if acq_software == "SpikeGLX": # sglx_probe = ephys.get_openephys_probe_data(key) recording_fullpath = kilosort_dir / recording_filename - # sglx_si_recording = se.load_from_folder(recording_file) + # sglx_si_recording = se.load_from_folder(recording_file) sglx_si_recording = sic.load_extractor(recording_fullpath) # assert len(oe_probe.recording_info["recording_files"]) == 1 ## Assume that the worker process will trigger this sorting step - # - Will need to store/load the sorter_name, sglx_si_recording object etc. + # - Will need to store/load the sorter_name, sglx_si_recording object etc. 
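# Reviewer note: the PostProcessing table below re-loads the dumped recording
# and sorting, extracts waveforms, and computes quality metrics. A condensed
# sketch of that flow; SpikeInterface also exposes extract_waveforms() as a
# one-call alternative to WaveformExtractor.create + run_extract_waveforms.
# File and folder names mirror the ones used in this module:
import spikeinterface.full as si

recording = si.load_extractor("si_recording.pkl")
sorting = si.load_extractor("sorting_kilosort.pkl")
we = si.extract_waveforms(
    recording, sorting, folder="waveforms",
    ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500,
)
snrs = si.compute_snrs(we)
metrics = si.compute_quality_metrics(
    we, metric_names=["firing_rate", "snr", "presence_ratio", "isi_violation"]
)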
# - Store in shared EC2 space accessible by all containers (needs to be mounted) - # - Load into the cloud init script, and + # - Load into the cloud init script, and # - Option A: Can call this function within a separate container within spike_sorting_worker - if clustering_method.startswith('kilosort2.5'): + if clustering_method.startswith("kilosort2.5"): sorter_name = "kilosort2_5" else: sorter_name = clustering_method - # sorting_kilosort = si.run_sorter( - # sorter_name = sorter_name, - # recording = sglx_si_recording, - # output_folder = kilosort_dir, - # docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", - # **params - # ) - sorting_save_path = kilosort_dir / 'sorting_kilosort.pkl' + sorting_kilosort = si.run_sorter( + sorter_name=sorter_name, + recording=sglx_si_recording, + output_folder=kilosort_dir, + docker_image=f"spikeinterface/{sorter_name}-compiled-base:latest", + **params, + ) + sorting_save_path = kilosort_dir / "sorting_kilosort.pkl" sorting_kilosort.dump_to_pickle(sorting_save_path) elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) - oe_si_recording = sic.load_extractor(recording_fullpath) + oe_si_recording = sic.load_extractor(recording_fullpath) assert len(oe_probe.recording_info["recording_files"]) == 1 - if clustering_method.startswith('kilosort2.5'): + if clustering_method.startswith("kilosort2.5"): sorter_name = "kilosort2_5" else: sorter_name = clustering_method - # sorting_kilosort = si.run_sorter( - # sorter_name = sorter_name, - # recording = oe_si_recording, - # output_folder = kilosort_dir, - # docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", - # **params - # ) - sorting_save_path = kilosort_dir / 'sorting_kilosort.pkl' + sorting_kilosort = si.run_sorter( + sorter_name=sorter_name, + recording=oe_si_recording, + output_folder=kilosort_dir, + docker_image=f"spikeinterface/{sorter_name}-compiled-base:latest", + **params, + ) + sorting_save_path = kilosort_dir / "sorting_kilosort.pkl" sorting_kilosort.dump_to_pickle(sorting_save_path) - # sorting_kilosort.save(folder=kilosort_dir, n_jobs=20, chunk_size=30000) self.insert1( { **key, + "sorting_filename": list(sorting_save_path.parts)[-1], "execution_time": execution_time, "execution_duration": ( datetime.utcnow() - execution_time @@ -321,6 +346,7 @@ def make(self, key): } ) + @schema class PostProcessing(dj.Imported): """A processing table to handle each clustering task.""" @@ -345,53 +371,100 @@ def make(self, key): params = (PreProcessing & key).fetch1("params") if acq_software == "SpikeGLX": - sorting_file = kilosort_dir / 'sorting_kilosort' - recording_file = kilosort_dir / 'sglx_recording_cmr.json' - sglx_si_recording = sic.load_extractor(recording_file) + recording_filename = (PreProcessing & key).fetch1("recording_filename") + sorting_file = kilosort_dir / "sorting_kilosort" + filtered_recording_file = kilosort_dir / recording_filename + sglx_si_recording_filtered = sic.load_extractor(recording_file) sorting_kilosort = sic.load_extractor(sorting_file) - we_kilosort = si.WaveformExtractor.create(sglx_si_recording, sorting_kilosort, "waveforms", remove_if_exists=True) + we_kilosort = si.WaveformExtractor.create( + sglx_si_recording_filtered, + sorting_kilosort, + "waveforms", + remove_if_exists=True, + ) + we_kilosort.set_params(ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500) we_kilosort.run_extract_waveforms(n_jobs=-1, chunk_size=30000) unit_id0 = sorting_kilosort.unit_ids[0] waveforms = we_kilosort.get_waveforms(unit_id0) template 
= we_kilosort.get_template(unit_id0) snrs = si.compute_snrs(we_kilosort) - - # QC Metrics - si_violations_ratio, isi_violations_rate, isi_violations_count = si.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) - metrics = si.compute_quality_metrics(we_kilosort, metric_names=["firing_rate","snr","presence_ratio","isi_violation", - "num_spikes","amplitude_cutoff","amplitude_median","sliding_rp_violation","rp_violation","drift"]) + # QC Metrics + ( + si_violations_ratio, + isi_violations_rate, + isi_violations_count, + ) = si.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) + metrics = si.compute_quality_metrics( + we_kilosort, + metric_names=[ + "firing_rate", + "snr", + "presence_ratio", + "isi_violation", + "num_spikes", + "amplitude_cutoff", + "amplitude_median", + "sliding_rp_violation", + "rp_violation", + "drift", + ], + ) sie.export_report(we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000) # ["firing_rate","snr","presence_ratio","isi_violation", # "number_violation","amplitude_cutoff","isolation_distance","l_ratio","d_prime","nn_hit_rate", # "nn_miss_rate","silhouette_core","cumulative_drift","contamination_rate"]) - - we_kilosort.save_to_folder('we_kilosort',kilosort_dir, n_jobs=-1, chunk_size=30000) - + we_savedir = kilosort_dir / "we_kilosort" + we_kilosort.save(we_savedir, n_jobs=-1, chunk_size=30000) elif acq_software == "Open Ephys": - sorting_file = kilosort_dir / 'sorting_kilosort' - recording_file = kilosort_dir / 'sglx_recording_cmr.json' + sorting_file = kilosort_dir / "sorting_kilosort" + recording_file = kilosort_dir / "sglx_recording_cmr.json" sglx_si_recording = sic.load_extractor(recording_file) sorting_kilosort = sic.load_extractor(sorting_file) - we_kilosort = si.WaveformExtractor.create(sglx_si_recording, sorting_kilosort, "waveforms", remove_if_exists=True) + we_kilosort = si.WaveformExtractor.create( + sglx_si_recording, sorting_kilosort, "waveforms", remove_if_exists=True + ) + we_kilosort.set_params(ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500) we_kilosort.run_extract_waveforms(n_jobs=-1, chunk_size=30000) unit_id0 = sorting_kilosort.unit_ids[0] waveforms = we_kilosort.get_waveforms(unit_id0) template = we_kilosort.get_template(unit_id0) snrs = si.compute_snrs(we_kilosort) - - # QC Metrics - si_violations_ratio, isi_violations_rate, isi_violations_count = si.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) - metrics = si.compute_quality_metrics(we_kilosort, metric_names=["firing_rate","snr","presence_ratio","isi_violation", - "num_spikes","amplitude_cutoff","amplitude_median","sliding_rp_violation","rp_violation","drift"]) + # QC Metrics + # Apply waveform extractor extensions + spike_locations = si.compute_spike_locations(we_kilosort) + spike_amplitudes = si.compute_spike_amplitudes(we_kilosort) + unit_locations = si.compute_unit_locations(we_kilosort) + template_metrics = si.compute_template_metrics(we_kilosort) + noise_levels = si.compute_noise_levels(we_kilosort) + drift_metrics = si.compute_drift_metrics(we_kilosort) + + (isi_violations_ratio, isi_violations_count) = si.compute_isi_violations( + we_kilosort, isi_threshold_ms=1.5 + ) + (isi_histograms, bins) = si.compute_isi_histograms(we_kilosort) + metrics = si.compute_quality_metrics( + we_kilosort, + metric_names=[ + "firing_rate", + "snr", + "presence_ratio", + "isi_violation", + "num_spikes", + "amplitude_cutoff", + "amplitude_median", + # "sliding_rp_violation", + "rp_violation", + "drift", + ], + ) sie.export_report(we_kilosort, kilosort_dir, n_jobs=-1, 
chunk_size=30000) - we_kilosort.save_to_folder('we_kilosort',kilosort_dir, n_jobs=-1, chunk_size=30000) - + we_kilosort.save("we_kilosort", kilosort_dir, n_jobs=-1, chunk_size=30000) self.insert1( { @@ -409,8 +482,9 @@ def make(self, key): {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True ) + ## Example SI parameter set -''' +""" {'detect_threshold': 6, 'projection_threshold': [10, 4], 'preclust_threshold': 8, @@ -447,4 +521,4 @@ def make(self, key): 'ZeroChannelPad': False, 'DeepInterpolation': False, 'Resample': False}} -''' \ No newline at end of file +""" From 1febd7e05111aa19679c64f86947973e0b533ebf Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 17 Feb 2023 19:37:12 -0600 Subject: [PATCH 031/152] post processing waveform extractor extensions --- element_array_ephys/spike_sorting/si_clustering.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 13f129d8..a018119d 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -441,8 +441,9 @@ def make(self, key): unit_locations = si.compute_unit_locations(we_kilosort) template_metrics = si.compute_template_metrics(we_kilosort) noise_levels = si.compute_noise_levels(we_kilosort) + pcs = si.compute_principal_components(we_kilosort) drift_metrics = si.compute_drift_metrics(we_kilosort) - + template_similarity = si.compute_tempoate_similarity(we_kilosort) (isi_violations_ratio, isi_violations_count) = si.compute_isi_violations( we_kilosort, isi_threshold_ms=1.5 ) From a478e0679ee5ae3747a324880415f090791cb868 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Mon, 20 Feb 2023 16:21:23 -0600 Subject: [PATCH 032/152] Fix data loading bug related to cluster_groups and KSLabel df key --- element_array_ephys/readers/kilosort.py | 42 +++++++++++++------------ 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/element_array_ephys/readers/kilosort.py b/element_array_ephys/readers/kilosort.py index abddee74..e88ba335 100644 --- a/element_array_ephys/readers/kilosort.py +++ b/element_array_ephys/readers/kilosort.py @@ -1,19 +1,16 @@ -import logging -import pathlib -import re -from datetime import datetime from os import path - -import numpy as np +from datetime import datetime +import pathlib import pandas as pd - +import numpy as np +import re +import logging from .utils import convert_to_number log = logging.getLogger(__name__) class Kilosort: - _kilosort_core_files = [ "params.py", "amplitudes.npy", @@ -118,7 +115,8 @@ def _load(self): # Read the Cluster Groups for cluster_pattern, cluster_col_name in zip( - ["cluster_group.*", "cluster_KSLabel.*"], ["group", "KSLabel"] + ["cluster_group.*", "cluster_KSLabel.*", "cluster_group.*"], + ["group", "KSLabel", "KSLabel"], ): try: cluster_file = next(self._kilosort_dir.glob(cluster_pattern)) @@ -127,22 +125,26 @@ def _load(self): else: cluster_file_suffix = cluster_file.suffix assert cluster_file_suffix in (".tsv", ".xlsx") - break + + if cluster_file_suffix == ".tsv": + df = pd.read_csv(cluster_file, sep="\t", header=0) + elif cluster_file_suffix == ".xlsx": + df = pd.read_excel(cluster_file, engine="openpyxl") + else: + df = pd.read_csv(cluster_file, delimiter="\t") + + try: + self._data["cluster_groups"] = np.array(df[cluster_col_name].values) + self._data["cluster_ids"] = np.array(df["cluster_id"].values) + except KeyError: + continue + else: + break else: raise 
FileNotFoundError( 'Neither "cluster_groups" nor "cluster_KSLabel" file found!' ) - if cluster_file_suffix == ".tsv": - df = pd.read_csv(cluster_file, sep="\t", header=0) - elif cluster_file_suffix == ".xlsx": - df = pd.read_excel(cluster_file, engine="openpyxl") - else: - df = pd.read_csv(cluster_file, delimiter="\t") - - self._data["cluster_groups"] = np.array(df[cluster_col_name].values) - self._data["cluster_ids"] = np.array(df["cluster_id"].values) - def get_best_channel(self, unit): template_idx = self.data["spike_templates"][ np.where(self.data["spike_clusters"] == unit)[0][0] From 634761ddad17bfd024b84567881499fba5e2d46e Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Tue, 21 Feb 2023 18:13:02 -0600 Subject: [PATCH 033/152] waveform extraction wip --- element_array_ephys/ephys_no_curation.py | 109 ++++++++++-------- .../spike_sorting/si_clustering.py | 4 +- 2 files changed, 63 insertions(+), 50 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index f4ed4b55..69afaea2 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1,17 +1,17 @@ -import gc -import importlib -import inspect +import datajoint as dj import pathlib import re -from decimal import Decimal - -import datajoint as dj import numpy as np +import inspect +import importlib +import gc +from decimal import Decimal import pandas as pd -from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory -from . import ephys_report, get_logger, probe -from .readers import kilosort, openephys, spikeglx +from element_interface.utils import find_root_directory, find_full_path, dict_to_uuid +from .readers import spikeglx, kilosort, openephys +from element_array_ephys import probe, get_logger, ephys_report + log = get_logger(__name__) @@ -19,6 +19,9 @@ _linking_module = None +import spikeinterface +import spikeinterface.full as si + def activate( ephys_schema_name: str, @@ -32,7 +35,7 @@ def activate( Args: ephys_schema_name (str): A string containing the name of the ephys schema. - probe_schema_name (str): A string containing the name of the probe schema. + probe_schema_name (str): A string containing the name of the probe scehma. create_schema (bool): If True, schema will be created in the database. create_tables (bool): If True, tables related to the schema will be created in the database. linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema. @@ -129,7 +132,7 @@ class AcquisitionSoftware(dj.Lookup): """ definition = """ # Name of software used for recording of neuropixels probes - SpikeGLX or Open Ephys - acq_software: varchar(24) + acq_software: varchar(24) """ contents = zip(["SpikeGLX", "Open Ephys"]) @@ -272,11 +275,11 @@ class EphysRecording(dj.Imported): definition = """ # Ephys recording from a probe insertion for a given session. - -> ProbeInsertion + -> ProbeInsertion --- -> probe.ElectrodeConfig -> AcquisitionSoftware - sampling_rate: float # (Hz) + sampling_rate: float # (Hz) recording_datetime: datetime # datetime of the recording from this probe recording_duration: float # (seconds) duration of the recording from this probe """ @@ -315,8 +318,8 @@ def make(self, key): break else: raise FileNotFoundError( - "Ephys recording data not found!" - " Neither SpikeGLX nor Open Ephys recording files found" + f"Ephys recording data not found!" 
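+                # adjacent string literals concatenate into one message; the
+                # f-prefixes are inert here since neither literal has placeholders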
+ f" Neither SpikeGLX nor Open Ephys recording files found" ) supported_probe_types = probe.ProbeType.fetch("probe_type") @@ -471,9 +474,9 @@ class Electrode(dj.Part): definition = """ -> master - -> probe.ElectrodeConfig.Electrode + -> probe.ElectrodeConfig.Electrode --- - lfp: longblob # (uV) recorded lfp at this electrode + lfp: longblob # (uV) recorded lfp at this electrode """ # Only store LFP for every 9th channel, due to high channel density, @@ -614,14 +617,14 @@ class ClusteringParamSet(dj.Lookup): ClusteringMethod (dict): ClusteringMethod primary key. paramset_desc (varchar(128) ): Description of the clustering parameter set. param_set_hash (uuid): UUID hash for the parameter set. - params (longblob): Set of clustering parameters + params (longblob) """ definition = """ # Parameter set to be used in a clustering procedure paramset_idx: smallint --- - -> ClusteringMethod + -> ClusteringMethod paramset_desc: varchar(128) param_set_hash: uuid unique index (param_set_hash) @@ -724,18 +727,15 @@ class ClusteringTask(dj.Manual): """ @classmethod - def infer_output_dir( - cls, key, relative: bool = False, mkdir: bool = False - ) -> pathlib.Path: + def infer_output_dir(cls, key, relative: bool = False, mkdir: bool = False): """Infer output directory if it is not provided. Args: key (dict): ClusteringTask primary key. Returns: - Expected clustering_output_dir based on the following convention: - processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx} - e.g.: sub4/sess1/probe_2/kilosort2_0 + Pathlib.Path: Expected clustering_output_dir based on the following convention: processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx} + e.g.: sub4/sess1/probe_2/kilosort2_0 """ processed_dir = pathlib.Path(get_processed_root_data_dir()) session_dir = find_full_path( @@ -802,14 +802,14 @@ class Clustering(dj.Imported): Attributes: ClusteringTask (foreign key): ClusteringTask primary key. clustering_time (datetime): Time when clustering results are generated. - package_version (varchar(16): Package version used for a clustering analysis. + package_version (varchar(16) ): Package version used for a clustering analysis. """ definition = """ # Clustering Procedure -> ClusteringTask --- - clustering_time: datetime # time of generation of this set of clustering results + clustering_time: datetime # time of generation of this set of clustering results package_version='': varchar(16) """ @@ -850,10 +850,6 @@ def make(self, key): spikeglx_meta_filepath.parent ) spikeglx_recording.validate_file("ap") - run_CatGT = ( - params.pop("run_CatGT", True) - and "_tcat." not in spikeglx_meta_filepath.stem - ) if clustering_method.startswith("pykilosort"): kilosort_triggering.run_pykilosort( @@ -874,7 +870,7 @@ def make(self, key): ks_output_dir=kilosort_dir, params=params, KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=run_CatGT, + run_CatGT=True, ) run_kilosort.run_modules() elif acq_software == "Open Ephys": @@ -929,7 +925,7 @@ class CuratedClustering(dj.Imported): definition = """ # Clustering results of the spike sorting step. - -> Clustering + -> Clustering """ class Unit(dj.Part): @@ -946,7 +942,7 @@ class Unit(dj.Part): spike_depths (longblob): Array of depths associated with each spike, relative to each spike. 
""" - definition = """ + definition = """ # Properties of a given unit from a round of clustering (and curation) -> master unit: int @@ -956,7 +952,7 @@ class Unit(dj.Part): spike_count: int # how many spikes in this recording for this unit spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording spike_sites : longblob # array of electrode associated with each spike - spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe + spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe """ def make(self, key): @@ -1080,8 +1076,8 @@ class Waveform(dj.Part): # Spike waveforms and their mean across spikes for the given unit -> master -> CuratedClustering.Unit - -> probe.ElectrodeConfig.Electrode - --- + -> probe.ElectrodeConfig.Electrode + --- waveform_mean: longblob # (uV) mean waveform across spikes of the given unit waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit """ @@ -1109,15 +1105,32 @@ def make(self, key): for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit") } + waveforms_folder = kilosort_dir / "we_kilosort" + + waveforms_folder = kilosort_dir.rglob(*waveform) + # Mean waveforms need to be extracted from waveform extractor object + if (waveforms_folder).exists(): + we_kilosort = si.load_waveforms(waveforms_folder) + unit_waveforms = we_kilosort.get_all_templates() + + def yield_unit_waveforms(): + for unit_no, unit_waveform in zip( + kilosort_dataset.data["cluster_ids"], unit_waveforms + ): + unit_peak_waveform = {} + unit_electrode_waveforms = [] + + if unit_no in units: + unit_waveform = we_kilosort.get_waveforms(unit_id=unit_no) + mean_templates = we_kilosort.get_templates(unit_id=unit_no) + if (kilosort_dir / "mean_waveforms.npy").exists(): unit_waveforms = np.load( kilosort_dir / "mean_waveforms.npy" ) # unit x channel x sample def yield_unit_waveforms(): - for unit_no, unit_waveform in zip( - kilosort_dataset.data["cluster_ids"], unit_waveforms - ): + for unit_no, unit_waveform in zip(cluster_ids, unit_waveforms): unit_peak_waveform = {} unit_electrode_waveforms = [] if unit_no in units: @@ -1207,7 +1220,7 @@ class QualityMetrics(dj.Imported): definition = """ # Clusters and waveforms metrics - -> CuratedClustering + -> CuratedClustering """ class Cluster(dj.Part): @@ -1232,26 +1245,26 @@ class Cluster(dj.Part): contamination_rate (float): Frequency of spikes in the refractory period. 
""" - definition = """ + definition = """ # Cluster metrics for a particular unit -> master -> CuratedClustering.Unit --- - firing_rate=null: float # (Hz) firing rate for a unit + firing_rate=null: float # (Hz) firing rate for a unit snr=null: float # signal-to-noise ratio for a unit presence_ratio=null: float # fraction of time in which spikes are present isi_violation=null: float # rate of ISI violation as a fraction of overall rate number_violation=null: int # total number of ISI violations amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram isolation_distance=null: float # distance to nearest cluster in Mahalanobis space - l_ratio=null: float # + l_ratio=null: float # d_prime=null: float # Classification accuracy based on LDA nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster silhouette_score=null: float # Standard metric for cluster overlap max_drift=null: float # Maximum change in spike depth throughout recording - cumulative_drift=null: float # Cumulative change in spike depth throughout recording - contamination_rate=null: float # + cumulative_drift=null: float # Cumulative change in spike depth throughout recording + contamination_rate=null: float # """ class Waveform(dj.Part): @@ -1268,10 +1281,10 @@ class Waveform(dj.Part): recovery_slope (float): Slope of the regression line fit to first 30 microseconds from peak to tail. spread (float): The range with amplitude over 12-percent of maximum amplitude along the probe. velocity_above (float): inverse velocity of waveform propagation from soma to the top of the probe. - velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe. + velocity_below (float) inverse velocity of waveform propagation from soma toward the bottom of the probe. 
""" - definition = """ + definition = """ # Waveform metrics for a particular unit -> master -> CuratedClustering.Unit diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index a018119d..be50356b 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -464,8 +464,8 @@ def make(self, key): ], ) sie.export_report(we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000) - - we_kilosort.save("we_kilosort", kilosort_dir, n_jobs=-1, chunk_size=30000) + we_savedir = kilosort_dir / "we_kilosort" + we_kilosort.save(we_savedir, n_jobs=-1, chunk_size=30000) self.insert1( { From ff0dfee68bc45fbc43e42dec19541473ec9090e0 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Wed, 22 Feb 2023 20:08:36 -0600 Subject: [PATCH 034/152] modification to handle spike interface waveforms --- element_array_ephys/ephys_no_curation.py | 84 ++++++++++++++++++++---- 1 file changed, 70 insertions(+), 14 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 69afaea2..85ecb1a7 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1105,13 +1105,14 @@ def make(self, key): for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit") } - waveforms_folder = kilosort_dir / "we_kilosort" + waveforms_folder = [ + f for f in kilosort_dir.parent.rglob(r"*/waveforms*") if f.is_dir() + ] - waveforms_folder = kilosort_dir.rglob(*waveform) - # Mean waveforms need to be extracted from waveform extractor object - if (waveforms_folder).exists(): - we_kilosort = si.load_waveforms(waveforms_folder) - unit_waveforms = we_kilosort.get_all_templates() + if (kilosort_dir / "mean_waveforms.npy").exists(): + unit_waveforms = np.load( + kilosort_dir / "mean_waveforms.npy" + ) # unit x channel x sample def yield_unit_waveforms(): for unit_no, unit_waveform in zip( @@ -1119,18 +1120,46 @@ def yield_unit_waveforms(): ): unit_peak_waveform = {} unit_electrode_waveforms = [] - if unit_no in units: - unit_waveform = we_kilosort.get_waveforms(unit_id=unit_no) - mean_templates = we_kilosort.get_templates(unit_id=unit_no) + for channel, channel_waveform in zip( + kilosort_dataset.data["channel_map"], unit_waveform + ): + unit_electrode_waveforms.append( + { + **units[unit_no], + **channel2electrodes[channel], + "waveform_mean": channel_waveform, + } + ) + if ( + channel2electrodes[channel]["electrode"] + == units[unit_no]["electrode"] + ): + unit_peak_waveform = { + **units[unit_no], + "peak_electrode_waveform": channel_waveform, + } + yield unit_peak_waveform, unit_electrode_waveforms - if (kilosort_dir / "mean_waveforms.npy").exists(): - unit_waveforms = np.load( - kilosort_dir / "mean_waveforms.npy" - ) # unit x channel x sample + # Spike interface mean and peak waveform extraction from we object + + elif len(waveforms_folder) > 0 & (waveforms_folder[0]).exists(): + we_kilosort = si.load_waveforms(waveforms_folder[0].parent) + unit_templates = we_kilosort.get_all_templates() + unit_waveforms = np.reshape( + unit_templates, + ( + unit_templates.shape[1], + unit_templates.shape[3], + unit_templates.shape[2], + ), + ) + # Approach assumes unit_waveforms was generated correctly (templates are actually the same as mean_waveforms) def yield_unit_waveforms(): - for unit_no, unit_waveform in zip(cluster_ids, unit_waveforms): + for unit_no, unit_waveform in zip( + kilosort_dataset.data["cluster_ids"], 
unit_waveforms + ): unit_peak_waveform = {} unit_electrode_waveforms = [] if unit_no in units: @@ -1154,6 +1183,33 @@ def yield_unit_waveforms(): } yield unit_peak_waveform, unit_electrode_waveforms + # Approach not using spike interface templates (ie. taking mean of each unit waveform) + # def yield_unit_waveforms(): + # for unit_id in we_kilosort.unit_ids: + # unit_waveform = np.mean(we_kilosort.get_waveforms(unit_id), 0) + # unit_peak_waveform = {} + # unit_electrode_waveforms = [] + # if unit_id in units: + # for channel, channel_waveform in zip( + # kilosort_dataset.data["channel_map"], unit_waveform + # ): + # unit_electrode_waveforms.append( + # { + # **units[unit_id], + # **channel2electrodes[channel], + # "waveform_mean": channel_waveform, + # } + # ) + # if ( + # channel2electrodes[channel]["electrode"] + # == units[unit_id]["electrode"] + # ): + # unit_peak_waveform = { + # **units[unit_id], + # "peak_electrode_waveform": channel_waveform, + # } + # yield unit_peak_waveform, unit_electrode_waveforms + else: if acq_software == "SpikeGLX": spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) From cb31229e89d6599c71647a3c7bb34e2498dcd192 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Wed, 22 Feb 2023 20:11:37 -0600 Subject: [PATCH 035/152] adjust post processing --- .../spike_sorting/si_clustering.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index be50356b..84f26644 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -130,10 +130,15 @@ def make(self, key): clustering_method in _supported_kilosort_versions ), f'Clustering_method "{clustering_method}" is not supported' + if clustering_method.startswith("kilosort2.5"): + sorter_name = "kilosort2_5" + else: + sorter_name = clustering_method # add additional probe-recording and channels details into `params` # params = {**params, **ephys.get_recording_channels_details(key)} # params["fs"] = params["sample_rate"] + default_params = si.get_default_sorter_params(sorter_name) preprocess_list = params.pop("PreProcessing_params") # If else @@ -406,7 +411,7 @@ def make(self, key): "num_spikes", "amplitude_cutoff", "amplitude_median", - "sliding_rp_violation", + # "sliding_rp_violation", "rp_violation", "drift", ], @@ -436,14 +441,14 @@ def make(self, key): # QC Metrics # Apply waveform extractor extensions - spike_locations = si.compute_spike_locations(we_kilosort) - spike_amplitudes = si.compute_spike_amplitudes(we_kilosort) - unit_locations = si.compute_unit_locations(we_kilosort) - template_metrics = si.compute_template_metrics(we_kilosort) - noise_levels = si.compute_noise_levels(we_kilosort) - pcs = si.compute_principal_components(we_kilosort) - drift_metrics = si.compute_drift_metrics(we_kilosort) - template_similarity = si.compute_tempoate_similarity(we_kilosort) + _ = si.compute_spike_locations(we_kilosort) + _ = si.compute_spike_amplitudes(we_kilosort) + _ = si.compute_unit_locations(we_kilosort) + _ = si.compute_template_metrics(we_kilosort) + _ = si.compute_noise_levels(we_kilosort) + _ = si.compute_principal_components(we_kilosort) + _ = si.compute_drift_metrics(we_kilosort) + _ = si.compute_tempoate_similarity(we_kilosort) (isi_violations_ratio, isi_violations_count) = si.compute_isi_violations( we_kilosort, isi_threshold_ms=1.5 ) From 6098421e9037017be02010672402efd885ce5b24 Mon Sep 17 00:00:00 2001 
From: Sidharth Hulyalkar Date: Mon, 6 Mar 2023 11:56:38 -0600 Subject: [PATCH 036/152] bugfix in postprocessing definition --- element_array_ephys/spike_sorting/si_clustering.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 84f26644..f3fd4c1d 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -10,7 +10,7 @@ - create recording extractor and link it to a probe - bandpass filtering - common mode referencing -2. ClusteringModule - kilosort (MATLAB) - requires GPU and docker/singularity containers +2. SIClustering - kilosort (MATLAB) - requires GPU and docker/singularity containers - supports kilosort 2.0, 2.5 or 3.0 (https://github.com/MouseLand/Kilosort.git) 3. PostProcessing - for postprocessing steps (no GPU required) - create waveform extractor object @@ -357,7 +357,7 @@ class PostProcessing(dj.Imported): """A processing table to handle each clustering task.""" definition = """ - -> ClusteringModule + -> SIClustering --- execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration @@ -426,11 +426,11 @@ def make(self, key): elif acq_software == "Open Ephys": sorting_file = kilosort_dir / "sorting_kilosort" recording_file = kilosort_dir / "sglx_recording_cmr.json" - sglx_si_recording = sic.load_extractor(recording_file) + oe_si_recording = sic.load_extractor(recording_file) sorting_kilosort = sic.load_extractor(sorting_file) we_kilosort = si.WaveformExtractor.create( - sglx_si_recording, sorting_kilosort, "waveforms", remove_if_exists=True + oe_si_recording, sorting_kilosort, "waveforms", remove_if_exists=True ) we_kilosort.set_params(ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500) we_kilosort.run_extract_waveforms(n_jobs=-1, chunk_size=30000) @@ -472,6 +472,9 @@ def make(self, key): we_savedir = kilosort_dir / "we_kilosort" we_kilosort.save(we_savedir, n_jobs=-1, chunk_size=30000) + metrics_savefile = kilosort_dir / "metrics.csv" + metrics.to_csv(metrics_savefile) + self.insert1( { **key, From 60064646c1d3c65d2d37173f0ef686fa5a108387 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Tue, 7 Mar 2023 18:17:35 -0600 Subject: [PATCH 037/152] add SI ibl destriping and catGT implementations --- .../spike_sorting/si_clustering.py | 108 +++++++++++------- 1 file changed, 69 insertions(+), 39 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index f3fd4c1d..cb4e1858 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -141,45 +141,6 @@ def make(self, key): default_params = si.get_default_sorter_params(sorter_name) preprocess_list = params.pop("PreProcessing_params") - # If else - # need to figure out ordering - if preprocess_list["Filter"]: - oe_si_recording = sip.FilterRecording(oe_si_recording) - elif preprocess_list["BandpassFilter"]: - oe_si_recording = sip.BandpassFilterRecording(oe_si_recording) - elif preprocess_list["HighpassFilter"]: - oe_si_recording = sip.HighpassFilterRecording(oe_si_recording) - elif preprocess_list["NormalizeByQuantile"]: - oe_si_recording = sip.NormalizeByQuantileRecording(oe_si_recording) - elif preprocess_list["Scale"]: - oe_si_recording = sip.ScaleRecording(oe_si_recording) - elif preprocess_list["Center"]: - oe_si_recording = 
sip.CenterRecording(oe_si_recording) - elif preprocess_list["ZScore"]: - oe_si_recording = sip.ZScoreRecording(oe_si_recording) - elif preprocess_list["Whiten"]: - oe_si_recording = sip.WhitenRecording(oe_si_recording) - elif preprocess_list["CommonReference"]: - oe_si_recording = sip.CommonReferenceRecording(oe_si_recording) - elif preprocess_list["PhaseShift"]: - oe_si_recording = sip.PhaseShiftRecording(oe_si_recording) - elif preprocess_list["Rectify"]: - oe_si_recording = sip.RectifyRecording(oe_si_recording) - elif preprocess_list["Clip"]: - oe_si_recording = sip.ClipRecording(oe_si_recording) - elif preprocess_list["BlankSaturation"]: - oe_si_recording = sip.BlankSaturationRecording(oe_si_recording) - elif preprocess_list["RemoveArtifacts"]: - oe_si_recording = sip.RemoveArtifactsRecording(oe_si_recording) - elif preprocess_list["RemoveBadChannels"]: - oe_si_recording = sip.RemoveBadChannelsRecording(oe_si_recording) - elif preprocess_list["ZeroChannelPad"]: - oe_si_recording = sip.ZeroChannelPadRecording(oe_si_recording) - elif preprocess_list["DeepInterpolation"]: - oe_si_recording = sip.DeepInterpolationRecording(oe_si_recording) - elif preprocess_list["Resample"]: - oe_si_recording = sip.ResampleRecording(oe_si_recording) - if acq_software == "SpikeGLX": # sglx_session_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) sglx_filepath = ephys.get_spikeglx_meta_filepath(key) @@ -212,6 +173,8 @@ def make(self, key): save_file_path = kilosort_dir / save_file_name sglx_si_recording_filtered.dump_to_pickle(file_path=save_file_path) + sglx_si_recording = run_IBLdestriping(sglx_si_recording) + elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) oe_session_full_path = find_full_path( @@ -492,6 +455,73 @@ def make(self, key): ) +# def runPreProcessList(preprocess_list, recording): +# # If else +# # need to figure out ordering +# if preprocess_list["Filter"]: +# recording = sip.FilterRecording(recording) +# if preprocess_list["BandpassFilter"]: +# recording = sip.BandpassFilterRecording(recording) +# if preprocess_list["HighpassFilter"]: +# recording = sip.HighpassFilterRecording(recording) +# if preprocess_list["NormalizeByQuantile"]: +# recording = sip.NormalizeByQuantileRecording(recording) +# if preprocess_list["Scale"]: +# recording = sip.ScaleRecording(recording) +# if preprocess_list["Center"]: +# recording = sip.CenterRecording(recording) +# if preprocess_list["ZScore"]: +# recording = sip.ZScoreRecording(recording) +# if preprocess_list["Whiten"]: +# recording = sip.WhitenRecording(recording) +# if preprocess_list["CommonReference"]: +# recording = sip.CommonReferenceRecording(recording) +# if preprocess_list["PhaseShift"]: +# recording = sip.PhaseShiftRecording(recording) +# elif preprocess_list["Rectify"]: +# recording = sip.RectifyRecording(recording) +# elif preprocess_list["Clip"]: +# recording = sip.ClipRecording(recording) +# elif preprocess_list["BlankSaturation"]: +# recording = sip.BlankSaturationRecording(recording) +# elif preprocess_list["RemoveArtifacts"]: +# recording = sip.RemoveArtifactsRecording(recording) +# elif preprocess_list["RemoveBadChannels"]: +# recording = sip.RemoveBadChannelsRecording(recording) +# elif preprocess_list["ZeroChannelPad"]: +# recording = sip.ZeroChannelPadRecording(recording) +# elif preprocess_list["DeepInterpolation"]: +# recording = sip.DeepInterpolationRecording(recording) +# elif preprocess_list["Resample"]: +# recording = sip.ResampleRecording(recording) + + 
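+# A hedged sketch of one way to settle the ordering question above (the name
+# run_preprocess_chain is illustrative only, not part of this module): drive
+# the chain from an ordered list of step names rather than if/elif, so steps
+# apply exactly in the order given, using the function-style
+# spikeinterface.preprocessing API already imported as `sip`.
+# def run_preprocess_chain(recording, ordered_steps):
+#     step_functions = {
+#         "PhaseShift": sip.phase_shift,
+#         "HighpassFilter": sip.highpass_filter,
+#         "BandpassFilter": sip.bandpass_filter,
+#         "CommonReference": sip.common_reference,
+#     }
+#     for step_name in ordered_steps:
+#         # each call returns a new lazy recording wrapping the previous one
+#         recording = step_functions[step_name](recording)
+#     return recording
+# e.g. recording = run_preprocess_chain(
+#     recording, ["PhaseShift", "BandpassFilter", "CommonReference"]
+# )
+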
+def mimic_IBLdestriping_modified(recording): + # From SpikeInterface Implementation (https://spikeinterface.readthedocs.io/en/latest/how_to/analyse_neuropixels.html) + recording = si.highpass_filter(recording, freq_min=400.0) + bad_channel_ids, channel_labels = si.detect_bad_channels(recording) + # For IBL destriping interpolate bad channels + recording = recording.remove_channels(bad_channel_ids) + recording = si.phase_shift(recording) + recording = si.common_reference(recording, operator="median", reference="global") + return recording + +def mimic_IBLdestriping(recording): + # From International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022. 9 Jun 2022. + recording = si.highpass_filter(recording, freq_min=400.0) + bad_channel_ids, channel_labels = si.detect_bad_channels(recording) + # For IBL destriping interpolate bad channels + recording = sip.interpolate_bad_channels(bad_channel_ids) + recording = si.phase_shift(recording) + recording = si.highpass_spatial_filter(recording, operator="median", reference="global") + # For IBL destriping use highpass_spatial_filter used instead of common reference + return recording + +def mimic_catGT(sglx_recording): + sglx_recording = si.phase_shift(sglx_recording) + sglx_recording = si.common_reference(sglx_recording, operator="median", reference="global") + return sglx_recording + ## Example SI parameter set """ {'detect_threshold': 6, From c050875c3f3d98e085c1a2d6aa591952df124632 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Tue, 7 Mar 2023 18:24:09 -0600 Subject: [PATCH 038/152] remove preprocess params list --- .../spike_sorting/si_clustering.py | 33 +++++++++++-------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index cb4e1858..81e5f0b3 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -139,7 +139,7 @@ def make(self, key): # params["fs"] = params["sample_rate"] default_params = si.get_default_sorter_params(sorter_name) - preprocess_list = params.pop("PreProcessing_params") + # preprocess_list = params.pop("PreProcessing_params") if acq_software == "SpikeGLX": # sglx_session_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) @@ -173,7 +173,7 @@ def make(self, key): save_file_path = kilosort_dir / save_file_name sglx_si_recording_filtered.dump_to_pickle(file_path=save_file_path) - sglx_si_recording = run_IBLdestriping(sglx_si_recording) + sglx_si_recording = mimic_catGT(sglx_si_recording) elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) @@ -208,17 +208,17 @@ def make(self, key): oe_si_recording.set_probe(probe=si_probe) # run preprocessing and save results to output folder - # Switch case to allow for specified preprocessing steps - oe_si_recording_filtered = sip.bandpass_filter( - oe_si_recording, freq_min=300, freq_max=6000 - ) - oe_recording_cmr = sip.common_reference( - oe_si_recording_filtered, reference="global", operator="median" - ) - + # # Switch case to allow for specified preprocessing steps + # oe_si_recording_filtered = sip.bandpass_filter( + # oe_si_recording, freq_min=300, freq_max=6000 + # ) + # oe_recording_cmr = sip.common_reference( + # oe_si_recording_filtered, reference="global", operator="median" + # ) + oe_si_recording = mimic_IBLdestriping(oe_si_recording) save_file_name = "si_recording.pkl" 
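            # NOTE: dump_to_pickle stores the lazy preprocessing graph rather than
            # the filtered traces, so this file stays small; downstream steps restore
            # the recording with e.g. `sic.load_extractor(kilosort_dir / save_file_name)`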
save_file_path = kilosort_dir / save_file_name - oe_si_recording_filtered.dump_to_pickle(file_path=save_file_path) + oe_si_recording.dump_to_pickle(file_path=save_file_path) self.insert1( { @@ -506,6 +506,7 @@ def mimic_IBLdestriping_modified(recording): recording = si.common_reference(recording, operator="median", reference="global") return recording + def mimic_IBLdestriping(recording): # From International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022. 9 Jun 2022. recording = si.highpass_filter(recording, freq_min=400.0) @@ -513,15 +514,21 @@ def mimic_IBLdestriping(recording): # For IBL destriping interpolate bad channels recording = sip.interpolate_bad_channels(bad_channel_ids) recording = si.phase_shift(recording) - recording = si.highpass_spatial_filter(recording, operator="median", reference="global") # For IBL destriping use highpass_spatial_filter used instead of common reference + recording = si.highpass_spatial_filter( + recording, operator="median", reference="global" + ) return recording + def mimic_catGT(sglx_recording): sglx_recording = si.phase_shift(sglx_recording) - sglx_recording = si.common_reference(sglx_recording, operator="median", reference="global") + sglx_recording = si.common_reference( + sglx_recording, operator="median", reference="global" + ) return sglx_recording + ## Example SI parameter set """ {'detect_threshold': 6, From 4ea56c0f3d86bd41ad4d623cc31ef82d131f03a3 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Tue, 7 Mar 2023 18:28:18 -0600 Subject: [PATCH 039/152] preprocessing changes --- .../spike_sorting/si_clustering.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 81e5f0b3..33704081 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -138,7 +138,7 @@ def make(self, key): # params = {**params, **ephys.get_recording_channels_details(key)} # params["fs"] = params["sample_rate"] - default_params = si.get_default_sorter_params(sorter_name) + # default_params = si.get_default_sorter_params(sorter_name) # preprocess_list = params.pop("PreProcessing_params") if acq_software == "SpikeGLX": @@ -163,17 +163,15 @@ def make(self, key): si_probe.set_device_channel_indices(channels_details["channel_ind"]) sglx_si_recording.set_probe(probe=si_probe) - # run preprocessing and save results to output folder - sglx_si_recording_filtered = sip.bandpass_filter( - sglx_si_recording, freq_min=300, freq_max=6000 - ) + # # run preprocessing and save results to output folder + # sglx_si_recording_filtered = sip.bandpass_filter( + # sglx_si_recording, freq_min=300, freq_max=6000 + # ) # sglx_recording_cmr = sip.common_reference(sglx_si_recording_filtered, reference="global", operator="median") - + sglx_si_recording = mimic_catGT(sglx_si_recording) save_file_name = "si_recording.pkl" save_file_path = kilosort_dir / save_file_name - sglx_si_recording_filtered.dump_to_pickle(file_path=save_file_path) - - sglx_si_recording = mimic_catGT(sglx_si_recording) + sglx_si_recording.dump_to_pickle(file_path=save_file_path) elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) From 0a875794726cc3d4c481825ed5566ede22f46514 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Mon, 15 May 2023 18:04:32 -0500 Subject: [PATCH 040/152] Update requirements.txt --- requirements.txt | 3 +-- 1 file 
changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 528f6349..0d47a42f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,5 +6,4 @@ plotly pyopenephys>=1.1.6 seaborn scikit-image -spikeinterface -nbformat>=4.2.0 \ No newline at end of file +nbformat>=4.2.0 From 22f1f65fe3773f2e4d8803cab15b694e3921d0a2 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Wed, 14 Jun 2023 15:36:07 -0500 Subject: [PATCH 041/152] fix spikeglx stream loading --- element_array_ephys/spike_sorting/si_clustering.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 33704081..f99d63d2 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -146,7 +146,12 @@ def make(self, key): sglx_filepath = ephys.get_spikeglx_meta_filepath(key) # Create SI recording extractor object - sglx_si_recording = se.read_spikeglx(folder_path=sglx_filepath.parent) + stream_name = sglx_filepath.stem.split(".", 1)[1] + sglx_si_recording = se.read_spikeglx( + folder_path=sglx_filepath.parent, + stream_name=stream_name, + stream_id=stream_name, + ) channels_details = ephys.get_recording_channels_details(key) xy_coords = [ From b62f16215efe91c9310383e99a72d7abdc8de983 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 11 Oct 2023 18:32:49 -0500 Subject: [PATCH 042/152] build: :pushpin: update requirements.txt & add env,.yml --- env.yml | 7 +++++++ requirements.txt | 3 ++- setup.py | 1 + 3 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 env.yml diff --git a/env.yml b/env.yml new file mode 100644 index 00000000..e9b3ce13 --- /dev/null +++ b/env.yml @@ -0,0 +1,7 @@ +channels: + - conda-forge + - defaults +dependencies: + - pip + - python>=3.7,<3.11 +name: element_array_ephys diff --git a/requirements.txt b/requirements.txt index 0d47a42f..721bfeda 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,10 @@ datajoint>=0.13 element-interface>=0.4.0 ipywidgets +nbformat>=4.2.0 openpyxl plotly pyopenephys>=1.1.6 seaborn scikit-image -nbformat>=4.2.0 +spikeinterface \ No newline at end of file diff --git a/setup.py b/setup.py index 31b9be61..cc538478 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ setup( name=pkg_name.replace("_", "-"), + python_requires='>=3.7, <3.11', version=__version__, # noqa F821 description="DataJoint Element for Extracellular Array Electrophysiology", long_description=long_description, From 849b576c8982b9231b4b1167740d2d1a2ad1cbdd Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 11 Oct 2023 18:36:51 -0500 Subject: [PATCH 043/152] refactor: :art: clean up spikeinterface import & remove unused import --- .../spike_sorting/si_clustering.py | 157 +++++++++--------- 1 file changed, 77 insertions(+), 80 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index f99d63d2..2cb5bf2e 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -16,16 +16,12 @@ - create waveform extractor object - extract templates, waveforms and snrs - quality_metrics - - """ + import datajoint as dj import os from element_array_ephys import get_logger -from decimal import Decimal -import json -import numpy as np -from datetime import datetime, timedelta +from datetime import datetime from element_interface.utils import 
find_full_path from element_array_ephys.readers import ( @@ -34,13 +30,7 @@ ) import element_array_ephys.probe as probe -import spikeinterface -import spikeinterface.full as si -import spikeinterface.core as sic -import spikeinterface.extractors as se -import spikeinterface.exporters as sie -import spikeinterface.sorters as ss -import spikeinterface.preprocessing as sip +import spikeinterface as si import probeinterface as pi log = get_logger(__name__) @@ -138,7 +128,7 @@ def make(self, key): # params = {**params, **ephys.get_recording_channels_details(key)} # params["fs"] = params["sample_rate"] - # default_params = si.get_default_sorter_params(sorter_name) + # default_params = si.full.get_default_sorter_params(sorter_name) # preprocess_list = params.pop("PreProcessing_params") if acq_software == "SpikeGLX": @@ -147,7 +137,7 @@ def make(self, key): # Create SI recording extractor object stream_name = sglx_filepath.stem.split(".", 1)[1] - sglx_si_recording = se.read_spikeglx( + sglx_si_recording = si.extractors.read_spikeglx( folder_path=sglx_filepath.parent, stream_name=stream_name, stream_id=stream_name, @@ -169,10 +159,10 @@ def make(self, key): sglx_si_recording.set_probe(probe=si_probe) # # run preprocessing and save results to output folder - # sglx_si_recording_filtered = sip.bandpass_filter( + # sglx_si_recording_filtered = si.preprocessing.bandpass_filter( # sglx_si_recording, freq_min=300, freq_max=6000 # ) - # sglx_recording_cmr = sip.common_reference(sglx_si_recording_filtered, reference="global", operator="median") + # sglx_recording_cmr = si.preprocessing.common_reference(sglx_si_recording_filtered, reference="global", operator="median") sglx_si_recording = mimic_catGT(sglx_si_recording) save_file_name = "si_recording.pkl" save_file_path = kilosort_dir / save_file_name @@ -190,8 +180,8 @@ def make(self, key): ] # Create SI recording extractor object - # oe_si_recording = se.OpenEphysBinaryRecordingExtractor(folder_path=oe_full_path, stream_name=stream_name) - oe_si_recording = se.read_openephys( + # oe_si_recording = si.extractors.OpenEphysBinaryRecordingExtractor(folder_path=oe_full_path, stream_name=stream_name) + oe_si_recording = si.extractors.read_openephys( folder_path=oe_session_full_path, stream_name=stream_name ) @@ -212,10 +202,10 @@ def make(self, key): # run preprocessing and save results to output folder # # Switch case to allow for specified preprocessing steps - # oe_si_recording_filtered = sip.bandpass_filter( + # oe_si_recording_filtered = si.preprocessing.bandpass_filter( # oe_si_recording, freq_min=300, freq_max=6000 # ) - # oe_recording_cmr = sip.common_reference( + # oe_recording_cmr = si.preprocessing.common_reference( # oe_si_recording_filtered, reference="global", operator="median" # ) oe_si_recording = mimic_IBLdestriping(oe_si_recording) @@ -265,8 +255,8 @@ def make(self, key): if acq_software == "SpikeGLX": # sglx_probe = ephys.get_openephys_probe_data(key) recording_fullpath = kilosort_dir / recording_filename - # sglx_si_recording = se.load_from_folder(recording_file) - sglx_si_recording = sic.load_extractor(recording_fullpath) + # sglx_si_recording = si.extractors.load_from_folder(recording_file) + sglx_si_recording = si.core.load_extractor(recording_fullpath) # assert len(oe_probe.recording_info["recording_files"]) == 1 ## Assume that the worker process will trigger this sorting step @@ -278,7 +268,7 @@ def make(self, key): sorter_name = "kilosort2_5" else: sorter_name = clustering_method - sorting_kilosort = si.run_sorter( + sorting_kilosort = 
si.full.run_sorter( sorter_name=sorter_name, recording=sglx_si_recording, output_folder=kilosort_dir, @@ -289,13 +279,13 @@ def make(self, key): sorting_kilosort.dump_to_pickle(sorting_save_path) elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) - oe_si_recording = sic.load_extractor(recording_fullpath) + oe_si_recording = si.core.load_extractor(recording_fullpath) assert len(oe_probe.recording_info["recording_files"]) == 1 if clustering_method.startswith("kilosort2.5"): sorter_name = "kilosort2_5" else: sorter_name = clustering_method - sorting_kilosort = si.run_sorter( + sorting_kilosort = si.full.run_sorter( sorter_name=sorter_name, recording=oe_si_recording, output_folder=kilosort_dir, @@ -345,10 +335,10 @@ def make(self, key): recording_filename = (PreProcessing & key).fetch1("recording_filename") sorting_file = kilosort_dir / "sorting_kilosort" filtered_recording_file = kilosort_dir / recording_filename - sglx_si_recording_filtered = sic.load_extractor(recording_file) - sorting_kilosort = sic.load_extractor(sorting_file) + sglx_si_recording_filtered = si.core.load_extractor(recording_file) + sorting_kilosort = si.core.load_extractor(sorting_file) - we_kilosort = si.WaveformExtractor.create( + we_kilosort = si.full.WaveformExtractor.create( sglx_si_recording_filtered, sorting_kilosort, "waveforms", @@ -359,15 +349,15 @@ def make(self, key): unit_id0 = sorting_kilosort.unit_ids[0] waveforms = we_kilosort.get_waveforms(unit_id0) template = we_kilosort.get_template(unit_id0) - snrs = si.compute_snrs(we_kilosort) + snrs = si.full.compute_snrs(we_kilosort) # QC Metrics ( si_violations_ratio, isi_violations_rate, isi_violations_count, - ) = si.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) - metrics = si.compute_quality_metrics( + ) = si.full.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) + metrics = si.full.compute_quality_metrics( we_kilosort, metric_names=[ "firing_rate", @@ -382,7 +372,9 @@ def make(self, key): "drift", ], ) - sie.export_report(we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000) + si.exporters.export_report( + we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000 + ) # ["firing_rate","snr","presence_ratio","isi_violation", # "number_violation","amplitude_cutoff","isolation_distance","l_ratio","d_prime","nn_hit_rate", # "nn_miss_rate","silhouette_core","cumulative_drift","contamination_rate"]) @@ -392,10 +384,10 @@ def make(self, key): elif acq_software == "Open Ephys": sorting_file = kilosort_dir / "sorting_kilosort" recording_file = kilosort_dir / "sglx_recording_cmr.json" - oe_si_recording = sic.load_extractor(recording_file) - sorting_kilosort = sic.load_extractor(sorting_file) + oe_si_recording = si.core.load_extractor(recording_file) + sorting_kilosort = si.core.load_extractor(sorting_file) - we_kilosort = si.WaveformExtractor.create( + we_kilosort = si.full.WaveformExtractor.create( oe_si_recording, sorting_kilosort, "waveforms", remove_if_exists=True ) we_kilosort.set_params(ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500) @@ -403,23 +395,24 @@ def make(self, key): unit_id0 = sorting_kilosort.unit_ids[0] waveforms = we_kilosort.get_waveforms(unit_id0) template = we_kilosort.get_template(unit_id0) - snrs = si.compute_snrs(we_kilosort) + snrs = si.full.compute_snrs(we_kilosort) # QC Metrics # Apply waveform extractor extensions - _ = si.compute_spike_locations(we_kilosort) - _ = si.compute_spike_amplitudes(we_kilosort) - _ = si.compute_unit_locations(we_kilosort) - _ = 
si.compute_template_metrics(we_kilosort) - _ = si.compute_noise_levels(we_kilosort) - _ = si.compute_principal_components(we_kilosort) - _ = si.compute_drift_metrics(we_kilosort) - _ = si.compute_tempoate_similarity(we_kilosort) - (isi_violations_ratio, isi_violations_count) = si.compute_isi_violations( - we_kilosort, isi_threshold_ms=1.5 - ) - (isi_histograms, bins) = si.compute_isi_histograms(we_kilosort) - metrics = si.compute_quality_metrics( + _ = si.full.compute_spike_locations(we_kilosort) + _ = si.full.compute_spike_amplitudes(we_kilosort) + _ = si.full.compute_unit_locations(we_kilosort) + _ = si.full.compute_template_metrics(we_kilosort) + _ = si.full.compute_noise_levels(we_kilosort) + _ = si.full.compute_principal_components(we_kilosort) + _ = si.full.compute_drift_metrics(we_kilosort) + _ = si.full.compute_tempoate_similarity(we_kilosort) + ( + isi_violations_ratio, + isi_violations_count, + ) = si.full.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) + (isi_histograms, bins) = si.full.compute_isi_histograms(we_kilosort) + metrics = si.full.compute_quality_metrics( we_kilosort, metric_names=[ "firing_rate", @@ -434,7 +427,9 @@ def make(self, key): "drift", ], ) - sie.export_report(we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000) + si.exporters.export_report( + we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000 + ) we_savedir = kilosort_dir / "we_kilosort" we_kilosort.save(we_savedir, n_jobs=-1, chunk_size=30000) @@ -462,71 +457,73 @@ def make(self, key): # # If else # # need to figure out ordering # if preprocess_list["Filter"]: -# recording = sip.FilterRecording(recording) +# recording = si.preprocessing.FilterRecording(recording) # if preprocess_list["BandpassFilter"]: -# recording = sip.BandpassFilterRecording(recording) +# recording = si.preprocessing.BandpassFilterRecording(recording) # if preprocess_list["HighpassFilter"]: -# recording = sip.HighpassFilterRecording(recording) +# recording = si.preprocessing.HighpassFilterRecording(recording) # if preprocess_list["NormalizeByQuantile"]: -# recording = sip.NormalizeByQuantileRecording(recording) +# recording = si.preprocessing.NormalizeByQuantileRecording(recording) # if preprocess_list["Scale"]: -# recording = sip.ScaleRecording(recording) +# recording = si.preprocessing.ScaleRecording(recording) # if preprocess_list["Center"]: -# recording = sip.CenterRecording(recording) +# recording = si.preprocessing.CenterRecording(recording) # if preprocess_list["ZScore"]: -# recording = sip.ZScoreRecording(recording) +# recording = si.preprocessing.ZScoreRecording(recording) # if preprocess_list["Whiten"]: -# recording = sip.WhitenRecording(recording) +# recording = si.preprocessing.WhitenRecording(recording) # if preprocess_list["CommonReference"]: -# recording = sip.CommonReferenceRecording(recording) +# recording = si.preprocessing.CommonReferenceRecording(recording) # if preprocess_list["PhaseShift"]: -# recording = sip.PhaseShiftRecording(recording) +# recording = si.preprocessing.PhaseShiftRecording(recording) # elif preprocess_list["Rectify"]: -# recording = sip.RectifyRecording(recording) +# recording = si.preprocessing.RectifyRecording(recording) # elif preprocess_list["Clip"]: -# recording = sip.ClipRecording(recording) +# recording = si.preprocessing.ClipRecording(recording) # elif preprocess_list["BlankSaturation"]: -# recording = sip.BlankSaturationRecording(recording) +# recording = si.preprocessing.BlankSaturationRecording(recording) # elif preprocess_list["RemoveArtifacts"]: -# recording = 
sip.RemoveArtifactsRecording(recording) +# recording = si.preprocessing.RemoveArtifactsRecording(recording) # elif preprocess_list["RemoveBadChannels"]: -# recording = sip.RemoveBadChannelsRecording(recording) +# recording = si.preprocessing.RemoveBadChannelsRecording(recording) # elif preprocess_list["ZeroChannelPad"]: -# recording = sip.ZeroChannelPadRecording(recording) +# recording = si.preprocessing.ZeroChannelPadRecording(recording) # elif preprocess_list["DeepInterpolation"]: -# recording = sip.DeepInterpolationRecording(recording) +# recording = si.preprocessing.DeepInterpolationRecording(recording) # elif preprocess_list["Resample"]: -# recording = sip.ResampleRecording(recording) +# recording = si.preprocessing.ResampleRecording(recording) def mimic_IBLdestriping_modified(recording): # From SpikeInterface Implementation (https://spikeinterface.readthedocs.io/en/latest/how_to/analyse_neuropixels.html) - recording = si.highpass_filter(recording, freq_min=400.0) - bad_channel_ids, channel_labels = si.detect_bad_channels(recording) + recording = si.full.highpass_filter(recording, freq_min=400.0) + bad_channel_ids, channel_labels = si.full.detect_bad_channels(recording) # For IBL destriping interpolate bad channels recording = recording.remove_channels(bad_channel_ids) - recording = si.phase_shift(recording) - recording = si.common_reference(recording, operator="median", reference="global") + recording = si.full.phase_shift(recording) + recording = si.full.common_reference( + recording, operator="median", reference="global" + ) return recording def mimic_IBLdestriping(recording): # From International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022. 9 Jun 2022. - recording = si.highpass_filter(recording, freq_min=400.0) - bad_channel_ids, channel_labels = si.detect_bad_channels(recording) + recording = si.full.highpass_filter(recording, freq_min=400.0) + bad_channel_ids, channel_labels = si.full.detect_bad_channels(recording) # For IBL destriping interpolate bad channels - recording = sip.interpolate_bad_channels(bad_channel_ids) - recording = si.phase_shift(recording) + recording = si.preprocessing.interpolate_bad_channels(bad_channel_ids) + recording = si.full.phase_shift(recording) # For IBL destriping use highpass_spatial_filter used instead of common reference - recording = si.highpass_spatial_filter( + recording = si.full.highpass_spatial_filter( recording, operator="median", reference="global" ) return recording def mimic_catGT(sglx_recording): - sglx_recording = si.phase_shift(sglx_recording) - sglx_recording = si.common_reference( + sglx_recording = si.full.phase_shift(sglx_recording) + sglx_recording = si.full.common_reference( sglx_recording, operator="median", reference="global" ) return sglx_recording From 7836a8b4e6c600ba31b8af7efd0e17030e25b158 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 13 Oct 2023 16:47:09 -0500 Subject: [PATCH 044/152] modify key_source in PreProcessing --- element_array_ephys/spike_sorting/si_clustering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 2cb5bf2e..e8b6517c 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -85,11 +85,11 @@ class PreProcessing(dj.Imported): @property def key_source(self): - return ( + return (( ephys.ClusteringTask * ephys.ClusteringParamSet & {"task_mode": 
"trigger"} & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")' - ) - ephys.Clustering + ) - ephys.Clustering).proj() def make(self, key): """Triggers or imports clustering analysis.""" From 6bee166f5fe356ddea10e1a5b391e6daf6929ec9 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Thu, 14 Dec 2023 16:31:31 -0600 Subject: [PATCH 045/152] feat: :sparkles: improve to_probeinterface --- element_array_ephys/readers/probe_geometry.py | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/element_array_ephys/readers/probe_geometry.py b/element_array_ephys/readers/probe_geometry.py index 11e3ae99..7247abe9 100644 --- a/element_array_ephys/readers/probe_geometry.py +++ b/element_array_ephys/readers/probe_geometry.py @@ -132,8 +132,8 @@ def build_npx_probe( return elec_pos_df -def to_probeinterface(electrodes_df): - from probeinterface import Probe +def to_probeinterface(electrodes_df, **kwargs): + import probeinterface as pi probe_df = electrodes_df.copy() probe_df.rename( @@ -145,10 +145,22 @@ def to_probeinterface(electrodes_df): }, inplace=True, ) - probe_df["contact_shapes"] = "square" - probe_df["width"] = 12 - - return Probe.from_dataframe(probe_df) + # Get the contact shapes. By default, it's set to circle with a radius of 10. + contact_shapes = kwargs.get("contact_shapes", "circle") + assert ( + contact_shapes in pi.probe._possible_contact_shapes + ), f"contacts shape should be in {pi.probe._possible_contact_shapes}" + + probe_df["contact_shapes"] = contact_shapes + if contact_shapes == "circle": + probe_df["radius"] = kwargs.get("radius", 10) + elif contact_shapes == "square": + probe_df["width"] = kwargs.get("width", 10) + elif contact_shapes == "rect": + probe_df["width"] = kwargs.get("width") + probe_df["height"] = kwargs.get("height") + + return pi.Probe.from_dataframe(probe_df) def build_electrode_layouts( From 9ae6b4492583429db9931eda93c37d5092b05e8e Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 15 Dec 2023 17:10:07 -0600 Subject: [PATCH 046/152] create preprocessing.py --- .../spike_sorting/preprocessing.py | 85 +++++++++++++++++++ .../spike_sorting/si_clustering.py | 11 ++- 2 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 element_array_ephys/spike_sorting/preprocessing.py diff --git a/element_array_ephys/spike_sorting/preprocessing.py b/element_array_ephys/spike_sorting/preprocessing.py new file mode 100644 index 00000000..77a95792 --- /dev/null +++ b/element_array_ephys/spike_sorting/preprocessing.py @@ -0,0 +1,85 @@ +import spikeinterface as si +from spikeinterface import preprocessing + + +def mimic_catGT(recording): + recording = si.preprocessing.phase_shift(recording) + recording = si.preprocessing.common_reference( + recording, operator="median", reference="global" + ) + return recording + + +def mimic_IBLdestriping(recording): + # From International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022. 9 Jun 2022. 
From 9ae6b4492583429db9931eda93c37d5092b05e8e Mon Sep 17 00:00:00 2001
From: JaerongA
Date: Fri, 15 Dec 2023 17:10:07 -0600
Subject: [PATCH 046/152] create preprocessing.py

---
 .../spike_sorting/preprocessing.py            | 85 +++++++++++++++++++
 .../spike_sorting/si_clustering.py            | 11 ++-
 2 files changed, 95 insertions(+), 1 deletion(-)
 create mode 100644 element_array_ephys/spike_sorting/preprocessing.py

diff --git a/element_array_ephys/spike_sorting/preprocessing.py b/element_array_ephys/spike_sorting/preprocessing.py
new file mode 100644
index 00000000..77a95792
--- /dev/null
+++ b/element_array_ephys/spike_sorting/preprocessing.py
@@ -0,0 +1,85 @@
+import spikeinterface as si
+from spikeinterface import preprocessing
+
+
+def mimic_catGT(recording):
+    recording = si.preprocessing.phase_shift(recording)
+    recording = si.preprocessing.common_reference(
+        recording, operator="median", reference="global"
+    )
+    return recording
+
+
+def mimic_IBLdestriping(recording):
+    # From International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022. 9 Jun 2022.
+    recording = si.preprocessing.highpass_filter(recording, freq_min=400.0)
+    bad_channel_ids, channel_labels = si.preprocessing.detect_bad_channels(recording)
+    # For IBL destriping interpolate bad channels
+    recording = si.preprocessing.interpolate_bad_channels(bad_channel_ids)
+    recording = si.preprocessing.phase_shift(recording)
+    # For IBL destriping use highpass_spatial_filter used instead of common reference
+    recording = si.preprocessing.highpass_spatial_filter(
+        recording, operator="median", reference="global"
+    )
+    return recording
+
+
+def mimic_IBLdestriping_modified(recording):
+    # From SpikeInterface Implementation (https://spikeinterface.readthedocs.io/en/latest/how_to/analyse_neuropixels.html)
+    recording = si.preprocessing.highpass_filter(recording, freq_min=400.0)
+    bad_channel_ids, channel_labels = si.preprocessing.detect_bad_channels(recording)
+    # For IBL destriping interpolate bad channels
+    recording = recording.remove_channels(bad_channel_ids)
+    recording = si.preprocessing.phase_shift(recording)
+    recording = si.preprocessing.common_reference(
+        recording, operator="median", reference="global"
+    )
+    return recording
+
+
+_preprocessing_function = {
+    "catGT": mimic_catGT,
+    "IBLdestriping": mimic_IBLdestriping,
+    "IBLdestriping_modified": mimic_IBLdestriping_modified,
+}
+
+
+## Example SI parameter set
+"""
+{'detect_threshold': 6,
+ 'projection_threshold': [10, 4],
+ 'preclust_threshold': 8,
+ 'car': True,
+ 'minFR': 0.02,
+ 'minfr_goodchannels': 0.1,
+ 'nblocks': 5,
+ 'sig': 20,
+ 'freq_min': 150,
+ 'sigmaMask': 30,
+ 'nPCs': 3,
+ 'ntbuff': 64,
+ 'nfilt_factor': 4,
+ 'NT': None,
+ 'do_correction': True,
+ 'wave_length': 61,
+ 'keep_good_only': False,
+ 'PreProcessing_params': {'Filter': False,
+ 'BandpassFilter': True,
+ 'HighpassFilter': False,
+ 'NotchFilter': False,
+ 'NormalizeByQuantile': False,
+ 'Scale': False,
+ 'Center': False,
+ 'ZScore': False,
+ 'Whiten': False,
+ 'CommonReference': False,
+ 'PhaseShift': False,
+ 'Rectify': False,
+ 'Clip': False,
+ 'BlankSaturation': False,
+ 'RemoveArtifacts': False,
+ 'RemoveBadChannels': False,
+ 'ZeroChannelPad': False,
+ 'DeepInterpolation': False,
+ 'Resample': False}}
+"""
diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py
index e8b6517c..a8d5d8c0 100644
--- a/element_array_ephys/spike_sorting/si_clustering.py
+++ b/element_array_ephys/spike_sorting/si_clustering.py
@@ -31,7 +31,16 @@
 import element_array_ephys.probe as probe
 
 import spikeinterface as si
-import probeinterface as pi
+from element_interface.utils import find_full_path, find_root_directory
+from spikeinterface import sorters
+
+from element_array_ephys import get_logger, probe, readers
+
+from .preprocessing import (
+    mimic_catGT,
+    mimic_IBLdestriping,
+    mimic_IBLdestriping_modified,
+)
 
 log = get_logger(__name__)
 

From 9d5eee66aea3d9b967bd22beeefd2651b93a3b6b Mon Sep 17 00:00:00 2001
From: JaerongA
Date: Fri, 15 Dec 2023 17:12:19 -0600
Subject: [PATCH 047/152] add SI_SORTERS, SI_READERS

---
 .../spike_sorting/si_clustering.py | 27 ++++++++-----------
 1 file changed, 11 insertions(+), 16 deletions(-)

diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py
index a8d5d8c0..b9e9cb2e 100644
--- a/element_array_ephys/spike_sorting/si_clustering.py
+++ b/element_array_ephys/spike_sorting/si_clustering.py
@@ -18,18 +18,10 @@
 - quality_metrics
 """
 
-import datajoint as dj
-import os
-from element_array_ephys import get_logger
from datetime import datetime -from element_interface.utils import find_full_path -from element_array_ephys.readers import ( - spikeglx, - kilosort_triggering, -) -import element_array_ephys.probe as probe - +import datajoint as dj +import probeinterface as pi import spikeinterface as si from element_interface.utils import find_full_path, find_root_directory from spikeinterface import sorters @@ -48,12 +40,6 @@ ephys = None -_supported_kilosort_versions = [ - "kilosort2", - "kilosort2.5", - "kilosort3", -] - def activate( schema_name, @@ -79,6 +65,15 @@ def activate( ) +SI_SORTERS = [s.replace(".", "_") for s in si.sorters.sorter_dict.keys()] + +SI_READERS = { + "Open Ephys": si.extractors.read_openephys, + "SpikeGLX": si.extractors.read_spikeglx, + "Intan": si.extractors.read_intan, +} + + @schema class PreProcessing(dj.Imported): """A table to handle preprocessing of each clustering task.""" From 1fceb90a1816613a0e86f2f7288ba56ba254de40 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 15 Dec 2023 17:15:45 -0600 Subject: [PATCH 048/152] feat: :art: si_clustering.PreProcessing --- .../spike_sorting/si_clustering.py | 149 +++++------------- 1 file changed, 37 insertions(+), 112 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index b9e9cb2e..5510436f 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -76,34 +76,31 @@ def activate( @schema class PreProcessing(dj.Imported): - """A table to handle preprocessing of each clustering task.""" + """A table to handle preprocessing of each clustering task. The output will be serialized and stored as a si_recording.pkl in the output directory.""" definition = """ -> ephys.ClusteringTask --- - recording_filename: varchar(30) # filename where recording object is saved to - params: longblob # finalized parameterset for this run execution_time: datetime # datetime of the start of this step - execution_duration: float # (hour) execution duration + execution_duration: float # execution duration in hours """ @property def key_source(self): - return (( + return ( ephys.ClusteringTask * ephys.ClusteringParamSet & {"task_mode": "trigger"} - & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")' - ) - ephys.Clustering).proj() + & f"clustering_method in {tuple(SI_SORTERS)}" + ) - ephys.Clustering def make(self, key): """Triggers or imports clustering analysis.""" execution_time = datetime.utcnow() - task_mode, output_dir = (ephys.ClusteringTask & key).fetch1( - "task_mode", "clustering_output_dir" - ) - - assert task_mode == "trigger", 'Supporting "trigger" task_mode only' + # Set the output directory + acq_software, output_dir = ( + ephys.ClusteringTask * ephys.EphysRecording & key + ).fetch1("acq_software", "clustering_output_dir") if not output_dir: output_dir = ephys.ClusteringTask.infer_output_dir( @@ -113,115 +110,43 @@ def make(self, key): ephys.ClusteringTask.update1( {**key, "clustering_output_dir": output_dir.as_posix()} ) + output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - acq_software, clustering_method, params = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method", "params") - - assert ( - clustering_method in _supported_kilosort_versions - ), f'Clustering_method "{clustering_method}" is not supported' - - if 
clustering_method.startswith("kilosort2.5"): - sorter_name = "kilosort2_5" - else: - sorter_name = clustering_method - # add additional probe-recording and channels details into `params` - # params = {**params, **ephys.get_recording_channels_details(key)} - # params["fs"] = params["sample_rate"] - - # default_params = si.full.get_default_sorter_params(sorter_name) - # preprocess_list = params.pop("PreProcessing_params") - - if acq_software == "SpikeGLX": - # sglx_session_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) - sglx_filepath = ephys.get_spikeglx_meta_filepath(key) - - # Create SI recording extractor object - stream_name = sglx_filepath.stem.split(".", 1)[1] - sglx_si_recording = si.extractors.read_spikeglx( - folder_path=sglx_filepath.parent, - stream_name=stream_name, - stream_id=stream_name, - ) - - channels_details = ephys.get_recording_channels_details(key) - xy_coords = [ - list(i) - for i in zip(channels_details["x_coords"], channels_details["y_coords"]) - ] - - # Create SI probe object - si_probe = pi.Probe(ndim=2, si_units="um") - si_probe.set_contacts( - positions=xy_coords, shapes="square", shape_params={"width": 12} - ) - si_probe.create_auto_shape(probe_type="tip") - si_probe.set_device_channel_indices(channels_details["channel_ind"]) - sglx_si_recording.set_probe(probe=si_probe) - - # # run preprocessing and save results to output folder - # sglx_si_recording_filtered = si.preprocessing.bandpass_filter( - # sglx_si_recording, freq_min=300, freq_max=6000 - # ) - # sglx_recording_cmr = si.preprocessing.common_reference(sglx_si_recording_filtered, reference="global", operator="median") - sglx_si_recording = mimic_catGT(sglx_si_recording) - save_file_name = "si_recording.pkl" - save_file_path = kilosort_dir / save_file_name - sglx_si_recording.dump_to_pickle(file_path=save_file_path) - - elif acq_software == "Open Ephys": - oe_probe = ephys.get_openephys_probe_data(key) - oe_session_full_path = find_full_path( - ephys.get_ephys_root_data_dir(), ephys.get_session_directory(key) - ) + # Create SI recording extractor object + si_recording: si.BaseRecording = SI_READERS[acq_software]( + folder_path=output_dir + ) - assert len(oe_probe.recording_info["recording_files"]) == 1 - stream_name = os.path.split(oe_probe.recording_info["recording_files"][0])[ - 1 - ] - - # Create SI recording extractor object - # oe_si_recording = si.extractors.OpenEphysBinaryRecordingExtractor(folder_path=oe_full_path, stream_name=stream_name) - oe_si_recording = si.extractors.read_openephys( - folder_path=oe_session_full_path, stream_name=stream_name + # Add probe information to recording object + electrode_config_key = ( + probe.ElectrodeConfig * ephys.EphysRecording & key + ).fetch1("KEY") + electrodes_df = ( + ( + probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode + & electrode_config_key ) + .fetch(format="frame") + .reset_index()[["electrode", "x_coord", "y_coord", "shank"]] + ) - channels_details = ephys.get_recording_channels_details(key) - xy_coords = [ - list(i) - for i in zip(channels_details["x_coords"], channels_details["y_coords"]) - ] + # Create SI probe object + si_probe = readers.probe_geometry.to_probeinterface(electrodes_df) + si_recording.set_probe(probe=si_probe, in_place=True) - # Create SI probe object - si_probe = pi.Probe(ndim=2, si_units="um") - si_probe.set_contacts( - positions=xy_coords, shapes="square", shape_params={"width": 12} - ) - si_probe.create_auto_shape(probe_type="tip") - 
si_probe.set_device_channel_indices(channels_details["channel_ind"]) - oe_si_recording.set_probe(probe=si_probe) - - # run preprocessing and save results to output folder - # # Switch case to allow for specified preprocessing steps - # oe_si_recording_filtered = si.preprocessing.bandpass_filter( - # oe_si_recording, freq_min=300, freq_max=6000 - # ) - # oe_recording_cmr = si.preprocessing.common_reference( - # oe_si_recording_filtered, reference="global", operator="median" - # ) - oe_si_recording = mimic_IBLdestriping(oe_si_recording) - save_file_name = "si_recording.pkl" - save_file_path = kilosort_dir / save_file_name - oe_si_recording.dump_to_pickle(file_path=save_file_path) + # Run preprocessing and save results to output folder + preprocessing_method = "catGT" # where to load this info? + si_recording = { + "catGT": mimic_catGT, + "IBLdestriping": mimic_IBLdestriping, + "IBLdestriping_modified": mimic_IBLdestriping_modified, + }[preprocessing_method](si_recording) + recording_file_name = output_dir / "si_recording.pkl" + si_recording.dump_to_pickle(file_path=recording_file_name) self.insert1( { **key, - "recording_filename": save_file_name, - "params": params, "execution_time": execution_time, "execution_duration": ( datetime.utcnow() - execution_time From df8ed7464ebc3452148c0daa1f1e43987d7035ec Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 19 Dec 2023 21:53:43 -0600 Subject: [PATCH 049/152] feat: :art: si_clustering.SIClustering --- .../spike_sorting/si_clustering.py | 93 +++++++------------ 1 file changed, 35 insertions(+), 58 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 5510436f..debc4336 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -21,10 +21,11 @@ from datetime import datetime import datajoint as dj +import pandas as pd import probeinterface as pi import spikeinterface as si from element_interface.utils import find_full_path, find_root_directory -from spikeinterface import sorters +from spikeinterface import exporters, qualitymetrics, sorters from element_array_ephys import get_logger, probe, readers @@ -65,7 +66,7 @@ def activate( ) -SI_SORTERS = [s.replace(".", "_") for s in si.sorters.sorter_dict.keys()] +SI_SORTERS = [s.replace("_", ".") for s in si.sorters.sorter_dict.keys()] SI_READERS = { "Open Ephys": si.extractors.read_openephys, @@ -141,8 +142,8 @@ def make(self, key): "IBLdestriping": mimic_IBLdestriping, "IBLdestriping_modified": mimic_IBLdestriping_modified, }[preprocessing_method](si_recording) - recording_file_name = output_dir / "si_recording.pkl" - si_recording.dump_to_pickle(file_path=recording_file_name) + recording_file = output_dir / "si_recording.pkl" + si_recording.dump_to_pickle(file_path=recording_file) self.insert1( { @@ -162,72 +163,48 @@ class SIClustering(dj.Imported): definition = """ -> PreProcessing + sorter_name: varchar(30) # name of the sorter used --- - sorting_filename: varchar(30) # filename of saved sorting object - execution_time: datetime # datetime of the start of this step - execution_duration: float # (hour) execution duration + execution_time: datetime # datetime of the start of this step + execution_duration: float # execution duration in hours """ def make(self, key): execution_time = datetime.utcnow() + # Load recording object. 
output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - acq_software, clustering_method = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method") - - params = (PreProcessing & key).fetch1("params") - recording_filename = (PreProcessing & key).fetch1("recording_filename") + output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + recording_file = output_dir / "si_recording.pkl" + si_recording: si.BaseRecording = si.load_extractor(recording_file) + + # Get sorter method and create output directory. + clustering_method, params = ( + ephys.ClusteringTask * ephys.ClusteringParamSet & key + ).fetch1("clustering_method", "params") + sorter_name = ( + "kilosort_2_5" if clustering_method == "kilsort2.5" else clustering_method + ) + sorter_dir = output_dir / sorter_name + + # Run sorting + si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter( + sorter_name=sorter_name, + recording=si_recording, + output_folder=sorter_dir, + verbse=True, + docker_image=True, + **params, + ) - if acq_software == "SpikeGLX": - # sglx_probe = ephys.get_openephys_probe_data(key) - recording_fullpath = kilosort_dir / recording_filename - # sglx_si_recording = si.extractors.load_from_folder(recording_file) - sglx_si_recording = si.core.load_extractor(recording_fullpath) - # assert len(oe_probe.recording_info["recording_files"]) == 1 - - ## Assume that the worker process will trigger this sorting step - # - Will need to store/load the sorter_name, sglx_si_recording object etc. - # - Store in shared EC2 space accessible by all containers (needs to be mounted) - # - Load into the cloud init script, and - # - Option A: Can call this function within a separate container within spike_sorting_worker - if clustering_method.startswith("kilosort2.5"): - sorter_name = "kilosort2_5" - else: - sorter_name = clustering_method - sorting_kilosort = si.full.run_sorter( - sorter_name=sorter_name, - recording=sglx_si_recording, - output_folder=kilosort_dir, - docker_image=f"spikeinterface/{sorter_name}-compiled-base:latest", - **params, - ) - sorting_save_path = kilosort_dir / "sorting_kilosort.pkl" - sorting_kilosort.dump_to_pickle(sorting_save_path) - elif acq_software == "Open Ephys": - oe_probe = ephys.get_openephys_probe_data(key) - oe_si_recording = si.core.load_extractor(recording_fullpath) - assert len(oe_probe.recording_info["recording_files"]) == 1 - if clustering_method.startswith("kilosort2.5"): - sorter_name = "kilosort2_5" - else: - sorter_name = clustering_method - sorting_kilosort = si.full.run_sorter( - sorter_name=sorter_name, - recording=oe_si_recording, - output_folder=kilosort_dir, - docker_image=f"spikeinterface/{sorter_name}-compiled-base:latest", - **params, - ) - sorting_save_path = kilosort_dir / "sorting_kilosort.pkl" - sorting_kilosort.dump_to_pickle(sorting_save_path) + # Run sorting + sorting_save_path = sorter_dir / "si_sorting.pkl" + si_sorting.dump_to_pickle(sorting_save_path) self.insert1( { **key, - "sorting_filename": list(sorting_save_path.parts)[-1], + "sorter_name": sorter_name, "execution_time": execution_time, "execution_duration": ( datetime.utcnow() - execution_time From 2ed337bf0fa8245a7f8d481dc779b072cfcbadd0 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 20 Dec 2023 15:57:55 -0600 Subject: [PATCH 050/152] feat: :sparkles: add PostProcessing table & clean up --- .../spike_sorting/si_clustering.py | 275 
+++--------------- 1 file changed, 45 insertions(+), 230 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index debc4336..935d7360 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -25,7 +25,7 @@ import probeinterface as pi import spikeinterface as si from element_interface.utils import find_full_path, find_root_directory -from spikeinterface import exporters, qualitymetrics, sorters +from spikeinterface import exporters, postprocessing, qualitymetrics, sorters from element_array_ephys import get_logger, probe, readers @@ -222,126 +222,58 @@ class PostProcessing(dj.Imported): -> SIClustering --- execution_time: datetime # datetime of the start of this step - execution_duration: float # (hour) execution duration + execution_duration: float # execution duration in hours """ def make(self, key): execution_time = datetime.utcnow() + JOB_KWARGS = dict(n_jobs=-1, chunk_size=30000) + # Load sorting & recording object. output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - acq_software, clustering_method = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method") - - params = (PreProcessing & key).fetch1("params") - - if acq_software == "SpikeGLX": - recording_filename = (PreProcessing & key).fetch1("recording_filename") - sorting_file = kilosort_dir / "sorting_kilosort" - filtered_recording_file = kilosort_dir / recording_filename - sglx_si_recording_filtered = si.core.load_extractor(recording_file) - sorting_kilosort = si.core.load_extractor(sorting_file) - - we_kilosort = si.full.WaveformExtractor.create( - sglx_si_recording_filtered, - sorting_kilosort, - "waveforms", - remove_if_exists=True, - ) - we_kilosort.set_params(ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500) - we_kilosort.run_extract_waveforms(n_jobs=-1, chunk_size=30000) - unit_id0 = sorting_kilosort.unit_ids[0] - waveforms = we_kilosort.get_waveforms(unit_id0) - template = we_kilosort.get_template(unit_id0) - snrs = si.full.compute_snrs(we_kilosort) - - # QC Metrics - ( - si_violations_ratio, - isi_violations_rate, - isi_violations_count, - ) = si.full.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) - metrics = si.full.compute_quality_metrics( - we_kilosort, - metric_names=[ - "firing_rate", - "snr", - "presence_ratio", - "isi_violation", - "num_spikes", - "amplitude_cutoff", - "amplitude_median", - # "sliding_rp_violation", - "rp_violation", - "drift", - ], - ) - si.exporters.export_report( - we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000 - ) - # ["firing_rate","snr","presence_ratio","isi_violation", - # "number_violation","amplitude_cutoff","isolation_distance","l_ratio","d_prime","nn_hit_rate", - # "nn_miss_rate","silhouette_core","cumulative_drift","contamination_rate"]) - we_savedir = kilosort_dir / "we_kilosort" - we_kilosort.save(we_savedir, n_jobs=-1, chunk_size=30000) - - elif acq_software == "Open Ephys": - sorting_file = kilosort_dir / "sorting_kilosort" - recording_file = kilosort_dir / "sglx_recording_cmr.json" - oe_si_recording = si.core.load_extractor(recording_file) - sorting_kilosort = si.core.load_extractor(sorting_file) - - we_kilosort = si.full.WaveformExtractor.create( - oe_si_recording, sorting_kilosort, "waveforms", remove_if_exists=True - ) - 
we_kilosort.set_params(ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500) - we_kilosort.run_extract_waveforms(n_jobs=-1, chunk_size=30000) - unit_id0 = sorting_kilosort.unit_ids[0] - waveforms = we_kilosort.get_waveforms(unit_id0) - template = we_kilosort.get_template(unit_id0) - snrs = si.full.compute_snrs(we_kilosort) - - # QC Metrics - # Apply waveform extractor extensions - _ = si.full.compute_spike_locations(we_kilosort) - _ = si.full.compute_spike_amplitudes(we_kilosort) - _ = si.full.compute_unit_locations(we_kilosort) - _ = si.full.compute_template_metrics(we_kilosort) - _ = si.full.compute_noise_levels(we_kilosort) - _ = si.full.compute_principal_components(we_kilosort) - _ = si.full.compute_drift_metrics(we_kilosort) - _ = si.full.compute_tempoate_similarity(we_kilosort) - ( - isi_violations_ratio, - isi_violations_count, - ) = si.full.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) - (isi_histograms, bins) = si.full.compute_isi_histograms(we_kilosort) - metrics = si.full.compute_quality_metrics( - we_kilosort, - metric_names=[ - "firing_rate", - "snr", - "presence_ratio", - "isi_violation", - "num_spikes", - "amplitude_cutoff", - "amplitude_median", - # "sliding_rp_violation", - "rp_violation", - "drift", - ], - ) - si.exporters.export_report( - we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000 - ) - we_savedir = kilosort_dir / "we_kilosort" - we_kilosort.save(we_savedir, n_jobs=-1, chunk_size=30000) + output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + recording_file = output_dir / "si_recording.pkl" + sorter_dir = output_dir / key["sorter_name"] + sorting_file = sorter_dir / "si_sorting.pkl" + + si_recording: si.BaseRecording = si.load_extractor(recording_file) + si_sorting: si.sorters.BaseSorter = si.load_extractor(sorting_file) + + # Extract waveforms + we: si.WaveformExtractor = si.extract_waveforms( + si_recording, + si_sorting, + folder=sorter_dir / "waveform", # The folder where waveforms are cached + ms_before=3.0, + ms_after=4.0, + max_spikes_per_unit=500, + overwrite=True, + **JOB_KWARGS, + ) - metrics_savefile = kilosort_dir / "metrics.csv" - metrics.to_csv(metrics_savefile) + # Calculate QC Metrics + metrics: pd.DataFrame = si.qualitymetrics.compute_quality_metrics( + we, + metric_names=[ + "firing_rate", + "snr", + "presence_ratio", + "isi_violation", + "num_spikes", + "amplitude_cutoff", + "amplitude_median", + "sliding_rp_violation", + "rp_violation", + "drift", + ], + ) + # Add PCA based metrics. These will be added to the metrics dataframe above. 
+ _ = si.postprocessing.compute_principal_components( + waveform_extractor=we, n_components=5, mode="by_channel_local" + ) # TODO: the parameters need to be checked + metrics = si.qualitymetrics.compute_quality_metrics(waveform_extractor=we) + # Save results self.insert1( { **key, @@ -357,120 +289,3 @@ def make(self, key): ephys.Clustering.insert1( {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True ) - - -# def runPreProcessList(preprocess_list, recording): -# # If else -# # need to figure out ordering -# if preprocess_list["Filter"]: -# recording = si.preprocessing.FilterRecording(recording) -# if preprocess_list["BandpassFilter"]: -# recording = si.preprocessing.BandpassFilterRecording(recording) -# if preprocess_list["HighpassFilter"]: -# recording = si.preprocessing.HighpassFilterRecording(recording) -# if preprocess_list["NormalizeByQuantile"]: -# recording = si.preprocessing.NormalizeByQuantileRecording(recording) -# if preprocess_list["Scale"]: -# recording = si.preprocessing.ScaleRecording(recording) -# if preprocess_list["Center"]: -# recording = si.preprocessing.CenterRecording(recording) -# if preprocess_list["ZScore"]: -# recording = si.preprocessing.ZScoreRecording(recording) -# if preprocess_list["Whiten"]: -# recording = si.preprocessing.WhitenRecording(recording) -# if preprocess_list["CommonReference"]: -# recording = si.preprocessing.CommonReferenceRecording(recording) -# if preprocess_list["PhaseShift"]: -# recording = si.preprocessing.PhaseShiftRecording(recording) -# elif preprocess_list["Rectify"]: -# recording = si.preprocessing.RectifyRecording(recording) -# elif preprocess_list["Clip"]: -# recording = si.preprocessing.ClipRecording(recording) -# elif preprocess_list["BlankSaturation"]: -# recording = si.preprocessing.BlankSaturationRecording(recording) -# elif preprocess_list["RemoveArtifacts"]: -# recording = si.preprocessing.RemoveArtifactsRecording(recording) -# elif preprocess_list["RemoveBadChannels"]: -# recording = si.preprocessing.RemoveBadChannelsRecording(recording) -# elif preprocess_list["ZeroChannelPad"]: -# recording = si.preprocessing.ZeroChannelPadRecording(recording) -# elif preprocess_list["DeepInterpolation"]: -# recording = si.preprocessing.DeepInterpolationRecording(recording) -# elif preprocess_list["Resample"]: -# recording = si.preprocessing.ResampleRecording(recording) - - -def mimic_IBLdestriping_modified(recording): - # From SpikeInterface Implementation (https://spikeinterface.readthedocs.io/en/latest/how_to/analyse_neuropixels.html) - recording = si.full.highpass_filter(recording, freq_min=400.0) - bad_channel_ids, channel_labels = si.full.detect_bad_channels(recording) - # For IBL destriping interpolate bad channels - recording = recording.remove_channels(bad_channel_ids) - recording = si.full.phase_shift(recording) - recording = si.full.common_reference( - recording, operator="median", reference="global" - ) - return recording - - -def mimic_IBLdestriping(recording): - # From International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022. 9 Jun 2022. 
- recording = si.full.highpass_filter(recording, freq_min=400.0) - bad_channel_ids, channel_labels = si.full.detect_bad_channels(recording) - # For IBL destriping interpolate bad channels - recording = si.preprocessing.interpolate_bad_channels(bad_channel_ids) - recording = si.full.phase_shift(recording) - # For IBL destriping use highpass_spatial_filter used instead of common reference - recording = si.full.highpass_spatial_filter( - recording, operator="median", reference="global" - ) - return recording - - -def mimic_catGT(sglx_recording): - sglx_recording = si.full.phase_shift(sglx_recording) - sglx_recording = si.full.common_reference( - sglx_recording, operator="median", reference="global" - ) - return sglx_recording - - -## Example SI parameter set -""" -{'detect_threshold': 6, - 'projection_threshold': [10, 4], - 'preclust_threshold': 8, - 'car': True, - 'minFR': 0.02, - 'minfr_goodchannels': 0.1, - 'nblocks': 5, - 'sig': 20, - 'freq_min': 150, - 'sigmaMask': 30, - 'nPCs': 3, - 'ntbuff': 64, - 'nfilt_factor': 4, - 'NT': None, - 'do_correction': True, - 'wave_length': 61, - 'keep_good_only': False, - 'PreProcessing_params': {'Filter': False, - 'BandpassFilter': True, - 'HighpassFilter': False, - 'NotchFilter': False, - 'NormalizeByQuantile': False, - 'Scale': False, - 'Center': False, - 'ZScore': False, - 'Whiten': False, - 'CommonReference': False, - 'PhaseShift': False, - 'Rectify': False, - 'Clip': False, - 'BlankSaturation': False, - 'RemoveArtifacts': False, - 'RemoveBadChannels': False, - 'ZeroChannelPad': False, - 'DeepInterpolation': False, - 'Resample': False}} -""" From f6e3e4624255b9b42e89de3e520c8696bc60089f Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 2 Jan 2024 18:04:25 -0600 Subject: [PATCH 051/152] fix: :bug: fix input/output data directory --- .../spike_sorting/si_clustering.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 935d7360..80449c88 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -18,6 +18,7 @@ - quality_metrics """ +import pathlib from datetime import datetime import datajoint as dj @@ -111,11 +112,19 @@ def make(self, key): ephys.ClusteringTask.update1( {**key, "clustering_output_dir": output_dir.as_posix()} ) - output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + output_full_dir = find_full_path( + ephys.get_ephys_root_data_dir(), output_dir + ) # output directory in the processed data directory # Create SI recording extractor object + data_dir = ( + ephys.get_ephys_root_data_dir()[0] / pathlib.Path(output_dir).parent + ) # raw data directory + stream_names, stream_ids = si.extractors.get_neo_streams( + acq_software.strip().lower(), folder_path=data_dir + ) si_recording: si.BaseRecording = SI_READERS[acq_software]( - folder_path=output_dir + folder_path=data_dir, stream_name=stream_names[0] ) # Add probe information to recording object @@ -142,7 +151,7 @@ def make(self, key): "IBLdestriping": mimic_IBLdestriping, "IBLdestriping_modified": mimic_IBLdestriping_modified, }[preprocessing_method](si_recording) - recording_file = output_dir / "si_recording.pkl" + recording_file = output_full_dir / "si_recording.pkl" si_recording.dump_to_pickle(file_path=recording_file) self.insert1( From e1c0d689d6b7958c231389e8d11b7ef2e326657f Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 3 Jan 2024 11:47:03 -0600 Subject: 
[PATCH 052/152] check for presence of recording file --- .../spike_sorting/si_clustering.py | 107 ++++++++++-------- 1 file changed, 57 insertions(+), 50 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 80449c88..8a5adffb 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -25,7 +25,7 @@ import pandas as pd import probeinterface as pi import spikeinterface as si -from element_interface.utils import find_full_path, find_root_directory +from element_interface.utils import find_full_path from spikeinterface import exporters, postprocessing, qualitymetrics, sorters from element_array_ephys import get_logger, probe, readers @@ -112,58 +112,65 @@ def make(self, key): ephys.ClusteringTask.update1( {**key, "clustering_output_dir": output_dir.as_posix()} ) + output_dir = pathlib.Path(output_dir) output_full_dir = find_full_path( - ephys.get_ephys_root_data_dir(), output_dir - ) # output directory in the processed data directory - - # Create SI recording extractor object - data_dir = ( - ephys.get_ephys_root_data_dir()[0] / pathlib.Path(output_dir).parent - ) # raw data directory - stream_names, stream_ids = si.extractors.get_neo_streams( - acq_software.strip().lower(), folder_path=data_dir - ) - si_recording: si.BaseRecording = SI_READERS[acq_software]( - folder_path=data_dir, stream_name=stream_names[0] - ) - - # Add probe information to recording object - electrode_config_key = ( - probe.ElectrodeConfig * ephys.EphysRecording & key - ).fetch1("KEY") - electrodes_df = ( - ( - probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode - & electrode_config_key + ephys.get_ephys_root_data_dir(), output_dir.parent + ) # recording object will be stored in the parent dir since it can be re-used for multiple sorters + + recording_file = ( + output_full_dir / "si_recording.pkl" + ) # recording cache to be created for each key + + if not recording_file.exists(): # skip if si_recording.pkl already exists + # Create SI recording extractor object + data_dir = ( + ephys.get_ephys_root_data_dir()[0] / output_dir.parent + ) # raw data directory + stream_names, stream_ids = si.extractors.get_neo_streams( + acq_software.strip().lower(), folder_path=data_dir + ) + si_recording: si.BaseRecording = SI_READERS[acq_software]( + folder_path=data_dir, stream_name=stream_names[0] ) - .fetch(format="frame") - .reset_index()[["electrode", "x_coord", "y_coord", "shank"]] - ) - - # Create SI probe object - si_probe = readers.probe_geometry.to_probeinterface(electrodes_df) - si_recording.set_probe(probe=si_probe, in_place=True) - - # Run preprocessing and save results to output folder - preprocessing_method = "catGT" # where to load this info? 
- si_recording = { - "catGT": mimic_catGT, - "IBLdestriping": mimic_IBLdestriping, - "IBLdestriping_modified": mimic_IBLdestriping_modified, - }[preprocessing_method](si_recording) - recording_file = output_full_dir / "si_recording.pkl" - si_recording.dump_to_pickle(file_path=recording_file) - self.insert1( - { - **key, - "execution_time": execution_time, - "execution_duration": ( - datetime.utcnow() - execution_time - ).total_seconds() - / 3600, - } - ) + # Add probe information to recording object + electrode_config_key = ( + probe.ElectrodeConfig * ephys.EphysRecording & key + ).fetch1("KEY") + electrodes_df = ( + ( + probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode + & electrode_config_key + ) + .fetch(format="frame") + .reset_index()[["electrode", "x_coord", "y_coord", "shank"]] + ) + channels_details = ephys.get_recording_channels_details(key) + + # Create SI probe object + si_probe = readers.probe_geometry.to_probeinterface(electrodes_df) + si_probe.set_device_channel_indices(channels_details["channel_ind"]) + si_recording.set_probe(probe=si_probe, in_place=True) + + # Run preprocessing and save results to output folder + preprocessing_method = "catGT" # where to load this info? + si_recording = { + "catGT": mimic_catGT, + "IBLdestriping": mimic_IBLdestriping, + "IBLdestriping_modified": mimic_IBLdestriping_modified, + }[preprocessing_method](si_recording) + si_recording.dump_to_pickle(file_path=recording_file) + + self.insert1( + { + **key, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) @schema From 653e7e84bcacd3cf7ae382e5236b732db500ad06 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 3 Jan 2024 15:05:53 -0600 Subject: [PATCH 053/152] fix: :bug: fix path & typo --- element_array_ephys/spike_sorting/si_clustering.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 8a5adffb..32804645 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -190,8 +190,8 @@ def make(self, key): # Load recording object. output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") - output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - recording_file = output_dir / "si_recording.pkl" + output_full_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + recording_file = output_full_dir.parent / "si_recording.pkl" si_recording: si.BaseRecording = si.load_extractor(recording_file) # Get sorter method and create output directory. 
@@ -199,9 +199,9 @@ def make(self, key): ephys.ClusteringTask * ephys.ClusteringParamSet & key ).fetch1("clustering_method", "params") sorter_name = ( - "kilosort_2_5" if clustering_method == "kilsort2.5" else clustering_method + "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method ) - sorter_dir = output_dir / sorter_name + sorter_dir = output_full_dir / sorter_name # Run sorting si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter( From 8c25bd21f69bb5151508729f559f7515b8fc3d08 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 5 Jan 2024 12:56:42 -0600 Subject: [PATCH 054/152] code review --- .../{preprocessing.py => si_preprocessing.py} | 2 +- .../{si_clustering.py => si_spike_sorting.py} | 131 +++++++++--------- 2 files changed, 70 insertions(+), 63 deletions(-) rename element_array_ephys/spike_sorting/{preprocessing.py => si_preprocessing.py} (98%) rename element_array_ephys/spike_sorting/{si_clustering.py => si_spike_sorting.py} (72%) diff --git a/element_array_ephys/spike_sorting/preprocessing.py b/element_array_ephys/spike_sorting/si_preprocessing.py similarity index 98% rename from element_array_ephys/spike_sorting/preprocessing.py rename to element_array_ephys/spike_sorting/si_preprocessing.py index 77a95792..2edf443d 100644 --- a/element_array_ephys/spike_sorting/preprocessing.py +++ b/element_array_ephys/spike_sorting/si_preprocessing.py @@ -37,7 +37,7 @@ def mimic_IBLdestriping_modified(recording): return recording -_preprocessing_function = { +preprocessing_function_mapping = { "catGT": mimic_catGT, "IBLdestriping": mimic_IBLdestriping, "IBLdestriping_modified": mimic_IBLdestriping_modified, diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_spike_sorting.py similarity index 72% rename from element_array_ephys/spike_sorting/si_clustering.py rename to element_array_ephys/spike_sorting/si_spike_sorting.py index 32804645..f491b5b5 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -30,11 +30,7 @@ from element_array_ephys import get_logger, probe, readers -from .preprocessing import ( - mimic_catGT, - mimic_IBLdestriping, - mimic_IBLdestriping_modified, -) +from . 
import si_preprocessing log = get_logger(__name__) @@ -100,9 +96,13 @@ def make(self, key): execution_time = datetime.utcnow() # Set the output directory - acq_software, output_dir = ( - ephys.ClusteringTask * ephys.EphysRecording & key - ).fetch1("acq_software", "clustering_output_dir") + acq_software, clustering_method, params = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method", "params") + + for req_key in ("SI_PREPROCESSING_METHOD", "SI_SORTING_PARAMS", "SI_WAVEFORM_EXTRACTION_PARAMS", "SI_QUALITY_METRICS_PARAMS"): + if req_key not in params: + raise ValueError(f"{req_key} must be defined in ClusteringParamSet for SpikeInterface execution") if not output_dir: output_dir = ephys.ClusteringTask.infer_output_dir( @@ -114,63 +114,68 @@ def make(self, key): ) output_dir = pathlib.Path(output_dir) output_full_dir = find_full_path( - ephys.get_ephys_root_data_dir(), output_dir.parent - ) # recording object will be stored in the parent dir since it can be re-used for multiple sorters + ephys.get_ephys_root_data_dir(), output_dir + ) recording_file = ( output_full_dir / "si_recording.pkl" ) # recording cache to be created for each key - if not recording_file.exists(): # skip if si_recording.pkl already exists - # Create SI recording extractor object - data_dir = ( - ephys.get_ephys_root_data_dir()[0] / output_dir.parent - ) # raw data directory - stream_names, stream_ids = si.extractors.get_neo_streams( - acq_software.strip().lower(), folder_path=data_dir - ) - si_recording: si.BaseRecording = SI_READERS[acq_software]( - folder_path=data_dir, stream_name=stream_names[0] - ) + # Create SI recording extractor object + if acq_software == "SpikeGLX": + spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) + spikeglx_recording.validate_file("ap") + data_dir = spikeglx_meta_filepath.parent + elif acq_software == "Open Ephys": + oe_probe = ephys.get_openephys_probe_data(key) + assert len(oe_probe.recording_info["recording_files"]) == 1 + data_dir = oe_probe.recording_info["recording_files"][0] + else: + raise NotImplementedError(f"Not implemented for {acq_software}") + + stream_names, stream_ids = si.extractors.get_neo_streams( + acq_software.strip().lower(), folder_path=data_dir + ) + si_recording: si.BaseRecording = SI_READERS[acq_software]( + folder_path=data_dir, stream_name=stream_names[0] + ) - # Add probe information to recording object - electrode_config_key = ( - probe.ElectrodeConfig * ephys.EphysRecording & key - ).fetch1("KEY") - electrodes_df = ( - ( - probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode - & electrode_config_key - ) - .fetch(format="frame") - .reset_index()[["electrode", "x_coord", "y_coord", "shank"]] - ) - channels_details = ephys.get_recording_channels_details(key) - - # Create SI probe object - si_probe = readers.probe_geometry.to_probeinterface(electrodes_df) - si_probe.set_device_channel_indices(channels_details["channel_ind"]) - si_recording.set_probe(probe=si_probe, in_place=True) - - # Run preprocessing and save results to output folder - preprocessing_method = "catGT" # where to load this info? 
- si_recording = { - "catGT": mimic_catGT, - "IBLdestriping": mimic_IBLdestriping, - "IBLdestriping_modified": mimic_IBLdestriping_modified, - }[preprocessing_method](si_recording) - si_recording.dump_to_pickle(file_path=recording_file) - - self.insert1( - { - **key, - "execution_time": execution_time, - "execution_duration": ( - datetime.utcnow() - execution_time - ).total_seconds() - / 3600, - } + # Add probe information to recording object + electrode_config_key = ( + probe.ElectrodeConfig * ephys.EphysRecording & key + ).fetch1("KEY") + electrodes_df = ( + ( + probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode + & electrode_config_key ) + .fetch(format="frame") + .reset_index()[["electrode", "x_coord", "y_coord", "shank"]] + ) + channels_details = ephys.get_recording_channels_details(key) + + # Create SI probe object + si_probe = readers.probe_geometry.to_probeinterface(electrodes_df) + si_probe.set_device_channel_indices(channels_details["channel_ind"]) + si_recording.set_probe(probe=si_probe, in_place=True) + + # Run preprocessing and save results to output folder + preprocessing_method = params["SI_PREPROCESSING_METHOD"] + si_preproc_func = si_preprocessing.preprocessing_function_mapping[preprocessing_method] + si_recording = si_preproc_func(si_recording) + si_recording.dump_to_pickle(file_path=recording_file) + + self.insert1( + { + **key, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) @schema @@ -203,6 +208,8 @@ def make(self, key): ) sorter_dir = output_full_dir / sorter_name + si_sorting_params = params["SI_SORTING_PARAMS"] + # Run sorting si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter( sorter_name=sorter_name, @@ -210,7 +217,7 @@ def make(self, key): output_folder=sorter_dir, verbse=True, docker_image=True, - **params, + **si_sorting_params, ) # Run sorting @@ -255,14 +262,14 @@ def make(self, key): si_recording: si.BaseRecording = si.load_extractor(recording_file) si_sorting: si.sorters.BaseSorter = si.load_extractor(sorting_file) + si_waveform_extraction_params = params["SI_WAVEFORM_EXTRACTION_PARAMS"] + # Extract waveforms we: si.WaveformExtractor = si.extract_waveforms( si_recording, si_sorting, folder=sorter_dir / "waveform", # The folder where waveforms are cached - ms_before=3.0, - ms_after=4.0, - max_spikes_per_unit=500, + **si_waveform_extraction_params overwrite=True, **JOB_KWARGS, ) From 90d1ba4e4c87493c4e79ae93ed139ee0f5e3218f Mon Sep 17 00:00:00 2001 From: JaerongA Date: Thu, 1 Feb 2024 15:27:57 -0600 Subject: [PATCH 055/152] feat: :sparkles: modify QualityMetrics make function --- element_array_ephys/ephys_no_curation.py | 34 +++++++++++++++--------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 8ee7ee8b..ca293d95 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1358,24 +1358,34 @@ class Waveform(dj.Part): def make(self, key): """Populates tables with quality metrics data.""" + # Load metrics.csv output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - metric_fp = kilosort_dir / "metrics.csv" - rename_dict = { - "isi_viol": "isi_violation", - "num_viol": "number_violation", - "contam_rate": "contamination_rate", - } - + output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) + metric_fp = output_dir / 
"metrics.csv" if not metric_fp.exists(): raise FileNotFoundError(f"QC metrics file not found: {metric_fp}") - metrics_df = pd.read_csv(metric_fp) - metrics_df.set_index("cluster_id", inplace=True) + + # Conform the dataframe to match the table definition + if "cluster_id" in metrics_df.columns: + metrics_df.set_index("cluster_id", inplace=True) + else: + metrics_df.rename( + columns={metrics_df.columns[0]: "cluster_id"}, inplace=True + ) + metrics_df.set_index("cluster_id", inplace=True) metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True) metrics_df.columns = metrics_df.columns.str.lower() - metrics_df.rename(columns=rename_dict, inplace=True) + + metrics_df.rename( + columns={ + "isi_viol": "isi_violation", + "num_viol": "number_violation", + "contam_rate": "contamination_rate", + }, + inplace=True, + ) + metrics_list = [ dict(metrics_df.loc[unit_key["unit"]], **unit_key) for unit_key in (CuratedClustering.Unit & key).fetch("KEY") From cacefaceb0246bbc7b1a86b7e9afbaff335f5a86 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 6 Feb 2024 21:02:47 +0000 Subject: [PATCH 056/152] update si_spike_sorting.PreProcessing make function --- .../spike_sorting/si_spike_sorting.py | 30 ++++++++++++------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index f491b5b5..82729fe7 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -96,13 +96,20 @@ def make(self, key): execution_time = datetime.utcnow() # Set the output directory - acq_software, clustering_method, params = ( + acq_software, output_dir, params = ( ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method", "params") - - for req_key in ("SI_PREPROCESSING_METHOD", "SI_SORTING_PARAMS", "SI_WAVEFORM_EXTRACTION_PARAMS", "SI_QUALITY_METRICS_PARAMS"): + ).fetch1("acq_software", "clustering_output_dir", "params") + + for req_key in ( + "SI_SORTING_PARAMS", + "SI_PREPROCESSING_METHOD", + "SI_WAVEFORM_EXTRACTION_PARAMS", + "SI_QUALITY_METRICS_PARAMS", + ): if req_key not in params: - raise ValueError(f"{req_key} must be defined in ClusteringParamSet for SpikeInterface execution") + raise ValueError( + f"{req_key} must be defined in ClusteringParamSet for SpikeInterface execution" + ) if not output_dir: output_dir = ephys.ClusteringTask.infer_output_dir( @@ -113,9 +120,7 @@ def make(self, key): {**key, "clustering_output_dir": output_dir.as_posix()} ) output_dir = pathlib.Path(output_dir) - output_full_dir = find_full_path( - ephys.get_ephys_root_data_dir(), output_dir - ) + output_full_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) recording_file = ( output_full_dir / "si_recording.pkl" @@ -124,7 +129,9 @@ def make(self, key): # Create SI recording extractor object if acq_software == "SpikeGLX": spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) + spikeglx_recording = readers.spikeglx.SpikeGLX( + spikeglx_meta_filepath.parent + ) spikeglx_recording.validate_file("ap") data_dir = spikeglx_meta_filepath.parent elif acq_software == "Open Ephys": @@ -161,8 +168,9 @@ def make(self, key): si_recording.set_probe(probe=si_probe, in_place=True) # Run preprocessing and save results to output folder - preprocessing_method = params["SI_PREPROCESSING_METHOD"] - si_preproc_func = 
si_preprocessing.preprocessing_function_mapping[preprocessing_method] + si_preproc_func = si_preprocessing.preprocessing_function_mapping[ + params["SI_PREPROCESSING_METHOD"] + ] si_recording = si_preproc_func(si_recording) si_recording.dump_to_pickle(file_path=recording_file) From f98e1ed332c7c5ca61b9321e461c0101b063a064 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 7 Feb 2024 17:20:15 +0000 Subject: [PATCH 057/152] update SIClustering make function --- .../spike_sorting/si_spike_sorting.py | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 82729fe7..a2f2db24 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -25,11 +25,10 @@ import pandas as pd import probeinterface as pi import spikeinterface as si +from element_array_ephys import get_logger, probe, readers from element_interface.utils import find_full_path from spikeinterface import exporters, postprocessing, qualitymetrics, sorters -from element_array_ephys import get_logger, probe, readers - from . import si_preprocessing log = get_logger(__name__) @@ -202,34 +201,31 @@ def make(self, key): execution_time = datetime.utcnow() # Load recording object. - output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") - output_full_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - recording_file = output_full_dir.parent / "si_recording.pkl" + clustering_method, output_dir, params = ( + ephys.ClusteringTask * ephys.ClusteringParamSet & key + ).fetch1("clustering_method", "clustering_output_dir", "params") + output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + recording_file = output_dir / "si_recording.pkl" si_recording: si.BaseRecording = si.load_extractor(recording_file) # Get sorter method and create output directory. 
- clustering_method, params = ( - ephys.ClusteringTask * ephys.ClusteringParamSet & key - ).fetch1("clustering_method", "params") sorter_name = ( "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method ) - sorter_dir = output_full_dir / sorter_name - - si_sorting_params = params["SI_SORTING_PARAMS"] # Run sorting si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter( sorter_name=sorter_name, recording=si_recording, - output_folder=sorter_dir, - verbse=True, + output_folder=output_dir / sorter_name, + remove_existing_folder=True, + verbose=True, docker_image=True, - **si_sorting_params, + **params.get("SI_SORTING_PARAMS", {}), ) # Run sorting - sorting_save_path = sorter_dir / "si_sorting.pkl" + sorting_save_path = output_dir / "si_sorting.pkl" si_sorting.dump_to_pickle(sorting_save_path) self.insert1( From 7a060ef5b875dd74189841aa5a72d3eea55d7dda Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 7 Feb 2024 17:21:16 +0000 Subject: [PATCH 058/152] update PostProcessing make function --- .../spike_sorting/si_spike_sorting.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index a2f2db24..5eb2b822 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -254,28 +254,26 @@ class PostProcessing(dj.Imported): def make(self, key): execution_time = datetime.utcnow() - JOB_KWARGS = dict(n_jobs=-1, chunk_size=30000) # Load sorting & recording object. - output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") + output_dir, params = (ephys.ClusteringTask & key).fetch1( + "clustering_output_dir", "params" + ) output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) recording_file = output_dir / "si_recording.pkl" - sorter_dir = output_dir / key["sorter_name"] - sorting_file = sorter_dir / "si_sorting.pkl" + sorting_file = output_dir / "si_sorting.pkl" si_recording: si.BaseRecording = si.load_extractor(recording_file) si_sorting: si.sorters.BaseSorter = si.load_extractor(sorting_file) - si_waveform_extraction_params = params["SI_WAVEFORM_EXTRACTION_PARAMS"] - # Extract waveforms we: si.WaveformExtractor = si.extract_waveforms( si_recording, si_sorting, - folder=sorter_dir / "waveform", # The folder where waveforms are cached - **si_waveform_extraction_params + folder=output_dir / "waveform", # The folder where waveforms are cached overwrite=True, - **JOB_KWARGS, + **params.get("SI_WAVEFORM_EXTRACTION_PARAMS", {}), + **params.get("SI_JOB_KWARGS", {"n_jobs": -1, "chunk_size": 30000}), ) # Calculate QC Metrics @@ -296,9 +294,11 @@ def make(self, key): ) # Add PCA based metrics. These will be added to the metrics dataframe above. 
_ = si.postprocessing.compute_principal_components( - waveform_extractor=we, n_components=5, mode="by_channel_local" - ) # TODO: the parameters need to be checked + waveform_extractor=we, **params.get("SI_QUALITY_METRICS_PARAMS", None) + ) + # Save the output (metrics.csv to the output dir) metrics = si.qualitymetrics.compute_quality_metrics(waveform_extractor=we) + metrics.to_csv(output_dir / "metrics.csv") # Save results self.insert1( @@ -312,7 +312,7 @@ def make(self, key): } ) - # all finished, insert this `key` into ephys.Clustering + # Once finished, insert this `key` into ephys.Clustering ephys.Clustering.insert1( {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True ) From 6daf0c5a1f3f40d3700ce09bfa047326b4477cab Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 7 Feb 2024 17:01:52 -0600 Subject: [PATCH 059/152] feat: :sparkles: add n.a. to ClusterQualityLabel --- element_array_ephys/ephys_no_curation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index ca293d95..aa743598 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -703,6 +703,7 @@ class ClusterQualityLabel(dj.Lookup): ("ok", "probably a single unit, but could be contaminated"), ("mua", "multi-unit activity"), ("noise", "bad unit"), + ("n.a.", "not available"), ] From e41ff1daeaddeb2847b426893f430906a86e91d6 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 7 Feb 2024 17:15:05 -0600 Subject: [PATCH 060/152] extract all waveforms --- element_array_ephys/spike_sorting/si_spike_sorting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 5eb2b822..432b6c10 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -271,6 +271,7 @@ def make(self, key): si_recording, si_sorting, folder=output_dir / "waveform", # The folder where waveforms are cached + max_spikes_per_unit=None, overwrite=True, **params.get("SI_WAVEFORM_EXTRACTION_PARAMS", {}), **params.get("SI_JOB_KWARGS", {"n_jobs": -1, "chunk_size": 30000}), From e8d9854f4014302248d483604faaa2b4f2858fc2 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 7 Feb 2024 19:15:34 -0600 Subject: [PATCH 061/152] feat: :sparkles: modify CuratedClustering make function for spike interface --- element_array_ephys/ephys_no_curation.py | 195 ++++++++++++++++------- 1 file changed, 138 insertions(+), 57 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index aa743598..70ce87cf 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -959,75 +959,156 @@ class Unit(dj.Part): def make(self, key): """Automated population of Unit information.""" output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) + output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - acq_software, sample_rate = (EphysRecording & key).fetch1( - "acq_software", "sampling_rate" - ) + if (output_dir / "waveform").exists(): # read from spikeinterface outputs + we: si.WaveformExtractor = si.load_waveforms( + output_dir / "waveform", with_recording=False + ) + si_sorting: si.sorters.BaseSorter = si.load_extractor( + output_dir / 
"sorting.pkl" + ) - sample_rate = kilosort_dataset.data["params"].get("sample_rate", sample_rate) + unit_peak_channel_map: dict[int, int] = si.get_template_extremum_channel( + we, outputs="index" + ) # {unit: peak_channel_index} - # ---------- Unit ---------- - # -- Remove 0-spike units - withspike_idx = [ - i - for i, u in enumerate(kilosort_dataset.data["cluster_ids"]) - if (kilosort_dataset.data["spike_clusters"] == u).any() - ] - valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx] - valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx] - # -- Get channel and electrode-site mapping - channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software) - - # -- Spike-times -- - # spike_times_sec_adj > spike_times_sec > spike_times - spike_time_key = ( - "spike_times_sec_adj" - if "spike_times_sec_adj" in kilosort_dataset.data - else "spike_times_sec" - if "spike_times_sec" in kilosort_dataset.data - else "spike_times" - ) - spike_times = kilosort_dataset.data[spike_time_key] - kilosort_dataset.extract_spike_depths() + spike_count_dict = dict[int, int] = si_sorting.count_num_spikes_per_unit() + # {unit: spike_count} - # -- Spike-sites and Spike-depths -- - spike_sites = np.array( - [ - channel2electrodes[s]["electrode"] - for s in kilosort_dataset.data["spike_sites"] - ] - ) - spike_depths = kilosort_dataset.data["spike_depths"] - - # -- Insert unit, label, peak-chn - units = [] - for unit, unit_lbl in zip(valid_units, valid_unit_labels): - if (kilosort_dataset.data["spike_clusters"] == unit).any(): - unit_channel, _ = kilosort_dataset.get_best_channel(unit) - unit_spike_times = ( - spike_times[kilosort_dataset.data["spike_clusters"] == unit] - / sample_rate - ) - spike_count = len(unit_spike_times) + spikes = si_sorting.to_spike_vector( + extremum_channel_inds=unit_peak_channel_map + ) + + # Get electrode info + electrode_config_key = ( + EphysRecording * probe.ElectrodeConfig & key + ).fetch1("KEY") + + electrode_query = ( + probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode + & electrode_config_key + ) + channel2electrode_map = dict( + zip(*electrode_query.fetch("channel", "electrode")) + ) + + # Get channel to electrode mapping + channel2depth_map = dict(zip(*electrode_query.fetch("channel", "y_coord"))) + + peak_electrode_ind = np.array( + [ + channel2electrode_map[unit_peak_channel_map[unit_id]] + for unit_id in si_sorting.unit_ids + ] + ) + + # Get channel to depth mapping + electrode_depth_ind = np.array( + [ + channel2depth_map[unit_peak_channel_map[unit_id]] + for unit_id in si_sorting.unit_ids + ] + ) + spikes["electrode"] = peak_electrode_ind[spikes["unit_index"]] + spikes["depth"] = electrode_depth_ind[spikes["unit_index"]] + + units = [] + for unit_id in si_sorting.unit_ids: + unit_id = int(unit_id) units.append( { - "unit": unit, - "cluster_quality_label": unit_lbl, - **channel2electrodes[unit_channel], - "spike_times": unit_spike_times, - "spike_count": spike_count, - "spike_sites": spike_sites[ - kilosort_dataset.data["spike_clusters"] == unit + "unit": unit_id, + "cluster_quality_label": "n.a.", + "spike_times": si_sorting.get_unit_spike_train( + unit_id, return_times=True + ), + "spike_count": spike_count_dict[unit_id], + "spike_sites": spikes["electrode"][ + spikes["unit_index"] == unit_id ], - "spike_depths": spike_depths[ - kilosort_dataset.data["spike_clusters"] == unit + "spike_depths": spikes["depth"][ + spikes["unit_index"] == unit_id ], } ) + else: + kilosort_dataset = kilosort.Kilosort(output_dir) + 
acq_software, sample_rate = (EphysRecording & key).fetch1( + "acq_software", "sampling_rate" + ) + + sample_rate = kilosort_dataset.data["params"].get( + "sample_rate", sample_rate + ) + + # ---------- Unit ---------- + # -- Remove 0-spike units + withspike_idx = [ + i + for i, u in enumerate(kilosort_dataset.data["cluster_ids"]) + if (kilosort_dataset.data["spike_clusters"] == u).any() + ] + valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx] + valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx] + + # -- Spike-times -- + # spike_times_sec_adj > spike_times_sec > spike_times + spike_time_key = ( + "spike_times_sec_adj" + if "spike_times_sec_adj" in kilosort_dataset.data + else ( + "spike_times_sec" + if "spike_times_sec" in kilosort_dataset.data + else "spike_times" + ) + ) + spike_times = kilosort_dataset.data[spike_time_key] + kilosort_dataset.extract_spike_depths() + + # Get channel and electrode-site mapping + channel2electrodes = get_neuropixels_channel2electrode_map( + key, acq_software + ) + + # -- Spike-sites and Spike-depths -- + spike_sites = np.array( + [ + channel2electrodes[s]["electrode"] + for s in kilosort_dataset.data["spike_sites"] + ] + ) + spike_depths = kilosort_dataset.data["spike_depths"] + + # -- Insert unit, label, peak-chn + units = [] + for unit, unit_lbl in zip(valid_units, valid_unit_labels): + if (kilosort_dataset.data["spike_clusters"] == unit).any(): + unit_channel, _ = kilosort_dataset.get_best_channel(unit) + unit_spike_times = ( + spike_times[kilosort_dataset.data["spike_clusters"] == unit] + / sample_rate + ) + spike_count = len(unit_spike_times) + + units.append( + { + "unit": unit, + "cluster_quality_label": unit_lbl, + **channel2electrodes[unit_channel], + "spike_times": unit_spike_times, + "spike_count": spike_count, + "spike_sites": spike_sites[ + kilosort_dataset.data["spike_clusters"] == unit + ], + "spike_depths": spike_depths[ + kilosort_dataset.data["spike_clusters"] == unit + ], + } + ) + self.insert1(key) self.Unit.insert([{**key, **u} for u in units]) From 00b82f81017fdb92459cd334cda8bb3dc49b0fde Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 7 Feb 2024 19:17:15 -0600 Subject: [PATCH 062/152] refactor: :recycle: import si module & re-organize imports --- element_array_ephys/ephys_no_curation.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 70ce87cf..63e72951 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1,17 +1,18 @@ -import datajoint as dj +import gc +import importlib +import inspect import pathlib import re -import numpy as np -import inspect -import importlib -import gc from decimal import Decimal -import pandas as pd -from element_interface.utils import find_root_directory, find_full_path, dict_to_uuid -from .readers import spikeglx, kilosort, openephys -from element_array_ephys import probe, get_logger, ephys_report +import datajoint as dj +import numpy as np +import pandas as pd +from element_array_ephys import ephys_report, get_logger, probe +from element_interface.utils import (dict_to_uuid, find_full_path, + find_root_directory) +from .readers import kilosort, openephys, spikeglx log = get_logger(__name__) @@ -19,8 +20,8 @@ _linking_module = None -import spikeinterface -import spikeinterface.full as si +import spikeinterface as si +from spikeinterface import exporters, postprocessing, qualitymetrics, 
sorters def activate( From b01c36c81595e8e4cb38e22f2dee146986508b79 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Mon, 12 Feb 2024 09:39:04 -0600 Subject: [PATCH 063/152] update WaveformSet ingestion --- element_array_ephys/ephys_no_curation.py | 368 ++++++++++++++--------- 1 file changed, 218 insertions(+), 150 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 63e72951..4887096c 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1168,177 +1168,245 @@ class Waveform(dj.Part): def make(self, key): """Populates waveform tables.""" output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) + output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - kilosort_dataset = kilosort.Kilosort(kilosort_dir) + if (output_dir / "waveform").exists(): # read from spikeinterface outputs - acq_software, probe_serial_number = ( - EphysRecording * ProbeInsertion & key - ).fetch1("acq_software", "probe") + we: si.WaveformExtractor = si.load_waveforms( + output_dir / "waveform", with_recording=False + ) + unit_id_to_peak_channel_indices: dict[int, np.ndarray] = ( + si.ChannelSparsity.from_best_channels( + we, 1, peak_sign="neg" + ).unit_id_to_channel_indices + ) # {unit: peak_channel_index} - # -- Get channel and electrode-site mapping - recording_key = (EphysRecording & key).fetch1("KEY") - channel2electrodes = get_neuropixels_channel2electrode_map( - recording_key, acq_software - ) + units = (CuratedClustering.Unit & key).fetch("KEY", order_by="unit") - # Get all units - units = { - u["unit"]: u - for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit") - } + # Get electrode info + electrode_config_key = ( + EphysRecording * probe.ElectrodeConfig & key + ).fetch1("KEY") - waveforms_folder = [ - f for f in kilosort_dir.parent.rglob(r"*/waveforms*") if f.is_dir() - ] + electrode_query = ( + probe.ProbeType.Electrode.proj() * probe.ElectrodeConfig.Electrode + & electrode_config_key + ) + electrode_info = electrode_query.fetch( + "KEY", order_by="electrode", as_dict=True + ) - if (kilosort_dir / "mean_waveforms.npy").exists(): - unit_waveforms = np.load( - kilosort_dir / "mean_waveforms.npy" - ) # unit x channel x sample + # Get mean waveform for each unit from all channels + mean_waveforms = we.get_all_templates( + mode="average" + ) # (unit x sample x channel) - def yield_unit_waveforms(): - for unit_no, unit_waveform in zip( - kilosort_dataset.data["cluster_ids"], unit_waveforms - ): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - if unit_no in units: - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], unit_waveform - ): - unit_electrode_waveforms.append( - { - **units[unit_no], - **channel2electrodes[channel], - "waveform_mean": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == units[unit_no]["electrode"] - ): - unit_peak_waveform = { - **units[unit_no], - "peak_electrode_waveform": channel_waveform, - } - yield unit_peak_waveform, unit_electrode_waveforms - - # Spike interface mean and peak waveform extraction from we object - - elif len(waveforms_folder) > 0 & (waveforms_folder[0]).exists(): - we_kilosort = si.load_waveforms(waveforms_folder[0].parent) - unit_templates = we_kilosort.get_all_templates() - unit_waveforms = np.reshape( - unit_templates, - ( - unit_templates.shape[1], - unit_templates.shape[3], - 
unit_templates.shape[2], - ), + unit_peak_waveform = [] + unit_electrode_waveforms = [] + + for unit in units: + unit_peak_waveform.append( + { + **unit, + "peak_electrode_waveform": we.get_template( + unit_id=unit["unit"], mode="average", force_dense=True + )[:, unit_id_to_peak_channel_indices[unit["unit"]][0]], + } + ) + + unit_electrode_waveforms.extend( + [ + { + **unit, + **e, + "waveform_mean": mean_waveforms[ + unit["unit"], :, e["electrode"] + ], + } + for e in electrode_info + ] + ) + + self.insert1(key) + self.PeakWaveform.insert(unit_peak_waveform) + self.Waveform.insert(unit_electrode_waveforms) + + else: + kilosort_dataset = kilosort.Kilosort(output_dir) + + acq_software, probe_serial_number = ( + EphysRecording * ProbeInsertion & key + ).fetch1("acq_software", "probe") + + # -- Get channel and electrode-site mapping + recording_key = (EphysRecording & key).fetch1("KEY") + channel2electrodes = get_neuropixels_channel2electrode_map( + recording_key, acq_software ) - # Approach assumes unit_waveforms was generated correctly (templates are actually the same as mean_waveforms) - def yield_unit_waveforms(): - for unit_no, unit_waveform in zip( - kilosort_dataset.data["cluster_ids"], unit_waveforms - ): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - if unit_no in units: + # Get all units + units = { + u["unit"]: u + for u in (CuratedClustering.Unit & key).fetch( + as_dict=True, order_by="unit" + ) + } + + waveforms_folder = [ + f for f in output_dir.parent.rglob(r"*/waveforms*") if f.is_dir() + ] + + if (output_dir / "mean_waveforms.npy").exists(): + unit_waveforms = np.load( + output_dir / "mean_waveforms.npy" + ) # unit x channel x sample + + def yield_unit_waveforms(): + for unit_no, unit_waveform in zip( + kilosort_dataset.data["cluster_ids"], unit_waveforms + ): + unit_peak_waveform = {} + unit_electrode_waveforms = [] + if unit_no in units: + for channel, channel_waveform in zip( + kilosort_dataset.data["channel_map"], unit_waveform + ): + unit_electrode_waveforms.append( + { + **units[unit_no], + **channel2electrodes[channel], + "waveform_mean": channel_waveform, + } + ) + if ( + channel2electrodes[channel]["electrode"] + == units[unit_no]["electrode"] + ): + unit_peak_waveform = { + **units[unit_no], + "peak_electrode_waveform": channel_waveform, + } + yield unit_peak_waveform, unit_electrode_waveforms + + # Spike interface mean and peak waveform extraction from we object + + elif len(waveforms_folder) > 0 & (waveforms_folder[0]).exists(): + we_kilosort = si.load_waveforms(waveforms_folder[0].parent) + unit_templates = we_kilosort.get_all_templates() + unit_waveforms = np.reshape( + unit_templates, + ( + unit_templates.shape[1], + unit_templates.shape[3], + unit_templates.shape[2], + ), + ) + + # Approach assumes unit_waveforms was generated correctly (templates are actually the same as mean_waveforms) + def yield_unit_waveforms(): + for unit_no, unit_waveform in zip( + kilosort_dataset.data["cluster_ids"], unit_waveforms + ): + unit_peak_waveform = {} + unit_electrode_waveforms = [] + if unit_no in units: + for channel, channel_waveform in zip( + kilosort_dataset.data["channel_map"], unit_waveform + ): + unit_electrode_waveforms.append( + { + **units[unit_no], + **channel2electrodes[channel], + "waveform_mean": channel_waveform, + } + ) + if ( + channel2electrodes[channel]["electrode"] + == units[unit_no]["electrode"] + ): + unit_peak_waveform = { + **units[unit_no], + "peak_electrode_waveform": channel_waveform, + } + yield unit_peak_waveform, 
unit_electrode_waveforms + + # Approach not using spike interface templates (ie. taking mean of each unit waveform) + # def yield_unit_waveforms(): + # for unit_id in we_kilosort.unit_ids: + # unit_waveform = np.mean(we_kilosort.get_waveforms(unit_id), 0) + # unit_peak_waveform = {} + # unit_electrode_waveforms = [] + # if unit_id in units: + # for channel, channel_waveform in zip( + # kilosort_dataset.data["channel_map"], unit_waveform + # ): + # unit_electrode_waveforms.append( + # { + # **units[unit_id], + # **channel2electrodes[channel], + # "waveform_mean": channel_waveform, + # } + # ) + # if ( + # channel2electrodes[channel]["electrode"] + # == units[unit_id]["electrode"] + # ): + # unit_peak_waveform = { + # **units[unit_id], + # "peak_electrode_waveform": channel_waveform, + # } + # yield unit_peak_waveform, unit_electrode_waveforms + + else: + if acq_software == "SpikeGLX": + spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) + neuropixels_recording = spikeglx.SpikeGLX( + spikeglx_meta_filepath.parent + ) + elif acq_software == "Open Ephys": + session_dir = find_full_path( + get_ephys_root_data_dir(), get_session_directory(key) + ) + openephys_dataset = openephys.OpenEphys(session_dir) + neuropixels_recording = openephys_dataset.probes[ + probe_serial_number + ] + + def yield_unit_waveforms(): + for unit_dict in units.values(): + unit_peak_waveform = {} + unit_electrode_waveforms = [] + + spikes = unit_dict["spike_times"] + waveforms = neuropixels_recording.extract_spike_waveforms( + spikes, kilosort_dataset.data["channel_map"] + ) # (sample x channel x spike) + waveforms = waveforms.transpose( + (1, 2, 0) + ) # (channel x spike x sample) for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], unit_waveform + kilosort_dataset.data["channel_map"], waveforms ): unit_electrode_waveforms.append( { - **units[unit_no], + **unit_dict, **channel2electrodes[channel], - "waveform_mean": channel_waveform, + "waveform_mean": channel_waveform.mean(axis=0), + "waveforms": channel_waveform, } ) if ( channel2electrodes[channel]["electrode"] - == units[unit_no]["electrode"] + == unit_dict["electrode"] ): unit_peak_waveform = { - **units[unit_no], - "peak_electrode_waveform": channel_waveform, + **unit_dict, + "peak_electrode_waveform": channel_waveform.mean( + axis=0 + ), } - yield unit_peak_waveform, unit_electrode_waveforms - - # Approach not using spike interface templates (ie. 
taking mean of each unit waveform) - # def yield_unit_waveforms(): - # for unit_id in we_kilosort.unit_ids: - # unit_waveform = np.mean(we_kilosort.get_waveforms(unit_id), 0) - # unit_peak_waveform = {} - # unit_electrode_waveforms = [] - # if unit_id in units: - # for channel, channel_waveform in zip( - # kilosort_dataset.data["channel_map"], unit_waveform - # ): - # unit_electrode_waveforms.append( - # { - # **units[unit_id], - # **channel2electrodes[channel], - # "waveform_mean": channel_waveform, - # } - # ) - # if ( - # channel2electrodes[channel]["electrode"] - # == units[unit_id]["electrode"] - # ): - # unit_peak_waveform = { - # **units[unit_id], - # "peak_electrode_waveform": channel_waveform, - # } - # yield unit_peak_waveform, unit_electrode_waveforms - else: - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - elif acq_software == "Open Ephys": - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - openephys_dataset = openephys.OpenEphys(session_dir) - neuropixels_recording = openephys_dataset.probes[probe_serial_number] - - def yield_unit_waveforms(): - for unit_dict in units.values(): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - - spikes = unit_dict["spike_times"] - waveforms = neuropixels_recording.extract_spike_waveforms( - spikes, kilosort_dataset.data["channel_map"] - ) # (sample x channel x spike) - waveforms = waveforms.transpose( - (1, 2, 0) - ) # (channel x spike x sample) - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], waveforms - ): - unit_electrode_waveforms.append( - { - **unit_dict, - **channel2electrodes[channel], - "waveform_mean": channel_waveform.mean(axis=0), - "waveforms": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == unit_dict["electrode"] - ): - unit_peak_waveform = { - **unit_dict, - "peak_electrode_waveform": channel_waveform.mean( - axis=0 - ), - } - - yield unit_peak_waveform, unit_electrode_waveforms + yield unit_peak_waveform, unit_electrode_waveforms # insert waveform on a per-unit basis to mitigate potential memory issue self.insert1(key) @@ -1448,7 +1516,7 @@ def make(self, key): if not metric_fp.exists(): raise FileNotFoundError(f"QC metrics file not found: {metric_fp}") metrics_df = pd.read_csv(metric_fp) - + # Conform the dataframe to match the table definition if "cluster_id" in metrics_df.columns: metrics_df.set_index("cluster_id", inplace=True) From 853b66f5bd39edd82ed14de635064f24c52855e4 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 13 Feb 2024 11:38:39 -0600 Subject: [PATCH 064/152] Update element_array_ephys/ephys_no_curation.py Co-authored-by: Kushal Bakshi <52367253+kushalbakshi@users.noreply.github.com> --- element_array_ephys/ephys_no_curation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 4887096c..c82f986f 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -319,8 +319,8 @@ def make(self, key): break else: raise FileNotFoundError( - f"Ephys recording data not found!" - f" Neither SpikeGLX nor Open Ephys recording files found" + "Ephys recording data not found!" 
+ "Neither SpikeGLX nor Open Ephys recording files found" ) supported_probe_types = probe.ProbeType.fetch("probe_type") From 67ebf4e994411617826263782142fcfc270b98f0 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 13 Feb 2024 11:38:47 -0600 Subject: [PATCH 065/152] Update element_array_ephys/ephys_no_curation.py Co-authored-by: Kushal Bakshi <52367253+kushalbakshi@users.noreply.github.com> --- element_array_ephys/ephys_no_curation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index c82f986f..0efddf9a 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -618,7 +618,7 @@ class ClusteringParamSet(dj.Lookup): ClusteringMethod (dict): ClusteringMethod primary key. paramset_desc (varchar(128) ): Description of the clustering parameter set. param_set_hash (uuid): UUID hash for the parameter set. - params (longblob) + params (longblob): Set of clustering parameters. """ definition = """ From 5fe60434655f8f0c9f8cb2b2ecaa63e4a3a28e2a Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 13 Feb 2024 11:38:51 -0600 Subject: [PATCH 066/152] Update element_array_ephys/ephys_no_curation.py Co-authored-by: Kushal Bakshi <52367253+kushalbakshi@users.noreply.github.com> --- element_array_ephys/ephys_no_curation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 0efddf9a..92224409 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1488,7 +1488,7 @@ class Waveform(dj.Part): recovery_slope (float): Slope of the regression line fit to first 30 microseconds from peak to tail. spread (float): The range with amplitude over 12-percent of maximum amplitude along the probe. velocity_above (float): inverse velocity of waveform propagation from soma to the top of the probe. - velocity_below (float) inverse velocity of waveform propagation from soma toward the bottom of the probe. + velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe. 
""" definition = """ From ac08163cb55ac46d97a1b5995d965be6e8140ed8 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 13 Feb 2024 11:48:41 -0600 Subject: [PATCH 067/152] ci: run test only on the main branch --- .github/workflows/test.yaml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index acaddca0..fec7ce0c 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -1,11 +1,13 @@ name: Test on: push: + branches: + - main pull_request: + branches: + - main workflow_dispatch: jobs: - # devcontainer-build: - # uses: datajoint/.github/.github/workflows/devcontainer-build.yaml@main tests: runs-on: ubuntu-latest strategy: @@ -31,4 +33,3 @@ jobs: run: | python_version=${{matrix.py_ver}} black element_array_ephys --check --verbose --target-version py${python_version//.} - From e95331c54babe966ccbb6ce902463eb48c869c6d Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 13 Feb 2024 14:33:26 -0600 Subject: [PATCH 068/152] build: :heavy_plus_sign: add spikingcircus dependencies --- setup.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index f93247b6..ebf5d114 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from os import path -from setuptools import find_packages, setup +from setuptools import find_packages, setup pkg_name = "element_array_ephys" here = path.abspath(path.dirname(__file__)) @@ -16,7 +16,7 @@ setup( name=pkg_name.replace("_", "-"), - python_requires='>=3.7, <3.11', + python_requires=">=3.7, <3.11", version=__version__, # noqa F821 description="Extracellular Array Electrophysiology DataJoint Element", long_description=long_description, @@ -50,5 +50,6 @@ ], "nwb": ["dandi", "neuroconv[ecephys]", "pynwb"], "tests": ["pre-commit", "pytest", "pytest-cov"], + "spikingcircus": ["hdbscan", "numba"], }, ) From be5135e05ba40bb2054a18e9b00486ea1ba411d0 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 13 Feb 2024 16:00:47 -0600 Subject: [PATCH 069/152] refactor: fix typo & black formatting --- element_array_ephys/ephys_no_curation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 76fdeefc..b105f8f8 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -35,7 +35,7 @@ def activate( Args: ephys_schema_name (str): A string containing the name of the ephys schema. - probe_schema_name (str): A string containing the name of the probe scehma. + probe_schema_name (str): A string containing the name of the probe schema. create_schema (bool): If True, schema will be created in the database. create_tables (bool): If True, tables related to the schema will be created in the database. linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema. 
@@ -1174,11 +1174,11 @@ def make(self, key):
             we: si.WaveformExtractor = si.load_waveforms(
                 output_dir / "waveform", with_recording=False
             )
-            unit_id_to_peak_channel_indices: dict[int, np.ndarray] = (
-                si.ChannelSparsity.from_best_channels(
-                    we, 1, peak_sign="neg"
-                ).unit_id_to_channel_indices
-            )  # {unit: peak_channel_index}
+            unit_id_to_peak_channel_indices: dict[
+                int, np.ndarray
+            ] = si.ChannelSparsity.from_best_channels(
+                we, 1, peak_sign="neg"
+            ).unit_id_to_channel_indices  # {unit: peak_channel_index}

             units = (CuratedClustering.Unit & key).fetch("KEY", order_by="unit")

From 00b82f81017fdb92459cd334cda8bb3dc49b0fde Mon Sep 17 00:00:00 2001
From: JaerongA
Date: Wed, 14 Feb 2024 20:10:56 -0600
Subject: [PATCH 070/152] feat: :sparkles: add EphysRecording.Channel part table

---
 element_array_ephys/ephys_no_curation.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index b105f8f8..25b2c147 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -284,6 +284,15 @@ class EphysRecording(dj.Imported):
     recording_duration: float # (seconds) duration of the recording from this probe
     """

+    class Channel(dj.Part):
+        definitoin = """
+        -> master
+        channel_idx: int  # channel index
+        ---
+        -> probe.ElectrodeConfig.Electrode
+        channel_name="": varchar(64)
+        """
+
     class EphysFile(dj.Part):
         """Paths of electrophysiology recording files for each insertion.

From 48025112da9acaab8b6e043b8da56e54a9e9725d Mon Sep 17 00:00:00 2001
From: JaerongA
Date: Fri, 16 Feb 2024 14:25:45 -0600
Subject: [PATCH 071/152] fix: :bug: fix get_logger missing error

---
 element_array_ephys/__init__.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/element_array_ephys/__init__.py b/element_array_ephys/__init__.py
index 1c0c7285..3a0e5af6 100644
--- a/element_array_ephys/__init__.py
+++ b/element_array_ephys/__init__.py
@@ -1 +1,22 @@
+"""
+isort:skip_file
+"""
+
+import logging
+import os
+
+import datajoint as dj
+
+
+__all__ = ["ephys", "get_logger"]
+
+dj.config["enable_python_native_blobs"] = True
+
+
+def get_logger(name):
+    log = logging.getLogger(name)
+    log.setLevel(os.getenv("LOGLEVEL", "INFO"))
+    return log
+
+
 from . import ephys_acute as ephys

From bca67b3e47138cd2d7eaaf6c0b892557ff576786 Mon Sep 17 00:00:00 2001
From: JaerongA
Date: Fri, 16 Feb 2024 21:03:28 +0000
Subject: [PATCH 072/152] fix typo & remove sorter_name

---
 element_array_ephys/ephys_no_curation.py              | 2 +-
 element_array_ephys/spike_sorting/si_spike_sorting.py | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index 25b2c147..5894fe16 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -285,7 +285,7 @@ class EphysRecording(dj.Imported):
     """

     class Channel(dj.Part):
-        definitoin = """
+        definition = """
         -> master
         channel_idx: int  # channel index
         ---

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 432b6c10..461987e6 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -191,7 +191,6 @@ class SIClustering(dj.Imported):

     definition = """
     -> PreProcessing
-    sorter_name: varchar(30)  # name of the sorter used
     ---
     execution_time: datetime  # datetime of the start of this step
     execution_duration: float  # execution duration in hours
@@ -231,7 +230,6 @@ def make(self, key):
         self.insert1(
             {
                 **key,
-                "sorter_name": sorter_name,
                 "execution_time": execution_time,
                 "execution_duration": (
                     datetime.utcnow() - execution_time

From 1faa8f2e0e7f78f1097a220e4282ea4f8d6e7d1b Mon Sep 17 00:00:00 2001
From: JaerongA
Date: Fri, 16 Feb 2024 21:08:15 +0000
Subject: [PATCH 073/152] feat: :sparkles: add memoized_result implementation in SIClustering

---
 .../spike_sorting/si_spike_sorting.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 461987e6..6acd5a2b 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -26,7 +26,7 @@
 import probeinterface as pi
 import spikeinterface as si
 from element_array_ephys import get_logger, probe, readers
-from element_interface.utils import find_full_path
+from element_interface.utils import find_full_path, memoized_result
 from spikeinterface import exporters, postprocessing, qualitymetrics, sorters

 from . import si_preprocessing
@@ -213,7 +213,17 @@ def make(self, key):
         )

         # Run sorting
-        si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(
+        @memoized_result(
+            parameters={**key, **params},
+            output_directory=output_dir / sorter_name,
+        )
+        def _run_sorter(*args, **kwargs):
+            si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(*args, **kwargs)
+            sorting_save_path = output_dir / sorter_name / "si_sorting.pkl"
+            si_sorting.dump_to_pickle(sorting_save_path)
+            return sorting_save_path
+
+        sorting_save_path = _run_sorter(
             sorter_name=sorter_name,
             recording=si_recording,
             output_folder=output_dir / sorter_name,

From 58f3a4453a821e67bef49971d74f4e97d9f97ef2 Mon Sep 17 00:00:00 2001
From: JaerongA
Date: Fri, 16 Feb 2024 22:40:01 +0000
Subject: [PATCH 074/152] create a folder for storing recording pickle object

---
 .../spike_sorting/si_spike_sorting.py | 22 ++++++++++++-------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 6acd5a2b..3205d056 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -95,11 +95,16 @@ def make(self, key):
         execution_time = datetime.utcnow()

         # Set the output directory
-        acq_software, output_dir, params = (
+        clustering_method, acq_software, output_dir, params = (
             ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key
-        ).fetch1("acq_software", "clustering_output_dir", "params")
-
-        for req_key in (
+        ).fetch1("clustering_method", "acq_software", "clustering_output_dir", "params")
+
+        # Get sorter method and create output directory.
+        sorter_name = (
+            "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method
+        )
+
+        for required_key in (
             "SI_SORTING_PARAMS",
             "SI_PREPROCESSING_METHOD",
             "SI_WAVEFORM_EXTRACTION_PARAMS",
             "SI_QUALITY_METRICS_PARAMS",
         ):
             if req_key not in params:
                 raise ValueError(
                     f"{req_key} must be defined in ClusteringParamSet for SpikeInterface execution"
                 )

+        # Set directory to store recording file.
         if not output_dir:
             output_dir = ephys.ClusteringTask.infer_output_dir(
                 key, relative=True, mkdir=True
             )
@@ -118,11 +124,11 @@ def make(self, key):
             ephys.ClusteringTask.update1(
                 {**key, "clustering_output_dir": output_dir.as_posix()}
             )
-        output_dir = pathlib.Path(output_dir)
-        output_full_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
-
+        output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
+        recording_dir = output_dir / sorter_name / "recording"
+        recording_dir.mkdir(parents=True, exist_ok=True)
         recording_file = (
-            output_full_dir / "si_recording.pkl"
+            recording_dir / "si_recording.pkl"
         )  # recording cache to be created for each key

         # Create SI recording extractor object

From 4fcea517577095232ef317e4a77421e3181b3632 Mon Sep 17 00:00:00 2001
From: JaerongA
Date: Fri, 16 Feb 2024 22:41:31 +0000
Subject: [PATCH 075/152] install element_interface from datajoint upstream

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index ebf5d114..532c72f6 100644
--- a/setup.py
+++ b/setup.py
@@ -44,7 +44,7 @@
         "elements": [
             "element-animal @ git+https://github.com/datajoint/element-animal.git",
             "element-event @ git+https://github.com/datajoint/element-event.git",
-            "element-interface @ git+https://github.com/datajoint/element-interface.git",
+            "element-interface @ git+https://github.com/datajoint/element-interface.git@dev_memoized_results",
             "element-lab @ git+https://github.com/datajoint/element-lab.git",
             "element-session @ git+https://github.com/datajoint/element-session.git",
         ],

From 4f0e0204cd5b6c31d78576dd3c3ff9b88c5598c3 Mon Sep 17 00:00:00 2001
From: JaerongA
Date: Fri, 16 Feb 2024 22:52:49 +0000
Subject: [PATCH 076/152] add required_key for parameters

---
 element_array_ephys/spike_sorting/si_spike_sorting.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 3205d056..d09a2c6b 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -110,9 +110,9 @@ def make(self, key):
             "SI_WAVEFORM_EXTRACTION_PARAMS",
             "SI_QUALITY_METRICS_PARAMS",
         ):
-            if req_key not in params:
+            if required_key not in params:
                 raise ValueError(
-                    f"{req_key} must be defined in ClusteringParamSet for SpikeInterface execution"
+                    f"{required_key} must be defined in ClusteringParamSet for SpikeInterface execution"
                 )

         # Set directory to store recording file.
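For orientation, a ClusteringParamSet.params entry that satisfies the validation loop above might look roughly like the sketch below. Only the four required top-level key names (plus the optional SI_JOB_KWARGS read elsewhere in this series) come from the patches; every nested value is an illustrative assumption, not a setting used by the authors.

    # A minimal sketch of a params dict accepted by PreProcessing.make();
    # the nested values are hypothetical placeholders.
    params = {
        "SI_SORTING_PARAMS": {"detect_threshold": 6},  # forwarded to si.sorters.run_sorter
        "SI_PREPROCESSING_METHOD": "catGT",  # or "IBLdestriping" / "IBLdestriping_modified"
        "SI_WAVEFORM_EXTRACTION_PARAMS": {"ms_before": 1.0, "ms_after": 2.0},
        "SI_QUALITY_METRICS_PARAMS": {"n_components": 5, "mode": "by_channel_local"},
        "SI_JOB_KWARGS": {"n_jobs": -1, "chunk_size": 30000},  # optional; defaulted if absent
    }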
From 83e7a166c18e7b20d83e05a6c050ff02bde9b74e Mon Sep 17 00:00:00 2001
From: JaerongA
Date: Fri, 16 Feb 2024 22:55:58 +0000
Subject: [PATCH 077/152] set recording channel info

---
 element_array_ephys/spike_sorting/si_spike_sorting.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index d09a2c6b..6bf3c4bd 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -165,11 +165,10 @@ def make(self, key):
             .fetch(format="frame")
             .reset_index()[["electrode", "x_coord", "y_coord", "shank"]]
         )
-        channels_details = ephys.get_recording_channels_details(key)
-
+        # Create SI probe object
         si_probe = readers.probe_geometry.to_probeinterface(electrodes_df)
-        si_probe.set_device_channel_indices(channels_details["channel_ind"])
+        si_probe.set_device_channel_indices(range(len(electrodes_df)))
         si_recording.set_probe(probe=si_probe, in_place=True)

         # Run preprocessing and save results to output folder

From 1d18a39cf22d81d3ce74aa591c5c8f4138895b01 Mon Sep 17 00:00:00 2001
From: JaerongA
Date: Fri, 16 Feb 2024 22:57:55 +0000
Subject: [PATCH 078/152] fix loading preprocessor

---
 .../spike_sorting/si_preprocessing.py | 56 ++-----
 .../spike_sorting/si_spike_sorting.py |  4 +-
 2 files changed, 5 insertions(+), 55 deletions(-)

diff --git a/element_array_ephys/spike_sorting/si_preprocessing.py b/element_array_ephys/spike_sorting/si_preprocessing.py
index 2edf443d..07a49293 100644
--- a/element_array_ephys/spike_sorting/si_preprocessing.py
+++ b/element_array_ephys/spike_sorting/si_preprocessing.py
@@ -2,7 +2,7 @@
 from spikeinterface import preprocessing

-def mimic_catGT(recording):
+def catGT(recording):
     recording = si.preprocessing.phase_shift(recording)
     recording = si.preprocessing.common_reference(
         recording, operator="median", reference="global"
     )
     return recording

-def mimic_IBLdestriping(recording):
+def IBLdestriping(recording):
     # From International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022. 9 Jun 2022.
     recording = si.preprocessing.highpass_filter(recording, freq_min=400.0)
     bad_channel_ids, channel_labels = si.preprocessing.detect_bad_channels(recording)
@@ -24,7 +24,7 @@ def mimic_IBLdestriping(recording):
     return recording

-def mimic_IBLdestriping_modified(recording):
+def IBLdestriping_modified(recording):
     # From SpikeInterface Implementation (https://spikeinterface.readthedocs.io/en/latest/how_to/analyse_neuropixels.html)
     recording = si.preprocessing.highpass_filter(recording, freq_min=400.0)
     bad_channel_ids, channel_labels = si.preprocessing.detect_bad_channels(recording)
@@ -34,52 +34,4 @@ def mimic_IBLdestriping_modified(recording):
     recording = si.preprocessing.common_reference(
         recording, operator="median", reference="global"
     )
-    return recording
-
-
-preprocessing_function_mapping = {
-    "catGT": mimic_catGT,
-    "IBLdestriping": mimic_IBLdestriping,
-    "IBLdestriping_modified": mimic_IBLdestriping_modified,
-}
-
-
-## Example SI parameter set
-"""
-{'detect_threshold': 6,
- 'projection_threshold': [10, 4],
- 'preclust_threshold': 8,
- 'car': True,
- 'minFR': 0.02,
- 'minfr_goodchannels': 0.1,
- 'nblocks': 5,
- 'sig': 20,
- 'freq_min': 150,
- 'sigmaMask': 30,
- 'nPCs': 3,
- 'ntbuff': 64,
- 'nfilt_factor': 4,
- 'NT': None,
- 'do_correction': True,
- 'wave_length': 61,
- 'keep_good_only': False,
- 'PreProcessing_params': {'Filter': False,
-  'BandpassFilter': True,
-  'HighpassFilter': False,
-  'NotchFilter': False,
-  'NormalizeByQuantile': False,
-  'Scale': False,
-  'Center': False,
-  'ZScore': False,
-  'Whiten': False,
-  'CommonReference': False,
-  'PhaseShift': False,
-  'Rectify': False,
-  'Clip': False,
-  'BlankSaturation': False,
-  'RemoveArtifacts': False,
-  'RemoveBadChannels': False,
-  'ZeroChannelPad': False,
-  'DeepInterpolation': False,
-  'Resample': False}}
-"""
+    return recording
\ No newline at end of file

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 6bf3c4bd..c8d2f1b7 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -172,9 +172,7 @@ def make(self, key):

         # Run preprocessing and save results to output folder
-        si_preproc_func = si_preprocessing.preprocessing_function_mapping[
-            params["SI_PREPROCESSING_METHOD"]
-        ]
+        si_preproc_func = getattr(si_preprocessing, params["SI_PREPROCESSING_METHOD"])
         si_recording = si_preproc_func(si_recording)
         si_recording.dump_to_pickle(file_path=recording_file)

From b0a863fec91f82a1171ae2c7041f76cf7dc612aa Mon Sep 17 00:00:00 2001
From: JaerongA
Date: Fri, 16 Feb 2024 23:18:42 +0000
Subject: [PATCH 079/152] make all output dir non-sharable

---
 .../spike_sorting/si_spike_sorting.py | 40 +++++++++++--------
 1 file changed, 23 insertions(+), 17 deletions(-)

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index c8d2f1b7..13f569e7 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -207,39 +207,35 @@ def make(self, key):
             ephys.ClusteringTask * ephys.ClusteringParamSet & key
         ).fetch1("clustering_method", "clustering_output_dir", "params")
         output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
-        recording_file = output_dir / "si_recording.pkl"
-        si_recording: si.BaseRecording = si.load_extractor(recording_file)

         # Get sorter method and create output directory.
sorter_name = ( "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method ) - + recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" + si_recording: si.BaseRecording = si.load_extractor(recording_file) + # Run sorting @memoized_result( parameters={**key, **params}, - output_directory=output_dir / sorter_name, + output_directory=output_dir / sorter_name / "spike_sorting", ) def _run_sorter(*args, **kwargs): si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(*args, **kwargs) - sorting_save_path = output_dir / sorter_name / "si_sorting.pkl" + sorting_save_path = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" si_sorting.dump_to_pickle(sorting_save_path) return sorting_save_path sorting_save_path = _run_sorter( sorter_name=sorter_name, recording=si_recording, - output_folder=output_dir / sorter_name, + output_folder=output_dir / sorter_name / "spike_sorting", remove_existing_folder=True, verbose=True, docker_image=True, **params.get("SI_SORTING_PARAMS", {}), ) - # Run sorting - sorting_save_path = output_dir / "si_sorting.pkl" - si_sorting.dump_to_pickle(sorting_save_path) - self.insert1( { **key, @@ -266,13 +262,20 @@ class PostProcessing(dj.Imported): def make(self, key): execution_time = datetime.utcnow() - # Load sorting & recording object. - output_dir, params = (ephys.ClusteringTask & key).fetch1( - "clustering_output_dir", "params" + # Load recording object. + clustering_method, output_dir, params = ( + ephys.ClusteringTask * ephys.ClusteringParamSet & key + ).fetch1("clustering_method", "clustering_output_dir", "params") + output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + # Get sorter method and create output directory. + sorter_name = ( + "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method ) + output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - recording_file = output_dir / "si_recording.pkl" - sorting_file = output_dir / "si_sorting.pkl" + recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" + sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" si_recording: si.BaseRecording = si.load_extractor(recording_file) si_sorting: si.sorters.BaseSorter = si.load_extractor(sorting_file) @@ -281,7 +284,7 @@ def make(self, key): we: si.WaveformExtractor = si.extract_waveforms( si_recording, si_sorting, - folder=output_dir / "waveform", # The folder where waveforms are cached + folder=output_dir / sorter_name / "waveform", # The folder where waveforms are cached max_spikes_per_unit=None, overwrite=True, **params.get("SI_WAVEFORM_EXTRACTION_PARAMS", {}), @@ -309,8 +312,11 @@ def make(self, key): waveform_extractor=we, **params.get("SI_QUALITY_METRICS_PARAMS", None) ) # Save the output (metrics.csv to the output dir) + metrics_output_dir = output_dir / sorter_name / "metrics" + metrics_output_dir.mkdir(parents=True, exist_ok=True) + metrics = si.qualitymetrics.compute_quality_metrics(waveform_extractor=we) - metrics.to_csv(output_dir / "metrics.csv") + metrics.to_csv(metrics_output_dir / "metrics.csv") # Save results self.insert1( From 7d28351baa629e1a72e0d9a52c67f0d6590af689 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Mon, 19 Feb 2024 15:30:30 -0600 Subject: [PATCH 080/152] refactor & accept changes from code review --- element_array_ephys/ephys_no_curation.py | 12 +++++----- .../spike_sorting/si_preprocessing.py | 2 +- .../spike_sorting/si_spike_sorting.py | 22 +++++++++++-------- 3 files changed, 20 
insertions(+), 16 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 5894fe16..acf7c76f 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -989,7 +989,7 @@ def make(self, key): extremum_channel_inds=unit_peak_channel_map ) - # Get electrode info + # Get electrode info !#TODO: need to be modified electrode_config_key = ( EphysRecording * probe.ElectrodeConfig & key ).fetch1("KEY") @@ -1183,11 +1183,11 @@ def make(self, key): we: si.WaveformExtractor = si.load_waveforms( output_dir / "waveform", with_recording=False ) - unit_id_to_peak_channel_indices: dict[ - int, np.ndarray - ] = si.ChannelSparsity.from_best_channels( - we, 1, peak_sign="neg" - ).unit_id_to_channel_indices # {unit: peak_channel_index} + unit_id_to_peak_channel_indices: dict[int, np.ndarray] = ( + si.ChannelSparsity.from_best_channels( + we, 1, peak_sign="neg" + ).unit_id_to_channel_indices + ) # {unit: peak_channel_index} units = (CuratedClustering.Unit & key).fetch("KEY", order_by="unit") diff --git a/element_array_ephys/spike_sorting/si_preprocessing.py b/element_array_ephys/spike_sorting/si_preprocessing.py index 07a49293..4db5f303 100644 --- a/element_array_ephys/spike_sorting/si_preprocessing.py +++ b/element_array_ephys/spike_sorting/si_preprocessing.py @@ -34,4 +34,4 @@ def IBLdestriping_modified(recording): recording = si.preprocessing.common_reference( recording, operator="median", reference="global" ) - return recording \ No newline at end of file + return recording diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 13f569e7..1b8366dc 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -26,7 +26,7 @@ import probeinterface as pi import spikeinterface as si from element_array_ephys import get_logger, probe, readers -from element_interface.utils import find_full_path, memoized_result +from element_interface.utils import find_full_path # , memoized_result from spikeinterface import exporters, postprocessing, qualitymetrics, sorters from . import si_preprocessing @@ -98,12 +98,12 @@ def make(self, key): clustering_method, acq_software, output_dir, params = ( ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("clustering_method", "acq_software", "clustering_output_dir", "params") - + # Get sorter method and create output directory. 
sorter_name = ( "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method ) - + for required_key in ( "SI_SORTING_PARAMS", "SI_PREPROCESSING_METHOD", @@ -165,7 +165,7 @@ def make(self, key): .fetch(format="frame") .reset_index()[["electrode", "x_coord", "y_coord", "shank"]] ) - + # Create SI probe object si_probe = readers.probe_geometry.to_probeinterface(electrodes_df) si_probe.set_device_channel_indices(range(len(electrodes_df))) @@ -214,7 +214,7 @@ def make(self, key): ) recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" si_recording: si.BaseRecording = si.load_extractor(recording_file) - + # Run sorting @memoized_result( parameters={**key, **params}, @@ -222,7 +222,9 @@ def make(self, key): ) def _run_sorter(*args, **kwargs): si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(*args, **kwargs) - sorting_save_path = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" + sorting_save_path = ( + output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" + ) si_sorting.dump_to_pickle(sorting_save_path) return sorting_save_path @@ -272,7 +274,7 @@ def make(self, key): sorter_name = ( "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method ) - + output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" @@ -284,7 +286,9 @@ def make(self, key): we: si.WaveformExtractor = si.extract_waveforms( si_recording, si_sorting, - folder=output_dir / sorter_name / "waveform", # The folder where waveforms are cached + folder=output_dir + / sorter_name + / "waveform", # The folder where waveforms are cached max_spikes_per_unit=None, overwrite=True, **params.get("SI_WAVEFORM_EXTRACTION_PARAMS", {}), @@ -314,7 +318,7 @@ def make(self, key): # Save the output (metrics.csv to the output dir) metrics_output_dir = output_dir / sorter_name / "metrics" metrics_output_dir.mkdir(parents=True, exist_ok=True) - + metrics = si.qualitymetrics.compute_quality_metrics(waveform_extractor=we) metrics.to_csv(metrics_output_dir / "metrics.csv") From 95f5286704ca51a8768a1c8bad41d2ae2de94767 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Mon, 19 Feb 2024 16:50:15 -0600 Subject: [PATCH 081/152] remove memoized_result for testing --- .../spike_sorting/si_spike_sorting.py | 21 +++++++------------ 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 1b8366dc..0e3da684 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -26,7 +26,7 @@ import probeinterface as pi import spikeinterface as si from element_array_ephys import get_logger, probe, readers -from element_interface.utils import find_full_path # , memoized_result +from element_interface.utils import find_full_path from spikeinterface import exporters, postprocessing, qualitymetrics, sorters from . 
import si_preprocessing @@ -216,19 +216,7 @@ def make(self, key): si_recording: si.BaseRecording = si.load_extractor(recording_file) # Run sorting - @memoized_result( - parameters={**key, **params}, - output_directory=output_dir / sorter_name / "spike_sorting", - ) - def _run_sorter(*args, **kwargs): - si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(*args, **kwargs) - sorting_save_path = ( - output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" - ) - si_sorting.dump_to_pickle(sorting_save_path) - return sorting_save_path - - sorting_save_path = _run_sorter( + si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter( sorter_name=sorter_name, recording=si_recording, output_folder=output_dir / sorter_name / "spike_sorting", @@ -238,6 +226,11 @@ def _run_sorter(*args, **kwargs): **params.get("SI_SORTING_PARAMS", {}), ) + sorting_save_path = ( + output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" + ) + si_sorting.dump_to_pickle(sorting_save_path) + self.insert1( { **key, From a7ebb9a61c6c2e90fd01d7f5529b8409b36889c3 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 20 Feb 2024 17:59:46 -0600 Subject: [PATCH 082/152] build: :heavy_plus_sign: add element-interface to required packages --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 532c72f6..204008a6 100644 --- a/setup.py +++ b/setup.py @@ -39,12 +39,12 @@ "scikit-image", "nbformat>=4.2.0", "pyopenephys>=1.1.6", + "element-interface @ git+https://github.com/datajoint/element-interface.git", ], extras_require={ "elements": [ "element-animal @ git+https://github.com/datajoint/element-animal.git", "element-event @ git+https://github.com/datajoint/element-event.git", - "element-interface @ git+https://github.com/datajoint/element-interface.git@dev_memoized_results", "element-lab @ git+https://github.com/datajoint/element-lab.git", "element-session @ git+https://github.com/datajoint/element-session.git", ], From 134ff54eb124896f6fd70f5d335680ed7c2c9a06 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 20 Feb 2024 18:00:07 -0600 Subject: [PATCH 083/152] update pre-commit with the latest hooks --- .pre-commit-config.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0d513df7..6d28ef11 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ exclude: (^.github/|^docs/|^images/) repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -16,7 +16,7 @@ repos: # black - repo: https://github.com/psf/black - rev: 22.12.0 + rev: 24.2.0 hooks: - id: black - id: black-jupyter @@ -25,7 +25,7 @@ repos: # isort - repo: https://github.com/pycqa/isort - rev: 5.11.2 + rev: 5.13.2 hooks: - id: isort args: ["--profile", "black"] @@ -33,7 +33,7 @@ repos: # flake8 - repo: https://github.com/pycqa/flake8 - rev: 4.0.1 + rev: 7.0.0 hooks: - id: flake8 args: # arguments to configure flake8 From 79724268bd3b0c29f129b2fc3781577a931eb098 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 27 Feb 2024 17:25:37 -0600 Subject: [PATCH 084/152] build: :heavy_plus_sign: Add numba as required package --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 204008a6..52cd38b1 100644 --- a/setup.py +++ b/setup.py @@ -40,6 +40,7 @@ "nbformat>=4.2.0", "pyopenephys>=1.1.6", "element-interface @ git+https://github.com/datajoint/element-interface.git", + 
"numba", ], extras_require={ "elements": [ @@ -50,6 +51,6 @@ ], "nwb": ["dandi", "neuroconv[ecephys]", "pynwb"], "tests": ["pre-commit", "pytest", "pytest-cov"], - "spikingcircus": ["hdbscan", "numba"], + "spikingcircus": ["hdbscan"], }, ) From ed11526a00649cf6b1849802720c151a33beb374 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 27 Feb 2024 17:26:24 -0600 Subject: [PATCH 085/152] adjust extract_waveforms parameters --- element_array_ephys/spike_sorting/si_spike_sorting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 0e3da684..6df25de8 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -282,8 +282,8 @@ def make(self, key): folder=output_dir / sorter_name / "waveform", # The folder where waveforms are cached - max_spikes_per_unit=None, overwrite=True, + allow_unfiltered=True, **params.get("SI_WAVEFORM_EXTRACTION_PARAMS", {}), **params.get("SI_JOB_KWARGS", {"n_jobs": -1, "chunk_size": 30000}), ) From bab86b7c1471f02e2c9d95c1f4ee8442345b3817 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 28 Feb 2024 22:52:28 -0600 Subject: [PATCH 086/152] refactor: :recycle: update the output dir for CuratedClustering --- element_array_ephys/ephys_no_curation.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index acf7c76f..71648fec 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -967,22 +967,31 @@ class Unit(dj.Part): def make(self, key): """Automated population of Unit information.""" - output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") + clustering_method, output_dir = ( + ClusteringTask * ClusteringParamSet & key + ).fetch1("clustering_method", "clustering_output_dir") output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - if (output_dir / "waveform").exists(): # read from spikeinterface outputs + # Get sorter method and create output directory. 
+        sorter_name = (
+            "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method
+        )
+        waveform_dir = output_dir / sorter_name / "waveform"
+        sorting_dir = output_dir / sorter_name / "spike_sorting"
+
+        if waveform_dir.exists():  # read from spikeinterface outputs
             we: si.WaveformExtractor = si.load_waveforms(
-                output_dir / "waveform", with_recording=False
+                waveform_dir, with_recording=False
             )
             si_sorting: si.sorters.BaseSorter = si.load_extractor(
-                output_dir / "sorting.pkl"
+                sorting_dir / "si_sorting.pkl"
             )

             unit_peak_channel_map: dict[int, int] = si.get_template_extremum_channel(
                 we, outputs="index"
             )  # {unit: peak_channel_index}

-            spike_count_dict = dict[int, int] = si_sorting.count_num_spikes_per_unit()
+            spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
             # {unit: spike_count}

From 0898ce5cdaeb7656ed7142f395351dfcde652d40 Mon Sep 17 00:00:00 2001
From: JaerongA
Date: Thu, 29 Feb 2024 10:08:22 -0600
Subject: [PATCH 087/152] feat: :sparkles: add quality label mapping

---
 element_array_ephys/ephys_no_curation.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index 71648fec..cb91baa9 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -1011,6 +1011,21 @@ def make(self, key):
                 zip(*electrode_query.fetch("channel", "electrode"))
             )

+            # Get unit id to quality label mapping
+            cluster_quality_label_map = {}
+            try:
+                cluster_quality_label_map = pd.read_csv(
+                    sorting_dir / "sorter_output" / "cluster_KSLabel.tsv",
+                    delimiter="\t",
+                )
+                cluster_quality_label_map: dict[
+                    int, str
+                ] = cluster_quality_label_map.set_index("cluster_id")[
+                    "KSLabel"
+                ].to_dict()  # {unit: quality_label}
+            except FileNotFoundError:
+                pass
+
             # Get channel to electrode mapping
             channel2depth_map = dict(zip(*electrode_query.fetch("channel", "y_coord")))

@@ -1038,7 +1053,9 @@ def make(self, key):
                 units.append(
                     {
                         "unit": unit_id,
-                        "cluster_quality_label": "n.a.",
+                        "cluster_quality_label": cluster_quality_label_map.get(
+                            unit_id, "n.a."
+                        ),
                         "spike_times": si_sorting.get_unit_spike_train(
                             unit_id, return_times=True
                         ),

From 727af24cca6f00fc35651f88f7f258d2ab67e43e Mon Sep 17 00:00:00 2001
From: JaerongA
Date: Fri, 1 Mar 2024 10:19:09 -0600
Subject: [PATCH 088/152] feat: :sparkles: Ingest EphysRecording.Channel

---
 element_array_ephys/ephys_no_curation.py | 242 +++++++++++++----------
 1 file changed, 133 insertions(+), 109 deletions(-)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index cb91baa9..e9adc290 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -8,7 +8,9 @@
 import datajoint as dj
 import numpy as np
 import pandas as pd
+import spikeinterface as si
 from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory
+from spikeinterface import exporters, postprocessing, qualitymetrics, sorters

 from . import ephys_report, probe
 from .readers import kilosort, openephys, spikeglx
@@ -19,9 +21,6 @@

 _linking_module = None

-import spikeinterface as si
-from spikeinterface import exporters, postprocessing, qualitymetrics, sorters

 def activate(
@@ -327,129 +326,154 @@ def make(self, key):
                 break
         else:
             raise FileNotFoundError(
-                "Ephys recording data not found!"
+                f"Ephys recording data not found! for {key}."
"Neither SpikeGLX nor Open Ephys recording files found" ) - supported_probe_types = probe.ProbeType.fetch("probe_type") + if acq_software not in AcquisitionSoftware.fetch("acq_software"): + raise NotImplementedError( + f"Processing ephys files from acquisition software of type {acq_software} is not yet implemented." + ) - if acq_software == "SpikeGLX": - for meta_filepath in ephys_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format(key) - ) + else: + supported_probe_types = probe.ProbeType.fetch("probe_type") + + if acq_software == "SpikeGLX": + for meta_filepath in ephys_meta_filepaths: + spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) + if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: + break + else: + raise FileNotFoundError( + "No SpikeGLX data found for probe insertion: {}".format(key) + ) + + if spikeglx_meta.probe_model in supported_probe_types: + probe_type = spikeglx_meta.probe_model + electrode_query = probe.ProbeType.Electrode & { + "probe_type": probe_type + } - if spikeglx_meta.probe_model in supported_probe_types: - probe_type = spikeglx_meta.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} + probe_electrodes = { + (shank, shank_col, shank_row): key + for key, shank, shank_col, shank_row in zip( + *electrode_query.fetch( + "KEY", "shank", "shank_col", "shank_row" + ) + ) + } - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") + electrode_group_members = [ + probe_electrodes[(shank, shank_col, shank_row)] + for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap[ + "data" + ] + ] + else: + raise NotImplementedError( + "Processing for neuropixels probe model" + " {} not yet implemented".format(spikeglx_meta.probe_model) ) - } - electrode_group_members = [ - probe_electrodes[(shank, shank_col, shank_row)] - for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels probe model" - " {} not yet implemented".format(spikeglx_meta.probe_model) + self.insert1( + { + **key, + **generate_electrode_config( + probe_type, electrode_group_members + ), + "acq_software": acq_software, + "sampling_rate": spikeglx_meta.meta["imSampRate"], + "recording_datetime": spikeglx_meta.recording_time, + "recording_duration": ( + spikeglx_meta.recording_duration + or spikeglx.retrieve_recording_duration(meta_filepath) + ), + } ) - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": spikeglx_meta.meta["imSampRate"], - "recording_datetime": spikeglx_meta.recording_time, - "recording_duration": ( - spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(meta_filepath) - ), - } - ) - - root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) - self.EphysFile.insert1( - {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} - ) - elif acq_software == "Open Ephys": - dataset = openephys.OpenEphys(session_dir) - for serial_number, probe_data in dataset.probes.items(): - if str(serial_number) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No Open Ephys data found for probe insertion: 
{}".format(key) + root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) + self.EphysFile.insert1( + {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} ) + elif acq_software == "Open Ephys": + dataset = openephys.OpenEphys(session_dir) + for serial_number, probe_data in dataset.probes.items(): + if str(serial_number) == inserted_probe_serial_number: + break + else: + raise FileNotFoundError( + "No Open Ephys data found for probe insertion: {}".format(key) + ) - if not probe_data.ap_meta: - raise IOError( - 'No analog signals found - check "structure.oebin" file or "continuous" directory' - ) + if not probe_data.ap_meta: + raise IOError( + 'No analog signals found - check "structure.oebin" file or "continuous" directory' + ) - if probe_data.probe_model in supported_probe_types: - probe_type = probe_data.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} + if probe_data.probe_model in supported_probe_types: + probe_type = probe_data.probe_model + electrode_query = probe.ProbeType.Electrode & { + "probe_type": probe_type + } - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } + probe_electrodes = { + key["electrode"]: key for key in electrode_query.fetch("KEY") + } - electrode_group_members = [ - probe_electrodes[channel_idx] - for channel_idx in probe_data.ap_meta["channels_indices"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels" - " probe model {} not yet implemented".format(probe_data.probe_model) + electrode_group_members = [ + probe_electrodes[channel_idx] + for channel_idx in probe_data.ap_meta["channels_indices"] + ] + else: + raise NotImplementedError( + "Processing for neuropixels" + " probe model {} not yet implemented".format( + probe_data.probe_model + ) + ) + + self.insert1( + { + **key, + **generate_electrode_config( + probe_type, electrode_group_members + ), + "acq_software": acq_software, + "sampling_rate": probe_data.ap_meta["sample_rate"], + "recording_datetime": probe_data.recording_info[ + "recording_datetimes" + ][0], + "recording_duration": np.sum( + probe_data.recording_info["recording_durations"] + ), + } ) - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": probe_data.ap_meta["sample_rate"], - "recording_datetime": probe_data.recording_info[ - "recording_datetimes" - ][0], - "recording_duration": np.sum( - probe_data.recording_info["recording_durations"] - ), - } - ) + root_dir = find_root_directory( + get_ephys_root_data_dir(), + probe_data.recording_info["recording_files"][0], + ) + self.EphysFile.insert( + [ + {**key, "file_path": fp.relative_to(root_dir).as_posix()} + for fp in probe_data.recording_info["recording_files"] + ] + ) + # Explicitly garbage collect "dataset" as these may have large memory footprint and may not be cleared fast enough + del probe_data, dataset + gc.collect() - root_dir = find_root_directory( - get_ephys_root_data_dir(), - probe_data.recording_info["recording_files"][0], + # Insert channel information + # Get channel and electrode-site mapping + channel2electrodes = get_neuropixels_channel2electrode_map( + key, acq_software ) - self.EphysFile.insert( + self.Channel.insert( [ - {**key, "file_path": fp.relative_to(root_dir).as_posix()} - for fp in probe_data.recording_info["recording_files"] + {**key, "channel_idx": channel_idx, **channel_info} + for channel_idx, channel_info in channel2electrodes.items() ] ) - # 
explicitly garbage collect "dataset" - # as these may have large memory footprint and may not be cleared fast enough - del probe_data, dataset - gc.collect() - else: - raise NotImplementedError( - f"Processing ephys files from" - f" acquisition software of type {acq_software} is" - f" not yet implemented" - ) @schema @@ -1209,11 +1233,11 @@ def make(self, key): we: si.WaveformExtractor = si.load_waveforms( output_dir / "waveform", with_recording=False ) - unit_id_to_peak_channel_indices: dict[int, np.ndarray] = ( - si.ChannelSparsity.from_best_channels( - we, 1, peak_sign="neg" - ).unit_id_to_channel_indices - ) # {unit: peak_channel_index} + unit_id_to_peak_channel_indices: dict[ + int, np.ndarray + ] = si.ChannelSparsity.from_best_channels( + we, 1, peak_sign="neg" + ).unit_id_to_channel_indices # {unit: peak_channel_index} units = (CuratedClustering.Unit & key).fetch("KEY", order_by="unit") From 1df41ea6ee8668c5b3ca5cb970dad024f65a668c Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 1 Mar 2024 11:05:05 -0600 Subject: [PATCH 089/152] get channel to electrode mapping in CuratedClustering --- element_array_ephys/ephys_no_curation.py | 30 +++++++++++++----------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index e9adc290..c313684f 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -9,7 +9,8 @@ import numpy as np import pandas as pd import spikeinterface as si -from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory +from element_interface.utils import (dict_to_uuid, find_full_path, + find_root_directory) from spikeinterface import exporters, postprocessing, qualitymetrics, sorters from . 
import ephys_report, probe @@ -1022,7 +1023,7 @@ def make(self, key): extremum_channel_inds=unit_peak_channel_map ) - # Get electrode info !#TODO: need to be modified + # Get electrode & channel info electrode_config_key = ( EphysRecording * probe.ElectrodeConfig & key ).fetch1("KEY") @@ -1030,10 +1031,11 @@ def make(self, key): electrode_query = ( probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode & electrode_config_key - ) + ) * (dj.U("electrode", "channel_idx") & EphysRecording.Channel) + channel2electrode_map = dict( - zip(*electrode_query.fetch("channel", "electrode")) - ) + zip(*electrode_query.fetch("channel_idx", "electrode")) + ) # {channel: electrode} # Get unit id to quality label mapping cluster_quality_label_map = {} @@ -1051,24 +1053,24 @@ def make(self, key): pass # Get channel to electrode mapping - channel2depth_map = dict(zip(*electrode_query.fetch("channel", "y_coord"))) + channel2depth_map = dict(zip(*electrode_query.fetch("channel_idx", "y_coord"))) # {channel: depth} peak_electrode_ind = np.array( [ channel2electrode_map[unit_peak_channel_map[unit_id]] for unit_id in si_sorting.unit_ids ] - ) + ) # get the electrode where peak unit activity is recorded # Get channel to depth mapping - electrode_depth_ind = np.array( + channel_depth_ind = np.array( [ channel2depth_map[unit_peak_channel_map[unit_id]] for unit_id in si_sorting.unit_ids ] ) spikes["electrode"] = peak_electrode_ind[spikes["unit_index"]] - spikes["depth"] = electrode_depth_ind[spikes["unit_index"]] + spikes["depth"] = channel_depth_ind[spikes["unit_index"]] units = [] @@ -1233,11 +1235,11 @@ def make(self, key): we: si.WaveformExtractor = si.load_waveforms( output_dir / "waveform", with_recording=False ) - unit_id_to_peak_channel_indices: dict[ - int, np.ndarray - ] = si.ChannelSparsity.from_best_channels( - we, 1, peak_sign="neg" - ).unit_id_to_channel_indices # {unit: peak_channel_index} + unit_id_to_peak_channel_indices: dict[int, np.ndarray] = ( + si.ChannelSparsity.from_best_channels( + we, 1, peak_sign="neg" + ).unit_id_to_channel_indices + ) # {unit: peak_channel_index} units = (CuratedClustering.Unit & key).fetch("KEY", order_by="unit") From 417219fd4baeeadc1d9e4feeaa29ef04c4b21555 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 1 Mar 2024 12:10:29 -0600 Subject: [PATCH 090/152] refactor: :recycle: Fix metrics directory in QualityMetrics --- element_array_ephys/ephys_no_curation.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index c313684f..70525451 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -9,8 +9,7 @@ import numpy as np import pandas as pd import spikeinterface as si -from element_interface.utils import (dict_to_uuid, find_full_path, - find_root_directory) +from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory from spikeinterface import exporters, postprocessing, qualitymetrics, sorters from . 
import ephys_report, probe @@ -1032,7 +1031,7 @@ def make(self, key): probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode & electrode_config_key ) * (dj.U("electrode", "channel_idx") & EphysRecording.Channel) - + channel2electrode_map = dict( zip(*electrode_query.fetch("channel_idx", "electrode")) ) # {channel: electrode} @@ -1053,7 +1052,9 @@ def make(self, key): pass # Get channel to electrode mapping - channel2depth_map = dict(zip(*electrode_query.fetch("channel_idx", "y_coord"))) # {channel: depth} + channel2depth_map = dict( + zip(*electrode_query.fetch("channel_idx", "y_coord")) + ) # {channel: depth} peak_electrode_ind = np.array( [ @@ -1570,9 +1571,14 @@ class Waveform(dj.Part): def make(self, key): """Populates tables with quality metrics data.""" # Load metrics.csv - output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") + clustering_method, output_dir = ( + ClusteringTask * ClusteringParamSet & key + ).fetch1("clustering_method", "clustering_output_dir") output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - metric_fp = output_dir / "metrics.csv" + sorter_name = ( + "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method + ) + metric_fp = output_dir / sorter_name / "metrics" / "metrics.csv" if not metric_fp.exists(): raise FileNotFoundError(f"QC metrics file not found: {metric_fp}") metrics_df = pd.read_csv(metric_fp) From f70ae4ee1e294b0ba1173fa6a9b1255e2d27f6b3 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 1 Mar 2024 13:06:22 -0600 Subject: [PATCH 091/152] feat: :sparkles: replace get_neuropixels_channel2electrode_map with channel_info --- element_array_ephys/ephys_no_curation.py | 54 ++++++++++++++++-------- 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 70525451..5730ebf3 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1096,7 +1096,7 @@ def make(self, key): } ) - else: + else: # read from kilosort outputs kilosort_dataset = kilosort.Kilosort(output_dir) acq_software, sample_rate = (EphysRecording & key).fetch1( "acq_software", "sampling_rate" @@ -1131,14 +1131,19 @@ def make(self, key): kilosort_dataset.extract_spike_depths() # Get channel and electrode-site mapping - channel2electrodes = get_neuropixels_channel2electrode_map( - key, acq_software + channel_info = ( + (EphysRecording.Channel & key) + .proj(..., "-channel_name") + .fetch(as_dict=True, order_by="channel_idx") ) + channel_info: dict[int, dict] = { + ch.pop("channel_idx"): ch for ch in channel_info + } # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}} # -- Spike-sites and Spike-depths -- spike_sites = np.array( [ - channel2electrodes[s]["electrode"] + channel_info[s]["electrode"] for s in kilosort_dataset.data["spike_sites"] ] ) @@ -1157,9 +1162,10 @@ def make(self, key): units.append( { + **key, "unit": unit, "cluster_quality_label": unit_lbl, - **channel2electrodes[unit_channel], + **channel_info[unit_channel], "spike_times": unit_spike_times, "spike_count": spike_count, "spike_sites": spike_sites[ @@ -1228,13 +1234,21 @@ class Waveform(dj.Part): def make(self, key): """Populates waveform tables.""" - output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") + clustering_method, output_dir = ( + ClusteringTask * ClusteringParamSet 
& key + ).fetch1("clustering_method", "clustering_output_dir") output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) + sorter_name = ( + "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method + ) - if (output_dir / "waveform").exists(): # read from spikeinterface outputs + if ( + output_dir / sorter_name / "waveform" + ).exists(): # read from spikeinterface outputs + waveform_dir = output_dir / sorter_name / "waveform" we: si.WaveformExtractor = si.load_waveforms( - output_dir / "waveform", with_recording=False + waveform_dir, with_recording=False ) unit_id_to_peak_channel_indices: dict[int, np.ndarray] = ( si.ChannelSparsity.from_best_channels( @@ -1299,11 +1313,15 @@ def make(self, key): EphysRecording * ProbeInsertion & key ).fetch1("acq_software", "probe") - # -- Get channel and electrode-site mapping - recording_key = (EphysRecording & key).fetch1("KEY") - channel2electrodes = get_neuropixels_channel2electrode_map( - recording_key, acq_software + # Get channel and electrode-site mapping + channel_info = ( + (EphysRecording.Channel & key) + .proj(..., "-channel_name") + .fetch(as_dict=True, order_by="channel_idx") ) + channel_info: dict[int, dict] = { + ch.pop("channel_idx"): ch for ch in channel_info + } # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}} # Get all units units = { @@ -1335,12 +1353,12 @@ def yield_unit_waveforms(): unit_electrode_waveforms.append( { **units[unit_no], - **channel2electrodes[channel], + **channel_info[channel], "waveform_mean": channel_waveform, } ) if ( - channel2electrodes[channel]["electrode"] + channel_info[channel]["electrode"] == units[unit_no]["electrode"] ): unit_peak_waveform = { @@ -1377,12 +1395,12 @@ def yield_unit_waveforms(): unit_electrode_waveforms.append( { **units[unit_no], - **channel2electrodes[channel], + **channel_info[channel], "waveform_mean": channel_waveform, } ) if ( - channel2electrodes[channel]["electrode"] + channel_info[channel]["electrode"] == units[unit_no]["electrode"] ): unit_peak_waveform = { @@ -1451,13 +1469,13 @@ def yield_unit_waveforms(): unit_electrode_waveforms.append( { **unit_dict, - **channel2electrodes[channel], + **channel_info[channel], "waveform_mean": channel_waveform.mean(axis=0), "waveforms": channel_waveform, } ) if ( - channel2electrodes[channel]["electrode"] + channel_info[channel]["electrode"] == unit_dict["electrode"] ): unit_peak_waveform = { From 0ccbec962b5a11107d8fe28906e79bf19c693475 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 1 Mar 2024 21:59:06 +0000 Subject: [PATCH 092/152] fix CuratedClustering make function --- element_array_ephys/ephys_no_curation.py | 36 +++++++++++++++++------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 5730ebf3..6ddd8fec 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1032,6 +1032,12 @@ def make(self, key): & electrode_config_key ) * (dj.U("electrode", "channel_idx") & EphysRecording.Channel) + channel_info = electrode_query.fetch(as_dict=True, order_by="channel_idx") + + channel_info: dict[int, dict] = { + ch.pop("channel_idx"): ch for ch in channel_info + } + channel2electrode_map = dict( zip(*electrode_query.fetch("channel_idx", "electrode")) ) # {channel: electrode} @@ -1058,7 +1064,7 @@ def make(self, key): 
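To make the mapping concrete, a tiny self-contained sketch (toy values only) of the `channel_info` dictionary built in the hunk above: each fetched row is keyed by its `channel_idx`, which is popped out of the row itself so the remaining fields describe the electrode.

    rows = [  # toy stand-ins for electrode_query.fetch(as_dict=True, order_by="channel_idx")
        {"channel_idx": 0, "electrode": 0, "y_coord": 20.0},
        {"channel_idx": 1, "electrode": 1, "y_coord": 40.0},
    ]
    channel_info = {ch.pop("channel_idx"): ch for ch in rows}
    assert channel_info[1] == {"electrode": 1, "y_coord": 40.0}
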
peak_electrode_ind = np.array( [ - channel2electrode_map[unit_peak_channel_map[unit_id]] + channel_info[unit_peak_channel_map[unit_id]]["electrode"] for unit_id in si_sorting.unit_ids ] ) # get the electrode where peak unit activity is recorded @@ -1066,19 +1072,29 @@ def make(self, key): # Get channel to depth mapping channel_depth_ind = np.array( [ - channel2depth_map[unit_peak_channel_map[unit_id]] + channel_info[unit_peak_channel_map[unit_id]]["y_coord"] for unit_id in si_sorting.unit_ids ] ) - spikes["electrode"] = peak_electrode_ind[spikes["unit_index"]] - spikes["depth"] = channel_depth_ind[spikes["unit_index"]] + + # Assign electrode and depth for each spike + new_spikes = np.empty(spikes.shape, spikes.dtype.descr + [('electrode', ' Date: Fri, 1 Mar 2024 23:11:25 +0000 Subject: [PATCH 093/152] improve try except logic --- element_array_ephys/ephys_no_curation.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 6ddd8fec..b8c22f05 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1043,25 +1043,20 @@ def make(self, key): ) # {channel: electrode} # Get unit id to quality label mapping - cluster_quality_label_map = {} try: cluster_quality_label_map = pd.read_csv( sorting_dir / "sorter_output" / "cluster_KSLabel.tsv", delimiter="\t", ) + except FileNotFoundError: + cluster_quality_label_map = {} + else: cluster_quality_label_map: dict[ int, str ] = cluster_quality_label_map.set_index("cluster_id")[ "KSLabel" ].to_dict() # {unit: quality_label} - except FileNotFoundError: - pass - - # Get channel to electrode mapping - channel2depth_map = dict( - zip(*electrode_query.fetch("channel_idx", "y_coord")) - ) # {channel: depth} - + peak_electrode_ind = np.array( [ channel_info[unit_peak_channel_map[unit_id]]["electrode"] From 673faedc1c8232098888340bab687896cd89e7f1 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 1 Mar 2024 23:14:45 +0000 Subject: [PATCH 094/152] docs: :memo: update comments in ephys_no_curation --- element_array_ephys/ephys_no_curation.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index b8c22f05..f4ba5b29 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -286,10 +286,10 @@ class EphysRecording(dj.Imported): class Channel(dj.Part): definition = """ -> master - channel_idx: int # channel index + channel_idx: int # channel index (index of the raw data) --- -> probe.ElectrodeConfig.Electrode - channel_name="": varchar(64) + channel_name="": varchar(64) # alias of the channel """ class EphysFile(dj.Part): @@ -1033,14 +1033,9 @@ def make(self, key): ) * (dj.U("electrode", "channel_idx") & EphysRecording.Channel) channel_info = electrode_query.fetch(as_dict=True, order_by="channel_idx") - channel_info: dict[int, dict] = { ch.pop("channel_idx"): ch for ch in channel_info } - - channel2electrode_map = dict( - zip(*electrode_query.fetch("channel_idx", "electrode")) - ) # {channel: electrode} # Get unit id to quality label mapping try: @@ -1056,15 +1051,16 @@ def make(self, key): ] = cluster_quality_label_map.set_index("cluster_id")[ "KSLabel" ].to_dict() # {unit: quality_label} - + + # Get electrode where peak unit activity is recorded peak_electrode_ind = np.array( [ channel_info[unit_peak_channel_map[unit_id]]["electrode"] for 
unit_id in si_sorting.unit_ids ] - ) # get the electrode where peak unit activity is recorded + ) - # Get channel to depth mapping + # Get channel depth channel_depth_ind = np.array( [ channel_info[unit_peak_channel_map[unit_id]]["y_coord"] @@ -1707,7 +1703,7 @@ def get_openephys_probe_data(ephys_recording_key: dict) -> list: def get_neuropixels_channel2electrode_map( ephys_recording_key: dict, acq_software: str -) -> dict: +) -> dict: #TODO: remove this function """Get the channel map for neuropixels probe.""" if acq_software == "SpikeGLX": spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) From 8ab4c58a9d5c744bc3c071f31f360873ff8ee8a2 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 1 Mar 2024 23:15:21 +0000 Subject: [PATCH 095/152] refactor: :recycle: improve if else block in EphysRecording --- element_array_ephys/ephys_no_curation.py | 237 +++++++++++------------ 1 file changed, 118 insertions(+), 119 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index f4ba5b29..5d77a041 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -335,145 +335,144 @@ def make(self, key): f"Processing ephys files from acquisition software of type {acq_software} is not yet implemented." ) - else: - supported_probe_types = probe.ProbeType.fetch("probe_type") - - if acq_software == "SpikeGLX": - for meta_filepath in ephys_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format(key) - ) + supported_probe_types = probe.ProbeType.fetch("probe_type") - if spikeglx_meta.probe_model in supported_probe_types: - probe_type = spikeglx_meta.probe_model - electrode_query = probe.ProbeType.Electrode & { - "probe_type": probe_type - } + if acq_software == "SpikeGLX": + for meta_filepath in ephys_meta_filepaths: + spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) + if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: + break + else: + raise FileNotFoundError( + "No SpikeGLX data found for probe insertion: {}".format(key) + ) - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch( - "KEY", "shank", "shank_col", "shank_row" - ) - ) - } + if spikeglx_meta.probe_model in supported_probe_types: + probe_type = spikeglx_meta.probe_model + electrode_query = probe.ProbeType.Electrode & { + "probe_type": probe_type + } - electrode_group_members = [ - probe_electrodes[(shank, shank_col, shank_row)] - for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap[ - "data" - ] - ] - else: - raise NotImplementedError( - "Processing for neuropixels probe model" - " {} not yet implemented".format(spikeglx_meta.probe_model) + probe_electrodes = { + (shank, shank_col, shank_row): key + for key, shank, shank_col, shank_row in zip( + *electrode_query.fetch( + "KEY", "shank", "shank_col", "shank_row" + ) ) + } - self.insert1( - { - **key, - **generate_electrode_config( - probe_type, electrode_group_members - ), - "acq_software": acq_software, - "sampling_rate": spikeglx_meta.meta["imSampRate"], - "recording_datetime": spikeglx_meta.recording_time, - "recording_duration": ( - spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(meta_filepath) - ), - } + electrode_group_members = [ + probe_electrodes[(shank, shank_col, 
shank_row)] + for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap[ + "data" + ] + ] + else: + raise NotImplementedError( + "Processing for neuropixels probe model" + " {} not yet implemented".format(spikeglx_meta.probe_model) ) - root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) - self.EphysFile.insert1( - {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} + self.insert1( + { + **key, + **generate_electrode_config( + probe_type, electrode_group_members + ), + "acq_software": acq_software, + "sampling_rate": spikeglx_meta.meta["imSampRate"], + "recording_datetime": spikeglx_meta.recording_time, + "recording_duration": ( + spikeglx_meta.recording_duration + or spikeglx.retrieve_recording_duration(meta_filepath) + ), + } + ) + + root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) + self.EphysFile.insert1( + {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} + ) + elif acq_software == "Open Ephys": + dataset = openephys.OpenEphys(session_dir) + for serial_number, probe_data in dataset.probes.items(): + if str(serial_number) == inserted_probe_serial_number: + break + else: + raise FileNotFoundError( + "No Open Ephys data found for probe insertion: {}".format(key) ) - elif acq_software == "Open Ephys": - dataset = openephys.OpenEphys(session_dir) - for serial_number, probe_data in dataset.probes.items(): - if str(serial_number) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No Open Ephys data found for probe insertion: {}".format(key) - ) - if not probe_data.ap_meta: - raise IOError( - 'No analog signals found - check "structure.oebin" file or "continuous" directory' - ) + if not probe_data.ap_meta: + raise IOError( + 'No analog signals found - check "structure.oebin" file or "continuous" directory' + ) - if probe_data.probe_model in supported_probe_types: - probe_type = probe_data.probe_model - electrode_query = probe.ProbeType.Electrode & { - "probe_type": probe_type - } + if probe_data.probe_model in supported_probe_types: + probe_type = probe_data.probe_model + electrode_query = probe.ProbeType.Electrode & { + "probe_type": probe_type + } - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } + probe_electrodes = { + key["electrode"]: key for key in electrode_query.fetch("KEY") + } - electrode_group_members = [ - probe_electrodes[channel_idx] - for channel_idx in probe_data.ap_meta["channels_indices"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels" - " probe model {} not yet implemented".format( - probe_data.probe_model - ) + electrode_group_members = [ + probe_electrodes[channel_idx] + for channel_idx in probe_data.ap_meta["channels_indices"] + ] + else: + raise NotImplementedError( + "Processing for neuropixels" + " probe model {} not yet implemented".format( + probe_data.probe_model ) - - self.insert1( - { - **key, - **generate_electrode_config( - probe_type, electrode_group_members - ), - "acq_software": acq_software, - "sampling_rate": probe_data.ap_meta["sample_rate"], - "recording_datetime": probe_data.recording_info[ - "recording_datetimes" - ][0], - "recording_duration": np.sum( - probe_data.recording_info["recording_durations"] - ), - } ) - root_dir = find_root_directory( - get_ephys_root_data_dir(), - probe_data.recording_info["recording_files"][0], - ) - self.EphysFile.insert( - [ - {**key, "file_path": fp.relative_to(root_dir).as_posix()} - for fp in probe_data.recording_info["recording_files"] - ] - 
) - # Explicitly garbage collect "dataset" as these may have large memory footprint and may not be cleared fast enough - del probe_data, dataset - gc.collect() + self.insert1( + { + **key, + **generate_electrode_config( + probe_type, electrode_group_members + ), + "acq_software": acq_software, + "sampling_rate": probe_data.ap_meta["sample_rate"], + "recording_datetime": probe_data.recording_info[ + "recording_datetimes" + ][0], + "recording_duration": np.sum( + probe_data.recording_info["recording_durations"] + ), + } + ) - # Insert channel information - # Get channel and electrode-site mapping - channel2electrodes = get_neuropixels_channel2electrode_map( - key, acq_software + root_dir = find_root_directory( + get_ephys_root_data_dir(), + probe_data.recording_info["recording_files"][0], ) - self.Channel.insert( + self.EphysFile.insert( [ - {**key, "channel_idx": channel_idx, **channel_info} - for channel_idx, channel_info in channel2electrodes.items() + {**key, "file_path": fp.relative_to(root_dir).as_posix()} + for fp in probe_data.recording_info["recording_files"] ] ) + # Explicitly garbage collect "dataset" as these may have large memory footprint and may not be cleared fast enough + del probe_data, dataset + gc.collect() + + # Insert channel information + # Get channel and electrode-site mapping + channel2electrodes = get_neuropixels_channel2electrode_map( + key, acq_software + ) + self.Channel.insert( + [ + {**key, "channel_idx": channel_idx, **channel_info} + for channel_idx, channel_info in channel2electrodes.items() + ] + ) @schema From 226142b82614f4964a3f7c8655ffd2ca6f43dd3d Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 1 Mar 2024 23:52:16 +0000 Subject: [PATCH 096/152] feat: :sparkles: Update WaveformSet make function --- element_array_ephys/ephys_no_curation.py | 47 ++++++++---------------- 1 file changed, 16 insertions(+), 31 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 5d77a041..a721e5b6 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1248,6 +1248,16 @@ def make(self, key): "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method ) + # Get channel and electrode-site mapping + channel_info = ( + (EphysRecording.Channel & key) + .proj(..., "-channel_name") + .fetch(as_dict=True, order_by="channel_idx") + ) + channel_info: dict[int, dict] = { + ch.pop("channel_idx"): ch for ch in channel_info + } # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}} + if ( output_dir / sorter_name / "waveform" ).exists(): # read from spikeinterface outputs @@ -1256,27 +1266,12 @@ def make(self, key): we: si.WaveformExtractor = si.load_waveforms( waveform_dir, with_recording=False ) - unit_id_to_peak_channel_indices: dict[int, np.ndarray] = ( + unit_id_to_peak_channel_map: dict[int, np.ndarray] = ( si.ChannelSparsity.from_best_channels( we, 1, peak_sign="neg" ).unit_id_to_channel_indices ) # {unit: peak_channel_index} - units = (CuratedClustering.Unit & key).fetch("KEY", order_by="unit") - - # Get electrode info - electrode_config_key = ( - EphysRecording * probe.ElectrodeConfig & key - ).fetch1("KEY") - - electrode_query = ( - probe.ProbeType.Electrode.proj() * probe.ElectrodeConfig.Electrode - & electrode_config_key - ) - electrode_info = electrode_query.fetch( - "KEY", order_by="electrode", 
as_dict=True - ) - # Get mean waveform for each unit from all channels mean_waveforms = we.get_all_templates( mode="average" @@ -1285,13 +1280,13 @@ def make(self, key): unit_peak_waveform = [] unit_electrode_waveforms = [] - for unit in units: + for unit in (CuratedClustering.Unit & key).fetch("KEY", order_by="unit"): unit_peak_waveform.append( { **unit, "peak_electrode_waveform": we.get_template( unit_id=unit["unit"], mode="average", force_dense=True - )[:, unit_id_to_peak_channel_indices[unit["unit"]][0]], + )[:, unit_id_to_peak_channel_map[unit["unit"]][0]], } ) @@ -1299,12 +1294,12 @@ def make(self, key): [ { **unit, - **e, + **channel_info[c], "waveform_mean": mean_waveforms[ - unit["unit"], :, e["electrode"] + unit["unit"] - 1, :, c ], } - for e in electrode_info + for c in channel_info ] ) @@ -1319,16 +1314,6 @@ def make(self, key): EphysRecording * ProbeInsertion & key ).fetch1("acq_software", "probe") - # Get channel and electrode-site mapping - channel_info = ( - (EphysRecording.Channel & key) - .proj(..., "-channel_name") - .fetch(as_dict=True, order_by="channel_idx") - ) - channel_info: dict[int, dict] = { - ch.pop("channel_idx"): ch for ch in channel_info - } # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}} - # Get all units units = { u["unit"]: u From ea398391223d6ce8ac26ea22efc2240c86a95525 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 5 Mar 2024 09:24:14 -0600 Subject: [PATCH 097/152] refactor: :fire: remove & get_neuropixels_channel2electrode_map and generate_electrode_config --- element_array_ephys/ephys_no_curation.py | 305 +++++++++++------------ 1 file changed, 146 insertions(+), 159 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index a721e5b6..26608997 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -315,7 +315,7 @@ def make(self, key): "probe" ) - # search session dir and determine acquisition software + # Search session dir and determine acquisition software for ephys_pattern, ephys_acq_type in ( ("*.ap.meta", "SpikeGLX"), ("*.oebin", "Open Ephys"), @@ -338,62 +338,117 @@ def make(self, key): supported_probe_types = probe.ProbeType.fetch("probe_type") if acq_software == "SpikeGLX": - for meta_filepath in ephys_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format(key) - ) + spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) + spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath) - if spikeglx_meta.probe_model in supported_probe_types: + if spikeglx_meta.probe_model not in supported_probe_types: + raise NotImplementedError( + f"Processing for neuropixels probe model {spikeglx_meta.probe_model} not yet implemented." 
+ ) + else: probe_type = spikeglx_meta.probe_model - electrode_query = probe.ProbeType.Electrode & { - "probe_type": probe_type - } + electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} probe_electrodes = { (shank, shank_col, shank_row): key for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch( - "KEY", "shank", "shank_col", "shank_row" - ) + *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") ) - } - + } # electrode configuration electrode_group_members = [ probe_electrodes[(shank, shank_col, shank_row)] - for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap[ - "data" + for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] + ] # recording session-specific electrode configuration + + # Compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) + electrode_config_hash = dict_to_uuid( + {k["electrode"]: k for k in electrode_group_members} + ) + + electrode_list = sorted( + [k["electrode"] for k in electrode_group_members] + ) + electrode_gaps = ( + [-1] + + np.where(np.diff(electrode_list) > 1)[0].tolist() + + [len(electrode_list) - 1] + ) + electrode_config_name = "; ".join( + [ + f"{electrode_list[start + 1]}-{electrode_list[end]}" + for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) ] - ] - else: - raise NotImplementedError( - "Processing for neuropixels probe model" - " {} not yet implemented".format(spikeglx_meta.probe_model) ) + electrode_config_key = {"electrode_config_hash": electrode_config_hash} + + # Insert into ElectrodeConfig + if not probe.ElectrodeConfig & electrode_config_key: + probe.ElectrodeConfig.insert1( + { + **electrode_config_key, + "probe_type": probe_type, + "electrode_config_name": electrode_config_name, + } + ) + probe.ElectrodeConfig.Electrode.insert( + {**electrode_config_key, **electrode} + for electrode in electrode_group_members + ) + self.insert1( { **key, - **generate_electrode_config( - probe_type, electrode_group_members - ), + "electrode_config_hash": electrode_config_hash, "acq_software": acq_software, "sampling_rate": spikeglx_meta.meta["imSampRate"], "recording_datetime": spikeglx_meta.recording_time, "recording_duration": ( spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(meta_filepath) + or spikeglx.retrieve_recording_duration(spikeglx_meta_filepath) ), } ) - root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) + root_dir = find_root_directory( + get_ephys_root_data_dir(), spikeglx_meta_filepath + ) self.EphysFile.insert1( - {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} + { + **key, + "file_path": spikeglx_meta_filepath.relative_to( + root_dir + ).as_posix(), + } + ) + + # Insert channel information + # Get channel and electrode-site mapping + electrode_query = ( + probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode + & electrode_config_key + ) + + probe_electrodes = { + (shank, shank_col, shank_row): key + for key, shank, shank_col, shank_row in zip( + *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") + ) + } + + channel2electrode_map = { + recorded_site: probe_electrodes[(shank, shank_col, shank_row)] + for recorded_site, (shank, shank_col, shank_row, _) in enumerate( + spikeglx_meta.shankmap["data"] + ) + } + self.Channel.insert( + [ + {**key, "channel_idx": channel_idx, **channel_info} + for channel_idx, channel_info in channel2electrode_map.items() + ] ) + elif acq_software == "Open Ephys": dataset = openephys.OpenEphys(session_dir) for 
serial_number, probe_data in dataset.probes.items(): @@ -409,11 +464,13 @@ def make(self, key): 'No analog signals found - check "structure.oebin" file or "continuous" directory' ) - if probe_data.probe_model in supported_probe_types: + if probe_data.probe_model not in supported_probe_types: + raise NotImplementedError( + f"Processing for neuropixels probe model {spikeglx_meta.probe_model} not yet implemented." + ) + else: probe_type = probe_data.probe_model - electrode_query = probe.ProbeType.Electrode & { - "probe_type": probe_type - } + electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} probe_electrodes = { key["electrode"]: key for key in electrode_query.fetch("KEY") @@ -423,20 +480,33 @@ def make(self, key): probe_electrodes[channel_idx] for channel_idx in probe_data.ap_meta["channels_indices"] ] - else: - raise NotImplementedError( - "Processing for neuropixels" - " probe model {} not yet implemented".format( - probe_data.probe_model - ) + + # Compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) + electrode_config_hash = dict_to_uuid( + {k["electrode"]: k for k in electrode_group_members} + ) + + electrode_list = sorted( + [k["electrode"] for k in electrode_group_members] + ) + electrode_gaps = ( + [-1] + + np.where(np.diff(electrode_list) > 1)[0].tolist() + + [len(electrode_list) - 1] + ) + electrode_config_name = "; ".join( + [ + f"{electrode_list[start + 1]}-{electrode_list[end]}" + for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) + ] ) + electrode_config_key = {"electrode_config_hash": electrode_config_hash} + self.insert1( { **key, - **generate_electrode_config( - probe_type, electrode_group_members - ), + "electrode_config_hash": electrode_config_hash, "acq_software": acq_software, "sampling_rate": probe_data.ap_meta["sample_rate"], "recording_datetime": probe_data.recording_info[ @@ -462,17 +532,26 @@ def make(self, key): del probe_data, dataset gc.collect() - # Insert channel information - # Get channel and electrode-site mapping - channel2electrodes = get_neuropixels_channel2electrode_map( - key, acq_software - ) - self.Channel.insert( - [ - {**key, "channel_idx": channel_idx, **channel_info} - for channel_idx, channel_info in channel2electrodes.items() - ] - ) + probe_dataset = get_openephys_probe_data(key) + electrode_query = ( + probe.ProbeType.Electrode + * probe.ElectrodeConfig.Electrode + * EphysRecording + & key + ) + probe_electrodes = { + key["electrode"]: key for key in electrode_query.fetch("KEY") + } + channel2electrode_map = { + channel_idx: probe_electrodes[channel_idx] + for channel_idx in probe_dataset.ap_meta["channels_indices"] + } + self.Channel.insert( + [ + {**key, "channel_idx": channel_idx, **channel_info} + for channel_idx, channel_info in channel2electrode_map.items() + ] + ) @schema @@ -1034,7 +1113,7 @@ def make(self, key): channel_info = electrode_query.fetch(as_dict=True, order_by="channel_idx") channel_info: dict[int, dict] = { ch.pop("channel_idx"): ch for ch in channel_info - } + } # Get unit id to quality label mapping try: @@ -1050,14 +1129,14 @@ def make(self, key): ] = cluster_quality_label_map.set_index("cluster_id")[ "KSLabel" ].to_dict() # {unit: quality_label} - + # Get electrode where peak unit activity is recorded peak_electrode_ind = np.array( [ channel_info[unit_peak_channel_map[unit_id]]["electrode"] for unit_id in si_sorting.unit_ids ] - ) + ) # Get channel depth channel_depth_ind = np.array( @@ -1066,14 +1145,17 @@ def make(self, key): for unit_id in 
si_sorting.unit_ids ] ) - + # Assign electrode and depth for each spike - new_spikes = np.empty(spikes.shape, spikes.dtype.descr + [('electrode', ' list: return probe_data -def get_neuropixels_channel2electrode_map( - ephys_recording_key: dict, acq_software: str -) -> dict: #TODO: remove this function - """Get the channel map for neuropixels probe.""" - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) - spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath) - electrode_config_key = ( - EphysRecording * probe.ElectrodeConfig & ephys_recording_key - ).fetch1("KEY") - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & electrode_config_key - ) - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - channel2electrode_map = { - recorded_site: probe_electrodes[(shank, shank_col, shank_row)] - for recorded_site, (shank, shank_col, shank_row, _) in enumerate( - spikeglx_meta.shankmap["data"] - ) - } - elif acq_software == "Open Ephys": - probe_dataset = get_openephys_probe_data(ephys_recording_key) - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording - & ephys_recording_key - ) - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - channel2electrode_map = { - channel_idx: probe_electrodes[channel_idx] - for channel_idx in probe_dataset.ap_meta["channels_indices"] - } - - return channel2electrode_map - - -def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict: - """Generate and insert new ElectrodeConfig - - Args: - probe_type (str): probe type (e.g. 
neuropixels 2.0 - SS) - electrode_keys (list): list of keys of the probe.ProbeType.Electrode table - - Returns: - dict: representing a key of the probe.ElectrodeConfig table - """ - # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) - electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys}) - - electrode_list = sorted([k["electrode"] for k in electrode_keys]) - electrode_gaps = ( - [-1] - + np.where(np.diff(electrode_list) > 1)[0].tolist() - + [len(electrode_list) - 1] - ) - electrode_config_name = "; ".join( - [ - f"{electrode_list[start + 1]}-{electrode_list[end]}" - for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) - ] - ) - - electrode_config_key = {"electrode_config_hash": electrode_config_hash} - - # ---- make new ElectrodeConfig if needed ---- - if not probe.ElectrodeConfig & electrode_config_key: - probe.ElectrodeConfig.insert1( - { - **electrode_config_key, - "probe_type": probe_type, - "electrode_config_name": electrode_config_name, - } - ) - probe.ElectrodeConfig.Electrode.insert( - {**electrode_config_key, **electrode} for electrode in electrode_keys - ) - - return electrode_config_key - - def get_recording_channels_details(ephys_recording_key: dict) -> np.array: """Get details of recording channels for a given recording.""" channels_details = {} From 5bfe201293b5c2ee34a0fb0d666ed280caa22cf9 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 5 Mar 2024 14:53:51 -0600 Subject: [PATCH 098/152] Update element_array_ephys/ephys_no_curation.py Co-authored-by: Thinh Nguyen --- element_array_ephys/ephys_no_curation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 26608997..2e45d4fe 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -326,7 +326,7 @@ def make(self, key): break else: raise FileNotFoundError( - f"Ephys recording data not found! for {key}." + f"Ephys recording data not found in {session_dir}." 
"Neither SpikeGLX nor Open Ephys recording files found" ) From 1d805849f20ee1e5ce6a911d06d189f9d699900f Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 5 Mar 2024 15:19:36 -0600 Subject: [PATCH 099/152] add generate_electrode_config_name --- element_array_ephys/ephys_no_curation.py | 96 ++++++++++++------------ 1 file changed, 47 insertions(+), 49 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 2e45d4fe..7ce99df2 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -364,38 +364,10 @@ def make(self, key): electrode_config_hash = dict_to_uuid( {k["electrode"]: k for k in electrode_group_members} ) - - electrode_list = sorted( - [k["electrode"] for k in electrode_group_members] - ) - electrode_gaps = ( - [-1] - + np.where(np.diff(electrode_list) > 1)[0].tolist() - + [len(electrode_list) - 1] - ) - electrode_config_name = "; ".join( - [ - f"{electrode_list[start + 1]}-{electrode_list[end]}" - for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) - ] + electrode_config_name = generate_electrode_config_name( + probe_type, electrode_group_members ) - electrode_config_key = {"electrode_config_hash": electrode_config_hash} - - # Insert into ElectrodeConfig - if not probe.ElectrodeConfig & electrode_config_key: - probe.ElectrodeConfig.insert1( - { - **electrode_config_key, - "probe_type": probe_type, - "electrode_config_name": electrode_config_name, - } - ) - probe.ElectrodeConfig.Electrode.insert( - {**electrode_config_key, **electrode} - for electrode in electrode_group_members - ) - self.insert1( { **key, @@ -426,7 +398,7 @@ def make(self, key): # Get channel and electrode-site mapping electrode_query = ( probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & electrode_config_key + & {"electrode_config_hash": electrode_config_hash} ) probe_electrodes = { @@ -474,34 +446,20 @@ def make(self, key): probe_electrodes = { key["electrode"]: key for key in electrode_query.fetch("KEY") - } + } # electrode configuration electrode_group_members = [ probe_electrodes[channel_idx] for channel_idx in probe_data.ap_meta["channels_indices"] - ] + ] # recording session-specific electrode configuration # Compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) electrode_config_hash = dict_to_uuid( {k["electrode"]: k for k in electrode_group_members} ) - - electrode_list = sorted( - [k["electrode"] for k in electrode_group_members] - ) - electrode_gaps = ( - [-1] - + np.where(np.diff(electrode_list) > 1)[0].tolist() - + [len(electrode_list) - 1] + electrode_config_name = generate_electrode_config_name( + probe_type, electrode_group_members ) - electrode_config_name = "; ".join( - [ - f"{electrode_list[start + 1]}-{electrode_list[end]}" - for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) - ] - ) - - electrode_config_key = {"electrode_config_hash": electrode_config_hash} self.insert1( { @@ -553,6 +511,20 @@ def make(self, key): ] ) + # Insert into probe.ElectrodeConfig (recording configuration) + if not probe.ElectrodeConfig & {"electrode_config_hash": electrode_config_hash}: + probe.ElectrodeConfig.insert1( + { + "probe_type": probe_type, + "electrode_config_hash": electrode_config_hash, + "electrode_config_name": electrode_config_name, + } + ) + probe.ElectrodeConfig.Electrode.insert( + {"electrode_config_hash": electrode_config_hash, **electrode} + for electrode in electrode_group_members + ) + @schema class LFP(dj.Imported): @@ 
-1820,3 +1792,29 @@ def get_recording_channels_details(ephys_recording_key: dict) -> np.array: ) return channels_details + + +def generate_electrode_config_name(probe_type: str, electrode_keys: list) -> str: + """Generate electrode config name. + + Args: + probe_type (str): probe type (e.g. neuropixels 2.0 - SS) + electrode_keys (list): list of keys of the probe.ProbeType.Electrode table + + Returns: + electrode_config_name (str) + """ + electrode_list = sorted([k["electrode"] for k in electrode_keys]) + electrode_gaps = ( + [-1] + + np.where(np.diff(electrode_list) > 1)[0].tolist() + + [len(electrode_list) - 1] + ) + electrode_config_name = "; ".join( + [ + f"{electrode_list[start + 1]}-{electrode_list[end]}" + for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) + ] + ) + + return electrode_config_name From 88ce139bec18a0a01b13943b92c57dbcbc64e074 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Thu, 7 Mar 2024 12:29:50 -0600 Subject: [PATCH 100/152] refactor: :recycle: change sorter_name --- .../spike_sorting/si_spike_sorting.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 6df25de8..a7d1b963 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -100,9 +100,7 @@ def make(self, key): ).fetch1("clustering_method", "acq_software", "clustering_output_dir", "params") # Get sorter method and create output directory. - sorter_name = ( - "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method - ) + sorter_name = clustering_method.replace(".", "_") for required_key in ( "SI_SORTING_PARAMS", @@ -209,9 +207,7 @@ def make(self, key): output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) # Get sorter method and create output directory. - sorter_name = ( - "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method - ) + sorter_name = clustering_method.replace(".", "_") recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" si_recording: si.BaseRecording = si.load_extractor(recording_file) @@ -264,10 +260,7 @@ def make(self, key): output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) # Get sorter method and create output directory. 
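The hunk below drops the Kilosort-only name mapping in favor of a generic rule; a quick standalone check of that rule (the method names here are examples only):

    for clustering_method in ("kilosort2.5", "kilosort3", "spykingcircus"):
        print(clustering_method, "->", clustering_method.replace(".", "_"))
    # kilosort2.5 -> kilosort2_5; dot-free names pass through unchanged,
    # matching SpikeInterface's folder-safe sorter names.
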
- sorter_name = ( - "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method - ) - + sorter_name = clustering_method.replace(".", "_") output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" From 7eaefa49852f0e0807497de37a3c19d30ee1c5f2 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 8 Mar 2024 11:11:38 -0600 Subject: [PATCH 101/152] address review comments for generate_electrode_config_entry --- element_array_ephys/ephys_no_curation.py | 55 +++++++++++------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 7ce99df2..a6edbe54 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -360,18 +360,14 @@ def make(self, key): for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] ] # recording session-specific electrode configuration - # Compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) - electrode_config_hash = dict_to_uuid( - {k["electrode"]: k for k in electrode_group_members} - ) - electrode_config_name = generate_electrode_config_name( + econfig_entry, econfig_electrodes = generate_electrode_config_entry( probe_type, electrode_group_members ) self.insert1( { **key, - "electrode_config_hash": electrode_config_hash, + "electrode_config_hash": econfig_entry["electrode_config_hash"], "acq_software": acq_software, "sampling_rate": spikeglx_meta.meta["imSampRate"], "recording_datetime": spikeglx_meta.recording_time, @@ -398,7 +394,7 @@ def make(self, key): # Get channel and electrode-site mapping electrode_query = ( probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & {"electrode_config_hash": electrode_config_hash} + & {"electrode_config_hash": econfig_entry["electrode_config_hash"]} ) probe_electrodes = { @@ -453,18 +449,14 @@ def make(self, key): for channel_idx in probe_data.ap_meta["channels_indices"] ] # recording session-specific electrode configuration - # Compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) - electrode_config_hash = dict_to_uuid( - {k["electrode"]: k for k in electrode_group_members} - ) - electrode_config_name = generate_electrode_config_name( + econfig_entry, econfig_electrodes = generate_electrode_config_entry( probe_type, electrode_group_members ) self.insert1( { **key, - "electrode_config_hash": electrode_config_hash, + "electrode_config_hash": econfig_entry["electrode_config_hash"], "acq_software": acq_software, "sampling_rate": probe_data.ap_meta["sample_rate"], "recording_datetime": probe_data.recording_info[ @@ -512,18 +504,11 @@ def make(self, key): ) # Insert into probe.ElectrodeConfig (recording configuration) - if not probe.ElectrodeConfig & {"electrode_config_hash": electrode_config_hash}: - probe.ElectrodeConfig.insert1( - { - "probe_type": probe_type, - "electrode_config_hash": electrode_config_hash, - "electrode_config_name": electrode_config_name, - } - ) - probe.ElectrodeConfig.Electrode.insert( - {"electrode_config_hash": electrode_config_hash, **electrode} - for electrode in electrode_group_members - ) + if not probe.ElectrodeConfig & { + "electrode_config_hash": econfig_entry["electrode_config_hash"] + }: + probe.ElectrodeConfig.insert1(econfig_entry) + probe.ElectrodeConfig.Electrode.insert(econfig_electrodes) 
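A minimal sketch of the insert-if-new pattern above, assuming `element-interface` is installed; the electrode set is a toy example. Because `dict_to_uuid` hashes the electrode dictionary deterministically, an identical configuration always yields the same hash, so the `if not probe.ElectrodeConfig & ...` guard skips re-insertion of a known configuration.

    from element_interface.utils import dict_to_uuid

    electrode_keys = [{"electrode": e} for e in range(4)]  # toy electrode set
    hash_a = dict_to_uuid({k["electrode"]: k for k in electrode_keys})
    hash_b = dict_to_uuid({k["electrode"]: k for k in electrode_keys})
    assert hash_a == hash_b  # deterministic hash -> safe duplicate check
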
@schema @@ -1794,16 +1779,19 @@ def get_recording_channels_details(ephys_recording_key: dict) -> np.array: return channels_details -def generate_electrode_config_name(probe_type: str, electrode_keys: list) -> str: - """Generate electrode config name. +def generate_electrode_config_entry(probe_type: str, electrode_keys: list) -> dict: + """Generate and insert new ElectrodeConfig Args: probe_type (str): probe type (e.g. neuropixels 2.0 - SS) electrode_keys (list): list of keys of the probe.ProbeType.Electrode table Returns: - electrode_config_name (str) + dict: representing a key of the probe.ElectrodeConfig table """ + # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) + electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys}) + electrode_list = sorted([k["electrode"] for k in electrode_keys]) electrode_gaps = ( [-1] @@ -1816,5 +1804,14 @@ def generate_electrode_config_name(probe_type: str, electrode_keys: list) -> str for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) ] ) + electrode_config_key = {"electrode_config_hash": electrode_config_hash} + econfig_entry = { + **electrode_config_key, + "probe_type": probe_type, + "electrode_config_name": electrode_config_name, + } + econfig_electrodes = [ + {**electrode, **electrode_config_key} for electrode in electrode_keys + ] - return electrode_config_name + return econfig_entry, econfig_electrodes From d47be56dd8baeb88d8238701c1aa5ada457d3c36 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Mon, 11 Mar 2024 18:39:58 -0500 Subject: [PATCH 102/152] refactor: :art: refactor PostProcessing --- element_array_ephys/spike_sorting/si_spike_sorting.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index a7d1b963..a3f54e1e 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -253,15 +253,14 @@ class PostProcessing(dj.Imported): def make(self, key): execution_time = datetime.utcnow() - # Load recording object. + # Load recording & sorting object. clustering_method, output_dir, params = ( ephys.ClusteringTask * ephys.ClusteringParamSet & key ).fetch1("clustering_method", "clustering_output_dir", "params") output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - # Get sorter method and create output directory. 
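Condensed into a standalone sketch, the PostProcessing flow this patch tidies (PCA computation, then quality metrics, then CSV export); paths are hypothetical and a cached WaveformExtractor is assumed to exist on disk:

    import spikeinterface as si
    from spikeinterface import postprocessing, qualitymetrics

    we = si.load_waveforms("kilosort2_5/waveform", with_recording=False)
    postprocessing.compute_principal_components(waveform_extractor=we, n_components=5)
    metrics = qualitymetrics.compute_quality_metrics(waveform_extractor=we)  # pandas DataFrame
    metrics.to_csv("kilosort2_5/metrics/metrics.csv")
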
sorter_name = clustering_method.replace(".", "_") output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" @@ -301,14 +300,13 @@ def make(self, key): _ = si.postprocessing.compute_principal_components( waveform_extractor=we, **params.get("SI_QUALITY_METRICS_PARAMS", None) ) + metrics = si.qualitymetrics.compute_quality_metrics(waveform_extractor=we) + # Save the output (metrics.csv to the output dir) metrics_output_dir = output_dir / sorter_name / "metrics" metrics_output_dir.mkdir(parents=True, exist_ok=True) - - metrics = si.qualitymetrics.compute_quality_metrics(waveform_extractor=we) metrics.to_csv(metrics_output_dir / "metrics.csv") - # Save results self.insert1( { **key, From 8dfc8583017864ce23d22865b8075339a274b1f3 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Mon, 11 Mar 2024 18:43:09 -0500 Subject: [PATCH 103/152] chore: :art: run docker if the package is not built into spikeinterface --- element_array_ephys/spike_sorting/si_spike_sorting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index a3f54e1e..c74ee9d4 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -205,23 +205,23 @@ def make(self, key): ephys.ClusteringTask * ephys.ClusteringParamSet & key ).fetch1("clustering_method", "clustering_output_dir", "params") output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - # Get sorter method and create output directory. sorter_name = clustering_method.replace(".", "_") recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" si_recording: si.BaseRecording = si.load_extractor(recording_file) # Run sorting + # Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package. si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter( sorter_name=sorter_name, recording=si_recording, output_folder=output_dir / sorter_name / "spike_sorting", remove_existing_folder=True, verbose=True, - docker_image=True, + docker_image=sorter_name not in si.sorters.installed_sorters(), **params.get("SI_SORTING_PARAMS", {}), ) + # Save sorting object sorting_save_path = ( output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" ) From 6e20a11e92455220eebc065cd4870dc856739544 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Mon, 11 Mar 2024 18:45:21 -0500 Subject: [PATCH 104/152] refactor: :recycle: clean up import & docstring --- .../spike_sorting/si_spike_sorting.py | 21 ++----------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index c74ee9d4..36956d8f 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -1,34 +1,17 @@ """ -The following DataJoint pipeline implements the sequence of steps in the spike-sorting routine featured in the -"spikeinterface" pipeline. 
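Regarding PATCH 103 above: `docker_image=sorter_name not in si.sorters.installed_sorters()` makes containerization conditional rather than unconditional. A sketch of the resulting behavior, assuming a spikeinterface release (as used in this series) where `docker_image` accepts a bool and True selects the sorter's default image; `si_recording` stands in for a recording built upstream:

    import spikeinterface.sorters as ss

    sorter_name = "kilosort2_5"  # hypothetical; any sorter spikeinterface knows
    sorting = ss.run_sorter(
        sorter_name=sorter_name,
        recording=si_recording,  # a si.BaseRecording built upstream
        output_folder="spike_sorting",
        remove_existing_folder=True,
        # run natively if a local installation exists, otherwise in Docker
        docker_image=sorter_name not in ss.installed_sorters(),
    )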
-Spikeinterface developed by Alessio Buccino, Samuel Garcia, Cole Hurwitz, Jeremy Magland, and Matthias Hennig (https://github.com/SpikeInterface) - -The DataJoint pipeline currently incorporated Spikeinterfaces approach of running Kilosort using a container - -The follow pipeline features intermediary tables: -1. PreProcessing - for preprocessing steps (no GPU required) - - create recording extractor and link it to a probe - - bandpass filtering - - common mode referencing -2. SIClustering - kilosort (MATLAB) - requires GPU and docker/singularity containers - - supports kilosort 2.0, 2.5 or 3.0 (https://github.com/MouseLand/Kilosort.git) -3. PostProcessing - for postprocessing steps (no GPU required) - - create waveform extractor object - - extract templates, waveforms and snrs - - quality_metrics +The following DataJoint pipeline implements the sequence of steps in the spike-sorting routine featured in the "spikeinterface" pipeline. Spikeinterface was developed by Alessio Buccino, Samuel Garcia, Cole Hurwitz, Jeremy Magland, and Matthias Hennig (https://github.com/SpikeInterface) """ -import pathlib from datetime import datetime import datajoint as dj import pandas as pd -import probeinterface as pi import spikeinterface as si from element_array_ephys import get_logger, probe, readers from element_interface.utils import find_full_path from spikeinterface import exporters, postprocessing, qualitymetrics, sorters +from .. import get_logger, probe, readers from . import si_preprocessing log = get_logger(__name__) From 8d04e10ce8e08370d9052b1a085bfbdf89d53fbd Mon Sep 17 00:00:00 2001 From: JaerongA Date: Mon, 11 Mar 2024 18:55:01 -0500 Subject: [PATCH 105/152] revert: :art: replace SI_READERS with si_extractor --- .../spike_sorting/si_spike_sorting.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 36956d8f..2ebe90ba 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -47,12 +47,6 @@ def activate( SI_SORTERS = [s.replace("_", ".") for s in si.sorters.sorter_dict.keys()] -SI_READERS = { - "Open Ephys": si.extractors.read_openephys, - "SpikeGLX": si.extractors.read_spikeglx, - "Intan": si.extractors.read_intan, -} - @schema class PreProcessing(dj.Imported): @@ -108,9 +102,7 @@ def make(self, key): output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) recording_dir = output_dir / sorter_name / "recording" recording_dir.mkdir(parents=True, exist_ok=True) - recording_file = ( - recording_dir / "si_recording.pkl" - ) # recording cache to be created for each key + recording_file = recording_dir / "si_recording.pkl" # Create SI recording extractor object if acq_software == "SpikeGLX": @@ -125,12 +117,16 @@ def make(self, key): assert len(oe_probe.recording_info["recording_files"]) == 1 data_dir = oe_probe.recording_info["recording_files"][0] else: - raise NotImplementedError(f"Not implemented for {acq_software}") + si_extractor: si.extractors.neoextractors = ( + si.extractors.extractorlist.recording_extractor_full_dict[ + acq_software.replace(" ", "").lower() + ] + ) # data extractor object stream_names, stream_ids = si.extractors.get_neo_streams( acq_software.strip().lower(), folder_path=data_dir ) - si_recording: si.BaseRecording = SI_READERS[acq_software]( + si_recording: si.BaseRecording = si_extractor[acq_software]( folder_path=data_dir, 
stream_name=stream_names[0] ) From bb39194aeb2a06390be6b0415afe0bd46310dbbf Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 12 Mar 2024 10:39:15 -0500 Subject: [PATCH 106/152] fix acq_software name --- element_array_ephys/spike_sorting/si_spike_sorting.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 2ebe90ba..12fc069b 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -117,14 +117,13 @@ def make(self, key): assert len(oe_probe.recording_info["recording_files"]) == 1 data_dir = oe_probe.recording_info["recording_files"][0] else: + acq_software = acq_software.replace(" ", "").lower() si_extractor: si.extractors.neoextractors = ( - si.extractors.extractorlist.recording_extractor_full_dict[ - acq_software.replace(" ", "").lower() - ] + si.extractors.extractorlist.recording_extractor_full_dict[acq_software] ) # data extractor object stream_names, stream_ids = si.extractors.get_neo_streams( - acq_software.strip().lower(), folder_path=data_dir + acq_software, folder_path=data_dir ) si_recording: si.BaseRecording = si_extractor[acq_software]( folder_path=data_dir, stream_name=stream_names[0] From 01ff816fd4e2077c3b9c3ff4c5c439faddbb43c9 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 12 Mar 2024 10:54:39 -0500 Subject: [PATCH 107/152] feat: :ambulance: make all secondary attributes nullable in QualityMetrics some sorters don't output values expected by the table --- element_array_ephys/ephys_no_curation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index a6edbe54..bfb3e2ad 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1603,8 +1603,8 @@ class Waveform(dj.Part): -> master -> CuratedClustering.Unit --- - amplitude: float # (uV) absolute difference between waveform peak and trough - duration: float # (ms) time between waveform peak and trough + amplitude=null: float # (uV) absolute difference between waveform peak and trough + duration=null: float # (ms) time between waveform peak and trough halfwidth=null: float # (ms) spike width at half max amplitude pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0 repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak From 67a1ffc767261e5a9c7d9e7c85d418005c3dac80 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Wed, 17 Apr 2024 09:15:31 -0500 Subject: [PATCH 108/152] feat: save spike interface results with relative path --- element_array_ephys/spike_sorting/si_spike_sorting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 12fc069b..ba310d6e 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -150,7 +150,7 @@ def make(self, key): # Run preprocessing and save results to output folder si_preproc_func = getattr(si_preprocessing, params["SI_PREPROCESSING_METHOD"]) si_recording = si_preproc_func(si_recording) - si_recording.dump_to_pickle(file_path=recording_file) + si_recording.dump_to_pickle(file_path=recording_file, 
relative_to=output_dir) self.insert1( { @@ -203,7 +203,7 @@ def make(self, key): sorting_save_path = ( output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" ) - si_sorting.dump_to_pickle(sorting_save_path) + si_sorting.dump_to_pickle(sorting_save_path, relative_to=output_dir) self.insert1( { From d44dbaa03aa8debb2f9d15fe60811a4fcb52a535 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 26 Apr 2024 11:50:15 -0500 Subject: [PATCH 109/152] fix(spikeglx): bugfix loading spikeglx data --- element_array_ephys/ephys_no_curation.py | 11 +++++++++-- element_array_ephys/spike_sorting/si_spike_sorting.py | 11 +++++++---- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index bfb3e2ad..2dde282b 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -338,8 +338,15 @@ def make(self, key): supported_probe_types = probe.ProbeType.fetch("probe_type") if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath) + for meta_filepath in ephys_meta_filepaths: + spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) + if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: + spikeglx_meta_filepath = meta_filepath + break + else: + raise FileNotFoundError( + "No SpikeGLX data found for probe insertion: {}".format(key) + ) if spikeglx_meta.probe_model not in supported_probe_types: raise NotImplementedError( diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 12fc069b..0b53bf1d 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -117,10 +117,13 @@ def make(self, key): assert len(oe_probe.recording_info["recording_files"]) == 1 data_dir = oe_probe.recording_info["recording_files"][0] else: - acq_software = acq_software.replace(" ", "").lower() - si_extractor: si.extractors.neoextractors = ( - si.extractors.extractorlist.recording_extractor_full_dict[acq_software] - ) # data extractor object + raise NotImplementedError( + f"SpikeInterface processing for {acq_software} not yet implemented." 
+ ) + acq_software = acq_software.replace(" ", "").lower() + si_extractor: si.extractors.neoextractors = ( + si.extractors.extractorlist.recording_extractor_full_dict[acq_software] + ) # data extractor object stream_names, stream_ids = si.extractors.get_neo_streams( acq_software, folder_path=data_dir From d86928bf41a2bb0e30c7136d74fc485c9de2b90f Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 26 Apr 2024 12:13:15 -0500 Subject: [PATCH 110/152] fix: bugfix inserting `ElectrodeConfig` --- element_array_ephys/ephys_no_curation.py | 94 ++++++++++++------------ 1 file changed, 48 insertions(+), 46 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 2dde282b..dcb2ded6 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -371,31 +371,30 @@ def make(self, key): probe_type, electrode_group_members ) - self.insert1( - { - **key, - "electrode_config_hash": econfig_entry["electrode_config_hash"], - "acq_software": acq_software, - "sampling_rate": spikeglx_meta.meta["imSampRate"], - "recording_datetime": spikeglx_meta.recording_time, - "recording_duration": ( - spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(spikeglx_meta_filepath) - ), - } - ) + ephys_recording_entry = { + **key, + "electrode_config_hash": econfig_entry["electrode_config_hash"], + "acq_software": acq_software, + "sampling_rate": spikeglx_meta.meta["imSampRate"], + "recording_datetime": spikeglx_meta.recording_time, + "recording_duration": ( + spikeglx_meta.recording_duration + or spikeglx.retrieve_recording_duration(spikeglx_meta_filepath) + ), + } root_dir = find_root_directory( get_ephys_root_data_dir(), spikeglx_meta_filepath ) - self.EphysFile.insert1( + + ephys_file_entries = [ { **key, "file_path": spikeglx_meta_filepath.relative_to( root_dir ).as_posix(), } - ) + ] # Insert channel information # Get channel and electrode-site mapping @@ -417,13 +416,11 @@ def make(self, key): spikeglx_meta.shankmap["data"] ) } - self.Channel.insert( - [ - {**key, "channel_idx": channel_idx, **channel_info} - for channel_idx, channel_info in channel2electrode_map.items() - ] - ) + ephys_channel_entries = [ + {**key, "channel_idx": channel_idx, **channel_info} + for channel_idx, channel_info in channel2electrode_map.items() + ] elif acq_software == "Open Ephys": dataset = openephys.OpenEphys(session_dir) for serial_number, probe_data in dataset.probes.items(): @@ -460,31 +457,29 @@ def make(self, key): probe_type, electrode_group_members ) - self.insert1( - { - **key, - "electrode_config_hash": econfig_entry["electrode_config_hash"], - "acq_software": acq_software, - "sampling_rate": probe_data.ap_meta["sample_rate"], - "recording_datetime": probe_data.recording_info[ - "recording_datetimes" - ][0], - "recording_duration": np.sum( - probe_data.recording_info["recording_durations"] - ), - } - ) + ephys_recording_entry = { + **key, + "electrode_config_hash": econfig_entry["electrode_config_hash"], + "acq_software": acq_software, + "sampling_rate": probe_data.ap_meta["sample_rate"], + "recording_datetime": probe_data.recording_info["recording_datetimes"][ + 0 + ], + "recording_duration": np.sum( + probe_data.recording_info["recording_durations"] + ), + } root_dir = find_root_directory( get_ephys_root_data_dir(), probe_data.recording_info["recording_files"][0], ) - self.EphysFile.insert( - [ - {**key, "file_path": fp.relative_to(root_dir).as_posix()} - for fp in 
probe_data.recording_info["recording_files"] - ] - ) + + ephys_file_entries = [ + {**key, "file_path": fp.relative_to(root_dir).as_posix()} + for fp in probe_data.recording_info["recording_files"] + ] + # Explicitly garbage collect "dataset" as these may have large memory footprint and may not be cleared fast enough del probe_data, dataset gc.collect() @@ -503,11 +498,14 @@ def make(self, key): channel_idx: probe_electrodes[channel_idx] for channel_idx in probe_dataset.ap_meta["channels_indices"] } - self.Channel.insert( - [ - {**key, "channel_idx": channel_idx, **channel_info} - for channel_idx, channel_info in channel2electrode_map.items() - ] + + ephys_channel_entries = [ + {**key, "channel_idx": channel_idx, **channel_info} + for channel_idx, channel_info in channel2electrode_map.items() + ] + else: + raise NotImplementedError( + f"Processing ephys files from acquisition software of type {acq_software} is not yet implemented." ) # Insert into probe.ElectrodeConfig (recording configuration) @@ -517,6 +515,10 @@ def make(self, key): probe.ElectrodeConfig.insert1(econfig_entry) probe.ElectrodeConfig.Electrode.insert(econfig_electrodes) + self.insert1(ephys_recording_entry) + self.EphysFile.insert(ephys_file_entries) + self.Channel.insert(ephys_channel_entries) + @schema class LFP(dj.Imported): From f8ffd7760cb1be6ac19d24e37ebf69d11d773972 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Wed, 3 Apr 2024 14:14:35 -0500 Subject: [PATCH 111/152] feat(spikesorting): save to phy and generate report --- element_array_ephys/spike_sorting/si_spike_sorting.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 90a88260..52c96709 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -288,6 +288,11 @@ def make(self, key): metrics_output_dir.mkdir(parents=True, exist_ok=True) metrics.to_csv(metrics_output_dir / "metrics.csv") + # Save to phy format + si.exporters.export_to_phy(waveform_extractor=we, output_folder=output_dir / sorter_name / "phy") + # Generate spike interface report + si.exporters.export_report(waveform_extractor=we, output_folder=output_dir / sorter_name / "spikeinterface_report") + self.insert1( { **key, From 7309082858b5210dcbf9566f2e8afd72416e9655 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 26 Apr 2024 12:35:39 -0500 Subject: [PATCH 112/152] chore: cleanup init --- element_array_ephys/__init__.py | 21 ------------------- .../spike_sorting/ecephys_spike_sorting.py | 3 +-- .../spike_sorting/si_spike_sorting.py | 5 ++--- 3 files changed, 3 insertions(+), 26 deletions(-) diff --git a/element_array_ephys/__init__.py b/element_array_ephys/__init__.py index 3a0e5af6..1c0c7285 100644 --- a/element_array_ephys/__init__.py +++ b/element_array_ephys/__init__.py @@ -1,22 +1 @@ -""" -isort:skip_file -""" - -import logging -import os - -import datajoint as dj - - -__all__ = ["ephys", "get_logger"] - -dj.config["enable_python_native_blobs"] = True - - -def get_logger(name): - log = logging.getLogger(name) - log.setLevel(os.getenv("LOGLEVEL", "INFO")) - return log - - from . 
import ephys_acute as ephys diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index 4de349eb..3a43c384 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -22,7 +22,6 @@ import datajoint as dj -from element_array_ephys import get_logger from decimal import Decimal import json from datetime import datetime, timedelta @@ -33,7 +32,7 @@ kilosort_triggering, ) -log = get_logger(__name__) +log = dj.logger schema = dj.schema() diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 52c96709..306c1eb6 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -7,14 +7,13 @@ import datajoint as dj import pandas as pd import spikeinterface as si -from element_array_ephys import get_logger, probe, readers +from element_array_ephys import probe, readers from element_interface.utils import find_full_path from spikeinterface import exporters, postprocessing, qualitymetrics, sorters -from .. import get_logger, probe, readers from . import si_preprocessing -log = get_logger(__name__) +log = dj.logger schema = dj.schema() From d778b1e7d8822173ad43d60707fbb8fa8c7ff801 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 26 Apr 2024 13:24:45 -0500 Subject: [PATCH 113/152] fix: update channel-electrode mapping --- element_array_ephys/ephys_no_curation.py | 164 ++++++++---------- .../spike_sorting/si_spike_sorting.py | 13 +- 2 files changed, 81 insertions(+), 96 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index dcb2ded6..68251309 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1040,51 +1040,47 @@ def make(self, key): ).fetch1("clustering_method", "clustering_output_dir") output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - # Get sorter method and create output directory. - sorter_name = ( - "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method + # Get channel and electrode-site mapping + electrode_query = ( + (EphysRecording.Channel & key) + .proj(..., "-channel_name") ) - waveform_dir = output_dir / sorter_name / "waveform" - sorting_dir = output_dir / sorter_name / "spike_sorting" + channel2electrode_map = electrode_query.fetch(as_dict=True) + channel2electrode_map: dict[int, dict] = { + chn.pop("channel_idx"): chn for chn in channel2electrode_map + } - if waveform_dir.exists(): # read from spikeinterface outputs - we: si.WaveformExtractor = si.load_waveforms( - waveform_dir, with_recording=False - ) + # Get sorter method and create output directory. 
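The dictionary built at the start of this hunk is the recurring lookup used throughout PATCH 113: one row per recording channel, keyed by `channel_idx`. Condensed, with table names as in the diff (a reading aid, not new behavior):

    # fetch every EphysRecording.Channel row (minus channel_name) as a dict,
    # then key the rows by their channel index
    electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name")
    channel2electrode_map = {
        rec.pop("channel_idx"): rec for rec in electrode_query.fetch(as_dict=True)
    }
    # channel2electrode_map[0] -> {"electrode": 0, "electrode_config_hash": ..., ...}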
+ sorter_name = clustering_method.replace(".", "_") + si_waveform_dir = output_dir / sorter_name / "waveform" + si_sorting_dir = output_dir / sorter_name / "spike_sorting" + + if si_waveform_dir.exists(): + + # Read from spikeinterface outputs + we: si.WaveformExtractor = si.load_waveforms(si_waveform_dir, with_recording=False) si_sorting: si.sorters.BaseSorter = si.load_extractor( - sorting_dir / "si_sorting.pkl" + si_sorting_dir / "si_sorting.pkl" ) - unit_peak_channel_map: dict[int, int] = si.get_template_extremum_channel( - we, outputs="index" - ) # {unit: peak_channel_index} + unit_peak_channel: dict[int, int] = si.get_template_extremum_channel( + we, outputs="id" + ) # {unit: peak_channel_id} spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit() # {unit: spike_count} - spikes = si_sorting.to_spike_vector( - extremum_channel_inds=unit_peak_channel_map - ) - - # Get electrode & channel info - electrode_config_key = ( - EphysRecording * probe.ElectrodeConfig & key - ).fetch1("KEY") - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & electrode_config_key - ) * (dj.U("electrode", "channel_idx") & EphysRecording.Channel) + spikes = si_sorting.to_spike_vector() - channel_info = electrode_query.fetch(as_dict=True, order_by="channel_idx") - channel_info: dict[int, dict] = { - ch.pop("channel_idx"): ch for ch in channel_info + # reorder channel2electrode_map according to recording channel ids + channel2electrode_map = { + chn_id: channel2electrode_map[int(chn_id)] for chn_id in we.channel_ids } # Get unit id to quality label mapping try: cluster_quality_label_map = pd.read_csv( - sorting_dir / "sorter_output" / "cluster_KSLabel.tsv", + si_sorting_dir / "sorter_output" / "cluster_KSLabel.tsv", delimiter="\t", ) except FileNotFoundError: @@ -1099,7 +1095,7 @@ def make(self, key): # Get electrode where peak unit activity is recorded peak_electrode_ind = np.array( [ - channel_info[unit_peak_channel_map[unit_id]]["electrode"] + channel2electrode_map[unit_peak_channel[unit_id]]["electrode"] for unit_id in si_sorting.unit_ids ] ) @@ -1107,7 +1103,7 @@ def make(self, key): # Get channel depth channel_depth_ind = np.array( [ - channel_info[unit_peak_channel_map[unit_id]]["y_coord"] + channel2electrode_map[unit_peak_channel[unit_id]]["y_coord"] for unit_id in si_sorting.unit_ids ] ) @@ -1132,7 +1128,7 @@ def make(self, key): units.append( { **key, - **channel_info[unit_peak_channel_map[unit_id]], + **channel2electrode_map[unit_peak_channel[unit_id]], "unit": unit_id, "cluster_quality_label": cluster_quality_label_map.get( unit_id, "n.a." 
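For orientation, the spikeinterface calls in the hunk above (pre-0.101 WaveformExtractor API) reduce to three lookups: the extremum channel per unit, per-unit spike counts, and a flattened spike vector. A sketch under those assumptions, with hypothetical paths:

    import spikeinterface.full as si

    we = si.load_waveforms("kilosort2_5/waveform", with_recording=False)
    # index of the channel carrying the largest (most negative) template peak
    unit_peak_channel = si.get_template_extremum_channel(we, outputs="index")

    sorting = si.load_extractor("kilosort2_5/spike_sorting/si_sorting.pkl")
    spike_count_dict = sorting.count_num_spikes_per_unit()  # {unit_id: n_spikes}
    spikes = sorting.to_spike_vector()  # structured array: sample_index, unit_index, ...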
@@ -1143,10 +1139,10 @@ def make(self, key): "spike_count": spike_count_dict[unit_id], "spike_sites": new_spikes["electrode"][ new_spikes["unit_index"] == unit_id - ], + ], "spike_depths": new_spikes["depth"][ new_spikes["unit_index"] == unit_id - ], + ], } ) @@ -1184,20 +1180,10 @@ def make(self, key): spike_times = kilosort_dataset.data[spike_time_key] kilosort_dataset.extract_spike_depths() - # Get channel and electrode-site mapping - channel_info = ( - (EphysRecording.Channel & key) - .proj(..., "-channel_name") - .fetch(as_dict=True, order_by="channel_idx") - ) - channel_info: dict[int, dict] = { - ch.pop("channel_idx"): ch for ch in channel_info - } # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}} - # -- Spike-sites and Spike-depths -- spike_sites = np.array( [ - channel_info[s]["electrode"] + channel2electrode_map[s]["electrode"] for s in kilosort_dataset.data["spike_sites"] ] ) @@ -1219,7 +1205,7 @@ def make(self, key): **key, "unit": unit, "cluster_quality_label": unit_lbl, - **channel_info[unit_channel], + **channel2electrode_map[unit_channel], "spike_times": unit_spike_times, "spike_count": spike_count, "spike_sites": spike_sites[ @@ -1292,33 +1278,31 @@ def make(self, key): ClusteringTask * ClusteringParamSet & key ).fetch1("clustering_method", "clustering_output_dir") output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - sorter_name = ( - "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method - ) + sorter_name = clustering_method.replace(".", "_") # Get channel and electrode-site mapping - channel_info = ( + electrode_query = ( (EphysRecording.Channel & key) .proj(..., "-channel_name") - .fetch(as_dict=True, order_by="channel_idx") ) - channel_info: dict[int, dict] = { - ch.pop("channel_idx"): ch for ch in channel_info - } # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}} + channel2electrode_map = electrode_query.fetch(as_dict=True) + channel2electrode_map: dict[int, dict] = { + chn.pop("channel_idx"): chn for chn in channel2electrode_map + } - if ( - output_dir / sorter_name / "waveform" - ).exists(): # read from spikeinterface outputs + si_waveform_dir = output_dir / sorter_name / "waveform" + if si_waveform_dir.exists(): # read from spikeinterface outputs + we: si.WaveformExtractor = si.load_waveforms(si_waveform_dir, with_recording=False) + unit_id_to_peak_channel_map: dict[ + int, np.ndarray + ] = si.ChannelSparsity.from_best_channels( + we, 1, peak_sign="neg" + ).unit_id_to_channel_indices # {unit: peak_channel_index} - waveform_dir = output_dir / sorter_name / "waveform" - we: si.WaveformExtractor = si.load_waveforms( - waveform_dir, with_recording=False - ) - unit_id_to_peak_channel_map: dict[int, np.ndarray] = ( - si.ChannelSparsity.from_best_channels( - we, 1, peak_sign="neg" - ).unit_id_to_channel_indices - ) # {unit: peak_channel_index} + # reorder channel2electrode_map according to recording channel ids + channel2electrode_map = { + chn_id: channel2electrode_map[int(chn_id)] for chn_id in we.channel_ids + } # Get mean waveform for each unit from all channels mean_waveforms = we.get_all_templates( @@ -1329,30 +1313,32 @@ def make(self, key): unit_electrode_waveforms = [] for unit in (CuratedClustering.Unit & 
key).fetch("KEY", order_by="unit"): + unit_waveforms = we.get_template( + unit_id=unit["unit"], mode="average", force_dense=True + ) # (sample x channel) + peak_chn_idx = list(we.channel_ids).index( + unit_id_to_peak_channel_map[unit["unit"]][0] + ) unit_peak_waveform.append( { **unit, - "peak_electrode_waveform": we.get_template( - unit_id=unit["unit"], mode="average", force_dense=True - )[:, unit_id_to_peak_channel_map[unit["unit"]][0]], + "peak_electrode_waveform": unit_waveforms[:, peak_chn_idx], } ) - unit_electrode_waveforms.extend( [ { **unit, - **channel_info[c], - "waveform_mean": mean_waveforms[unit["unit"] - 1, :, c], + **channel2electrode_map[c], + "waveform_mean": mean_waveforms[unit["unit"] - 1, :, c_idx], } - for c in channel_info + for c_idx, c in enumerate(channel2electrode_map) ] ) self.insert1(key) self.PeakWaveform.insert(unit_peak_waveform) self.Waveform.insert(unit_electrode_waveforms) - else: kilosort_dataset = kilosort.Kilosort(output_dir) @@ -1390,12 +1376,12 @@ def yield_unit_waveforms(): unit_electrode_waveforms.append( { **units[unit_no], - **channel_info[channel], + **channel2electrode_map[channel], "waveform_mean": channel_waveform, } ) if ( - channel_info[channel]["electrode"] + channel2electrode_map[channel]["electrode"] == units[unit_no]["electrode"] ): unit_peak_waveform = { @@ -1405,7 +1391,6 @@ def yield_unit_waveforms(): yield unit_peak_waveform, unit_electrode_waveforms # Spike interface mean and peak waveform extraction from we object - elif len(waveforms_folder) > 0 & (waveforms_folder[0]).exists(): we_kilosort = si.load_waveforms(waveforms_folder[0].parent) unit_templates = we_kilosort.get_all_templates() @@ -1432,12 +1417,12 @@ def yield_unit_waveforms(): unit_electrode_waveforms.append( { **units[unit_no], - **channel_info[channel], + **channel2electrode_map[channel], "waveform_mean": channel_waveform, } ) if ( - channel_info[channel]["electrode"] + channel2electrode_map[channel]["electrode"] == units[unit_no]["electrode"] ): unit_peak_waveform = { @@ -1506,13 +1491,13 @@ def yield_unit_waveforms(): unit_electrode_waveforms.append( { **unit_dict, - **channel_info[channel], + **channel2electrode_map[channel], "waveform_mean": channel_waveform.mean(axis=0), "waveforms": channel_waveform, } ) if ( - channel_info[channel]["electrode"] + channel2electrode_map[channel]["electrode"] == unit_dict["electrode"] ): unit_peak_waveform = { @@ -1630,12 +1615,15 @@ def make(self, key): ClusteringTask * ClusteringParamSet & key ).fetch1("clustering_method", "clustering_output_dir") output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - sorter_name = ( - "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method - ) - metric_fp = output_dir / sorter_name / "metrics" / "metrics.csv" - if not metric_fp.exists(): - raise FileNotFoundError(f"QC metrics file not found: {metric_fp}") + sorter_name = clustering_method.replace(".", "_") + + # find metric_fp + for metric_fp in [output_dir / "metrics.csv", output_dir / sorter_name / "metrics" / "metrics.csv"]: + if metric_fp.exists(): + break + else: + raise FileNotFoundError(f"QC metrics file not found in: {output_dir}") + metrics_df = pd.read_csv(metric_fp) # Conform the dataframe to match the table definition diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 306c1eb6..d14746fb 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -132,21 
+132,18 @@ def make(self, key): ) # Add probe information to recording object - electrode_config_key = ( - probe.ElectrodeConfig * ephys.EphysRecording & key - ).fetch1("KEY") electrodes_df = ( ( - probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode - & electrode_config_key + ephys.EphysRecording.Channel * probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode + & key ) .fetch(format="frame") - .reset_index()[["electrode", "x_coord", "y_coord", "shank"]] + .reset_index() ) # Create SI probe object - si_probe = readers.probe_geometry.to_probeinterface(electrodes_df) - si_probe.set_device_channel_indices(range(len(electrodes_df))) + si_probe = readers.probe_geometry.to_probeinterface(electrodes_df[["electrode", "x_coord", "y_coord", "shank"]]) + si_probe.set_device_channel_indices(electrodes_df["channel_idx"].values) si_recording.set_probe(probe=si_probe, in_place=True) # Run preprocessing and save results to output folder From 015341c1127300e10e9011ec5d49a96abc3322f0 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 26 Apr 2024 17:23:29 -0500 Subject: [PATCH 114/152] feat: test spikeinterface for spikeglx data --- element_array_ephys/ephys_no_curation.py | 144 ++++++++---------- .../spike_sorting/si_preprocessing.py | 2 +- .../spike_sorting/si_spike_sorting.py | 8 +- 3 files changed, 72 insertions(+), 82 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 68251309..333a189a 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -352,24 +352,24 @@ def make(self, key): raise NotImplementedError( f"Processing for neuropixels probe model {spikeglx_meta.probe_model} not yet implemented." ) - else: - probe_type = spikeglx_meta.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } # electrode configuration - electrode_group_members = [ - probe_electrodes[(shank, shank_col, shank_row)] - for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] - ] # recording session-specific electrode configuration - - econfig_entry, econfig_electrodes = generate_electrode_config_entry( - probe_type, electrode_group_members + probe_type = spikeglx_meta.probe_model + electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} + + probe_electrodes = { + (shank, shank_col, shank_row): key + for key, shank, shank_col, shank_row in zip( + *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") ) + } # electrode configuration + electrode_group_members = [ + probe_electrodes[(shank, shank_col, shank_row)] + for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] + ] # recording session-specific electrode configuration + + econfig_entry, econfig_electrodes = generate_electrode_config_entry( + probe_type, electrode_group_members + ) ephys_recording_entry = { **key, @@ -398,18 +398,6 @@ def make(self, key): # Insert channel information # Get channel and electrode-site mapping - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & {"electrode_config_hash": econfig_entry["electrode_config_hash"]} - ) - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - channel2electrode_map = { 
recorded_site: probe_electrodes[(shank, shank_col, shank_row)] for recorded_site, (shank, shank_col, shank_row, _) in enumerate( @@ -418,7 +406,12 @@ def make(self, key): } ephys_channel_entries = [ - {**key, "channel_idx": channel_idx, **channel_info} + { + **key, + "electrode_config_hash": econfig_entry["electrode_config_hash"], + "channel_idx": channel_idx, + **channel_info, + } for channel_idx, channel_info in channel2electrode_map.items() ] elif acq_software == "Open Ephys": @@ -438,24 +431,24 @@ def make(self, key): if probe_data.probe_model not in supported_probe_types: raise NotImplementedError( - f"Processing for neuropixels probe model {spikeglx_meta.probe_model} not yet implemented." + f"Processing for neuropixels probe model {probe_data.probe_model} not yet implemented." ) - else: - probe_type = probe_data.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } # electrode configuration + probe_type = probe_data.probe_model + electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - electrode_group_members = [ - probe_electrodes[channel_idx] - for channel_idx in probe_data.ap_meta["channels_indices"] - ] # recording session-specific electrode configuration + probe_electrodes = { + key["electrode"]: key for key in electrode_query.fetch("KEY") + } # electrode configuration - econfig_entry, econfig_electrodes = generate_electrode_config_entry( - probe_type, electrode_group_members - ) + electrode_group_members = [ + probe_electrodes[channel_idx] + for channel_idx in probe_data.ap_meta["channels_indices"] + ] # recording session-specific electrode configuration + + econfig_entry, econfig_electrodes = generate_electrode_config_entry( + probe_type, electrode_group_members + ) ephys_recording_entry = { **key, @@ -480,29 +473,24 @@ def make(self, key): for fp in probe_data.recording_info["recording_files"] ] - # Explicitly garbage collect "dataset" as these may have large memory footprint and may not be cleared fast enough - del probe_data, dataset - gc.collect() - - probe_dataset = get_openephys_probe_data(key) - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } channel2electrode_map = { channel_idx: probe_electrodes[channel_idx] - for channel_idx in probe_dataset.ap_meta["channels_indices"] + for channel_idx in probe_data.ap_meta["channels_indices"] } ephys_channel_entries = [ - {**key, "channel_idx": channel_idx, **channel_info} + { + **key, + "electrode_config_hash": econfig_entry["electrode_config_hash"], + "channel_idx": channel_idx, + **channel_info, + } for channel_idx, channel_info in channel2electrode_map.items() ] + + # Explicitly garbage collect "dataset" as these may have large memory footprint and may not be cleared fast enough + del probe_data, dataset + gc.collect() else: raise NotImplementedError( f"Processing ephys files from acquisition software of type {acq_software} is not yet implemented." 
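One Python detail worth flagging in the QualityMetrics lookup introduced two patches back (and re-wrapped in the hunk below): it relies on for/else, where the else branch runs only if the loop finished without hitting break:

    from pathlib import Path

    output_dir = Path("processed/subject0/kilosort2_5")  # hypothetical
    for metric_fp in [
        output_dir / "metrics.csv",
        output_dir / "kilosort2_5" / "metrics" / "metrics.csv",
    ]:
        if metric_fp.exists():
            break  # metric_fp stays bound to the first hit
    else:  # no break -> no candidate existed
        raise FileNotFoundError(f"QC metrics file not found in: {output_dir}")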
@@ -1041,10 +1029,7 @@ def make(self, key): output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) # Get channel and electrode-site mapping - electrode_query = ( - (EphysRecording.Channel & key) - .proj(..., "-channel_name") - ) + electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name") channel2electrode_map = electrode_query.fetch(as_dict=True) channel2electrode_map: dict[int, dict] = { chn.pop("channel_idx"): chn for chn in channel2electrode_map @@ -1058,7 +1043,9 @@ def make(self, key): if si_waveform_dir.exists(): # Read from spikeinterface outputs - we: si.WaveformExtractor = si.load_waveforms(si_waveform_dir, with_recording=False) + we: si.WaveformExtractor = si.load_waveforms( + si_waveform_dir, with_recording=False + ) si_sorting: si.sorters.BaseSorter = si.load_extractor( si_sorting_dir / "si_sorting.pkl" ) @@ -1139,10 +1126,10 @@ def make(self, key): "spike_count": spike_count_dict[unit_id], "spike_sites": new_spikes["electrode"][ new_spikes["unit_index"] == unit_id - ], + ], "spike_depths": new_spikes["depth"][ new_spikes["unit_index"] == unit_id - ], + ], } ) @@ -1281,10 +1268,7 @@ def make(self, key): sorter_name = clustering_method.replace(".", "_") # Get channel and electrode-site mapping - electrode_query = ( - (EphysRecording.Channel & key) - .proj(..., "-channel_name") - ) + electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name") channel2electrode_map = electrode_query.fetch(as_dict=True) channel2electrode_map: dict[int, dict] = { chn.pop("channel_idx"): chn for chn in channel2electrode_map @@ -1292,12 +1276,14 @@ def make(self, key): si_waveform_dir = output_dir / sorter_name / "waveform" if si_waveform_dir.exists(): # read from spikeinterface outputs - we: si.WaveformExtractor = si.load_waveforms(si_waveform_dir, with_recording=False) - unit_id_to_peak_channel_map: dict[ - int, np.ndarray - ] = si.ChannelSparsity.from_best_channels( - we, 1, peak_sign="neg" - ).unit_id_to_channel_indices # {unit: peak_channel_index} + we: si.WaveformExtractor = si.load_waveforms( + si_waveform_dir, with_recording=False + ) + unit_id_to_peak_channel_map: dict[int, np.ndarray] = ( + si.ChannelSparsity.from_best_channels( + we, 1, peak_sign="neg" + ).unit_id_to_channel_indices + ) # {unit: peak_channel_index} # reorder channel2electrode_map according to recording channel ids channel2electrode_map = { @@ -1391,6 +1377,7 @@ def yield_unit_waveforms(): yield unit_peak_waveform, unit_electrode_waveforms # Spike interface mean and peak waveform extraction from we object + elif len(waveforms_folder) > 0 & (waveforms_folder[0]).exists(): we_kilosort = si.load_waveforms(waveforms_folder[0].parent) unit_templates = we_kilosort.get_all_templates() @@ -1618,7 +1605,10 @@ def make(self, key): sorter_name = clustering_method.replace(".", "_") # find metric_fp - for metric_fp in [output_dir / "metrics.csv", output_dir / sorter_name / "metrics" / "metrics.csv"]: + for metric_fp in [ + output_dir / "metrics.csv", + output_dir / sorter_name / "metrics" / "metrics.csv", + ]: if metric_fp.exists(): break else: diff --git a/element_array_ephys/spike_sorting/si_preprocessing.py b/element_array_ephys/spike_sorting/si_preprocessing.py index 4db5f303..22adbdca 100644 --- a/element_array_ephys/spike_sorting/si_preprocessing.py +++ b/element_array_ephys/spike_sorting/si_preprocessing.py @@ -2,7 +2,7 @@ from spikeinterface import preprocessing -def catGT(recording): +def CatGT(recording): recording = si.preprocessing.phase_shift(recording) recording = 
si.preprocessing.common_reference( recording, operator="median", reference="global" diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index d14746fb..c1a906ea 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -127,7 +127,7 @@ def make(self, key): stream_names, stream_ids = si.extractors.get_neo_streams( acq_software, folder_path=data_dir ) - si_recording: si.BaseRecording = si_extractor[acq_software]( + si_recording: si.BaseRecording = si_extractor( folder_path=data_dir, stream_name=stream_names[0] ) @@ -184,7 +184,7 @@ def make(self, key): output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) sorter_name = clustering_method.replace(".", "_") recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" - si_recording: si.BaseRecording = si.load_extractor(recording_file) + si_recording: si.BaseRecording = si.load_extractor(recording_file, base_folder=output_dir) # Run sorting # Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package. @@ -241,8 +241,8 @@ def make(self, key): recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" - si_recording: si.BaseRecording = si.load_extractor(recording_file) - si_sorting: si.sorters.BaseSorter = si.load_extractor(sorting_file) + si_recording: si.BaseRecording = si.load_extractor(recording_file, base_folder=output_dir) + si_sorting: si.sorters.BaseSorter = si.load_extractor(sorting_file, base_folder=output_dir) # Extract waveforms we: si.WaveformExtractor = si.extract_waveforms( From 05ccfdb80cee7418e58322ebb3bbb9f4a1df6b8e Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Mon, 29 Apr 2024 11:46:32 -0500 Subject: [PATCH 115/152] fix: update ingestion from spikeinterface results --- element_array_ephys/ephys_no_curation.py | 137 +++++------------------ 1 file changed, 27 insertions(+), 110 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 333a189a..0cf2021c 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1040,18 +1040,16 @@ def make(self, key): si_waveform_dir = output_dir / sorter_name / "waveform" si_sorting_dir = output_dir / sorter_name / "spike_sorting" - if si_waveform_dir.exists(): - - # Read from spikeinterface outputs + if si_waveform_dir.exists(): # Read from spikeinterface outputs we: si.WaveformExtractor = si.load_waveforms( si_waveform_dir, with_recording=False ) si_sorting: si.sorters.BaseSorter = si.load_extractor( - si_sorting_dir / "si_sorting.pkl" + si_sorting_dir / "si_sorting.pkl", base_folder=output_dir ) unit_peak_channel: dict[int, int] = si.get_template_extremum_channel( - we, outputs="id" + we, outputs="index" ) # {unit: peak_channel_id} spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit() @@ -1061,7 +1059,8 @@ def make(self, key): # reorder channel2electrode_map according to recording channel ids channel2electrode_map = { - chn_id: channel2electrode_map[int(chn_id)] for chn_id in we.channel_ids + chn_idx: channel2electrode_map[chn_idx] + for chn_idx in we.channel_ids_to_indices(we.channel_ids) } # Get unit id to quality label mapping @@ -1090,7 +1089,7 @@ def make(self, key): # Get channel depth channel_depth_ind = np.array( [ - 
channel2electrode_map[unit_peak_channel[unit_id]]["y_coord"] + we.get_probe().contact_positions[unit_peak_channel[unit_id]][1] for unit_id in si_sorting.unit_ids ] ) @@ -1132,7 +1131,6 @@ def make(self, key): ], } ) - else: # read from kilosort outputs kilosort_dataset = kilosort.Kilosort(output_dir) acq_software, sample_rate = (EphysRecording & key).fetch1( @@ -1286,46 +1284,38 @@ def make(self, key): ) # {unit: peak_channel_index} # reorder channel2electrode_map according to recording channel ids + channel_indices = we.channel_ids_to_indices(we.channel_ids).tolist() channel2electrode_map = { - chn_id: channel2electrode_map[int(chn_id)] for chn_id in we.channel_ids + chn_idx: channel2electrode_map[chn_idx] for chn_idx in channel_indices } - # Get mean waveform for each unit from all channels - mean_waveforms = we.get_all_templates( - mode="average" - ) # (unit x sample x channel) - - unit_peak_waveform = [] - unit_electrode_waveforms = [] - - for unit in (CuratedClustering.Unit & key).fetch("KEY", order_by="unit"): - unit_waveforms = we.get_template( - unit_id=unit["unit"], mode="average", force_dense=True - ) # (sample x channel) - peak_chn_idx = list(we.channel_ids).index( - unit_id_to_peak_channel_map[unit["unit"]][0] - ) - unit_peak_waveform.append( - { + def yield_unit_waveforms(): + for unit in (CuratedClustering.Unit & key).fetch( + "KEY", order_by="unit" + ): + # Get mean waveform for this unit from all channels - (sample x channel) + unit_waveforms = we.get_template( + unit_id=unit["unit"], mode="average", force_dense=True + ) + peak_chn_idx = channel_indices.index( + unit_id_to_peak_channel_map[unit["unit"]][0] + ) + unit_peak_waveform = { **unit, "peak_electrode_waveform": unit_waveforms[:, peak_chn_idx], } - ) - unit_electrode_waveforms.extend( - [ + + unit_electrode_waveforms = [ { **unit, - **channel2electrode_map[c], - "waveform_mean": mean_waveforms[unit["unit"] - 1, :, c_idx], + **channel2electrode_map[chn_idx], + "waveform_mean": unit_waveforms[:, chn_idx], } - for c_idx, c in enumerate(channel2electrode_map) + for chn_idx in channel_indices ] - ) - self.insert1(key) - self.PeakWaveform.insert(unit_peak_waveform) - self.Waveform.insert(unit_electrode_waveforms) - else: + yield unit_peak_waveform, unit_electrode_waveforms + else: # read from kilosort outputs kilosort_dataset = kilosort.Kilosort(output_dir) acq_software, probe_serial_number = ( @@ -1340,10 +1330,6 @@ def make(self, key): ) } - waveforms_folder = [ - f for f in output_dir.parent.rglob(r"*/waveforms*") if f.is_dir() - ] - if (output_dir / "mean_waveforms.npy").exists(): unit_waveforms = np.load( output_dir / "mean_waveforms.npy" @@ -1376,75 +1362,6 @@ def yield_unit_waveforms(): } yield unit_peak_waveform, unit_electrode_waveforms - # Spike interface mean and peak waveform extraction from we object - - elif len(waveforms_folder) > 0 & (waveforms_folder[0]).exists(): - we_kilosort = si.load_waveforms(waveforms_folder[0].parent) - unit_templates = we_kilosort.get_all_templates() - unit_waveforms = np.reshape( - unit_templates, - ( - unit_templates.shape[1], - unit_templates.shape[3], - unit_templates.shape[2], - ), - ) - - # Approach assumes unit_waveforms was generated correctly (templates are actually the same as mean_waveforms) - def yield_unit_waveforms(): - for unit_no, unit_waveform in zip( - kilosort_dataset.data["cluster_ids"], unit_waveforms - ): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - if unit_no in units: - for channel, channel_waveform in zip( - 
kilosort_dataset.data["channel_map"], unit_waveform - ): - unit_electrode_waveforms.append( - { - **units[unit_no], - **channel2electrode_map[channel], - "waveform_mean": channel_waveform, - } - ) - if ( - channel2electrode_map[channel]["electrode"] - == units[unit_no]["electrode"] - ): - unit_peak_waveform = { - **units[unit_no], - "peak_electrode_waveform": channel_waveform, - } - yield unit_peak_waveform, unit_electrode_waveforms - - # Approach not using spike interface templates (ie. taking mean of each unit waveform) - # def yield_unit_waveforms(): - # for unit_id in we_kilosort.unit_ids: - # unit_waveform = np.mean(we_kilosort.get_waveforms(unit_id), 0) - # unit_peak_waveform = {} - # unit_electrode_waveforms = [] - # if unit_id in units: - # for channel, channel_waveform in zip( - # kilosort_dataset.data["channel_map"], unit_waveform - # ): - # unit_electrode_waveforms.append( - # { - # **units[unit_id], - # **channel2electrodes[channel], - # "waveform_mean": channel_waveform, - # } - # ) - # if ( - # channel2electrodes[channel]["electrode"] - # == units[unit_id]["electrode"] - # ): - # unit_peak_waveform = { - # **units[unit_id], - # "peak_electrode_waveform": channel_waveform, - # } - # yield unit_peak_waveform, unit_electrode_waveforms - else: if acq_software == "SpikeGLX": spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) From 93895a965471902b3a3aa5448c7648ce09432928 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Wed, 8 May 2024 00:23:28 +0200 Subject: [PATCH 116/152] Refactor Quality Metrics Logic + blackformatting --- element_array_ephys/ephys_no_curation.py | 22 +++-- .../spike_sorting/si_spike_sorting.py | 84 +++++++++++++------ 2 files changed, 73 insertions(+), 33 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 0cf2021c..b0a8bc26 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1277,11 +1277,11 @@ def make(self, key): we: si.WaveformExtractor = si.load_waveforms( si_waveform_dir, with_recording=False ) - unit_id_to_peak_channel_map: dict[int, np.ndarray] = ( - si.ChannelSparsity.from_best_channels( - we, 1, peak_sign="neg" - ).unit_id_to_channel_indices - ) # {unit: peak_channel_index} + unit_id_to_peak_channel_map: dict[ + int, np.ndarray + ] = si.ChannelSparsity.from_best_channels( + we, 1, peak_sign="neg" + ).unit_id_to_channel_indices # {unit: peak_channel_index} # reorder channel2electrode_map according to recording channel ids channel_indices = we.channel_ids_to_indices(we.channel_ids).tolist() @@ -1315,6 +1315,7 @@ def yield_unit_waveforms(): ] yield unit_peak_waveform, unit_electrode_waveforms + else: # read from kilosort outputs kilosort_dataset = kilosort.Kilosort(output_dir) @@ -1546,9 +1547,14 @@ def make(self, key): metrics_df.rename( columns={ - "isi_viol": "isi_violation", - "num_viol": "number_violation", - "contam_rate": "contamination_rate", + "isi_violations_ratio": "isi_violation", + "isi_violations_count": "number_violation", + "silhouette": "silhouette_score", + "rp_contamination": "contamination_rate", + "drift_ptp": "max_drift", + "drift_mad": "cumulative_drift", + "half_width": "halfwidth", + "peak_trough_ratio": "pt_ratio", }, inplace=True, ) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index c1a906ea..94f12f84 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ 
-134,7 +134,9 @@ def make(self, key): # Add probe information to recording object electrodes_df = ( ( - ephys.EphysRecording.Channel * probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode + ephys.EphysRecording.Channel + * probe.ElectrodeConfig.Electrode + * probe.ProbeType.Electrode & key ) .fetch(format="frame") @@ -142,7 +144,9 @@ def make(self, key): ) # Create SI probe object - si_probe = readers.probe_geometry.to_probeinterface(electrodes_df[["electrode", "x_coord", "y_coord", "shank"]]) + si_probe = readers.probe_geometry.to_probeinterface( + electrodes_df[["electrode", "x_coord", "y_coord", "shank"]] + ) si_probe.set_device_channel_indices(electrodes_df["channel_idx"].values) si_recording.set_probe(probe=si_probe, in_place=True) @@ -184,7 +188,9 @@ def make(self, key): output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) sorter_name = clustering_method.replace(".", "_") recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" - si_recording: si.BaseRecording = si.load_extractor(recording_file, base_folder=output_dir) + si_recording: si.BaseRecording = si.load_extractor( + recording_file, base_folder=output_dir + ) # Run sorting # Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package. @@ -241,8 +247,12 @@ def make(self, key): recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" - si_recording: si.BaseRecording = si.load_extractor(recording_file, base_folder=output_dir) - si_sorting: si.sorters.BaseSorter = si.load_extractor(sorting_file, base_folder=output_dir) + si_recording: si.BaseRecording = si.load_extractor( + recording_file, base_folder=output_dir + ) + si_sorting: si.sorters.BaseSorter = si.load_extractor( + sorting_file, base_folder=output_dir + ) # Extract waveforms we: si.WaveformExtractor = si.extract_waveforms( @@ -257,27 +267,46 @@ def make(self, key): **params.get("SI_JOB_KWARGS", {"n_jobs": -1, "chunk_size": 30000}), ) - # Calculate QC Metrics - metrics: pd.DataFrame = si.qualitymetrics.compute_quality_metrics( - we, - metric_names=[ - "firing_rate", - "snr", - "presence_ratio", - "isi_violation", - "num_spikes", - "amplitude_cutoff", - "amplitude_median", - "sliding_rp_violation", - "rp_violation", - "drift", - ], - ) - # Add PCA based metrics. These will be added to the metrics dataframe above. + # Calculate Cluster and Waveform Metrics + + # To provide waveform_principal_component _ = si.postprocessing.compute_principal_components( waveform_extractor=we, **params.get("SI_QUALITY_METRICS_PARAMS", None) ) - metrics = si.qualitymetrics.compute_quality_metrics(waveform_extractor=we) + + # To estimate the location of each spike in the sorting output. + # The drift metrics require the `spike_locations` waveform extension. + _ = si.postprocessing.compute_spike_locations(waveform_extractor=we) + + # The `sd_ratio` metric requires the `spike_amplitudes` waveform extension. + # It is highly recommended before calculating amplitude-based quality metrics. + _ = si.postprocessing.compute_spike_amplitudes(waveform_extractor=we) + + # To compute correlograms for spike trains. + _ = si.postprocessing.compute_correlograms(we) + + metric_names = si.qualitymetrics.get_quality_metric_list() + metric_names.extend(si.qualitymetrics.get_quality_pca_metric_list()) + + # To compute commonly used cluster quality metrics. 
+ qc_metrics = si.qualitymetrics.compute_quality_metrics( + waveform_extractor=we, + metric_names=metric_names, + ) + + # To compute commonly used waveform/template metrics. + template_metric_names = si.postprocessing.get_template_metric_names() + template_metric_names.extend(["amplitude", "duration"]) + + template_metrics = si.postprocessing.compute_template_metrics( + waveform_extractor=we, + include_multi_channel_metrics=True, + metric_names=template_metric_names, + ) + + # Save the output (metrics.csv to the output dir) + metrics = pd.DataFrame() + metrics = pd.concat([qc_metrics, template_metrics], axis=1) # Save the output (metrics.csv to the output dir) metrics_output_dir = output_dir / sorter_name / "metrics" @@ -285,9 +314,14 @@ def make(self, key): metrics.to_csv(metrics_output_dir / "metrics.csv") # Save to phy format - si.exporters.export_to_phy(waveform_extractor=we, output_folder=output_dir / sorter_name / "phy") + si.exporters.export_to_phy( + waveform_extractor=we, output_folder=output_dir / sorter_name / "phy" + ) # Generate spike interface report - si.exporters.export_report(waveform_extractor=we, output_folder=output_dir / sorter_name / "spikeinterface_report") + si.exporters.export_report( + waveform_extractor=we, + output_folder=output_dir / sorter_name / "spikeinterface_report", + ) self.insert1( { From bd3bb8e9eccb7df3f44fce7398549325f994dec8 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Wed, 8 May 2024 10:34:09 -0500 Subject: [PATCH 117/152] Update si_spike_sorting.py --- element_array_ephys/spike_sorting/si_spike_sorting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index c1a906ea..1aea4ad0 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -42,6 +42,7 @@ def activate( create_tables=create_tables, add_objects=ephys.__dict__, ) + ephys.Clustering.key_source -= PreProcessing.key_source.proj() SI_SORTERS = [s.replace("_", ".") for s in si.sorters.sorter_dict.keys()] From 403d1df30c18eb63f84b200ea8a861c59d9d6ac5 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 9 May 2024 18:57:31 +0200 Subject: [PATCH 118/152] update `postprocessing` logic --- element_array_ephys/spike_sorting/si_spike_sorting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 94f12f84..4c90337e 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -286,7 +286,7 @@ def make(self, key): _ = si.postprocessing.compute_correlograms(we) metric_names = si.qualitymetrics.get_quality_metric_list() - metric_names.extend(si.qualitymetrics.get_quality_pca_metric_list()) + metric_names.extend(si.qualitymetrics.get_quality_pca_metric_list()) # To compute commonly used cluster quality metrics. 
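Taken together, PATCH 116 through the hunk above leave PostProcessing computing waveform extensions first, then two metric tables that are concatenated column-wise. A condensed sketch of that flow (pre-0.101 WaveformExtractor API; paths and defaults are placeholders):

    import pandas as pd
    import spikeinterface.full as si

    we = si.load_waveforms("kilosort2_5/waveform", with_recording=False)
    si.compute_principal_components(we)  # enables the PCA-based quality metrics
    si.compute_spike_locations(we)       # required by the drift metrics
    si.compute_spike_amplitudes(we)      # recommended for amplitude-based metrics
    si.compute_correlograms(we)

    qc_metrics = si.compute_quality_metrics(we)
    template_metrics = si.compute_template_metrics(we, include_multi_channel_metrics=True)
    metrics = pd.concat([qc_metrics, template_metrics], axis=1)
    metrics.to_csv("kilosort2_5/metrics/metrics.csv")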
qc_metrics = si.qualitymetrics.compute_quality_metrics( @@ -308,7 +308,7 @@ def make(self, key): metrics = pd.DataFrame() metrics = pd.concat([qc_metrics, template_metrics], axis=1) - # Save the output (metrics.csv to the output dir) + # Save metrics.csv to the output dir metrics_output_dir = output_dir / sorter_name / "metrics" metrics_output_dir.mkdir(parents=True, exist_ok=True) metrics.to_csv(metrics_output_dir / "metrics.csv") From c934e67ea6e5de2e30b35dbc10ab547e49917159 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 24 May 2024 11:00:06 -0500 Subject: [PATCH 119/152] feat: prototyping with the new `sorting_analyzer` --- .../spike_sorting/si_spike_sorting.py | 27 +++++++++++++++++-- setup.py | 2 +- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index ab803490..f7cb1e57 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -255,6 +255,29 @@ def make(self, key): sorting_file, base_folder=output_dir ) + # Sorting Analyzer + analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer" + if analyzer_output_dir.exists(): + sorting_analyzer = si.load_sorting_analyzer(folder=analyzer_output_dir) + else: + sorting_analyzer = si.create_sorting_analyzer( + sorting=si_sorting, + recording=si_recording, + format="binary_folder", + folder=analyzer_output_dir, + sparse=True, + overwrite=True, + ) + + job_kwargs = params.get("SI_JOB_KWARGS", {"n_jobs": -1, "chunk_duration": "1s"}) + all_computable_extensions = ['random_spikes', 'waveforms', 'templates', 'noise_levels', 'amplitude_scalings', 'correlograms', 'isi_histograms', 'principal_components', 'spike_amplitudes', 'spike_locations', 'template_metrics', 'template_similarity', 'unit_locations', 'quality_metrics'] + extensions_to_compute = ['random_spikes', 'waveforms', 'templates', 'noise_levels', + 'spike_amplitudes', 'spike_locations', 'unit_locations', + 'principal_components', + 'template_metrics', 'quality_metrics'] + + sorting_analyzer.compute(extensions_to_compute, **job_kwargs) + # Extract waveforms we: si.WaveformExtractor = si.extract_waveforms( si_recording, @@ -287,7 +310,7 @@ def make(self, key): _ = si.postprocessing.compute_correlograms(we) metric_names = si.qualitymetrics.get_quality_metric_list() - metric_names.extend(si.qualitymetrics.get_quality_pca_metric_list()) + # metric_names.extend(si.qualitymetrics.get_quality_pca_metric_list()) # TODO: temporarily removed # To compute commonly used cluster quality metrics. qc_metrics = si.qualitymetrics.compute_quality_metrics( @@ -297,7 +320,7 @@ def make(self, key): # To compute commonly used waveform/template metrics. template_metric_names = si.postprocessing.get_template_metric_names() - template_metric_names.extend(["amplitude", "duration"]) + template_metric_names.extend(["amplitude", "duration"]) # TODO: does this do anything? 
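Aside: the `sorting_analyzer` object prototyped above replaces the older WaveformExtractor-centric flow. As orientation, a minimal usage sketch, assuming spikeinterface >= 0.101; `recording`, `sorting`, and the folder path are placeholders rather than values from this repo:

    import spikeinterface.full as si

    # A folder-backed analyzer persists each computed extension on disk.
    analyzer = si.create_sorting_analyzer(
        sorting=sorting,        # output of an earlier si.run_sorter() call
        recording=recording,    # the preprocessed recording
        format="binary_folder",
        folder="./sorting_analyzer",
        sparse=True,
    )
    # compute() resolves extension dependencies (e.g. waveforms before templates).
    analyzer.compute(["random_spikes", "waveforms", "templates", "quality_metrics"])
    qc = analyzer.get_extension("quality_metrics").get_data()  # pandas DataFrame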
From 3666cda077448cc40d7b7e9c219c9c489396cbd6 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Fri, 24 May 2024 14:34:05 -0500
Subject: [PATCH 120/152] feat: update ingestion to be compatible with
 spikeinterface 0.101+

---
 element_array_ephys/ephys_no_curation.py | 209 ++++++++----------
 .../spike_sorting/si_spike_sorting.py     |  93 ++------
 2 files changed, 116 insertions(+), 186 deletions(-)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index b0a8bc26..413868da 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -1037,98 +1037,69 @@ def make(self, key):
 
         # Get sorter method and create output directory.
         sorter_name = clustering_method.replace(".", "_")
-        si_waveform_dir = output_dir / sorter_name / "waveform"
-        si_sorting_dir = output_dir / sorter_name / "spike_sorting"
+        si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
 
-        if si_waveform_dir.exists():  # Read from spikeinterface outputs
-            we: si.WaveformExtractor = si.load_waveforms(
-                si_waveform_dir, with_recording=False
-            )
-            si_sorting: si.sorters.BaseSorter = si.load_extractor(
-                si_sorting_dir / "si_sorting.pkl", base_folder=output_dir
-            )
+        if si_sorting_analyzer_dir.exists():  # Read from spikeinterface outputs
+            sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
+            si_sorting = sorting_analyzer.sorting
 
-            unit_peak_channel: dict[int, int] = si.get_template_extremum_channel(
-                we, outputs="index"
-            )  # {unit: peak_channel_id}
+            # Find representative channel for each unit
+            unit_peak_channel: dict[int, np.ndarray] = (
+                si.ChannelSparsity.from_best_channels(
+                    sorting_analyzer, 1, peak_sign="neg"
+                ).unit_id_to_channel_indices
+            )  # {unit: peak_channel_index}
+            unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
 
             spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
             # {unit: spike_count}
 
-            spikes = si_sorting.to_spike_vector()
-
             # reorder channel2electrode_map according to recording channel ids
             channel2electrode_map = {
                 chn_idx: channel2electrode_map[chn_idx]
-                for chn_idx in we.channel_ids_to_indices(we.channel_ids)
+                for chn_idx in sorting_analyzer.channel_ids_to_indices(
+                    sorting_analyzer.channel_ids
+                )
             }
 
             # Get unit id to quality label mapping
-            try:
-                cluster_quality_label_map = pd.read_csv(
-                    si_sorting_dir / "sorter_output" / "cluster_KSLabel.tsv",
-                    delimiter="\t",
+            cluster_quality_label_map = {
+                int(unit_id): (
+                    si_sorting.get_unit_property(unit_id, "KSLabel")
+                    if "KSLabel" in si_sorting.get_property_keys()
+                    else "n.a."
                 )
-            except FileNotFoundError:
-                cluster_quality_label_map = {}
-            else:
-                cluster_quality_label_map: dict[
-                    int, str
-                ] = cluster_quality_label_map.set_index("cluster_id")[
-                    "KSLabel"
-                ].to_dict()  # {unit: quality_label}
-
-            # Get electrode where peak unit activity is recorded
-            peak_electrode_ind = np.array(
-                [
-                    channel2electrode_map[unit_peak_channel[unit_id]]["electrode"]
-                    for unit_id in si_sorting.unit_ids
-                ]
-            )
-
-            # Get channel depth
-            channel_depth_ind = np.array(
-                [
-                    we.get_probe().contact_positions[unit_peak_channel[unit_id]][1]
-                    for unit_id in si_sorting.unit_ids
-                ]
-            )
-
-            # Assign electrode and depth for each spike
-            new_spikes = np.empty(
-                spikes.shape,
-                spikes.dtype.descr + [("electrode", "
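For orientation, the ingestion path introduced above reads everything back from the persisted analyzer instead of from Kilosort's TSV files. A hedged sketch of that read pattern, with the folder path as a placeholder:

    import spikeinterface.full as si

    analyzer = si.load_sorting_analyzer(folder="./sorting_analyzer")
    sorting = analyzer.sorting

    spike_counts = sorting.count_num_spikes_per_unit()  # {unit_id: n_spikes}
    # KSLabel exists only when the sorter (or curation) wrote it as a unit property.
    if "KSLabel" in sorting.get_property_keys():
        labels = dict(zip(sorting.unit_ids, sorting.get_property("KSLabel")))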
Date: Fri, 24 May 2024 14:52:45 -0500
Subject: [PATCH 121/152] format: black formatting

---
 element_array_ephys/ephys_no_curation.py | 10 +++++++---
 .../spike_sorting/si_spike_sorting.py     | 19 ++++++++++++-------
 2 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index 413868da..99247e35 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -1256,7 +1256,9 @@ def make(self, key):
             unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
 
             # reorder channel2electrode_map according to recording channel ids
-            channel_indices = sorting_analyzer.channel_ids_to_indices(sorting_analyzer.channel_ids).tolist()
+            channel_indices = sorting_analyzer.channel_ids_to_indices(
+                sorting_analyzer.channel_ids
+            ).tolist()
             channel2electrode_map = {
                 chn_idx: channel2electrode_map[chn_idx] for chn_idx in channel_indices
             }
@@ -1500,7 +1502,9 @@ def make(self, key):
         if si_sorting_analyzer_dir.exists():  # read from spikeinterface outputs
             sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
             qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()
-            template_metrics = sorting_analyzer.get_extension("template_metrics").get_data()
+            template_metrics = sorting_analyzer.get_extension(
+                "template_metrics"
+            ).get_data()
 
             metrics_df = pd.concat([qc_metrics, template_metrics], axis=1)
             metrics_df.rename(
@@ -1514,7 +1518,7 @@ def make(self, key):
                     "drift_mad": "cumulative_drift",
                     "half_width": "halfwidth",
                     "peak_trough_ratio": "pt_ratio",
-                    "peak_to_valley": "duration"
+                    "peak_to_valley": "duration",
                 },
                 inplace=True,
             )

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 55c6efdd..33201d86 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -270,28 +270,33 @@ def make(self, key):
             overwrite=True,
         )
 
-        job_kwargs = params["SI_POSTPROCESSING_PARAMS"].get("job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"})
+        job_kwargs = params["SI_POSTPROCESSING_PARAMS"].get(
+            "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"}
+        )
 
         extensions_params = params["SI_POSTPROCESSING_PARAMS"].get("extensions", {})
         # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions()
         # each extension is parameterized by params specified in extensions_params dictionary (skip if not specified)
-        extensions_to_compute = {ext_name: extensions_params[ext_name]
-                                 for ext_name in sorting_analyzer.get_computable_extensions()
-                                 if ext_name in extensions_params}
+        extensions_to_compute = {
+            ext_name: extensions_params[ext_name]
+            for ext_name in sorting_analyzer.get_computable_extensions()
+            if ext_name in extensions_params
+        }
 
         sorting_analyzer.compute(extensions_to_compute, **job_kwargs)
 
         # Save to phy format
         if params["SI_POSTPROCESSING_PARAMS"].get("export_to_phy", False):
             si.exporters.export_to_phy(
-                sorting_analyzer=sorting_analyzer, output_folder=output_dir / sorter_name / "phy",
-                **job_kwargs
+                sorting_analyzer=sorting_analyzer,
+                output_folder=output_dir / sorter_name / "phy",
+                **job_kwargs,
             )
         # Generate spike interface report
         if params["SI_POSTPROCESSING_PARAMS"].get("export_report", True):
             si.exporters.export_report(
                 sorting_analyzer=sorting_analyzer,
                 output_folder=output_dir / sorter_name / "spikeinterface_report",
-                **job_kwargs
+                **job_kwargs,
             )

From 07a09f6152b9632ce713287a85dedd0ad1bf8e9b Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Fri, 24 May 2024 15:28:52 -0500
Subject: [PATCH 122/152] chore: code clean up

---
 .../spike_sorting/si_spike_sorting.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 33201d86..a0ff2035 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -80,11 +80,9 @@ def make(self, key):
         sorter_name = clustering_method.replace(".", "_")
 
         for required_key in (
-            "SI_SORTING_PARAMS",
             "SI_PREPROCESSING_METHOD",
+            "SI_SORTING_PARAMS",
             "SI_POSTPROCESSING_PARAMS",
-            "SI_WAVEFORM_EXTRACTION_PARAMS",
-            "SI_QUALITY_METRICS_PARAMS",
         ):
             if required_key not in params:
                 raise ValueError(
@@ -256,6 +254,10 @@ def make(self, key):
             sorting_file, base_folder=output_dir
         )
 
+        job_kwargs = params["SI_POSTPROCESSING_PARAMS"].get(
+            "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"}
+        )
+
         # Sorting Analyzer
         analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer"
         if (analyzer_output_dir / "extensions").exists():
@@ -268,14 +270,12 @@ def make(self, key):
                 folder=analyzer_output_dir,
                 sparse=True,
                 overwrite=True,
+                **job_kwargs
             )
 
-        job_kwargs = params["SI_POSTPROCESSING_PARAMS"].get(
-            "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"}
-        )
-
-        extensions_params = params["SI_POSTPROCESSING_PARAMS"].get("extensions", {})
         # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions()
         # each extension is parameterized by params specified in extensions_params dictionary (skip if not specified)
+        extensions_params = params["SI_POSTPROCESSING_PARAMS"].get("extensions", {})
         extensions_to_compute = {
             ext_name: extensions_params[ext_name]
             for ext_name in sorting_analyzer.get_computable_extensions()

From 3fcf542d1435f4f891f2bbf93eaa3668da1986ea Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Fri, 24 May 2024 15:29:09 -0500
Subject: [PATCH 123/152] update: update requirements to install
 `SpikeInterface` from github (latest version)

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index e62719d8..f1ba9c90 100644
--- a/setup.py
+++ b/setup.py
@@ -35,7 +35,7 @@
         "openpyxl",
         "plotly",
         "seaborn",
-        "spikeinterface>=0.101.0",
+        "spikeinterface @ git+https://github.com/SpikeInterface/spikeinterface.git",
         "scikit-image",
         "nbformat>=4.2.0",
         "pyopenephys>=1.1.6",

From 76dfc94568bf28296da18905d0b187588bc99397 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Tue, 28 May 2024 10:32:19 -0500
Subject: [PATCH 124/152] fix: minor bug in spikes ingestion

---
 element_array_ephys/ephys_no_curation.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index 99247e35..9222ccd2 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -1048,8 +1048,8 @@ def make(self, key):
                 si.ChannelSparsity.from_best_channels(
                     sorting_analyzer, 1, peak_sign="neg"
                 ).unit_id_to_channel_indices
-            )  # {unit: peak_channel_index}
-            unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
+            )
+            unit_peak_channel: dict[int, int] = {u: chn[0] for u, chn in unit_peak_channel.items()}
 
             spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
             # {unit: spike_count}
@@ -1076,9 +1076,9 @@ def make(self, key):
             spikes_df = pd.DataFrame(spike_locations.spikes)
 
             units = []
-            for unit_id in si_sorting.unit_ids:
+            for unit_idx, unit_id in enumerate(si_sorting.unit_ids):
                 unit_id = int(unit_id)
-                unit_spikes_df = spikes_df[spikes_df.unit_index == unit_id]
+                unit_spikes_df = spikes_df[spikes_df.unit_index == unit_idx]
                 spike_sites = np.array(
                     [
                         channel2electrode_map[chn_idx]["electrode"]
@@ -1087,6 +1087,9 @@ def make(self, key):
                 )
                 unit_spikes_loc = spike_locations.get_data()[unit_spikes_df.index]
                 _, spike_depths = zip(*unit_spikes_loc)  # x-coordinates, y-coordinates
+                spike_times = si_sorting.get_unit_spike_train(unit_id, return_times=True)
+
+                assert len(spike_times) == len(spike_sites) == len(spike_depths)
 
                 units.append(
                     {
@@ -1094,9 +1097,7 @@ def make(self, key):
                         **channel2electrode_map[unit_peak_channel[unit_id]],
                         "unit": unit_id,
                         "cluster_quality_label": cluster_quality_label_map[unit_id],
-                        "spike_times": si_sorting.get_unit_spike_train(
-                            unit_id, return_times=True
-                        ),
+                        "spike_times": spike_times,
                         "spike_count": spike_count_dict[unit_id],
                         "spike_sites": spike_sites,
                         "spike_depths": spike_depths,
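The bug fixed above is the classic unit_id-versus-unit_index confusion: the spike vector indexes units by position, not by their possibly non-contiguous ids. A small illustration with hypothetical ids:

    unit_ids = [3, 7, 12]               # ids need not be 0..N-1
    spikes = sorting.to_spike_vector()  # structured array with a `unit_index` field
    for unit_idx, unit_id in enumerate(unit_ids):
        # Filtering on unit_id here would silently select the wrong (or no) spikes.
        unit_spikes = spikes[spikes["unit_index"] == unit_idx]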
From 9094754b6f23bd65a71390094ac509e06d22b34c Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Tue, 28 May 2024 10:38:59 -0500
Subject: [PATCH 125/152] update: bump version

---
 CHANGELOG.md                   | 5 +++++
 element_array_ephys/version.py | 2 +-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5e45e427..5d81dcba 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,11 @@
 Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
 [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.
 
+## [0.4.0] - 2024-05-28
+
++ Add - support for SpikeInterface version >= 0.101.0 (updated API)
+
+
 ## [0.3.4] - 2024-03-22
 
 + Add - pytest

diff --git a/element_array_ephys/version.py b/element_array_ephys/version.py
index 148bac24..2e6de55a 100644
--- a/element_array_ephys/version.py
+++ b/element_array_ephys/version.py
@@ -1,3 +1,3 @@
 """Package metadata."""
 
-__version__ = "0.3.4"
+__version__ = "0.4.0"

From 51e2ced3f36fa1b69bacf69ea1fbf295c84eaf16 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Tue, 28 May 2024 13:14:00 -0500
Subject: [PATCH 126/152] feat: add `memoized_result` on spike sorting

---
 CHANGELOG.md                              |   1 +
 .../spike_sorting/si_spike_sorting.py     | 103 ++++++++++--------
 setup.py                                  |   2 +-
 3 files changed, 60 insertions(+), 46 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5d81dcba..cd8bb5b0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,7 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
 ## [0.4.0] - 2024-05-28
 
 + Add - support for SpikeInterface version >= 0.101.0 (updated API)
++ Add - feature for memoization of spike sorting results (prevent duplicated runs)
 
 ## [0.3.4] - 2024-03-22
 
 + Add - pytest

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index a0ff2035..dff74dd7 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -8,7 +8,7 @@
 import pandas as pd
 import spikeinterface as si
 from element_array_ephys import probe, readers
-from element_interface.utils import find_full_path
+from element_interface.utils import find_full_path, memoized_result
 from spikeinterface import exporters, postprocessing, qualitymetrics, sorters
 
 from . import si_preprocessing
 
@@ -192,23 +192,29 @@ def make(self, key):
             recording_file, base_folder=output_dir
         )
 
+        sorting_params = params["SI_SORTING_PARAMS"]
+        sorting_output_dir = output_dir / sorter_name / "spike_sorting"
+
         # Run sorting
-        # Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package.
-        si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(
-            sorter_name=sorter_name,
-            recording=si_recording,
-            output_folder=output_dir / sorter_name / "spike_sorting",
-            remove_existing_folder=True,
-            verbose=True,
-            docker_image=sorter_name not in si.sorters.installed_sorters(),
-            **params.get("SI_SORTING_PARAMS", {}),
+        @memoized_result(
+            uniqueness_dict=sorting_params,
+            output_directory=sorting_output_dir,
         )
+        def _run_sorter():
+            # Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package.
+            si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(
+                sorter_name=sorter_name,
+                recording=si_recording,
+                output_folder=sorting_output_dir,
+                remove_existing_folder=True,
+                verbose=True,
+                docker_image=sorter_name not in si.sorters.installed_sorters(),
+                **sorting_params,
+            )
 
-        # Save sorting object
-        sorting_save_path = (
-            output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl"
-        )
-        si_sorting.dump_to_pickle(sorting_save_path, relative_to=output_dir)
+            # Save sorting object
+            sorting_save_path = sorting_output_dir / "si_sorting.pkl"
+            si_sorting.dump_to_pickle(sorting_save_path, relative_to=output_dir)
 
         self.insert1(
             {
@@ -254,15 +260,20 @@ def make(self, key):
             sorting_file, base_folder=output_dir
         )
 
-        job_kwargs = params["SI_POSTPROCESSING_PARAMS"].get(
+        postprocessing_params = params["SI_POSTPROCESSING_PARAMS"]
+
+        job_kwargs = postprocessing_params.get(
             "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"}
         )
 
-        # Sorting Analyzer
         analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer"
-        if (analyzer_output_dir / "extensions").exists():
-            sorting_analyzer = si.load_sorting_analyzer(folder=analyzer_output_dir)
-        else:
+
+        @memoized_result(
+            uniqueness_dict=postprocessing_params,
+            output_directory=analyzer_output_dir,
+        )
+        def _sorting_analyzer_compute():
+            # Sorting Analyzer
             sorting_analyzer = si.create_sorting_analyzer(
                 sorting=si_sorting,
                 recording=si_recording,
@@ -273,31 +284,33 @@ def make(self, key):
                 **job_kwargs
             )
 
-        # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions()
-        # each extension is parameterized by params specified in extensions_params dictionary (skip if not specified)
-        extensions_params = params["SI_POSTPROCESSING_PARAMS"].get("extensions", {})
-        extensions_to_compute = {
-            ext_name: extensions_params[ext_name]
-            for ext_name in sorting_analyzer.get_computable_extensions()
-            if ext_name in extensions_params
-        }
-
-        sorting_analyzer.compute(extensions_to_compute, **job_kwargs)
-
-        # Save to phy format
-        if params["SI_POSTPROCESSING_PARAMS"].get("export_to_phy", False):
-            si.exporters.export_to_phy(
-                sorting_analyzer=sorting_analyzer,
-                output_folder=output_dir / sorter_name / "phy",
-                **job_kwargs,
-            )
-        # Generate spike interface report
-        if params["SI_POSTPROCESSING_PARAMS"].get("export_report", True):
-            si.exporters.export_report(
-                sorting_analyzer=sorting_analyzer,
-                output_folder=output_dir / sorter_name / "spikeinterface_report",
-                **job_kwargs,
-            )
+            # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions()
+            # each extension is parameterized by params specified in extensions_params dictionary (skip if not specified)
+            extensions_params = postprocessing_params.get("extensions", {})
+            extensions_to_compute = {
+                ext_name: extensions_params[ext_name]
+                for ext_name in sorting_analyzer.get_computable_extensions()
+                if ext_name in extensions_params
+            }
+
+            sorting_analyzer.compute(extensions_to_compute, **job_kwargs)
+
+            # Save to phy format
+            if postprocessing_params.get("export_to_phy", False):
+                si.exporters.export_to_phy(
+                    sorting_analyzer=sorting_analyzer,
+                    output_folder=analyzer_output_dir / "phy",
+                    **job_kwargs,
+                )
+            # Generate spike interface report
+            if postprocessing_params.get("export_report", True):
+                si.exporters.export_report(
+                    sorting_analyzer=sorting_analyzer,
+                    output_folder=analyzer_output_dir / "spikeinterface_report",
+                    **job_kwargs,
+                )
+
+        _sorting_analyzer_compute()
 
         self.insert1(
             {

diff --git a/setup.py b/setup.py
index f1ba9c90..66789740 100644
--- a/setup.py
+++ b/setup.py
@@ -39,7 +39,7 @@
         "scikit-image",
         "nbformat>=4.2.0",
         "pyopenephys>=1.1.6",
-        "element-interface @ git+https://github.com/datajoint/element-interface.git",
+        "element-interface @ git+https://github.com/datajoint/element-interface.git@dev_memoized_results",
         "numba",
     ],
     extras_require={
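`memoized_result` comes from element-interface; the idea is to fingerprint the parameter set against the step's output directory and skip recomputation when both already match. A conceptual sketch only, not the actual implementation (which lives in element-interface and differs in detail):

    import hashlib
    import json
    from pathlib import Path

    def memoized_result(uniqueness_dict: dict, output_directory):
        def decorator(fn):
            def wrapper(*args, **kwargs):
                out = Path(output_directory)
                key = hashlib.md5(
                    json.dumps(uniqueness_dict, sort_keys=True, default=str).encode()
                ).hexdigest()
                marker = out / f".memoized_{key}"
                if marker.exists():  # identical params already produced this output
                    return None
                result = fn(*args, **kwargs)
                out.mkdir(parents=True, exist_ok=True)
                marker.touch()
                return result
            return wrapper
        return decorator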
From 0afb4529de262fbee6b21461e5aec58765fd0e12 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Tue, 28 May 2024 14:22:20 -0500
Subject: [PATCH 127/152] chore: minor code cleanup

---
 element_array_ephys/ephys_no_curation.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index 9222ccd2..b49d4422 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -8,14 +8,12 @@
 import datajoint as dj
 import numpy as np
 import pandas as pd
-import spikeinterface as si
 from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory
-from spikeinterface import exporters, postprocessing, qualitymetrics, sorters
 
 from . import ephys_report, probe
 from .readers import kilosort, openephys, spikeglx
 
-log = dj.logger
+logger = dj.logger
 
 schema = dj.schema()
 
@@ -824,7 +822,7 @@ def infer_output_dir(cls, key, relative: bool = False, mkdir: bool = False):
 
         if mkdir:
             output_dir.mkdir(parents=True, exist_ok=True)
-            log.info(f"{output_dir} created!")
+            logger.info(f"{output_dir} created!")
 
         return output_dir.relative_to(processed_dir) if relative else output_dir
 
@@ -1040,6 +1038,8 @@ def make(self, key):
         si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
 
         if si_sorting_analyzer_dir.exists():  # Read from spikeinterface outputs
+            import spikeinterface as si
+
             sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
             si_sorting = sorting_analyzer.sorting
 
@@ -1246,6 +1246,8 @@ def make(self, key):
         si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
 
         if si_sorting_analyzer_dir.exists():  # read from spikeinterface outputs
+            import spikeinterface as si
+
             sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
 
             # Find representative channel for each unit
@@ -1501,6 +1503,8 @@ def make(self, key):
         si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
 
         if si_sorting_analyzer_dir.exists():  # read from spikeinterface outputs
+            import spikeinterface as si
+
             sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
             qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()
             template_metrics = sorting_analyzer.get_extension(
                 "template_metrics"
             ).get_data()

From e8f445c3b4b532b3159638e71d231e2048939a90 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Tue, 28 May 2024 16:47:22 -0500
Subject: [PATCH 128/152] fix: merge fix & formatting

---
 element_array_ephys/spike_sorting/si_spike_sorting.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index dff74dd7..9e14f636 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -248,7 +248,6 @@ def make(self, key):
         ).fetch1("clustering_method", "clustering_output_dir", "params")
         output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
         sorter_name = clustering_method.replace(".", "_")
-        output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
 
         recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl"
         sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl"
@@ -281,7 +280,7 @@ def make(self, key):
                 folder=analyzer_output_dir,
                 sparse=True,
                 overwrite=True,
-                **job_kwargs
+                **job_kwargs,
             )

From 6155f13fd755ac76ec79fdd1594b0e96ef8d550b Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Tue, 28 May 2024 17:01:10 -0500
Subject: [PATCH 129/152] fix: calling `_run_sorter()`

---
 element_array_ephys/spike_sorting/si_spike_sorting.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 9e14f636..5c1d6567 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -216,6 +216,8 @@ def make(self, key):
             sorting_save_path = sorting_output_dir / "si_sorting.pkl"
             si_sorting.dump_to_pickle(sorting_save_path, relative_to=output_dir)
 
+        _run_sorter()
+
         self.insert1(
             {
                 **key,

From f6a52d9d3f31b7ebe2853da4545551898cfa50ae Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Tue, 28 May 2024 20:07:27 -0500
Subject: [PATCH 130/152] chore: more robust channel mapping

---
 element_array_ephys/ephys_no_curation.py | 29 ++++++++----------------
 1 file changed, 10 insertions(+), 19 deletions(-)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index b49d4422..142f350b 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -1028,9 +1028,8 @@ def make(self, key):
 
         # Get channel and electrode-site mapping
         electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name")
-        channel2electrode_map = electrode_query.fetch(as_dict=True)
         channel2electrode_map: dict[int, dict] = {
-            chn.pop("channel_idx"): chn for chn in channel2electrode_map
+            chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True)
         }
 
         # Get sorter method and create output directory.
@@ -1054,12 +1053,10 @@ def make(self, key):
             spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
             # {unit: spike_count}
 
-            # reorder channel2electrode_map according to recording channel ids
+            # update channel2electrode_map to match with probe's channel index
             channel2electrode_map = {
-                chn_idx: channel2electrode_map[chn_idx]
-                for chn_idx in sorting_analyzer.channel_ids_to_indices(
-                    sorting_analyzer.channel_ids
-                )
+                idx: channel2electrode_map[int(chn_idx)]
+                for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids)
             }
 
             # Get unit id to quality label mapping
@@ -1239,9 +1236,8 @@ def make(self, key):
 
         # Get channel and electrode-site mapping
        electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name")
-        channel2electrode_map = electrode_query.fetch(as_dict=True)
         channel2electrode_map: dict[int, dict] = {
-            chn.pop("channel_idx"): chn for chn in channel2electrode_map
+            chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True)
         }
 
         si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
@@ -1258,12 +1254,10 @@ def make(self, key):
             )  # {unit: peak_channel_index}
             unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
 
-            # reorder channel2electrode_map according to recording channel ids
-            channel_indices = sorting_analyzer.channel_ids_to_indices(
-                sorting_analyzer.channel_ids
-            ).tolist()
+            # update channel2electrode_map to match with probe's channel index
             channel2electrode_map = {
-                chn_idx: channel2electrode_map[chn_idx] for chn_idx in channel_indices
+                idx: channel2electrode_map[int(chn_idx)]
+                for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids)
             }
 
             templates = sorting_analyzer.get_extension("templates")
@@ -1276,12 +1270,9 @@ def yield_unit_waveforms():
                 unit_waveforms = templates.get_unit_template(
                     unit_id=unit["unit"], operator="average"
                 )
-                peak_chn_idx = channel_indices.index(
-                    unit_peak_channel[unit["unit"]]
-                )
                 unit_peak_waveform = {
                     **unit,
-                    "peak_electrode_waveform": unit_waveforms[:, peak_chn_idx],
+                    "peak_electrode_waveform": unit_waveforms[:, unit_peak_channel[unit["unit"]]],
                 }
 
                 unit_electrode_waveforms = [
@@ -1290,7 +1281,7 @@ def yield_unit_waveforms():
                         **channel2electrode_map[chn_idx],
                         "waveform_mean": unit_waveforms[:, chn_idx],
                     }
-                    for chn_idx in channel_indices
+                    for chn_idx in channel2electrode_map
                 ]
 
                 yield unit_peak_waveform, unit_electrode_waveforms
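The "more robust" mapping above keys analyzer outputs by position against the probe's contact ids rather than trusting the recording's channel order. Illustratively, and assuming numeric contact ids as the `int(chn_idx)` cast implies:

    probe = sorting_analyzer.get_probe()  # probeinterface Probe object
    # Column i of any analyzer output corresponds to probe.contact_ids[i].
    channel_map = {
        idx: int(contact_id) for idx, contact_id in enumerate(probe.contact_ids)
    }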
From 1ff92dd15db6ff9e8458f53ec96fdffb6b93305d Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Wed, 29 May 2024 16:09:16 -0500
Subject: [PATCH 131/152] fix: use relative path for phy output

---
 element_array_ephys/spike_sorting/si_spike_sorting.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 5c1d6567..93619303 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -301,6 +301,7 @@ def _sorting_analyzer_compute():
                 si.exporters.export_to_phy(
                     sorting_analyzer=sorting_analyzer,
                     output_folder=analyzer_output_dir / "phy",
+                    use_relative_path=True,
                     **job_kwargs,
                 )
             # Generate spike interface report

From b45970974df001319a4ebae182bf291313f5e39a Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Wed, 29 May 2024 16:16:21 -0500
Subject: [PATCH 132/152] feat: in data ingestion, set peak_sign="both"

---
 element_array_ephys/ephys_no_curation.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index 142f350b..8eadba49 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -1045,7 +1045,7 @@ def make(self, key):
             # Find representative channel for each unit
             unit_peak_channel: dict[int, np.ndarray] = (
                 si.ChannelSparsity.from_best_channels(
-                    sorting_analyzer, 1, peak_sign="neg"
+                    sorting_analyzer, 1, peak_sign="both"
                 ).unit_id_to_channel_indices
             )
             unit_peak_channel: dict[int, int] = {u: chn[0] for u, chn in unit_peak_channel.items()}
@@ -1249,7 +1249,7 @@ def make(self, key):
             # Find representative channel for each unit
             unit_peak_channel: dict[int, np.ndarray] = (
                 si.ChannelSparsity.from_best_channels(
-                    sorting_analyzer, 1, peak_sign="neg"
+                    sorting_analyzer, 1, peak_sign="both"
                 ).unit_id_to_channel_indices
             )  # {unit: peak_channel_index}
             unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}

From 1a1b18f8a52b83298bffc8d82555ccc147151dd1 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Mon, 3 Jun 2024 13:22:49 -0500
Subject: [PATCH 133/152] feat: replace `output_folder` with `folder` when
 calling `run_sorter`, use default value for `peak_sign`

---
 element_array_ephys/ephys_no_curation.py | 21 ++++++++++++-------
 .../spike_sorting/si_spike_sorting.py     |  2 +-
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index 8eadba49..891cee0f 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -1045,10 +1045,13 @@ def make(self, key):
             # Find representative channel for each unit
             unit_peak_channel: dict[int, np.ndarray] = (
                 si.ChannelSparsity.from_best_channels(
-                    sorting_analyzer, 1, peak_sign="both"
+                    sorting_analyzer,
+                    1,
                 ).unit_id_to_channel_indices
             )
-            unit_peak_channel: dict[int, int] = {u: chn[0] for u, chn in unit_peak_channel.items()}
+            unit_peak_channel: dict[int, int] = {
+                u: chn[0] for u, chn in unit_peak_channel.items()
+            }
 
             spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
             # {unit: spike_count}
@@ -1084,7 +1087,9 @@ def make(self, key):
                 )
                 unit_spikes_loc = spike_locations.get_data()[unit_spikes_df.index]
                 _, spike_depths = zip(*unit_spikes_loc)  # x-coordinates, y-coordinates
-                spike_times = si_sorting.get_unit_spike_train(unit_id, return_times=True)
+                spike_times = si_sorting.get_unit_spike_train(
+                    unit_id, return_times=True
+                )
 
                 assert len(spike_times) == len(spike_sites) == len(spike_depths)
 
@@ -1243,13 +1248,13 @@ def make(self, key):
         si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
 
         if si_sorting_analyzer_dir.exists():  # read from spikeinterface outputs
             import spikeinterface as si
-            
+
             sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
 
             # Find representative channel for each unit
             unit_peak_channel: dict[int, np.ndarray] = (
                 si.ChannelSparsity.from_best_channels(
-                    sorting_analyzer, 1, peak_sign="both"
+                    sorting_analyzer, 1
                 ).unit_id_to_channel_indices
             )  # {unit: peak_channel_index}
             unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
@@ -1272,7 +1277,9 @@ def yield_unit_waveforms():
                 )
                 unit_peak_waveform = {
                     **unit,
-                    "peak_electrode_waveform": unit_waveforms[:, unit_peak_channel[unit["unit"]]],
+                    "peak_electrode_waveform": unit_waveforms[
+                        :, unit_peak_channel[unit["unit"]]
+                    ],
                 }
 
                 unit_electrode_waveforms = [
@@ -1495,7 +1502,7 @@ def make(self, key):
         si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
 
         if si_sorting_analyzer_dir.exists():  # read from spikeinterface outputs
             import spikeinterface as si
-            
+
             sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
             qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()
             template_metrics = sorting_analyzer.get_extension(

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 93619303..57aa0ba1 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -205,7 +205,7 @@ def _run_sorter():
             si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(
                 sorter_name=sorter_name,
                 recording=si_recording,
-                output_folder=sorting_output_dir,
+                folder=sorting_output_dir,
                 remove_existing_folder=True,
sorter_name / "sorting_analyzer" if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs import spikeinterface as si - + sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir) qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data() template_metrics = sorting_analyzer.get_extension( diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 93619303..57aa0ba1 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -205,7 +205,7 @@ def _run_sorter(): si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter( sorter_name=sorter_name, recording=si_recording, - output_folder=sorting_output_dir, + folder=sorting_output_dir, remove_existing_folder=True, verbose=True, docker_image=sorter_name not in si.sorters.installed_sorters(), From 4e645ebd9b83f5e607e1d18188c0c3ce5f84eb4a Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Wed, 5 Jun 2024 16:15:25 -0500 Subject: [PATCH 134/152] fix: remove `job_kwargs` for sparsity calculation - memory error in linux container --- element_array_ephys/spike_sorting/si_spike_sorting.py | 1 - 1 file changed, 1 deletion(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 57aa0ba1..b93d9c10 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -282,7 +282,6 @@ def _sorting_analyzer_compute(): folder=analyzer_output_dir, sparse=True, overwrite=True, - **job_kwargs, ) # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions() From 38fdfb2a5fd44f1115aa4f1660482e1639eaa3c2 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 5 Jul 2024 10:59:37 -0500 Subject: [PATCH 135/152] feat: separate `export` (phy and report) into a separate table --- .../spike_sorting/si_spike_sorting.py | 94 +++++++++++++++---- 1 file changed, 78 insertions(+), 16 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index b93d9c10..463af3df 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -239,6 +239,7 @@ class PostProcessing(dj.Imported): --- execution_time: datetime # datetime of the start of this step execution_duration: float # execution duration in hours + do_si_export=1: bool # whether to export to phy """ def make(self, key): @@ -295,22 +296,6 @@ def _sorting_analyzer_compute(): sorting_analyzer.compute(extensions_to_compute, **job_kwargs) - # Save to phy format - if postprocessing_params.get("export_to_phy", False): - si.exporters.export_to_phy( - sorting_analyzer=sorting_analyzer, - output_folder=analyzer_output_dir / "phy", - use_relative_path=True, - **job_kwargs, - ) - # Generate spike interface report - if postprocessing_params.get("export_report", True): - si.exporters.export_report( - sorting_analyzer=sorting_analyzer, - output_folder=analyzer_output_dir / "spikeinterface_report", - **job_kwargs, - ) - _sorting_analyzer_compute() self.insert1( @@ -321,6 +306,8 @@ def _sorting_analyzer_compute(): datetime.utcnow() - execution_time ).total_seconds() / 3600, + "do_si_export": postprocessing_params.get("export_to_phy", False) + or postprocessing_params.get("export_report", False), } ) @@ -328,3 +315,78 @@ def _sorting_analyzer_compute(): 
ephys.Clustering.insert1( {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True ) + + +@schema +class SIExport(dj.Computed): + """A SpikeInterface export report and to Phy""" + + definition = """ + -> PostProcessing + --- + execution_time: datetime + execution_duration: float + """ + + @property + def key_source(self): + return PostProcessing & "do_si_export = 1" + + def make(self, key): + execution_time = datetime.utcnow() + + clustering_method, output_dir, params = ( + ephys.ClusteringTask * ephys.ClusteringParamSet & key + ).fetch1("clustering_method", "clustering_output_dir", "params") + output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + sorter_name = clustering_method.replace(".", "_") + + postprocessing_params = params["SI_POSTPROCESSING_PARAMS"] + + job_kwargs = postprocessing_params.get( + "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"} + ) + + analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer" + sorting_analyzer = si.load_sorting_analyzer(folder=analyzer_output_dir) + + @memoized_result( + uniqueness_dict=postprocessing_params, + output_directory=analyzer_output_dir / "phy", + ) + def _export_to_phy(): + # Save to phy format + si.exporters.export_to_phy( + sorting_analyzer=sorting_analyzer, + output_folder=analyzer_output_dir / "phy", + use_relative_path=True, + **job_kwargs, + ) + + @memoized_result( + uniqueness_dict=postprocessing_params, + output_directory=analyzer_output_dir / "spikeinterface_report", + ) + def _export_report(): + # Generate spike interface report + si.exporters.export_report( + sorting_analyzer=sorting_analyzer, + output_folder=analyzer_output_dir / "spikeinterface_report", + **job_kwargs, + ) + + if postprocessing_params.get("export_report", False): + _export_report() + if postprocessing_params.get("export_to_phy", False): + _export_to_phy() + + self.insert1( + { + **key, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) From a4a8380405673bf2c85861223afa0c9e5e481296 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 5 Jul 2024 11:00:42 -0500 Subject: [PATCH 136/152] fix: export default to `False` --- element_array_ephys/spike_sorting/si_spike_sorting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 463af3df..6f2d7b53 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -239,7 +239,7 @@ class PostProcessing(dj.Imported): --- execution_time: datetime # datetime of the start of this step execution_duration: float # execution duration in hours - do_si_export=1: bool # whether to export to phy + do_si_export=0: bool # whether to export to phy """ def make(self, key): @@ -331,7 +331,7 @@ class SIExport(dj.Computed): @property def key_source(self): return PostProcessing & "do_si_export = 1" - + def make(self, key): execution_time = datetime.utcnow() From 1f05998e25d848b6aeb73231fa90e616580cd1d8 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 5 Jul 2024 16:45:40 -0500 Subject: [PATCH 137/152] fix: `spikes` object no longer available from `ComputeSpikeLocations` (https://github.com/SpikeInterface/spikeinterface/pull/3015) --- element_array_ephys/ephys_no_curation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/element_array_ephys/ephys_no_curation.py 
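Putting patches 135 and 136 together: whether `SIExport` populates is driven entirely by the parameter set. For illustration, a paramset fragment that would set `do_si_export = 1` might look like this (keys as used in the diffs above; values are examples only):

    params["SI_POSTPROCESSING_PARAMS"] = {
        "job_kwargs": {"n_jobs": -1, "chunk_duration": "1s"},
        "extensions": {
            "random_spikes": {},
            "waveforms": {},
            "templates": {},
            "quality_metrics": {},
        },
        "export_to_phy": True,   # either flag flips do_si_export to 1
        "export_report": True,
    }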
From 1f05998e25d848b6aeb73231fa90e616580cd1d8 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Fri, 5 Jul 2024 16:45:40 -0500
Subject: [PATCH 137/152] fix: `spikes` object no longer available from
 `ComputeSpikeLocations` (https://github.com/SpikeInterface/spikeinterface/pull/3015)

---
 element_array_ephys/ephys_no_curation.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index 891cee0f..5df8bad0 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -1073,7 +1073,9 @@ def make(self, key):
 
             spike_locations = sorting_analyzer.get_extension("spike_locations")
-            spikes_df = pd.DataFrame(spike_locations.spikes)
+            extremum_channel_inds = si.template_tools.get_template_extremum_channel(sorting_analyzer, outputs="index")
+            spikes_df = pd.DataFrame(
+                sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds))
 
             units = []
             for unit_idx, unit_id in enumerate(si_sorting.unit_ids):

From 7cd8ac8ce8eeb731f149924279bb3b0d990caa45 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Fri, 5 Jul 2024 21:26:48 -0500
Subject: [PATCH 138/152] chore: code cleanup

---
 element_array_ephys/spike_sorting/si_spike_sorting.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 6f2d7b53..8624e073 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -9,7 +9,7 @@
 import spikeinterface as si
 from element_array_ephys import probe, readers
 from element_interface.utils import find_full_path, memoized_result
-from spikeinterface import exporters, postprocessing, qualitymetrics, sorters
+from spikeinterface import exporters, extractors, sorters
 
 from . import si_preprocessing

From c87e49332f90386acc8eb696e65f87bfd7b6ae24 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Sat, 6 Jul 2024 08:00:44 -0500
Subject: [PATCH 139/152] fix: recording_extractor_full_dict is deprecated
 (https://github.com/SpikeInterface/spikeinterface/pull/3153)

---
 .../spike_sorting/si_spike_sorting.py | 27 +++++++++++--------
 1 file changed, 16 insertions(+), 11 deletions(-)

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 8624e073..7133b81c 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -111,25 +111,30 @@ def make(self, key):
             )
             spikeglx_recording.validate_file("ap")
             data_dir = spikeglx_meta_filepath.parent
+
+            si_extractor = si.extractors.neoextractors.spikeglx.SpikeGLXRecordingExtractor
+            stream_names, stream_ids = si.extractors.get_neo_streams(
+                acq_software, folder_path=data_dir
+            )
+            si_recording: si.BaseRecording = si_extractor(
+                folder_path=data_dir, stream_name=stream_names[0]
+            )
         elif acq_software == "Open Ephys":
             oe_probe = ephys.get_openephys_probe_data(key)
             assert len(oe_probe.recording_info["recording_files"]) == 1
             data_dir = oe_probe.recording_info["recording_files"][0]
+            si_extractor = si.extractors.neoextractors.openephys.OpenEphysBinaryRecordingExtractor
+
+            stream_names, stream_ids = si.extractors.get_neo_streams(
+                acq_software, folder_path=data_dir
+            )
+            si_recording: si.BaseRecording = si_extractor(
+                folder_path=data_dir, stream_name=stream_names[0]
+            )
         else:
             raise NotImplementedError(
                 f"SpikeInterface processing for {acq_software} not yet implemented."
             )
-        acq_software = acq_software.replace(" ", "").lower()
-        si_extractor: si.extractors.neoextractors = (
-            si.extractors.extractorlist.recording_extractor_full_dict[acq_software]
-        )  # data extractor object
-
-        stream_names, stream_ids = si.extractors.get_neo_streams(
-            acq_software, folder_path=data_dir
-        )
-        si_recording: si.BaseRecording = si_extractor(
-            folder_path=data_dir, stream_name=stream_names[0]
-        )

From 097d9bbf7694e40a839b4cebb49890d1acd325f1 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Tue, 30 Jul 2024 18:26:37 -0500
Subject: [PATCH 140/152] fix: bugfix spikeinterface extractor name

---
 element_array_ephys/spike_sorting/si_spike_sorting.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 7133b81c..550ae4a1 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -114,7 +114,7 @@ def make(self, key):
 
             si_extractor = si.extractors.neoextractors.spikeglx.SpikeGLXRecordingExtractor
             stream_names, stream_ids = si.extractors.get_neo_streams(
-                acq_software, folder_path=data_dir
+                "spikeglx", folder_path=data_dir
             )
             si_recording: si.BaseRecording = si_extractor(
                 folder_path=data_dir, stream_name=stream_names[0]
@@ -126,7 +126,7 @@ def make(self, key):
             si_extractor = si.extractors.neoextractors.openephys.OpenEphysBinaryRecordingExtractor
 
             stream_names, stream_ids = si.extractors.get_neo_streams(
-                acq_software, folder_path=data_dir
+                "openephysbinary", folder_path=data_dir
             )
             si_recording: si.BaseRecording = si_extractor(
                 folder_path=data_dir, stream_name=stream_names[0]
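These two patches converge on resolving neo stream names per acquisition system. A condensed sketch of that pattern (stream names vary per dataset; `data_dir` is a placeholder):

    import spikeinterface.extractors as se

    stream_names, stream_ids = se.get_neo_streams("spikeglx", folder_path=data_dir)
    recording = se.SpikeGLXRecordingExtractor(
        folder_path=data_dir, stream_name=stream_names[0]  # e.g. the AP band stream
    )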
From b6f131b814ed9dba2e2cc38d6918df52668dd590 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Thu, 15 Aug 2024 17:24:09 -0500
Subject: [PATCH 141/152] update: element-interface `main` branch

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 66789740..f1ba9c90 100644
--- a/setup.py
+++ b/setup.py
@@ -39,7 +39,7 @@
         "scikit-image",
         "nbformat>=4.2.0",
         "pyopenephys>=1.1.6",
-        "element-interface @ git+https://github.com/datajoint/element-interface.git@dev_memoized_results",
+        "element-interface @ git+https://github.com/datajoint/element-interface.git",
         "numba",
     ],
     extras_require={

From ccd23fc413d126f897c20cececcc35b86cb5190f Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Tue, 10 Sep 2024 12:04:48 -0500
Subject: [PATCH 142/152] rearrange(all): major refactor of modules

---
 CHANGELOG.md                                  |    9 +
 element_array_ephys/__init__.py               |    4 +-
 .../{ephys_no_curation.py => ephys.py}        |   13 +-
 element_array_ephys/ephys_acute.py            | 1594 -----------------
 element_array_ephys/ephys_chronic.py          | 1523 ----------------
 element_array_ephys/ephys_precluster.py       | 1435 ---------------
 element_array_ephys/ephys_report.py           |   14 +-
 element_array_ephys/export/nwb/nwb.py         |    9 +-
 .../spike_sorting/si_spike_sorting.py         |   28 +-
 element_array_ephys/version.py                |    2 +-
 tests/tutorial_pipeline.py                    |    6 +-
 11 files changed, 43 insertions(+), 4594 deletions(-)
 rename element_array_ephys/{ephys_no_curation.py => ephys.py} (99%)
 delete mode 100644 element_array_ephys/ephys_acute.py
 delete mode 100644 element_array_ephys/ephys_chronic.py
 delete mode 100644 element_array_ephys/ephys_precluster.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7068216b..34d1a2e4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,15 @@
 Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
 [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.
 
+
+## [1.0.0] - 2024-09-10
+
++ Update - No longer support multiple variation of ephys module, keep only `ephys_no_curation` module, renamed to `ephys`
++ Update - Remove other ephys modules (e.g. `ephys_acute`, `ephys_chronic`) (moved to different branches)
++ Update - Add support for `SpikeInterface`
++ Update - Remove support for `ecephys_spike_sorting` (moved to a different branch)
++ Update - Simplify the "activate" mechanism
+
 ## [0.4.0] - 2024-08-16
 
 + Add - support for SpikeInterface version >= 0.101.0 (updated API)

diff --git a/element_array_ephys/__init__.py b/element_array_ephys/__init__.py
index 1c0c7285..079950b4 100644
--- a/element_array_ephys/__init__.py
+++ b/element_array_ephys/__init__.py
@@ -1 +1,3 @@
-from . import ephys_acute as ephys
+from . import ephys
+
+ephys_no_curation = ephys  # alias for backward compatibility

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys.py
similarity index 99%
rename from element_array_ephys/ephys_no_curation.py
rename to element_array_ephys/ephys.py
index 5df8bad0..3025d289 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys.py
@@ -10,7 +10,7 @@
 import pandas as pd
 from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory
 
-from . import ephys_report, probe
+from . import probe
 from .readers import kilosort, openephys, spikeglx
 
 logger = dj.logger
@@ -22,7 +22,6 @@
 
 def activate(
     ephys_schema_name: str,
-    probe_schema_name: str = None,
     *,
     create_schema: bool = True,
     create_tables: bool = True,
@@ -32,7 +31,6 @@ def activate(
 
     Args:
         ephys_schema_name (str): A string containing the name of the ephys schema.
-        probe_schema_name (str): A string containing the name of the probe schema.
         create_schema (bool): If True, schema will be created in the database.
         create_tables (bool): If True, tables related to the schema will be created in the database.
         linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema.
@@ -46,7 +44,6 @@ def activate(
         get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s).
         get_session_direction(session_key: dict): Returns path to electrophysiology data for the a particular session as a list of strings.
         get_processed_data_dir(): Optional. Returns absolute path for processed data. Defaults to root directory.
-
     """
 
     if isinstance(linking_module, str):
@@ -58,17 +55,15 @@ def activate(
     global _linking_module
     _linking_module = linking_module
 
-    # activate
-    probe.activate(
-        probe_schema_name, create_schema=create_schema, create_tables=create_tables
-    )
+    if not probe.schema.is_activated():
+        raise RuntimeError("Please activate the `probe` schema first.")
+
     schema.activate(
         ephys_schema_name,
         create_schema=create_schema,
         create_tables=create_tables,
        add_objects=_linking_module.__dict__,
     )
-    ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name)
 
 
 # -------------- Functions required by the elements-ephys ---------------

diff --git a/element_array_ephys/ephys_acute.py b/element_array_ephys/ephys_acute.py
deleted file mode 100644
index c2627fc9..00000000
--- a/element_array_ephys/ephys_acute.py
+++ /dev/null
@@ -1,1594 +0,0 @@
-import gc
-import importlib
-import inspect
-import pathlib
-import re
-from decimal import Decimal
-
-import datajoint as dj
-import numpy as np
-import pandas as pd
-from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory
-
-from . import ephys_report, probe
-from .readers import kilosort, openephys, spikeglx
-
-log = dj.logger
-
-schema = dj.schema()
-
-_linking_module = None
-
-
-def activate(
-    ephys_schema_name: str,
-    probe_schema_name: str = None,
-    *,
-    create_schema: bool = True,
-    create_tables: bool = True,
-    linking_module: str = None,
-):
-    """Activates the `ephys` and `probe` schemas.
-
-    Args:
-        ephys_schema_name (str): A string containing the name of the ephys schema.
-        probe_schema_name (str): A string containing the name of the probe schema.
-        create_schema (bool): If True, schema will be created in the database.
-        create_tables (bool): If True, tables related to the schema will be created in the database.
-        linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema.
-
-    Dependencies:
-    Upstream tables:
-        Session: A parent table to ProbeInsertion
-        Probe: A parent table to EphysRecording. Probe information is required before electrophysiology data is imported.
-
-    Functions:
-        get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s).
-        get_session_direction(session_key: dict): Returns path to electrophysiology data for the a particular session as a list of strings.
-        get_processed_data_dir(): Optional. Returns absolute path for processed data. Defaults to root directory.
-    """
-
-    if isinstance(linking_module, str):
-        linking_module = importlib.import_module(linking_module)
-    assert inspect.ismodule(
-        linking_module
-    ), "The argument 'dependency' must be a module's name or a module"
-
-    global _linking_module
-    _linking_module = linking_module
-
-    probe.activate(
-        probe_schema_name, create_schema=create_schema, create_tables=create_tables
-    )
-    schema.activate(
-        ephys_schema_name,
-        create_schema=create_schema,
-        create_tables=create_tables,
-        add_objects=_linking_module.__dict__,
-    )
-    ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name)
-
-
-# -------------- Functions required by the elements-ephys ---------------
-
-
-def get_ephys_root_data_dir() -> list:
-    """Fetches absolute data path to ephys data directories.
-
-    The absolute path here is used as a reference for all downstream relative paths used in DataJoint.
-
-    Returns:
-        A list of the absolute path(s) to ephys data directories.
-    """
-    root_directories = _linking_module.get_ephys_root_data_dir()
-    if isinstance(root_directories, (str, pathlib.Path)):
-        root_directories = [root_directories]
-
-    if hasattr(_linking_module, "get_processed_root_data_dir"):
-        root_directories.append(_linking_module.get_processed_root_data_dir())
-
-    return root_directories
-
-
-def get_session_directory(session_key: dict) -> str:
-    """Retrieve the session directory with Neuropixels for the given session.
-
-    Args:
-        session_key (dict): A dictionary mapping subject to an entry in the subject table, and session_datetime corresponding to a session in the database.
-
-    Returns:
-        A string for the path to the session directory.
-    """
-    return _linking_module.get_session_directory(session_key)
-
-
-def get_processed_root_data_dir() -> str:
-    """Retrieve the root directory for all processed data.
-
-    Returns:
-        A string for the full path to the root directory for processed data.
-    """
-
-    if hasattr(_linking_module, "get_processed_root_data_dir"):
-        return _linking_module.get_processed_root_data_dir()
-    else:
-        return get_ephys_root_data_dir()[0]
-
-
-# ----------------------------- Table declarations ----------------------
-
-
-@schema
-class AcquisitionSoftware(dj.Lookup):
-    """Name of software used for recording electrophysiological data.
-
-    Attributes:
-        acq_software ( varchar(24) ): Acquisition software, e.g,. SpikeGLX, OpenEphys
-    """
-
-    definition = """  # Software used for recording of neuropixels probes
-    acq_software: varchar(24)
-    """
-    contents = zip(["SpikeGLX", "Open Ephys"])
-
-
-@schema
-class ProbeInsertion(dj.Manual):
-    """Information about probe insertion across subjects and sessions.
-
-    Attributes:
-        Session (foreign key): Session primary key.
-        insertion_number (foreign key, str): Unique insertion number for each probe insertion for a given session.
-        probe.Probe (str): probe.Probe primary key.
-    """
-
-    definition = """
-    # Probe insertion implanted into an animal for a given session.
-    -> Session
-    insertion_number: tinyint unsigned
-    ---
-    -> probe.Probe
-    """
-
-    @classmethod
-    def auto_generate_entries(cls, session_key):
-        """Automatically populate entries in ProbeInsertion table for a session."""
-        session_dir = find_full_path(
-            get_ephys_root_data_dir(), get_session_directory(session_key)
-        )
-        # search session dir and determine acquisition software
-        for ephys_pattern, ephys_acq_type in (
-            ("*.ap.meta", "SpikeGLX"),
-            ("*.oebin", "Open Ephys"),
-        ):
-            ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern))
-            if ephys_meta_filepaths:
-                acq_software = ephys_acq_type
-                break
-        else:
-            raise FileNotFoundError(
-                f"Ephys recording data not found!"
-                f" Neither SpikeGLX nor Open Ephys recording files found in: {session_dir}"
-            )
-
-        probe_list, probe_insertion_list = [], []
-        if acq_software == "SpikeGLX":
-            for meta_fp_idx, meta_filepath in enumerate(ephys_meta_filepaths):
-                spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath)
-
-                probe_key = {
-                    "probe_type": spikeglx_meta.probe_model,
-                    "probe": spikeglx_meta.probe_SN,
-                }
-                if probe_key["probe"] not in [p["probe"] for p in probe_list]:
-                    probe_list.append(probe_key)
-
-                probe_dir = meta_filepath.parent
-                try:
-                    probe_number = re.search("(imec)?\d{1}$", probe_dir.name).group()
-                    probe_number = int(probe_number.replace("imec", ""))
-                except AttributeError:
-                    probe_number = meta_fp_idx
-
-                probe_insertion_list.append(
-                    {
-                        **session_key,
-                        "probe": spikeglx_meta.probe_SN,
-                        "insertion_number": int(probe_number),
-                    }
-                )
-        elif acq_software == "Open Ephys":
-            loaded_oe = openephys.OpenEphys(session_dir)
-            for probe_idx, oe_probe in enumerate(loaded_oe.probes.values()):
-                probe_key = {
-                    "probe_type": oe_probe.probe_model,
-                    "probe": oe_probe.probe_SN,
-                }
-                if probe_key["probe"] not in [p["probe"] for p in probe_list]:
-                    probe_list.append(probe_key)
-                probe_insertion_list.append(
-                    {
-                        **session_key,
-                        "probe": oe_probe.probe_SN,
-                        "insertion_number": probe_idx,
-                    }
-                )
-        else:
-            raise NotImplementedError(f"Unknown acquisition software: {acq_software}")
-
-        probe.Probe.insert(probe_list, skip_duplicates=True)
-        cls.insert(probe_insertion_list, skip_duplicates=True)
-
-
-@schema
-class InsertionLocation(dj.Manual):
-    """Stereotaxic location information for each probe insertion.
-
-    Attributes:
-        ProbeInsertion (foreign key): ProbeInsertion primary key.
-        SkullReference (dict): SkullReference primary key.
-        ap_location (decimal (6, 2) ): Anterior-posterior location in micrometers. Reference is 0 with anterior values positive.
-        ml_location (decimal (6, 2) ): Medial-lateral location in micrometers. Reference is zero with right side values positive.
-        depth (decimal (6, 2) ): Manipulator depth relative to the surface of the brain at zero. Ventral is negative.
-        Theta (decimal (5, 2) ): elevation - rotation about the ml-axis in degrees relative to positive z-axis.
-        phi (decimal (5, 2) ): azimuth - rotation about the dv-axis in degrees relative to the positive x-axis.
-    """
-
-    definition = """
-    # Brain Location of a given probe insertion.
-    -> ProbeInsertion
-    ---
-    -> SkullReference
-    ap_location: decimal(6, 2) # (um) anterior-posterior; ref is 0; more anterior is more positive
-    ml_location: decimal(6, 2) # (um) medial axis; ref is 0 ; more right is more positive
-    depth:       decimal(6, 2) # (um) manipulator depth relative to surface of the brain (0); more ventral is more negative
-    theta=null:  decimal(5, 2) # (deg) - elevation - rotation about the ml-axis [0, 180] - w.r.t the z+ axis
-    phi=null:    decimal(5, 2) # (deg) - azimuth - rotation about the dv-axis [0, 360] - w.r.t the x+ axis
-    beta=null:   decimal(5, 2) # (deg) rotation about the shank of the probe [-180, 180] - clockwise is increasing in degree - 0 is the probe-front facing anterior
-    """
-
-
-@schema
-class EphysRecording(dj.Imported):
-    """Automated table with electrophysiology recording information for each probe inserted during an experimental session.
-
-    Attributes:
-        ProbeInsertion (foreign key): ProbeInsertion primary key.
-        probe.ElectrodeConfig (dict): probe.ElectrodeConfig primary key.
-        AcquisitionSoftware (dict): AcquisitionSoftware primary key.
-        sampling_rate (float): sampling rate of the recording in Hertz (Hz).
- recording_datetime (datetime): datetime of the recording from this probe. - recording_duration (float): duration of the entire recording from this probe in seconds. - """ - - definition = """ - # Ephys recording from a probe insertion for a given session. - -> ProbeInsertion - --- - -> probe.ElectrodeConfig - -> AcquisitionSoftware - sampling_rate: float # (Hz) - recording_datetime: datetime # datetime of the recording from this probe - recording_duration: float # (seconds) duration of the recording from this probe - """ - - class EphysFile(dj.Part): - """Paths of electrophysiology recording files for each insertion. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - file_path (varchar(255) ): relative file path for electrophysiology recording. - """ - - definition = """ - # Paths of files of a given EphysRecording round. - -> master - file_path: varchar(255) # filepath relative to root data directory - """ - - def make(self, key): - """Populates table with electrophysiology recording information.""" - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - - inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1( - "probe" - ) - - # search session dir and determine acquisition software - for ephys_pattern, ephys_acq_type in ( - ("*.ap.meta", "SpikeGLX"), - ("*.oebin", "Open Ephys"), - ): - ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern)) - if ephys_meta_filepaths: - acq_software = ephys_acq_type - break - else: - raise FileNotFoundError( - f"Ephys recording data not found!" - f" Neither SpikeGLX nor Open Ephys recording files found" - f" in {session_dir}" - ) - - supported_probe_types = probe.ProbeType.fetch("probe_type") - - if acq_software == "SpikeGLX": - for meta_filepath in ephys_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format(key) - ) - - if spikeglx_meta.probe_model in supported_probe_types: - probe_type = spikeglx_meta.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - electrode_group_members = [ - probe_electrodes[(shank, shank_col, shank_row)] - for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels probe model" - " {} not yet implemented".format(spikeglx_meta.probe_model) - ) - - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": spikeglx_meta.meta["imSampRate"], - "recording_datetime": spikeglx_meta.recording_time, - "recording_duration": ( - spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(meta_filepath) - ), - } - ) - - root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) - self.EphysFile.insert1( - {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} - ) - elif acq_software == "Open Ephys": - dataset = openephys.OpenEphys(session_dir) - for serial_number, probe_data in dataset.probes.items(): - if str(serial_number) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No Open Ephys data found for probe 
insertion: {}".format(key) - ) - - if not probe_data.ap_meta: - raise IOError( - 'No analog signals found - check "structure.oebin" file or "continuous" directory' - ) - - if probe_data.probe_model in supported_probe_types: - probe_type = probe_data.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - electrode_group_members = [ - probe_electrodes[channel_idx] - for channel_idx in probe_data.ap_meta["channels_indices"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels" - " probe model {} not yet implemented".format(probe_data.probe_model) - ) - - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": probe_data.ap_meta["sample_rate"], - "recording_datetime": probe_data.recording_info[ - "recording_datetimes" - ][0], - "recording_duration": np.sum( - probe_data.recording_info["recording_durations"] - ), - } - ) - - root_dir = find_root_directory( - get_ephys_root_data_dir(), - probe_data.recording_info["recording_files"][0], - ) - self.EphysFile.insert( - [ - {**key, "file_path": fp.relative_to(root_dir).as_posix()} - for fp in probe_data.recording_info["recording_files"] - ] - ) - # explicitly garbage collect "dataset" - # as these may have large memory footprint and may not be cleared fast enough - del probe_data, dataset - gc.collect() - else: - raise NotImplementedError( - f"Processing ephys files from" - f" acquisition software of type {acq_software} is" - f" not yet implemented" - ) - - -@schema -class LFP(dj.Imported): - """Extracts local field potentials (LFP) from an electrophysiology recording. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - lfp_sampling_rate (float): Sampling rate for LFPs in Hz. - lfp_time_stamps (longblob): Time stamps with respect to the start of the recording. - lfp_mean (longblob): Overall mean LFP across electrodes. - """ - - definition = """ - # Acquired local field potential (LFP) from a given Ephys recording. - -> EphysRecording - --- - lfp_sampling_rate: float # (Hz) - lfp_time_stamps: longblob # (s) timestamps with respect to the start of the recording (recording_timestamp) - lfp_mean: longblob # (uV) mean of LFP across electrodes - shape (time,) - """ - - class Electrode(dj.Part): - """Saves local field potential data for each electrode. - - Attributes: - LFP (foreign key): LFP primary key. - probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key. - lfp (longblob): LFP recording at this electrode in microvolts. 
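
Only every 9th channel is kept for LFP (see `_skip_channel_counts` in the make() below), since neighboring sites on a dense probe carry nearly identical signals. A sketch of that selection, assuming a hypothetical 384-channel recording:

import numpy as np

recording_channels = np.arange(384)      # hypothetical channel indices
skip_channel_counts = 9

# Start from the last channel and step backwards, matching the
# recording_channels[-1::-skip_channel_counts] slice used in make().
lfp_channel_ind = recording_channels[-1::-skip_channel_counts]
print(lfp_channel_ind[:5])               # [383 374 365 356 347]
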
- """ - - definition = """ - -> master - -> probe.ElectrodeConfig.Electrode - --- - lfp: longblob # (uV) recorded lfp at this electrode - """ - - # Only store LFP for every 9th channel, due to high channel density, - # close-by channels exhibit highly similar LFP - _skip_channel_counts = 9 - - def make(self, key): - """Populates the LFP tables.""" - acq_software = (EphysRecording * ProbeInsertion & key).fetch1("acq_software") - - electrode_keys, lfp = [], [] - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - - lfp_channel_ind = spikeglx_recording.lfmeta.recording_channels[ - -1 :: -self._skip_channel_counts - ] - - # Extract LFP data at specified channels and convert to uV - lfp = spikeglx_recording.lf_timeseries[ - :, lfp_channel_ind - ] # (sample x channel) - lfp = ( - lfp * spikeglx_recording.get_channel_bit_volts("lf")[lfp_channel_ind] - ).T # (channel x sample) - - self.insert1( - dict( - key, - lfp_sampling_rate=spikeglx_recording.lfmeta.meta["imSampRate"], - lfp_time_stamps=( - np.arange(lfp.shape[1]) - / spikeglx_recording.lfmeta.meta["imSampRate"] - ), - lfp_mean=lfp.mean(axis=0), - ) - ) - - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - for recorded_site in lfp_channel_ind: - shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap[ - "data" - ][recorded_site] - electrode_keys.append(probe_electrodes[(shank, shank_col, shank_row)]) - elif acq_software == "Open Ephys": - oe_probe = get_openephys_probe_data(key) - - lfp_channel_ind = np.r_[ - len(oe_probe.lfp_meta["channels_indices"]) - - 1 : 0 : -self._skip_channel_counts - ] - - # (sample x channel) - lfp = oe_probe.lfp_timeseries[:, lfp_channel_ind] - lfp = ( - lfp * np.array(oe_probe.lfp_meta["channels_gains"])[lfp_channel_ind] - ).T # (channel x sample) - lfp_timestamps = oe_probe.lfp_timestamps - - self.insert1( - dict( - key, - lfp_sampling_rate=oe_probe.lfp_meta["sample_rate"], - lfp_time_stamps=lfp_timestamps, - lfp_mean=lfp.mean(axis=0), - ) - ) - - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - electrode_keys.extend( - probe_electrodes[channel_idx] for channel_idx in lfp_channel_ind - ) - else: - raise NotImplementedError( - f"LFP extraction from acquisition software" - f" of type {acq_software} is not yet implemented" - ) - - # single insert in loop to mitigate potential memory issue - for electrode_key, lfp_trace in zip(electrode_keys, lfp): - self.Electrode.insert1({**key, **electrode_key, "lfp": lfp_trace}) - - -# ------------ Clustering -------------- - - -@schema -class ClusteringMethod(dj.Lookup): - """Kilosort clustering method. - - Attributes: - clustering_method (foreign key, varchar(16) ): Kilosort clustering method. - clustering_methods_desc (varchar(1000) ): Additional description of the clustering method. 
- """ - - definition = """ - # Method for clustering - clustering_method: varchar(16) - --- - clustering_method_desc: varchar(1000) - """ - - contents = [ - ("kilosort2", "kilosort2 clustering method"), - ("kilosort2.5", "kilosort2.5 clustering method"), - ("kilosort3", "kilosort3 clustering method"), - ] - - -@schema -class ClusteringParamSet(dj.Lookup): - """Parameters to be used in clustering procedure for spike sorting. - - Attributes: - paramset_idx (foreign key): Unique ID for the clustering parameter set. - ClusteringMethod (dict): ClusteringMethod primary key. - paramset_desc (varchar(128) ): Description of the clustering parameter set. - param_set_hash (uuid): UUID hash for the parameter set. - params (longblob): Parameters for clustering with Kilosort. - """ - - definition = """ - # Parameter set to be used in a clustering procedure - paramset_idx: smallint - --- - -> ClusteringMethod - paramset_desc: varchar(128) - param_set_hash: uuid - unique index (param_set_hash) - params: longblob # dictionary of all applicable parameters - """ - - @classmethod - def insert_new_params( - cls, - clustering_method: str, - paramset_desc: str, - params: dict, - paramset_idx: int = None, - ): - """Inserts new parameters into the ClusteringParamSet table. - - Args: - clustering_method (str): name of the clustering method. - paramset_desc (str): description of the parameter set - params (dict): clustering parameters - paramset_idx (int, optional): Unique parameter set ID. Defaults to None. - """ - if paramset_idx is None: - paramset_idx = ( - dj.U().aggr(cls, n="max(paramset_idx)").fetch1("n") or 0 - ) + 1 - - param_dict = { - "clustering_method": clustering_method, - "paramset_idx": paramset_idx, - "paramset_desc": paramset_desc, - "params": params, - "param_set_hash": dict_to_uuid( - {**params, "clustering_method": clustering_method} - ), - } - param_query = cls & {"param_set_hash": param_dict["param_set_hash"]} - - if param_query: # If the specified param-set already exists - existing_paramset_idx = param_query.fetch1("paramset_idx") - if ( - existing_paramset_idx == paramset_idx - ): # If the existing set has the same paramset_idx: job done - return - else: # If not same name: human error, trying to add the same paramset with different name - raise dj.DataJointError( - f"The specified param-set already exists" - f" - with paramset_idx: {existing_paramset_idx}" - ) - else: - if {"paramset_idx": paramset_idx} in cls.proj(): - raise dj.DataJointError( - f"The specified paramset_idx {paramset_idx} already exists," - f" please pick a different one." - ) - cls.insert1(param_dict) - - -@schema -class ClusterQualityLabel(dj.Lookup): - """Quality label for each spike sorted cluster. - - Attributes: - cluster_quality_label (foreign key, varchar(100) ): Cluster quality type. - cluster_quality_description ( varchar(4000) ): Description of the cluster quality type. - """ - - definition = """ - # Quality - cluster_quality_label: varchar(100) # cluster quality type - e.g. 'good', 'MUA', 'noise', etc. - --- - cluster_quality_description: varchar(4000) - """ - contents = [ - ("good", "single unit"), - ("ok", "probably a single unit, but could be contaminated"), - ("mua", "multi-unit activity"), - ("noise", "bad unit"), - ] - - -@schema -class ClusteringTask(dj.Manual): - """A clustering task to spike sort electrophysiology datasets. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - ClusteringParamSet (foreign key): ClusteringParamSet primary key. 
- clustering_output_dir ( varchar (255) ): Relative path to output clustering results. - task_mode (enum): `Trigger` computes clustering or and `load` imports existing data. - """ - - definition = """ - # Manual table for defining a clustering task ready to be run - -> EphysRecording - -> ClusteringParamSet - --- - clustering_output_dir='': varchar(255) # clustering output directory relative to the clustering root data directory - task_mode='load': enum('load', 'trigger') # 'load': load computed analysis results, 'trigger': trigger computation - """ - - @classmethod - def infer_output_dir( - cls, key: dict, relative: bool = False, mkdir: bool = False - ) -> pathlib.Path: - """Infer output directory if it is not provided. - - Args: - key (dict): ClusteringTask primary key. - - Returns: - Expected clustering_output_dir based on the following convention: - processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx} - e.g.: sub4/sess1/probe_2/kilosort2_0 - """ - processed_dir = pathlib.Path(get_processed_root_data_dir()) - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - root_dir = find_root_directory(get_ephys_root_data_dir(), session_dir) - - method = ( - (ClusteringParamSet * ClusteringMethod & key) - .fetch1("clustering_method") - .replace(".", "-") - ) - - output_dir = ( - processed_dir - / session_dir.relative_to(root_dir) - / f'probe_{key["insertion_number"]}' - / f'{method}_{key["paramset_idx"]}' - ) - - if mkdir: - output_dir.mkdir(parents=True, exist_ok=True) - log.info(f"{output_dir} created!") - - return output_dir.relative_to(processed_dir) if relative else output_dir - - @classmethod - def auto_generate_entries(cls, ephys_recording_key: dict, paramset_idx: int = 0): - """Autogenerate entries based on a particular ephys recording. - - Args: - ephys_recording_key (dict): EphysRecording primary key. - paramset_idx (int, optional): Parameter index to use for clustering task. Defaults to 0. - """ - key = {**ephys_recording_key, "paramset_idx": paramset_idx} - - processed_dir = get_processed_root_data_dir() - output_dir = ClusteringTask.infer_output_dir(key, relative=False, mkdir=True) - - try: - kilosort.Kilosort( - output_dir - ) # check if the directory is a valid Kilosort output - except FileNotFoundError: - task_mode = "trigger" - else: - task_mode = "load" - - cls.insert1( - { - **key, - "clustering_output_dir": output_dir.relative_to( - processed_dir - ).as_posix(), - "task_mode": task_mode, - } - ) - - -@schema -class Clustering(dj.Imported): - """A processing table to handle each clustering task. - - Attributes: - ClusteringTask (foreign key): ClusteringTask primary key. - clustering_time (datetime): Time when clustering results are generated. - package_version ( varchar(16) ): Package version used for a clustering analysis. 
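
A standalone sketch of the directory convention documented in `ClusteringTask.infer_output_dir` above, with hypothetical key values; note that the `.` in the method name is replaced to keep directory names portable.

import pathlib

processed_dir = pathlib.Path("/data/processed")   # assumed processed root
session_dir = pathlib.Path("sub4/sess1")          # session dir relative to its root
key = {"insertion_number": 2, "paramset_idx": 0}
method = "kilosort2.5".replace(".", "-")

output_dir = (
    processed_dir
    / session_dir
    / f"probe_{key['insertion_number']}"
    / f"{method}_{key['paramset_idx']}"
)
print(output_dir)  # /data/processed/sub4/sess1/probe_2/kilosort2-5_0
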
- """ - - definition = """ - # Clustering Procedure - -> ClusteringTask - --- - clustering_time: datetime # time of generation of this set of clustering results - package_version='': varchar(16) - """ - - def make(self, key): - """Triggers or imports clustering analysis.""" - task_mode, output_dir = (ClusteringTask & key).fetch1( - "task_mode", "clustering_output_dir" - ) - - if not output_dir: - output_dir = ClusteringTask.infer_output_dir(key, relative=True, mkdir=True) - # update clustering_output_dir - ClusteringTask.update1( - {**key, "clustering_output_dir": output_dir.as_posix()} - ) - - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - if task_mode == "load": - kilosort.Kilosort( - kilosort_dir - ) # check if the directory is a valid Kilosort output - elif task_mode == "trigger": - acq_software, clustering_method, params = ( - ClusteringTask * EphysRecording * ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method", "params") - - if "kilosort" in clustering_method: - from element_array_ephys.readers import kilosort_triggering - - # add additional probe-recording and channels details into `params` - params = {**params, **get_recording_channels_details(key)} - params["fs"] = params["sample_rate"] - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX( - spikeglx_meta_filepath.parent - ) - spikeglx_recording.validate_file("ap") - run_CatGT = ( - params.pop("run_CatGT", True) - and "_tcat." not in spikeglx_meta_filepath.stem - ) - - if clustering_method.startswith("pykilosort"): - kilosort_triggering.run_pykilosort( - continuous_file=spikeglx_recording.root_dir - / (spikeglx_recording.root_name + ".ap.bin"), - kilosort_output_directory=kilosort_dir, - channel_ind=params.pop("channel_ind"), - x_coords=params.pop("x_coords"), - y_coords=params.pop("y_coords"), - shank_ind=params.pop("shank_ind"), - connected=params.pop("connected"), - sample_rate=params.pop("sample_rate"), - params=params, - ) - else: - run_kilosort = kilosort_triggering.SGLXKilosortPipeline( - npx_input_dir=spikeglx_meta_filepath.parent, - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=run_CatGT, - ) - run_kilosort.run_modules() - elif acq_software == "Open Ephys": - oe_probe = get_openephys_probe_data(key) - - assert len(oe_probe.recording_info["recording_files"]) == 1 - - # run kilosort - if clustering_method.startswith("pykilosort"): - kilosort_triggering.run_pykilosort( - continuous_file=pathlib.Path( - oe_probe.recording_info["recording_files"][0] - ) - / "continuous.dat", - kilosort_output_directory=kilosort_dir, - channel_ind=params.pop("channel_ind"), - x_coords=params.pop("x_coords"), - y_coords=params.pop("y_coords"), - shank_ind=params.pop("shank_ind"), - connected=params.pop("connected"), - sample_rate=params.pop("sample_rate"), - params=params, - ) - else: - run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( - npx_input_dir=oe_probe.recording_info["recording_files"][0], - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - ) - run_kilosort.run_modules() - else: - raise NotImplementedError( - f"Automatic triggering of {clustering_method}" - f" clustering analysis is not yet supported" - ) - - else: - raise ValueError(f"Unknown task mode: {task_mode}") - - creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir) - 
self.insert1({**key, "clustering_time": creation_time, "package_version": ""}) - - -@schema -class Curation(dj.Manual): - """Curation procedure table. - - Attributes: - Clustering (foreign key): Clustering primary key. - curation_id (foreign key, int): Unique curation ID. - curation_time (datetime): Time when curation results are generated. - curation_output_dir ( varchar(255) ): Output directory of the curated results. - quality_control (bool): If True, this clustering result has undergone quality control. - manual_curation (bool): If True, manual curation has been performed on this clustering result. - curation_note ( varchar(2000) ): Notes about the curation task. - """ - - definition = """ - # Manual curation procedure - -> Clustering - curation_id: int - --- - curation_time: datetime # time of generation of this set of curated clustering results - curation_output_dir: varchar(255) # output directory of the curated results, relative to root data directory - quality_control: bool # has this clustering result undergone quality control? - manual_curation: bool # has manual curation been performed on this clustering result? - curation_note='': varchar(2000) - """ - - def create1_from_clustering_task(self, key, curation_note=""): - """ - A function to create a new corresponding "Curation" for a particular - "ClusteringTask" - """ - if key not in Clustering(): - raise ValueError( - f"No corresponding entry in Clustering available" - f" for: {key}; do `Clustering.populate(key)`" - ) - - task_mode, output_dir = (ClusteringTask & key).fetch1( - "task_mode", "clustering_output_dir" - ) - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - creation_time, is_curated, is_qc = kilosort.extract_clustering_info( - kilosort_dir - ) - # Synthesize curation_id - curation_id = ( - dj.U().aggr(self & key, n="ifnull(max(curation_id)+1,1)").fetch1("n") - ) - self.insert1( - { - **key, - "curation_id": curation_id, - "curation_time": creation_time, - "curation_output_dir": output_dir, - "quality_control": is_qc, - "manual_curation": is_curated, - "curation_note": curation_note, - } - ) - - -@schema -class CuratedClustering(dj.Imported): - """Clustering results after curation. - - Attributes: - Curation (foreign key): Curation primary key. - """ - - definition = """ - # Clustering results of a curation. - -> Curation - """ - - class Unit(dj.Part): - """Single unit properties after clustering and curation. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - unit (foreign key, int): Unique integer identifying a single unit. - probe.ElectrodeConfig.Electrode (dict): probe.ElectrodeConfig.Electrode primary key. - ClusteringQualityLabel (dict): CLusteringQualityLabel primary key. - spike_count (int): Number of spikes in this recording for this unit. - spike_times (longblob): Spike times of this unit, relative to start time of EphysRecording. - spike_sites (longblob): Array of electrode associated with each spike. - spike_depths (longblob): Array of depths associated with each spike, relative to each spike. 
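
The `ifnull(max(curation_id)+1,1)` aggregation in `create1_from_clustering_task` above auto-increments the curation id within each clustering. A pure-Python restatement with hypothetical ids:

existing_ids = [1, 2]                                  # hypothetical curation ids for this key
curation_id = (max(existing_ids) + 1) if existing_ids else 1
print(curation_id)                                     # 3; an empty table would yield 1
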
- """ - - definition = """ - # Properties of a given unit from a round of clustering (and curation) - -> master - unit: int - --- - -> probe.ElectrodeConfig.Electrode # electrode with highest waveform amplitude for this unit - -> ClusterQualityLabel - spike_count: int # how many spikes in this recording for this unit - spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording - spike_sites : longblob # array of electrode associated with each spike - spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe - """ - - def make(self, key): - """Automated population of Unit information.""" - output_dir = (Curation & key).fetch1("curation_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - acq_software, sample_rate = (EphysRecording & key).fetch1( - "acq_software", "sampling_rate" - ) - - sample_rate = kilosort_dataset.data["params"].get("sample_rate", sample_rate) - - # ---------- Unit ---------- - # -- Remove 0-spike units - withspike_idx = [ - i - for i, u in enumerate(kilosort_dataset.data["cluster_ids"]) - if (kilosort_dataset.data["spike_clusters"] == u).any() - ] - valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx] - valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx] - # -- Get channel and electrode-site mapping - channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software) - - # -- Spike-times -- - # spike_times_sec_adj > spike_times_sec > spike_times - spike_time_key = ( - "spike_times_sec_adj" - if "spike_times_sec_adj" in kilosort_dataset.data - else ( - "spike_times_sec" - if "spike_times_sec" in kilosort_dataset.data - else "spike_times" - ) - ) - spike_times = kilosort_dataset.data[spike_time_key] - kilosort_dataset.extract_spike_depths() - - # -- Spike-sites and Spike-depths -- - spike_sites = np.array( - [ - channel2electrodes[s]["electrode"] - for s in kilosort_dataset.data["spike_sites"] - ] - ) - spike_depths = kilosort_dataset.data["spike_depths"] - - # -- Insert unit, label, peak-chn - units = [] - for unit, unit_lbl in zip(valid_units, valid_unit_labels): - if (kilosort_dataset.data["spike_clusters"] == unit).any(): - unit_channel, _ = kilosort_dataset.get_best_channel(unit) - unit_spike_times = ( - spike_times[kilosort_dataset.data["spike_clusters"] == unit] - / sample_rate - ) - spike_count = len(unit_spike_times) - - units.append( - { - "unit": unit, - "cluster_quality_label": unit_lbl, - **channel2electrodes[unit_channel], - "spike_times": unit_spike_times, - "spike_count": spike_count, - "spike_sites": spike_sites[ - kilosort_dataset.data["spike_clusters"] == unit - ], - "spike_depths": ( - spike_depths[ - kilosort_dataset.data["spike_clusters"] == unit - ] - if spike_depths is not None - else None - ), - } - ) - - self.insert1(key) - self.Unit.insert([{**key, **u} for u in units]) - - -@schema -class WaveformSet(dj.Imported): - """A set of spike waveforms for units out of a given CuratedClustering. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - """ - - definition = """ - # A set of spike waveforms for units out of a given CuratedClustering - -> CuratedClustering - """ - - class PeakWaveform(dj.Part): - """Mean waveform across spikes for a given unit. - - Attributes: - WaveformSet (foreign key): WaveformSet primary key. 
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - peak_electrode_waveform (longblob): Mean waveform for a given unit at its representative electrode. - """ - - definition = """ - # Mean waveform across spikes for a given unit at its representative electrode - -> master - -> CuratedClustering.Unit - --- - peak_electrode_waveform: longblob # (uV) mean waveform for a given unit at its representative electrode - """ - - class Waveform(dj.Part): - """Spike waveforms for a given unit. - - Attributes: - WaveformSet (foreign key): WaveformSet primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key. - waveform_mean (longblob): mean waveform across spikes of the unit in microvolts. - waveforms (longblob): waveforms of a sampling of spikes at the given electrode and unit. - """ - - definition = """ - # Spike waveforms and their mean across spikes for the given unit - -> master - -> CuratedClustering.Unit - -> probe.ElectrodeConfig.Electrode - --- - waveform_mean: longblob # (uV) mean waveform across spikes of the given unit - waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit - """ - - def make(self, key): - """Populates waveform tables.""" - output_dir = (Curation & key).fetch1("curation_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - - acq_software, probe_serial_number = ( - EphysRecording * ProbeInsertion & key - ).fetch1("acq_software", "probe") - - # -- Get channel and electrode-site mapping - recording_key = (EphysRecording & key).fetch1("KEY") - channel2electrodes = get_neuropixels_channel2electrode_map( - recording_key, acq_software - ) - - is_qc = (Curation & key).fetch1("quality_control") - - # Get all units - units = { - u["unit"]: u - for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit") - } - - if is_qc: - unit_waveforms = np.load( - kilosort_dir / "mean_waveforms.npy" - ) # unit x channel x sample - - def yield_unit_waveforms(): - for unit_no, unit_waveform in zip( - kilosort_dataset.data["cluster_ids"], unit_waveforms - ): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - if unit_no in units: - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], unit_waveform - ): - unit_electrode_waveforms.append( - { - **units[unit_no], - **channel2electrodes[channel], - "waveform_mean": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == units[unit_no]["electrode"] - ): - unit_peak_waveform = { - **units[unit_no], - "peak_electrode_waveform": channel_waveform, - } - yield unit_peak_waveform, unit_electrode_waveforms - - else: - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - elif acq_software == "Open Ephys": - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - openephys_dataset = openephys.OpenEphys(session_dir) - neuropixels_recording = openephys_dataset.probes[probe_serial_number] - - def yield_unit_waveforms(): - for unit_dict in units.values(): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - - spikes = unit_dict["spike_times"] - waveforms = neuropixels_recording.extract_spike_waveforms( - spikes, 
kilosort_dataset.data["channel_map"] - ) # (sample x channel x spike) - waveforms = waveforms.transpose( - (1, 2, 0) - ) # (channel x spike x sample) - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], waveforms - ): - unit_electrode_waveforms.append( - { - **unit_dict, - **channel2electrodes[channel], - "waveform_mean": channel_waveform.mean(axis=0), - "waveforms": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == unit_dict["electrode"] - ): - unit_peak_waveform = { - **unit_dict, - "peak_electrode_waveform": channel_waveform.mean( - axis=0 - ), - } - - yield unit_peak_waveform, unit_electrode_waveforms - - # insert waveform on a per-unit basis to mitigate potential memory issue - self.insert1(key) - for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms(): - if unit_peak_waveform: - self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True) - if unit_electrode_waveforms: - self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True) - - -@schema -class QualityMetrics(dj.Imported): - """Clustering and waveform quality metrics. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - """ - - definition = """ - # Clusters and waveforms metrics - -> CuratedClustering - """ - - class Cluster(dj.Part): - """Cluster metrics for a unit. - - Attributes: - QualityMetrics (foreign key): QualityMetrics primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - firing_rate (float): Firing rate of the unit. - snr (float): Signal-to-noise ratio for a unit. - presence_ratio (float): Fraction of time where spikes are present. - isi_violation (float): rate of ISI violation as a fraction of overall rate. - number_violation (int): Total ISI violations. - amplitude_cutoff (float): Estimate of miss rate based on amplitude histogram. - isolation_distance (float): Distance to nearest cluster. - l_ratio (float): Amount of empty space between a cluster and other spikes in dataset. - d_prime (float): Classification accuracy based on LDA. - nn_hit_rate (float): Fraction of neighbors for target cluster that are also in target cluster. - nn_miss_rate (float): Fraction of neighbors outside target cluster that are in the target cluster. - silhouette_core (float): Maximum change in spike depth throughout recording. - cumulative_drift (float): Cumulative change in spike depth throughout recording. - contamination_rate (float): Frequency of spikes in the refractory period. 
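
Simplified illustrations of two of the metrics just listed; the authoritative values are read from the Kilosort/ecephys `metrics.csv` in the make() below rather than recomputed here.

import numpy as np

rng = np.random.default_rng(0)
spike_times = np.sort(rng.uniform(0, 600, 3000))   # hypothetical 10-minute unit

duration = spike_times[-1] - spike_times[0]
firing_rate = spike_times.size / duration          # (Hz) mean rate over the recording

# presence_ratio: fraction of equal time bins containing at least one spike
counts, _ = np.histogram(spike_times, bins=100)
presence_ratio = np.mean(counts > 0)
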
- """ - - definition = """ - # Cluster metrics for a particular unit - -> master - -> CuratedClustering.Unit - --- - firing_rate=null: float # (Hz) firing rate for a unit - snr=null: float # signal-to-noise ratio for a unit - presence_ratio=null: float # fraction of time in which spikes are present - isi_violation=null: float # rate of ISI violation as a fraction of overall rate - number_violation=null: int # total number of ISI violations - amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram - isolation_distance=null: float # distance to nearest cluster in Mahalanobis space - l_ratio=null: float # - d_prime=null: float # Classification accuracy based on LDA - nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster - nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster - silhouette_score=null: float # Standard metric for cluster overlap - max_drift=null: float # Maximum change in spike depth throughout recording - cumulative_drift=null: float # Cumulative change in spike depth throughout recording - contamination_rate=null: float # - """ - - class Waveform(dj.Part): - """Waveform metrics for a particular unit. - - Attributes: - QualityMetrics (foreign key): QualityMetrics primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - amplitude (float): Absolute difference between waveform peak and trough in microvolts. - duration (float): Time between waveform peak and trough in milliseconds. - halfwidth (float): Spike width at half max amplitude. - pt_ratio (float): Absolute amplitude of peak divided by absolute amplitude of trough relative to 0. - repolarization_slope (float): Slope of the regression line fit to first 30 microseconds from trough to peak. - recovery_slope (float): Slope of the regression line fit to first 30 microseconds from peak to tail. - spread (float): The range with amplitude over 12-percent of maximum amplitude along the probe. - velocity_above (float): inverse velocity of waveform propagation from soma to the top of the probe. - velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe. 
- """ - - definition = """ - # Waveform metrics for a particular unit - -> master - -> CuratedClustering.Unit - --- - amplitude: float # (uV) absolute difference between waveform peak and trough - duration: float # (ms) time between waveform peak and trough - halfwidth=null: float # (ms) spike width at half max amplitude - pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0 - repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak - recovery_slope=null: float # the recovery slope was defined by fitting a regression line to the first 30us from peak to tail - spread=null: float # (um) the range with amplitude above 12-percent of the maximum amplitude along the probe - velocity_above=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the top of the probe - velocity_below=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the bottom of the probe - """ - - def make(self, key): - """Populates tables with quality metrics data.""" - output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - metric_fp = kilosort_dir / "metrics.csv" - rename_dict = { - "isi_viol": "isi_violation", - "num_viol": "number_violation", - "contam_rate": "contamination_rate", - } - - if not metric_fp.exists(): - raise FileNotFoundError(f"QC metrics file not found: {metric_fp}") - - metrics_df = pd.read_csv(metric_fp) - metrics_df.set_index("cluster_id", inplace=True) - metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True) - metrics_df.columns = metrics_df.columns.str.lower() - metrics_df.rename(columns=rename_dict, inplace=True) - metrics_list = [ - dict(metrics_df.loc[unit_key["unit"]], **unit_key) - for unit_key in (CuratedClustering.Unit & key).fetch("KEY") - ] - - self.insert1(key) - self.Cluster.insert(metrics_list, ignore_extra_fields=True) - self.Waveform.insert(metrics_list, ignore_extra_fields=True) - - -# ---------------- HELPER FUNCTIONS ---------------- - - -def get_spikeglx_meta_filepath(ephys_recording_key: dict) -> str: - """Get spikeGLX data filepath.""" - # attempt to retrieve from EphysRecording.EphysFile - spikeglx_meta_filepath = pathlib.Path( - ( - EphysRecording.EphysFile - & ephys_recording_key - & 'file_path LIKE "%.ap.meta"' - ).fetch1("file_path") - ) - - try: - spikeglx_meta_filepath = find_full_path( - get_ephys_root_data_dir(), spikeglx_meta_filepath - ) - except FileNotFoundError: - # if not found, search in session_dir again - if not spikeglx_meta_filepath.exists(): - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(ephys_recording_key) - ) - inserted_probe_serial_number = ( - ProbeInsertion * probe.Probe & ephys_recording_key - ).fetch1("probe") - - spikeglx_meta_filepaths = [fp for fp in session_dir.rglob("*.ap.meta")] - for meta_filepath in spikeglx_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - spikeglx_meta_filepath = meta_filepath - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format( - ephys_recording_key - ) - ) - - return spikeglx_meta_filepath - - -def get_openephys_probe_data(ephys_recording_key: dict) -> list: - """Get OpenEphys probe data from file.""" - inserted_probe_serial_number = ( - ProbeInsertion * probe.Probe & 
ephys_recording_key - ).fetch1("probe") - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(ephys_recording_key) - ) - loaded_oe = openephys.OpenEphys(session_dir) - probe_data = loaded_oe.probes[inserted_probe_serial_number] - - # explicitly garbage collect "loaded_oe" - # as these may have large memory footprint and may not be cleared fast enough - del loaded_oe - gc.collect() - - return probe_data - - -def get_neuropixels_channel2electrode_map( - ephys_recording_key: dict, acq_software: str -) -> dict: - """Get the channel map for neuropixels probe.""" - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) - spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath) - electrode_config_key = ( - EphysRecording * probe.ElectrodeConfig & ephys_recording_key - ).fetch1("KEY") - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & electrode_config_key - ) - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - channel2electrode_map = { - recorded_site: probe_electrodes[(shank, shank_col, shank_row)] - for recorded_site, (shank, shank_col, shank_row, _) in enumerate( - spikeglx_meta.shankmap["data"] - ) - } - elif acq_software == "Open Ephys": - probe_dataset = get_openephys_probe_data(ephys_recording_key) - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording - & ephys_recording_key - ) - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - channel2electrode_map = { - channel_idx: probe_electrodes[channel_idx] - for channel_idx in probe_dataset.ap_meta["channels_indices"] - } - - return channel2electrode_map - - -def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict: - """Generate and insert new ElectrodeConfig - - Args: - probe_type (str): probe type (e.g. 
neuropixels 2.0 - SS) - electrode_keys (list): list of keys of the probe.ProbeType.Electrode table - - Returns: - dict: representing a key of the probe.ElectrodeConfig table - """ - # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) - electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys}) - - electrode_list = sorted([k["electrode"] for k in electrode_keys]) - electrode_gaps = ( - [-1] - + np.where(np.diff(electrode_list) > 1)[0].tolist() - + [len(electrode_list) - 1] - ) - electrode_config_name = "; ".join( - [ - f"{electrode_list[start + 1]}-{electrode_list[end]}" - for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) - ] - ) - - electrode_config_key = {"electrode_config_hash": electrode_config_hash} - - # ---- make new ElectrodeConfig if needed ---- - if not probe.ElectrodeConfig & electrode_config_key: - probe.ElectrodeConfig.insert1( - { - **electrode_config_key, - "probe_type": probe_type, - "electrode_config_name": electrode_config_name, - } - ) - probe.ElectrodeConfig.Electrode.insert( - {**electrode_config_key, **electrode} for electrode in electrode_keys - ) - - return electrode_config_key - - -def get_recording_channels_details(ephys_recording_key: dict) -> np.array: - """Get details of recording channels for a given recording.""" - channels_details = {} - - acq_software, sample_rate = (EphysRecording & ephys_recording_key).fetch1( - "acq_software", "sampling_rate" - ) - - probe_type = (ProbeInsertion * probe.Probe & ephys_recording_key).fetch1( - "probe_type" - ) - channels_details["probe_type"] = { - "neuropixels 1.0 - 3A": "3A", - "neuropixels 1.0 - 3B": "NP1", - "neuropixels UHD": "NP1100", - "neuropixels 2.0 - SS": "NP21", - "neuropixels 2.0 - MS": "NP24", - }[probe_type] - - electrode_config_key = ( - probe.ElectrodeConfig * EphysRecording & ephys_recording_key - ).fetch1("KEY") - ( - channels_details["channel_ind"], - channels_details["x_coords"], - channels_details["y_coords"], - channels_details["shank_ind"], - ) = ( - probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode - & electrode_config_key - ).fetch( - "electrode", "x_coord", "y_coord", "shank" - ) - channels_details["sample_rate"] = sample_rate - channels_details["num_channels"] = len(channels_details["channel_ind"]) - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - channels_details["uVPerBit"] = spikeglx_recording.get_channel_bit_volts("ap")[0] - channels_details["connected"] = np.array( - [v for *_, v in spikeglx_recording.apmeta.shankmap["data"]] - ) - elif acq_software == "Open Ephys": - oe_probe = get_openephys_probe_data(ephys_recording_key) - channels_details["uVPerBit"] = oe_probe.ap_meta["channels_gains"][0] - channels_details["connected"] = np.array( - [ - int(v == 1) - for c, v in oe_probe.channels_connected.items() - if c in channels_details["channel_ind"] - ] - ) - - return channels_details diff --git a/element_array_ephys/ephys_chronic.py b/element_array_ephys/ephys_chronic.py deleted file mode 100644 index 772e885f..00000000 --- a/element_array_ephys/ephys_chronic.py +++ /dev/null @@ -1,1523 +0,0 @@ -import gc -import importlib -import inspect -import pathlib -from decimal import Decimal - -import datajoint as dj -import numpy as np -import pandas as pd -from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory - -from . 
import ephys_report, probe -from .readers import kilosort, openephys, spikeglx - -log = dj.logger - -schema = dj.schema() - -_linking_module = None - - -def activate( - ephys_schema_name: str, - probe_schema_name: str = None, - *, - create_schema: bool = True, - create_tables: bool = True, - linking_module: str = None, -): - """Activates the `ephys` and `probe` schemas. - - Args: - ephys_schema_name (str): A string containing the name of the ephys schema. - probe_schema_name (str): A string containing the name of the probe schema. - create_schema (bool): If True, schema will be created in the database. - create_tables (bool): If True, tables related to the schema will be created in the database. - linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema. - - Dependencies: - Upstream tables: - Session: A parent table to ProbeInsertion - Probe: A parent table to EphysRecording. Probe information is required before electrophysiology data is imported. - - Functions: - get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s). - get_session_direction(session_key: dict): Returns path to electrophysiology data for the a particular session as a list of strings. - get_processed_data_dir(): Optional. Returns absolute path for processed data. Defaults to root directory. - """ - - if isinstance(linking_module, str): - linking_module = importlib.import_module(linking_module) - assert inspect.ismodule( - linking_module - ), "The argument 'dependency' must be a module's name or a module" - - global _linking_module - _linking_module = linking_module - - probe.activate( - probe_schema_name, create_schema=create_schema, create_tables=create_tables - ) - schema.activate( - ephys_schema_name, - create_schema=create_schema, - create_tables=create_tables, - add_objects=_linking_module.__dict__, - ) - ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name) - - -# -------------- Functions required by the elements-ephys --------------- - - -def get_ephys_root_data_dir() -> list: - """Fetches absolute data path to ephys data directories. - - The absolute path here is used as a reference for all downstream relative paths used in DataJoint. - - Returns: - A list of the absolute path(s) to ephys data directories. - """ - root_directories = _linking_module.get_ephys_root_data_dir() - if isinstance(root_directories, (str, pathlib.Path)): - root_directories = [root_directories] - - if hasattr(_linking_module, "get_processed_root_data_dir"): - root_directories.append(_linking_module.get_processed_root_data_dir()) - - return root_directories - - -def get_session_directory(session_key: dict) -> str: - """Retrieve the session directory with Neuropixels for the given session. - - Args: - session_key (dict): A dictionary mapping subject to an entry in the subject table, and session_datetime corresponding to a session in the database. - - Returns: - A string for the path to the session directory. - """ - return _linking_module.get_session_directory(session_key) - - -def get_processed_root_data_dir() -> str: - """Retrieve the root directory for all processed data. - - Returns: - A string for the full path to the root directory for processed data. 
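
This chronic variant differs from the acute module chiefly in `ProbeInsertion` (defined further below): insertions hang off `Subject` rather than `Session` and carry an optional `insertion_datetime`. A hypothetical manual entry:

ProbeInsertion.insert1(
    {
        "subject": "sub4",                       # hypothetical subject key
        "insertion_number": 0,
        "probe": "19011110123",                  # hypothetical probe serial number
        "insertion_datetime": "2022-11-30 15:00:00",
    }
)
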
- """ - - if hasattr(_linking_module, "get_processed_root_data_dir"): - return _linking_module.get_processed_root_data_dir() - else: - return get_ephys_root_data_dir()[0] - - -# ----------------------------- Table declarations ---------------------- - - -@schema -class AcquisitionSoftware(dj.Lookup): - """Name of software used for recording electrophysiological data. - - Attributes: - acq_software ( varchar(24) ): Acquisition software, e.g,. SpikeGLX, OpenEphys - """ - - definition = """ # Software used for recording of neuropixels probes - acq_software: varchar(24) - """ - contents = zip(["SpikeGLX", "Open Ephys"]) - - -@schema -class ProbeInsertion(dj.Manual): - """Information about probe insertion across subjects and sessions. - - Attributes: - Session (foreign key): Session primary key. - insertion_number (foreign key, str): Unique insertion number for each probe insertion for a given session. - probe.Probe (str): probe.Probe primary key. - """ - - definition = """ - # Probe insertion chronically implanted into an animal. - -> Subject - insertion_number: tinyint unsigned - --- - -> probe.Probe - insertion_datetime=null: datetime - """ - - -@schema -class InsertionLocation(dj.Manual): - """Stereotaxic location information for each probe insertion. - - Attributes: - ProbeInsertion (foreign key): ProbeInsertion primary key. - SkullReference (dict): SkullReference primary key. - ap_location (decimal (6, 2) ): Anterior-posterior location in micrometers. Reference is 0 with anterior values positive. - ml_location (decimal (6, 2) ): Medial-lateral location in micrometers. Reference is zero with right side values positive. - depth (decimal (6, 2) ): Manipulator depth relative to the surface of the brain at zero. Ventral is negative. - Theta (decimal (5, 2) ): elevation - rotation about the ml-axis in degrees relative to positive z-axis. - phi (decimal (5, 2) ): azimuth - rotation about the dv-axis in degrees relative to the positive x-axis. - """ - - definition = """ - # Brain Location of a given probe insertion. - -> ProbeInsertion - --- - -> SkullReference - ap_location: decimal(6, 2) # (um) anterior-posterior; ref is 0; more anterior is more positive - ml_location: decimal(6, 2) # (um) medial axis; ref is 0 ; more right is more positive - depth: decimal(6, 2) # (um) manipulator depth relative to surface of the brain (0); more ventral is more negative - theta=null: decimal(5, 2) # (deg) - elevation - rotation about the ml-axis [0, 180] - w.r.t the z+ axis - phi=null: decimal(5, 2) # (deg) - azimuth - rotation about the dv-axis [0, 360] - w.r.t the x+ axis - beta=null: decimal(5, 2) # (deg) rotation about the shank of the probe [-180, 180] - clockwise is increasing in degree - 0 is the probe-front facing anterior - """ - - -@schema -class EphysRecording(dj.Imported): - """Automated table with electrophysiology recording information for each probe inserted during an experimental session. - - Attributes: - ProbeInsertion (foreign key): ProbeInsertion primary key. - probe.ElectrodeConfig (dict): probe.ElectrodeConfig primary key. - AcquisitionSoftware (dict): AcquisitionSoftware primary key. - sampling_rate (float): sampling rate of the recording in Hertz (Hz). - recording_datetime (datetime): datetime of the recording from this probe. - recording_duration (float): duration of the entire recording from this probe in seconds. - """ - - definition = """ - # Ephys recording from a probe insertion for a given session. 
- -> Session - -> ProbeInsertion - --- - -> probe.ElectrodeConfig - -> AcquisitionSoftware - sampling_rate: float # (Hz) - recording_datetime: datetime # datetime of the recording from this probe - recording_duration: float # (seconds) duration of the recording from this probe - """ - - class EphysFile(dj.Part): - """Paths of electrophysiology recording files for each insertion. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - file_path (varchar(255) ): relative file path for electrophysiology recording. - """ - - definition = """ - # Paths of files of a given EphysRecording round. - -> master - file_path: varchar(255) # filepath relative to root data directory - """ - - def make(self, key): - """Populates table with electrophysiology recording information.""" - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - - inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1( - "probe" - ) - - # search session dir and determine acquisition software - for ephys_pattern, ephys_acq_type in ( - ("*.ap.meta", "SpikeGLX"), - ("*.oebin", "Open Ephys"), - ): - ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern)) - if ephys_meta_filepaths: - acq_software = ephys_acq_type - break - else: - raise FileNotFoundError( - f"Ephys recording data not found!" - f" Neither SpikeGLX nor Open Ephys recording files found" - f" in {session_dir}" - ) - - supported_probe_types = probe.ProbeType.fetch("probe_type") - - if acq_software == "SpikeGLX": - for meta_filepath in ephys_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - f"No SpikeGLX data found for probe insertion: {key}" - + " The probe serial number does not match." 
- ) - - if spikeglx_meta.probe_model in supported_probe_types: - probe_type = spikeglx_meta.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - electrode_group_members = [ - probe_electrodes[(shank, shank_col, shank_row)] - for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels probe model" - " {} not yet implemented".format(spikeglx_meta.probe_model) - ) - - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": spikeglx_meta.meta["imSampRate"], - "recording_datetime": spikeglx_meta.recording_time, - "recording_duration": ( - spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(meta_filepath) - ), - } - ) - - root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) - self.EphysFile.insert1( - {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} - ) - elif acq_software == "Open Ephys": - dataset = openephys.OpenEphys(session_dir) - for serial_number, probe_data in dataset.probes.items(): - if str(serial_number) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No Open Ephys data found for probe insertion: {}".format(key) - ) - - if not probe_data.ap_meta: - raise IOError( - 'No analog signals found - check "structure.oebin" file or "continuous" directory' - ) - - if probe_data.probe_model in supported_probe_types: - probe_type = probe_data.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - electrode_group_members = [ - probe_electrodes[channel_idx] - for channel_idx in probe_data.ap_meta["channels_indices"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels" - " probe model {} not yet implemented".format(probe_data.probe_model) - ) - - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": probe_data.ap_meta["sample_rate"], - "recording_datetime": probe_data.recording_info[ - "recording_datetimes" - ][0], - "recording_duration": np.sum( - probe_data.recording_info["recording_durations"] - ), - } - ) - - root_dir = find_root_directory( - get_ephys_root_data_dir(), - probe_data.recording_info["recording_files"][0], - ) - self.EphysFile.insert( - [ - {**key, "file_path": fp.relative_to(root_dir).as_posix()} - for fp in probe_data.recording_info["recording_files"] - ] - ) - # explicitly garbage collect "dataset" - # as these may have large memory footprint and may not be cleared fast enough - del probe_data, dataset - gc.collect() - else: - raise NotImplementedError( - f"Processing ephys files from" - f" acquisition software of type {acq_software} is" - f" not yet implemented" - ) - - -@schema -class LFP(dj.Imported): - """Extracts local field potentials (LFP) from an electrophysiology recording. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - lfp_sampling_rate (float): Sampling rate for LFPs in Hz. - lfp_time_stamps (longblob): Time stamps with respect to the start of the recording. - lfp_mean (longblob): Overall mean LFP across electrodes. 
- """ - - definition = """ - # Acquired local field potential (LFP) from a given Ephys recording. - -> EphysRecording - --- - lfp_sampling_rate: float # (Hz) - lfp_time_stamps: longblob # (s) timestamps with respect to the start of the recording (recording_timestamp) - lfp_mean: longblob # (uV) mean of LFP across electrodes - shape (time,) - """ - - class Electrode(dj.Part): - """Saves local field potential data for each electrode. - - Attributes: - LFP (foreign key): LFP primary key. - probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key. - lfp (longblob): LFP recording at this electrode in microvolts. - """ - - definition = """ - -> master - -> probe.ElectrodeConfig.Electrode - --- - lfp: longblob # (uV) recorded lfp at this electrode - """ - - # Only store LFP for every 9th channel, due to high channel density, - # close-by channels exhibit highly similar LFP - _skip_channel_counts = 9 - - def make(self, key): - """Populates the LFP tables.""" - acq_software = (EphysRecording * ProbeInsertion & key).fetch1("acq_software") - - electrode_keys, lfp = [], [] - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - - lfp_channel_ind = spikeglx_recording.lfmeta.recording_channels[ - -1 :: -self._skip_channel_counts - ] - - # Extract LFP data at specified channels and convert to uV - lfp = spikeglx_recording.lf_timeseries[ - :, lfp_channel_ind - ] # (sample x channel) - lfp = ( - lfp * spikeglx_recording.get_channel_bit_volts("lf")[lfp_channel_ind] - ).T # (channel x sample) - - self.insert1( - dict( - key, - lfp_sampling_rate=spikeglx_recording.lfmeta.meta["imSampRate"], - lfp_time_stamps=( - np.arange(lfp.shape[1]) - / spikeglx_recording.lfmeta.meta["imSampRate"] - ), - lfp_mean=lfp.mean(axis=0), - ) - ) - - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - for recorded_site in lfp_channel_ind: - shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap[ - "data" - ][recorded_site] - electrode_keys.append(probe_electrodes[(shank, shank_col, shank_row)]) - elif acq_software == "Open Ephys": - oe_probe = get_openephys_probe_data(key) - - lfp_channel_ind = np.r_[ - len(oe_probe.lfp_meta["channels_indices"]) - - 1 : 0 : -self._skip_channel_counts - ] - - # (sample x channel) - lfp = oe_probe.lfp_timeseries[:, lfp_channel_ind] - lfp = ( - lfp * np.array(oe_probe.lfp_meta["channels_gains"])[lfp_channel_ind] - ).T # (channel x sample) - lfp_timestamps = oe_probe.lfp_timestamps - - self.insert1( - dict( - key, - lfp_sampling_rate=oe_probe.lfp_meta["sample_rate"], - lfp_time_stamps=lfp_timestamps, - lfp_mean=lfp.mean(axis=0), - ) - ) - - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - electrode_keys.extend( - probe_electrodes[channel_idx] for channel_idx in lfp_channel_ind - ) - else: - raise NotImplementedError( - f"LFP extraction from acquisition software" - f" of type {acq_software} is not yet implemented" - ) - - # single insert in loop to mitigate potential memory issue - for electrode_key, lfp_trace in zip(electrode_keys, lfp): - 
self.Electrode.insert1({**key, **electrode_key, "lfp": lfp_trace}) - - -# ------------ Clustering -------------- - - -@schema -class ClusteringMethod(dj.Lookup): - """Kilosort clustering method. - - Attributes: - clustering_method (foreign key, varchar(16) ): Kilosort clustering method. - clustering_methods_desc (varchar(1000) ): Additional description of the clustering method. - """ - - definition = """ - # Method for clustering - clustering_method: varchar(16) - --- - clustering_method_desc: varchar(1000) - """ - - contents = [ - ("kilosort2", "kilosort2 clustering method"), - ("kilosort2.5", "kilosort2.5 clustering method"), - ("kilosort3", "kilosort3 clustering method"), - ] - - -@schema -class ClusteringParamSet(dj.Lookup): - """Parameters to be used in clustering procedure for spike sorting. - - Attributes: - paramset_idx (foreign key): Unique ID for the clustering parameter set. - ClusteringMethod (dict): ClusteringMethod primary key. - paramset_desc (varchar(128) ): Description of the clustering parameter set. - param_set_hash (uuid): UUID hash for the parameter set. - params (longblob): Parameters for clustering with Kilosort. - """ - - definition = """ - # Parameter set to be used in a clustering procedure - paramset_idx: smallint - --- - -> ClusteringMethod - paramset_desc: varchar(128) - param_set_hash: uuid - unique index (param_set_hash) - params: longblob # dictionary of all applicable parameters - """ - - @classmethod - def insert_new_params( - cls, - clustering_method: str, - paramset_desc: str, - params: dict, - paramset_idx: int = None, - ): - """Inserts new parameters into the ClusteringParamSet table. - - Args: - clustering_method (str): name of the clustering method. - paramset_desc (str): description of the parameter set - params (dict): clustering parameters - paramset_idx (int, optional): Unique parameter set ID. Defaults to None. - """ - if paramset_idx is None: - paramset_idx = ( - dj.U().aggr(cls, n="max(paramset_idx)").fetch1("n") or 0 - ) + 1 - - param_dict = { - "clustering_method": clustering_method, - "paramset_idx": paramset_idx, - "paramset_desc": paramset_desc, - "params": params, - "param_set_hash": dict_to_uuid( - {**params, "clustering_method": clustering_method} - ), - } - param_query = cls & {"param_set_hash": param_dict["param_set_hash"]} - - if param_query: # If the specified param-set already exists - existing_paramset_idx = param_query.fetch1("paramset_idx") - if ( - existing_paramset_idx == paramset_idx - ): # If the existing set has the same paramset_idx: job done - return - else: # If not same name: human error, trying to add the same paramset with different name - raise dj.DataJointError( - f"The specified param-set already exists" - f" - with paramset_idx: {existing_paramset_idx}" - ) - else: - if {"paramset_idx": paramset_idx} in cls.proj(): - raise dj.DataJointError( - f"The specified paramset_idx {paramset_idx} already exists," - f" please pick a different one." - ) - cls.insert1(param_dict) - - -@schema -class ClusterQualityLabel(dj.Lookup): - """Quality label for each spike sorted cluster. - - Attributes: - cluster_quality_label (foreign key, varchar(100) ): Cluster quality type. - cluster_quality_description (varchar(4000) ): Description of the cluster quality type. - """ - - definition = """ - # Quality - cluster_quality_label: varchar(100) # cluster quality type - e.g. 'good', 'MUA', 'noise', etc. 
-    ---
-    cluster_quality_description: varchar(4000)
-    """
-    contents = [
-        ("good", "single unit"),
-        ("ok", "probably a single unit, but could be contaminated"),
-        ("mua", "multi-unit activity"),
-        ("noise", "bad unit"),
-    ]
-
-
-@schema
-class ClusteringTask(dj.Manual):
-    """A clustering task to spike sort electrophysiology datasets.
-
-    Attributes:
-        EphysRecording (foreign key): EphysRecording primary key.
-        ClusteringParamSet (foreign key): ClusteringParamSet primary key.
-        clustering_output_dir (varchar(255) ): Relative path to the clustering output directory.
-        task_mode (enum): `trigger` computes the clustering; `load` imports existing results.
-    """
-
-    definition = """
-    # Manual table for defining a clustering task ready to be run
-    -> EphysRecording
-    -> ClusteringParamSet
-    ---
-    clustering_output_dir='': varchar(255)  # clustering output directory relative to the clustering root data directory
-    task_mode='load': enum('load', 'trigger')  # 'load': load computed analysis results, 'trigger': trigger computation
-    """
-
-    @classmethod
-    def infer_output_dir(cls, key, relative=False, mkdir=False) -> pathlib.Path:
-        """Infer output directory if it is not provided.
-
-        Args:
-            key (dict): ClusteringTask primary key.
-            relative (bool, optional): If True, return the path relative to the processed root directory. Defaults to False.
-            mkdir (bool, optional): If True, create the output directory. Defaults to False.
-
-        Returns:
-            Expected clustering_output_dir based on the following convention:
-                processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx}
-                e.g.: sub4/sess1/probe_2/kilosort2_0
-        """
-        processed_dir = pathlib.Path(get_processed_root_data_dir())
-        sess_dir = find_full_path(get_ephys_root_data_dir(), get_session_directory(key))
-        root_dir = find_root_directory(get_ephys_root_data_dir(), sess_dir)
-
-        method = (
-            (ClusteringParamSet * ClusteringMethod & key)
-            .fetch1("clustering_method")
-            .replace(".", "-")
-        )
-
-        output_dir = (
-            processed_dir
-            / sess_dir.relative_to(root_dir)
-            / f'probe_{key["insertion_number"]}'
-            / f'{method}_{key["paramset_idx"]}'
-        )
-
-        if mkdir:
-            output_dir.mkdir(parents=True, exist_ok=True)
-            log.info(f"{output_dir} created!")
-
-        return output_dir.relative_to(processed_dir) if relative else output_dir
-
-    @classmethod
-    def auto_generate_entries(cls, ephys_recording_key: dict, paramset_idx: int = 0):
-        """Autogenerate entries based on a particular ephys recording.
-
-        Args:
-            ephys_recording_key (dict): EphysRecording primary key.
-            paramset_idx (int, optional): Parameter index to use for the clustering task. Defaults to 0.
-        """
-        key = {**ephys_recording_key, "paramset_idx": paramset_idx}
-
-        processed_dir = get_processed_root_data_dir()
-        output_dir = ClusteringTask.infer_output_dir(key, relative=False, mkdir=True)
-
-        try:
-            kilosort.Kilosort(
-                output_dir
-            )  # check if the directory is a valid Kilosort output
-        except FileNotFoundError:
-            task_mode = "trigger"
-        else:
-            task_mode = "load"
-
-        cls.insert1(
-            {
-                **key,
-                "clustering_output_dir": output_dir.relative_to(
-                    processed_dir
-                ).as_posix(),
-                "task_mode": task_mode,
-            }
-        )
-
-
-@schema
-class Clustering(dj.Imported):
-    """A processing table to handle each clustering task.
-
-    Attributes:
-        ClusteringTask (foreign key): ClusteringTask primary key.
-        clustering_time (datetime): Time when clustering results are generated.
-        package_version (varchar(16) ): Package version used for a clustering analysis.
- """ - - definition = """ - # Clustering Procedure - -> ClusteringTask - --- - clustering_time: datetime # time of generation of this set of clustering results - package_version='': varchar(16) - """ - - def make(self, key): - """Triggers or imports clustering analysis.""" - task_mode, output_dir = (ClusteringTask & key).fetch1( - "task_mode", "clustering_output_dir" - ) - - if not output_dir: - output_dir = ClusteringTask.infer_output_dir(key, relative=True, mkdir=True) - # update clustering_output_dir - ClusteringTask.update1( - {**key, "clustering_output_dir": output_dir.as_posix()} - ) - - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - if task_mode == "load": - kilosort.Kilosort( - kilosort_dir - ) # check if the directory is a valid Kilosort output - elif task_mode == "trigger": - acq_software, clustering_method, params = ( - ClusteringTask * EphysRecording * ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method", "params") - - if "kilosort" in clustering_method: - from element_array_ephys.readers import kilosort_triggering - - # add additional probe-recording and channels details into `params` - params = {**params, **get_recording_channels_details(key)} - params["fs"] = params["sample_rate"] - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX( - spikeglx_meta_filepath.parent - ) - spikeglx_recording.validate_file("ap") - run_CatGT = ( - params.pop("run_CatGT", True) - and "_tcat." not in spikeglx_meta_filepath.stem - ) - - if clustering_method.startswith("pykilosort"): - kilosort_triggering.run_pykilosort( - continuous_file=spikeglx_recording.root_dir - / (spikeglx_recording.root_name + ".ap.bin"), - kilosort_output_directory=kilosort_dir, - channel_ind=params.pop("channel_ind"), - x_coords=params.pop("x_coords"), - y_coords=params.pop("y_coords"), - shank_ind=params.pop("shank_ind"), - connected=params.pop("connected"), - sample_rate=params.pop("sample_rate"), - params=params, - ) - else: - run_kilosort = kilosort_triggering.SGLXKilosortPipeline( - npx_input_dir=spikeglx_meta_filepath.parent, - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=run_CatGT, - ) - run_kilosort.run_modules() - elif acq_software == "Open Ephys": - oe_probe = get_openephys_probe_data(key) - - assert len(oe_probe.recording_info["recording_files"]) == 1 - - # run kilosort - if clustering_method.startswith("pykilosort"): - kilosort_triggering.run_pykilosort( - continuous_file=pathlib.Path( - oe_probe.recording_info["recording_files"][0] - ) - / "continuous.dat", - kilosort_output_directory=kilosort_dir, - channel_ind=params.pop("channel_ind"), - x_coords=params.pop("x_coords"), - y_coords=params.pop("y_coords"), - shank_ind=params.pop("shank_ind"), - connected=params.pop("connected"), - sample_rate=params.pop("sample_rate"), - params=params, - ) - else: - run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( - npx_input_dir=oe_probe.recording_info["recording_files"][0], - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - ) - run_kilosort.run_modules() - else: - raise NotImplementedError( - f"Automatic triggering of {clustering_method}" - f" clustering analysis is not yet supported" - ) - - else: - raise ValueError(f"Unknown task mode: {task_mode}") - - creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir) - 
self.insert1({**key, "clustering_time": creation_time, "package_version": ""})
-
-
-@schema
-class Curation(dj.Manual):
-    """Curation procedure table.
-
-    Attributes:
-        Clustering (foreign key): Clustering primary key.
-        curation_id (foreign key, int): Unique curation ID.
-        curation_time (datetime): Time when curation results are generated.
-        curation_output_dir (varchar(255) ): Output directory of the curated results.
-        quality_control (bool): If True, this clustering result has undergone quality control.
-        manual_curation (bool): If True, manual curation has been performed on this clustering result.
-        curation_note (varchar(2000) ): Notes about the curation task.
-    """
-
-    definition = """
-    # Manual curation procedure
-    -> Clustering
-    curation_id: int
-    ---
-    curation_time: datetime  # time of generation of this set of curated clustering results
-    curation_output_dir: varchar(255)  # output directory of the curated results, relative to root data directory
-    quality_control: bool  # has this clustering result undergone quality control?
-    manual_curation: bool  # has manual curation been performed on this clustering result?
-    curation_note='': varchar(2000)
-    """
-
-    def create1_from_clustering_task(self, key, curation_note: str = ""):
-        """Create a new "Curation" entry for the given "ClusteringTask" key."""
-        if key not in Clustering():
-            raise ValueError(
-                f"No corresponding entry in Clustering available"
-                f" for: {key}; do `Clustering.populate(key)`"
-            )
-
-        task_mode, output_dir = (ClusteringTask & key).fetch1(
-            "task_mode", "clustering_output_dir"
-        )
-        kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
-        creation_time, is_curated, is_qc = kilosort.extract_clustering_info(
-            kilosort_dir
-        )
-        # Synthesize curation_id
-        curation_id = (
-            dj.U().aggr(self & key, n="ifnull(max(curation_id)+1,1)").fetch1("n")
-        )
-        self.insert1(
-            {
-                **key,
-                "curation_id": curation_id,
-                "curation_time": creation_time,
-                "curation_output_dir": output_dir,
-                "quality_control": is_qc,
-                "manual_curation": is_curated,
-                "curation_note": curation_note,
-            }
-        )
-
-
-@schema
-class CuratedClustering(dj.Imported):
-    """Clustering results after curation.
-
-    Attributes:
-        Curation (foreign key): Curation primary key.
-    """
-
-    definition = """
-    # Clustering results of a curation.
-    -> Curation
-    """
-
-    class Unit(dj.Part):
-        """Single unit properties after clustering and curation.
-
-        Attributes:
-            CuratedClustering (foreign key): CuratedClustering primary key.
-            unit (foreign key, int): Unique integer identifying a single unit.
-            probe.ElectrodeConfig.Electrode (dict): probe.ElectrodeConfig.Electrode primary key.
-            ClusterQualityLabel (dict): ClusterQualityLabel primary key.
-            spike_count (int): Number of spikes in this recording for this unit.
-            spike_times (longblob): Spike times of this unit, relative to the start of the EphysRecording.
-            spike_sites (longblob): Array of electrodes associated with each spike.
-            spike_depths (longblob): Array of depths associated with each spike, relative to the (0, 0) of the probe.
- """ - - definition = """ - # Properties of a given unit from a round of clustering (and curation) - -> master - unit: int - --- - -> probe.ElectrodeConfig.Electrode # electrode with highest waveform amplitude for this unit - -> ClusterQualityLabel - spike_count: int # how many spikes in this recording for this unit - spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording - spike_sites : longblob # array of electrode associated with each spike - spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe - """ - - def make(self, key): - """Automated population of Unit information.""" - output_dir = (Curation & key).fetch1("curation_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - acq_software, sample_rate = (EphysRecording & key).fetch1( - "acq_software", "sampling_rate" - ) - - sample_rate = kilosort_dataset.data["params"].get("sample_rate", sample_rate) - - # ---------- Unit ---------- - # -- Remove 0-spike units - withspike_idx = [ - i - for i, u in enumerate(kilosort_dataset.data["cluster_ids"]) - if (kilosort_dataset.data["spike_clusters"] == u).any() - ] - valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx] - valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx] - # -- Get channel and electrode-site mapping - channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software) - - # -- Spike-times -- - # spike_times_sec_adj > spike_times_sec > spike_times - spike_time_key = ( - "spike_times_sec_adj" - if "spike_times_sec_adj" in kilosort_dataset.data - else ( - "spike_times_sec" - if "spike_times_sec" in kilosort_dataset.data - else "spike_times" - ) - ) - spike_times = kilosort_dataset.data[spike_time_key] - kilosort_dataset.extract_spike_depths() - - # -- Spike-sites and Spike-depths -- - spike_sites = np.array( - [ - channel2electrodes[s]["electrode"] - for s in kilosort_dataset.data["spike_sites"] - ] - ) - spike_depths = kilosort_dataset.data["spike_depths"] - - # -- Insert unit, label, peak-chn - units = [] - for unit, unit_lbl in zip(valid_units, valid_unit_labels): - if (kilosort_dataset.data["spike_clusters"] == unit).any(): - unit_channel, _ = kilosort_dataset.get_best_channel(unit) - unit_spike_times = ( - spike_times[kilosort_dataset.data["spike_clusters"] == unit] - / sample_rate - ) - spike_count = len(unit_spike_times) - - units.append( - { - "unit": unit, - "cluster_quality_label": unit_lbl, - **channel2electrodes[unit_channel], - "spike_times": unit_spike_times, - "spike_count": spike_count, - "spike_sites": spike_sites[ - kilosort_dataset.data["spike_clusters"] == unit - ], - "spike_depths": ( - spike_depths[ - kilosort_dataset.data["spike_clusters"] == unit - ] - if spike_depths is not None - else None - ), - } - ) - - self.insert1(key) - self.Unit.insert([{**key, **u} for u in units]) - - -@schema -class WaveformSet(dj.Imported): - """A set of spike waveforms for units out of a given CuratedClustering. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - """ - - definition = """ - # A set of spike waveforms for units out of a given CuratedClustering - -> CuratedClustering - """ - - class PeakWaveform(dj.Part): - """Mean waveform across spikes for a given unit. - - Attributes: - WaveformSet (foreign key): WaveformSet primary key. 
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - peak_electrode_waveform (longblob): Mean waveform for a given unit at its representative electrode. - """ - - definition = """ - # Mean waveform across spikes for a given unit at its representative electrode - -> master - -> CuratedClustering.Unit - --- - peak_electrode_waveform: longblob # (uV) mean waveform for a given unit at its representative electrode - """ - - class Waveform(dj.Part): - """Spike waveforms for a given unit. - - Attributes: - WaveformSet (foreign key): WaveformSet primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key. - waveform_mean (longblob): mean waveform across spikes of the unit in microvolts. - waveforms (longblob): waveforms of a sampling of spikes at the given electrode and unit. - """ - - definition = """ - # Spike waveforms and their mean across spikes for the given unit - -> master - -> CuratedClustering.Unit - -> probe.ElectrodeConfig.Electrode - --- - waveform_mean: longblob # (uV) mean waveform across spikes of the given unit - waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit - """ - - def make(self, key): - """Populates waveform tables.""" - output_dir = (Curation & key).fetch1("curation_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - - acq_software, probe_serial_number = ( - EphysRecording * ProbeInsertion & key - ).fetch1("acq_software", "probe") - - # -- Get channel and electrode-site mapping - recording_key = (EphysRecording & key).fetch1("KEY") - channel2electrodes = get_neuropixels_channel2electrode_map( - recording_key, acq_software - ) - - is_qc = (Curation & key).fetch1("quality_control") - - # Get all units - units = { - u["unit"]: u - for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit") - } - - if is_qc: - unit_waveforms = np.load( - kilosort_dir / "mean_waveforms.npy" - ) # unit x channel x sample - - def yield_unit_waveforms(): - for unit_no, unit_waveform in zip( - kilosort_dataset.data["cluster_ids"], unit_waveforms - ): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - if unit_no in units: - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], unit_waveform - ): - unit_electrode_waveforms.append( - { - **units[unit_no], - **channel2electrodes[channel], - "waveform_mean": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == units[unit_no]["electrode"] - ): - unit_peak_waveform = { - **units[unit_no], - "peak_electrode_waveform": channel_waveform, - } - yield unit_peak_waveform, unit_electrode_waveforms - - else: - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - elif acq_software == "Open Ephys": - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - openephys_dataset = openephys.OpenEphys(session_dir) - neuropixels_recording = openephys_dataset.probes[probe_serial_number] - - def yield_unit_waveforms(): - for unit_dict in units.values(): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - - spikes = unit_dict["spike_times"] - waveforms = neuropixels_recording.extract_spike_waveforms( - spikes, 
kilosort_dataset.data["channel_map"] - ) # (sample x channel x spike) - waveforms = waveforms.transpose( - (1, 2, 0) - ) # (channel x spike x sample) - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], waveforms - ): - unit_electrode_waveforms.append( - { - **unit_dict, - **channel2electrodes[channel], - "waveform_mean": channel_waveform.mean(axis=0), - "waveforms": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == unit_dict["electrode"] - ): - unit_peak_waveform = { - **unit_dict, - "peak_electrode_waveform": channel_waveform.mean( - axis=0 - ), - } - - yield unit_peak_waveform, unit_electrode_waveforms - - # insert waveform on a per-unit basis to mitigate potential memory issue - self.insert1(key) - for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms(): - if unit_peak_waveform: - self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True) - if unit_electrode_waveforms: - self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True) - - -@schema -class QualityMetrics(dj.Imported): - """Clustering and waveform quality metrics. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - """ - - definition = """ - # Clusters and waveforms metrics - -> CuratedClustering - """ - - class Cluster(dj.Part): - """Cluster metrics for a unit. - - Attributes: - QualityMetrics (foreign key): QualityMetrics primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - firing_rate (float): Firing rate of the unit. - snr (float): Signal-to-noise ratio for a unit. - presence_ratio (float): Fraction of time where spikes are present. - isi_violation (float): rate of ISI violation as a fraction of overall rate. - number_violation (int): Total ISI violations. - amplitude_cutoff (float): Estimate of miss rate based on amplitude histogram. - isolation_distance (float): Distance to nearest cluster. - l_ratio (float): Amount of empty space between a cluster and other spikes in dataset. - d_prime (float): Classification accuracy based on LDA. - nn_hit_rate (float): Fraction of neighbors for target cluster that are also in target cluster. - nn_miss_rate (float): Fraction of neighbors outside target cluster that are in the target cluster. - silhouette_core (float): Maximum change in spike depth throughout recording. - cumulative_drift (float): Cumulative change in spike depth throughout recording. - contamination_rate (float): Frequency of spikes in the refractory period. 
- """ - - definition = """ - # Cluster metrics for a particular unit - -> master - -> CuratedClustering.Unit - --- - firing_rate=null: float # (Hz) firing rate for a unit - snr=null: float # signal-to-noise ratio for a unit - presence_ratio=null: float # fraction of time in which spikes are present - isi_violation=null: float # rate of ISI violation as a fraction of overall rate - number_violation=null: int # total number of ISI violations - amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram - isolation_distance=null: float # distance to nearest cluster in Mahalanobis space - l_ratio=null: float # - d_prime=null: float # Classification accuracy based on LDA - nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster - nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster - silhouette_score=null: float # Standard metric for cluster overlap - max_drift=null: float # Maximum change in spike depth throughout recording - cumulative_drift=null: float # Cumulative change in spike depth throughout recording - contamination_rate=null: float # - """ - - class Waveform(dj.Part): - """Waveform metrics for a particular unit. - - Attributes: - QualityMetrics (foreign key): QualityMetrics primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - amplitude (float): Absolute difference between waveform peak and trough in microvolts. - duration (float): Time between waveform peak and trough in milliseconds. - halfwidth (float): Spike width at half max amplitude. - pt_ratio (float): Absolute amplitude of peak divided by absolute amplitude of trough relative to 0. - repolarization_slope (float): Slope of the regression line fit to first 30 microseconds from trough to peak. - recovery_slope (float): Slope of the regression line fit to first 30 microseconds from peak to tail. - spread (float): The range with amplitude over 12-percent of maximum amplitude along the probe. - velocity_above (float): inverse velocity of waveform propagation from soma to the top of the probe. - velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe. 
- """ - - definition = """ - # Waveform metrics for a particular unit - -> master - -> CuratedClustering.Unit - --- - amplitude: float # (uV) absolute difference between waveform peak and trough - duration: float # (ms) time between waveform peak and trough - halfwidth=null: float # (ms) spike width at half max amplitude - pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0 - repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak - recovery_slope=null: float # the recovery slope was defined by fitting a regression line to the first 30us from peak to tail - spread=null: float # (um) the range with amplitude above 12-percent of the maximum amplitude along the probe - velocity_above=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the top of the probe - velocity_below=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the bottom of the probe - """ - - def make(self, key): - """Populates tables with quality metrics data.""" - output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - metric_fp = kilosort_dir / "metrics.csv" - rename_dict = { - "isi_viol": "isi_violation", - "num_viol": "number_violation", - "contam_rate": "contamination_rate", - } - - if not metric_fp.exists(): - raise FileNotFoundError(f"QC metrics file not found: {metric_fp}") - - metrics_df = pd.read_csv(metric_fp) - metrics_df.set_index("cluster_id", inplace=True) - metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True) - metrics_df.columns = metrics_df.columns.str.lower() - metrics_df.rename(columns=rename_dict, inplace=True) - metrics_list = [ - dict(metrics_df.loc[unit_key["unit"]], **unit_key) - for unit_key in (CuratedClustering.Unit & key).fetch("KEY") - ] - - self.insert1(key) - self.Cluster.insert(metrics_list, ignore_extra_fields=True) - self.Waveform.insert(metrics_list, ignore_extra_fields=True) - - -# ---------------- HELPER FUNCTIONS ---------------- - - -def get_spikeglx_meta_filepath(ephys_recording_key: dict) -> str: - """Get spikeGLX data filepath.""" - # attempt to retrieve from EphysRecording.EphysFile - spikeglx_meta_filepath = pathlib.Path( - ( - EphysRecording.EphysFile - & ephys_recording_key - & 'file_path LIKE "%.ap.meta"' - ).fetch1("file_path") - ) - - try: - spikeglx_meta_filepath = find_full_path( - get_ephys_root_data_dir(), spikeglx_meta_filepath - ) - except FileNotFoundError: - # if not found, search in session_dir again - if not spikeglx_meta_filepath.exists(): - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(ephys_recording_key) - ) - inserted_probe_serial_number = ( - ProbeInsertion * probe.Probe & ephys_recording_key - ).fetch1("probe") - - spikeglx_meta_filepaths = [fp for fp in session_dir.rglob("*.ap.meta")] - for meta_filepath in spikeglx_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - spikeglx_meta_filepath = meta_filepath - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format( - ephys_recording_key - ) - ) - - return spikeglx_meta_filepath - - -def get_openephys_probe_data(ephys_recording_key: dict) -> list: - """Get OpenEphys probe data from file.""" - inserted_probe_serial_number = ( - ProbeInsertion * probe.Probe & 
ephys_recording_key - ).fetch1("probe") - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(ephys_recording_key) - ) - loaded_oe = openephys.OpenEphys(session_dir) - probe_data = loaded_oe.probes[inserted_probe_serial_number] - - # explicitly garbage collect "loaded_oe" - # as these may have large memory footprint and may not be cleared fast enough - del loaded_oe - gc.collect() - - return probe_data - - -def get_neuropixels_channel2electrode_map( - ephys_recording_key: dict, acq_software: str -) -> dict: - """Get the channel map for neuropixels probe.""" - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) - spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath) - electrode_config_key = ( - EphysRecording * probe.ElectrodeConfig & ephys_recording_key - ).fetch1("KEY") - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & electrode_config_key - ) - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - channel2electrode_map = { - recorded_site: probe_electrodes[(shank, shank_col, shank_row)] - for recorded_site, (shank, shank_col, shank_row, _) in enumerate( - spikeglx_meta.shankmap["data"] - ) - } - elif acq_software == "Open Ephys": - probe_dataset = get_openephys_probe_data(ephys_recording_key) - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording - & ephys_recording_key - ) - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - channel2electrode_map = { - channel_idx: probe_electrodes[channel_idx] - for channel_idx in probe_dataset.ap_meta["channels_indices"] - } - - return channel2electrode_map - - -def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict: - """Generate and insert new ElectrodeConfig - - Args: - probe_type (str): probe type (e.g. 
neuropixels 2.0 - SS) - electrode_keys (list): list of keys of the probe.ProbeType.Electrode table - - Returns: - dict: representing a key of the probe.ElectrodeConfig table - """ - # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) - electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys}) - - electrode_list = sorted([k["electrode"] for k in electrode_keys]) - electrode_gaps = ( - [-1] - + np.where(np.diff(electrode_list) > 1)[0].tolist() - + [len(electrode_list) - 1] - ) - electrode_config_name = "; ".join( - [ - f"{electrode_list[start + 1]}-{electrode_list[end]}" - for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) - ] - ) - - electrode_config_key = {"electrode_config_hash": electrode_config_hash} - - # ---- make new ElectrodeConfig if needed ---- - if not probe.ElectrodeConfig & electrode_config_key: - probe.ElectrodeConfig.insert1( - { - **electrode_config_key, - "probe_type": probe_type, - "electrode_config_name": electrode_config_name, - } - ) - probe.ElectrodeConfig.Electrode.insert( - {**electrode_config_key, **electrode} for electrode in electrode_keys - ) - - return electrode_config_key - - -def get_recording_channels_details(ephys_recording_key: dict) -> np.array: - """Get details of recording channels for a given recording.""" - channels_details = {} - - acq_software, sample_rate = (EphysRecording & ephys_recording_key).fetch1( - "acq_software", "sampling_rate" - ) - - probe_type = (ProbeInsertion * probe.Probe & ephys_recording_key).fetch1( - "probe_type" - ) - channels_details["probe_type"] = { - "neuropixels 1.0 - 3A": "3A", - "neuropixels 1.0 - 3B": "NP1", - "neuropixels UHD": "NP1100", - "neuropixels 2.0 - SS": "NP21", - "neuropixels 2.0 - MS": "NP24", - }[probe_type] - - electrode_config_key = ( - probe.ElectrodeConfig * EphysRecording & ephys_recording_key - ).fetch1("KEY") - ( - channels_details["channel_ind"], - channels_details["x_coords"], - channels_details["y_coords"], - channels_details["shank_ind"], - ) = ( - probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode - & electrode_config_key - ).fetch( - "electrode", "x_coord", "y_coord", "shank" - ) - channels_details["sample_rate"] = sample_rate - channels_details["num_channels"] = len(channels_details["channel_ind"]) - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - channels_details["uVPerBit"] = spikeglx_recording.get_channel_bit_volts("ap")[0] - channels_details["connected"] = np.array( - [v for *_, v in spikeglx_recording.apmeta.shankmap["data"]] - ) - elif acq_software == "Open Ephys": - oe_probe = get_openephys_probe_data(ephys_recording_key) - channels_details["uVPerBit"] = oe_probe.ap_meta["channels_gains"][0] - channels_details["connected"] = np.array( - [ - int(v == 1) - for c, v in oe_probe.channels_connected.items() - if c in channels_details["channel_ind"] - ] - ) - - return channels_details diff --git a/element_array_ephys/ephys_precluster.py b/element_array_ephys/ephys_precluster.py deleted file mode 100644 index 4d52c610..00000000 --- a/element_array_ephys/ephys_precluster.py +++ /dev/null @@ -1,1435 +0,0 @@ -import importlib -import inspect -import re - -import datajoint as dj -import numpy as np -import pandas as pd -from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory - -from . 
import ephys_report, probe
-from .readers import kilosort, openephys, spikeglx
-
-schema = dj.schema()
-
-_linking_module = None
-
-
-def activate(
-    ephys_schema_name: str,
-    probe_schema_name: str = None,
-    *,
-    create_schema: bool = True,
-    create_tables: bool = True,
-    linking_module: str = None,
-):
-    """Activates the `ephys` and `probe` schemas.
-
-    Args:
-        ephys_schema_name (str): A string containing the name of the ephys schema.
-        probe_schema_name (str): A string containing the name of the probe schema.
-        create_schema (bool): If True, schema will be created in the database.
-        create_tables (bool): If True, tables related to the schema will be created in the database.
-        linking_module (str): A string containing the module name, or the module itself, containing the required dependencies to activate the schema.
-
-    Dependencies:
-    Upstream tables:
-        Session: A parent table to ProbeInsertion
-        Probe: A parent table to EphysRecording. Probe information is required before electrophysiology data is imported.
-
-    Functions:
-        get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s).
-        get_session_directory(session_key: dict): Returns path to electrophysiology data for a particular session as a list of strings.
-    """
-
-    if isinstance(linking_module, str):
-        linking_module = importlib.import_module(linking_module)
-    assert inspect.ismodule(
-        linking_module
-    ), "The argument 'linking_module' must be a module's name or a module"
-
-    global _linking_module
-    _linking_module = linking_module
-
-    probe.activate(
-        probe_schema_name, create_schema=create_schema, create_tables=create_tables
-    )
-    schema.activate(
-        ephys_schema_name,
-        create_schema=create_schema,
-        create_tables=create_tables,
-        add_objects=_linking_module.__dict__,
-    )
-    ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name)
-
-
-# -------------- Functions required by the elements-ephys ---------------
-
-
-def get_ephys_root_data_dir() -> list:
-    """Fetches absolute data path to ephys data directories.
-
-    The absolute path here is used as a reference for all downstream relative paths used in DataJoint.
-
-    Returns:
-        A list of the absolute path(s) to ephys data directories.
-    """
-    return _linking_module.get_ephys_root_data_dir()
-
-
-def get_session_directory(session_key: dict) -> str:
-    """Retrieve the session directory with Neuropixels for the given session.
-
-    Args:
-        session_key (dict): A dictionary mapping subject to an entry in the subject table, and session_datetime corresponding to a session in the database.
-
-    Returns:
-        A string for the path to the session directory.
-    """
-    return _linking_module.get_session_directory(session_key)
-
-
-# ----------------------------- Table declarations ----------------------
-
-
-@schema
-class AcquisitionSoftware(dj.Lookup):
-    """Name of software used for recording electrophysiological data.
-
-    Attributes:
-        acq_software ( varchar(24) ): Acquisition software, e.g., SpikeGLX, Open Ephys
-    """
-
-    definition = """  # Name of software used for recording of neuropixels probes - SpikeGLX or Open Ephys
-    acq_software: varchar(24)
-    """
-    contents = zip(["SpikeGLX", "Open Ephys"])
-
-
-@schema
-class ProbeInsertion(dj.Manual):
-    """Information about probe insertion across subjects and sessions.
-
-    Attributes:
-        Session (foreign key): Session primary key.
-        insertion_number (foreign key, tinyint): Unique insertion number for each probe insertion for a given session.
- probe.Probe (str): probe.Probe primary key. - """ - - definition = """ - # Probe insertion implanted into an animal for a given session. - -> Session - insertion_number: tinyint unsigned - --- - -> probe.Probe - """ - - -@schema -class InsertionLocation(dj.Manual): - """Stereotaxic location information for each probe insertion. - - Attributes: - ProbeInsertion (foreign key): ProbeInsertion primary key. - SkullReference (dict): SkullReference primary key. - ap_location (decimal (6, 2) ): Anterior-posterior location in micrometers. Reference is 0 with anterior values positive. - ml_location (decimal (6, 2) ): Medial-lateral location in micrometers. Reference is zero with right side values positive. - depth (decimal (6, 2) ): Manipulator depth relative to the surface of the brain at zero. Ventral is negative. - Theta (decimal (5, 2) ): elevation - rotation about the ml-axis in degrees relative to positive z-axis. - phi (decimal (5, 2) ): azimuth - rotation about the dv-axis in degrees relative to the positive x-axis - - """ - - definition = """ - # Brain Location of a given probe insertion. - -> ProbeInsertion - --- - -> SkullReference - ap_location: decimal(6, 2) # (um) anterior-posterior; ref is 0; more anterior is more positive - ml_location: decimal(6, 2) # (um) medial axis; ref is 0 ; more right is more positive - depth: decimal(6, 2) # (um) manipulator depth relative to surface of the brain (0); more ventral is more negative - theta=null: decimal(5, 2) # (deg) - elevation - rotation about the ml-axis [0, 180] - w.r.t the z+ axis - phi=null: decimal(5, 2) # (deg) - azimuth - rotation about the dv-axis [0, 360] - w.r.t the x+ axis - beta=null: decimal(5, 2) # (deg) rotation about the shank of the probe [-180, 180] - clockwise is increasing in degree - 0 is the probe-front facing anterior - """ - - -@schema -class EphysRecording(dj.Imported): - """Automated table with electrophysiology recording information for each probe inserted during an experimental session. - - Attributes: - ProbeInsertion (foreign key): ProbeInsertion primary key. - probe.ElectrodeConfig (dict): probe.ElectrodeConfig primary key. - AcquisitionSoftware (dict): AcquisitionSoftware primary key. - sampling_rate (float): sampling rate of the recording in Hertz (Hz). - recording_datetime (datetime): datetime of the recording from this probe. - recording_duration (float): duration of the entire recording from this probe in seconds. - """ - - definition = """ - # Ephys recording from a probe insertion for a given session. - -> ProbeInsertion - --- - -> probe.ElectrodeConfig - -> AcquisitionSoftware - sampling_rate: float # (Hz) - recording_datetime: datetime # datetime of the recording from this probe - recording_duration: float # (seconds) duration of the recording from this probe - """ - - class EphysFile(dj.Part): - """Paths of electrophysiology recording files for each insertion. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - file_path (varchar(255) ): relative file path for electrophysiology recording. - """ - - definition = """ - # Paths of files of a given EphysRecording round. 
- -> master - file_path: varchar(255) # filepath relative to root data directory - """ - - def make(self, key): - """Populates table with electrophysiology recording information.""" - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - - inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1( - "probe" - ) - - # search session dir and determine acquisition software - for ephys_pattern, ephys_acq_type in ( - ("*.ap.meta", "SpikeGLX"), - ("*.oebin", "Open Ephys"), - ): - ephys_meta_filepaths = [fp for fp in session_dir.rglob(ephys_pattern)] - if ephys_meta_filepaths: - acq_software = ephys_acq_type - break - else: - raise FileNotFoundError( - f"Ephys recording data not found!" - f" Neither SpikeGLX nor Open Ephys recording files found" - f" in {session_dir}" - ) - - if acq_software == "SpikeGLX": - for meta_filepath in ephys_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format(key) - ) - - if re.search("(1.0|2.0)", spikeglx_meta.probe_model): - probe_type = spikeglx_meta.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - electrode_group_members = [ - probe_electrodes[(shank, shank_col, shank_row)] - for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels probe model" - " {} not yet implemented".format(spikeglx_meta.probe_model) - ) - - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": spikeglx_meta.meta["imSampRate"], - "recording_datetime": spikeglx_meta.recording_time, - "recording_duration": ( - spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(meta_filepath) - ), - } - ) - - root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) - self.EphysFile.insert1( - {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} - ) - elif acq_software == "Open Ephys": - dataset = openephys.OpenEphys(session_dir) - for serial_number, probe_data in dataset.probes.items(): - if str(serial_number) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No Open Ephys data found for probe insertion: {}".format(key) - ) - - if re.search("(1.0|2.0)", probe_data.probe_model): - probe_type = probe_data.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - electrode_group_members = [ - probe_electrodes[channel_idx] - for channel_idx in probe_data.ap_meta["channels_ids"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels" - " probe model {} not yet implemented".format(probe_data.probe_model) - ) - - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": probe_data.ap_meta["sample_rate"], - "recording_datetime": probe_data.recording_info[ - "recording_datetimes" - ][0], - "recording_duration": np.sum( - probe_data.recording_info["recording_durations"] - ), - } - ) - 
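# `EphysFile.file_path` is stored relative to one of the configured root
# directories so the same database rows resolve on machines with different
# mounts; `find_root_directory` below picks whichever root actually contains
# the recording files. Resolving a stored path back to an absolute one reuses
# the same helper seen throughout this module (a sketch; `stored_path` is a
# hypothetical fetched `file_path` value):
#
#     stored_path = (EphysRecording.EphysFile & key).fetch("file_path", limit=1)[0]
#     abs_path = find_full_path(get_ephys_root_data_dir(), stored_path)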
- root_dir = find_root_directory( - get_ephys_root_data_dir(), - probe_data.recording_info["recording_files"][0], - ) - self.EphysFile.insert( - [ - {**key, "file_path": fp.relative_to(root_dir).as_posix()} - for fp in probe_data.recording_info["recording_files"] - ] - ) - else: - raise NotImplementedError( - f"Processing ephys files from" - f" acquisition software of type {acq_software} is" - f" not yet implemented" - ) - - -@schema -class PreClusterMethod(dj.Lookup): - """Pre-clustering method - - Attributes: - precluster_method (foreign key, varchar(16) ): Pre-clustering method for the dataset. - precluster_method_desc(varchar(1000) ): Pre-clustering method description. - """ - - definition = """ - # Method for pre-clustering - precluster_method: varchar(16) - --- - precluster_method_desc: varchar(1000) - """ - - contents = [("catgt", "Time shift, Common average referencing, Zeroing")] - - -@schema -class PreClusterParamSet(dj.Lookup): - """Parameters for the pre-clustering method. - - Attributes: - paramset_idx (foreign key): Unique parameter set ID. - PreClusterMethod (dict): PreClusterMethod query for this dataset. - paramset_desc (varchar(128) ): Description for the pre-clustering parameter set. - param_set_hash (uuid): Unique hash for parameter set. - params (longblob): All parameters for the pre-clustering method. - """ - - definition = """ - # Parameter set to be used in a clustering procedure - paramset_idx: smallint - --- - -> PreClusterMethod - paramset_desc: varchar(128) - param_set_hash: uuid - unique index (param_set_hash) - params: longblob # dictionary of all applicable parameters - """ - - @classmethod - def insert_new_params( - cls, precluster_method: str, paramset_idx: int, paramset_desc: str, params: dict - ): - param_dict = { - "precluster_method": precluster_method, - "paramset_idx": paramset_idx, - "paramset_desc": paramset_desc, - "params": params, - "param_set_hash": dict_to_uuid(params), - } - param_query = cls & {"param_set_hash": param_dict["param_set_hash"]} - - if param_query: # If the specified param-set already exists - existing_paramset_idx = param_query.fetch1("paramset_idx") - if ( - existing_paramset_idx == paramset_idx - ): # If the existing set has the same paramset_idx: job done - return - else: # If not same name: human error, trying to add the same paramset with different name - raise dj.DataJointError( - "The specified param-set" - " already exists - paramset_idx: {}".format(existing_paramset_idx) - ) - else: - cls.insert1(param_dict) - - -@schema -class PreClusterParamSteps(dj.Manual): - """Ordered list of parameter sets that will be run. - - Attributes: - precluster_param_steps_id (foreign key): Unique ID for the pre-clustering parameter sets to be run. - precluster_param_steps_name (varchar(32) ): User-friendly name for the parameter steps. - precluster_param_steps_desc (varchar(128) ): Description of the parameter steps. - """ - - definition = """ - # Ordered list of paramset_idx that are to be run - # When pre-clustering is not performed, do not create an entry in `Step` Part table - precluster_param_steps_id: smallint - --- - precluster_param_steps_name: varchar(32) - precluster_param_steps_desc: varchar(128) - """ - - class Step(dj.Part): - """Define the order of operations for parameter sets. - - Attributes: - PreClusterParamSteps (foreign key): PreClusterParamSteps primary key. - step_number (foreign key, smallint): Order of operations. - PreClusterParamSet (dict): PreClusterParamSet to be used in pre-clustering. 
- """ - - definition = """ - -> master - step_number: smallint # Order of operations - --- - -> PreClusterParamSet - """ - - -@schema -class PreClusterTask(dj.Manual): - """Defines a pre-clustering task ready to be run. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - PreclusterParamSteps (foreign key): PreClusterParam Steps primary key. - precluster_output_dir (varchar(255) ): relative path to directory for storing results of pre-clustering. - task_mode (enum ): `none` (no pre-clustering), `load` results from file, or `trigger` automated pre-clustering. - """ - - definition = """ - # Manual table for defining a clustering task ready to be run - -> EphysRecording - -> PreClusterParamSteps - --- - precluster_output_dir='': varchar(255) # pre-clustering output directory relative to the root data directory - task_mode='none': enum('none','load', 'trigger') # 'none': no pre-clustering analysis - # 'load': load analysis results - # 'trigger': trigger computation - """ - - -@schema -class PreCluster(dj.Imported): - """ - A processing table to handle each PreClusterTask: - - Attributes: - PreClusterTask (foreign key): PreClusterTask primary key. - precluster_time (datetime): Time of generation of this set of pre-clustering results. - package_version (varchar(16) ): Package version used for performing pre-clustering. - """ - - definition = """ - -> PreClusterTask - --- - precluster_time: datetime # time of generation of this set of pre-clustering results - package_version='': varchar(16) - """ - - def make(self, key): - """Populate pre-clustering tables.""" - task_mode, output_dir = (PreClusterTask & key).fetch1( - "task_mode", "precluster_output_dir" - ) - precluster_output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - if task_mode == "none": - if len((PreClusterParamSteps.Step & key).fetch()) > 0: - raise ValueError( - "There are entries in the PreClusterParamSteps.Step " - "table and task_mode=none" - ) - creation_time = (EphysRecording & key).fetch1("recording_datetime") - elif task_mode == "load": - acq_software = (EphysRecording & key).fetch1("acq_software") - inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1( - "probe" - ) - - if acq_software == "SpikeGLX": - for meta_filepath in precluster_output_dir.rglob("*.ap.meta"): - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - creation_time = spikeglx_meta.recording_time - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format(key) - ) - else: - raise NotImplementedError( - f"Pre-clustering analysis of {acq_software}" "is not yet supported." - ) - elif task_mode == "trigger": - raise NotImplementedError( - "Automatic triggering of" - " pre-clustering analysis is not yet supported." - ) - else: - raise ValueError(f"Unknown task mode: {task_mode}") - - self.insert1({**key, "precluster_time": creation_time, "package_version": ""}) - - -@schema -class LFP(dj.Imported): - """Extracts local field potentials (LFP) from an electrophysiology recording. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - lfp_sampling_rate (float): Sampling rate for LFPs in Hz. - lfp_time_stamps (longblob): Time stamps with respect to the start of the recording. - lfp_mean (longblob): Overall mean LFP across electrodes. - """ - - definition = """ - # Acquired local field potential (LFP) from a given Ephys recording. 
- -> PreCluster - --- - lfp_sampling_rate: float # (Hz) - lfp_time_stamps: longblob # (s) timestamps with respect to the start of the recording (recording_timestamp) - lfp_mean: longblob # (uV) mean of LFP across electrodes - shape (time,) - """ - - class Electrode(dj.Part): - """Saves local field potential data for each electrode. - - Attributes: - LFP (foreign key): LFP primary key. - probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key. - lfp (longblob): LFP recording at this electrode in microvolts. - """ - - definition = """ - -> master - -> probe.ElectrodeConfig.Electrode - --- - lfp: longblob # (uV) recorded lfp at this electrode - """ - - # Only store LFP for every 9th channel, due to high channel density, - # close-by channels exhibit highly similar LFP - _skip_channel_counts = 9 - - def make(self, key): - """Populates the LFP tables.""" - acq_software, probe_sn = (EphysRecording * ProbeInsertion & key).fetch1( - "acq_software", "probe" - ) - - electrode_keys, lfp = [], [] - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - - lfp_channel_ind = spikeglx_recording.lfmeta.recording_channels[ - -1 :: -self._skip_channel_counts - ] - - # Extract LFP data at specified channels and convert to uV - lfp = spikeglx_recording.lf_timeseries[ - :, lfp_channel_ind - ] # (sample x channel) - lfp = ( - lfp * spikeglx_recording.get_channel_bit_volts("lf")[lfp_channel_ind] - ).T # (channel x sample) - - self.insert1( - dict( - key, - lfp_sampling_rate=spikeglx_recording.lfmeta.meta["imSampRate"], - lfp_time_stamps=( - np.arange(lfp.shape[1]) - / spikeglx_recording.lfmeta.meta["imSampRate"] - ), - lfp_mean=lfp.mean(axis=0), - ) - ) - - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - for recorded_site in lfp_channel_ind: - shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap[ - "data" - ][recorded_site] - electrode_keys.append(probe_electrodes[(shank, shank_col, shank_row)]) - elif acq_software == "Open Ephys": - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - - loaded_oe = openephys.OpenEphys(session_dir) - oe_probe = loaded_oe.probes[probe_sn] - - lfp_channel_ind = np.arange(len(oe_probe.lfp_meta["channels_ids"]))[ - -1 :: -self._skip_channel_counts - ] - - lfp = oe_probe.lfp_timeseries[:, lfp_channel_ind] # (sample x channel) - lfp = ( - lfp * np.array(oe_probe.lfp_meta["channels_gains"])[lfp_channel_ind] - ).T # (channel x sample) - lfp_timestamps = oe_probe.lfp_timestamps - - self.insert1( - dict( - key, - lfp_sampling_rate=oe_probe.lfp_meta["sample_rate"], - lfp_time_stamps=lfp_timestamps, - lfp_mean=lfp.mean(axis=0), - ) - ) - - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - for channel_idx in np.array(oe_probe.lfp_meta["channels_ids"])[ - lfp_channel_ind - ]: - electrode_keys.append(probe_electrodes[channel_idx]) - else: - raise NotImplementedError( - f"LFP extraction from acquisition software" - f" of type {acq_software} is not yet implemented" - ) - - # single insert in 
loop to mitigate potential memory issue - for electrode_key, lfp_trace in zip(electrode_keys, lfp): - self.Electrode.insert1({**key, **electrode_key, "lfp": lfp_trace}) - - -# ------------ Clustering -------------- - - -@schema -class ClusteringMethod(dj.Lookup): - """Kilosort clustering method. - - Attributes: - clustering_method (foreign key, varchar(16) ): Kilosort clustering method. - clustering_method_desc (varchar(1000) ): Additional description of the clustering method. - """ - - definition = """ - # Method for clustering - clustering_method: varchar(16) - --- - clustering_method_desc: varchar(1000) - """ - - contents = [ - ("kilosort", "kilosort clustering method"), - ("kilosort2", "kilosort2 clustering method"), - ] - - -@schema -class ClusteringParamSet(dj.Lookup): - """Parameters to be used in clustering procedure for spike sorting. - - Attributes: - paramset_idx (foreign key): Unique ID for the clustering parameter set. - ClusteringMethod (dict): ClusteringMethod primary key. - paramset_desc (varchar(128) ): Description of the clustering parameter set. - param_set_hash (uuid): UUID hash for the parameter set. - params (longblob): Paramset, dictionary of all applicable parameters. - """ - - definition = """ - # Parameter set to be used in a clustering procedure - paramset_idx: smallint - --- - -> ClusteringMethod - paramset_desc: varchar(128) - param_set_hash: uuid - unique index (param_set_hash) - params: longblob # dictionary of all applicable parameters - """ - - @classmethod - def insert_new_params( - cls, processing_method: str, paramset_idx: int, paramset_desc: str, params: dict - ): - """Inserts new parameters into the ClusteringParamSet table. - - Args: - processing_method (str): Name of the clustering method. - paramset_idx (int): Unique parameter set ID. - paramset_desc (str): Description of the parameter set. - params (dict): Clustering parameters. - """ - param_dict = { - "clustering_method": processing_method, - "paramset_idx": paramset_idx, - "paramset_desc": paramset_desc, - "params": params, - "param_set_hash": dict_to_uuid(params), - } - param_query = cls & {"param_set_hash": param_dict["param_set_hash"]} - - if param_query: # If the specified param-set already exists - existing_paramset_idx = param_query.fetch1("paramset_idx") - if ( - existing_paramset_idx == paramset_idx - ): # If the existing set has the same paramset_idx: job done - return - else: # Same parameters under a different paramset_idx: human error - raise dj.DataJointError( - "The specified param-set" - " already exists - paramset_idx: {}".format(existing_paramset_idx) - ) - else: - cls.insert1(param_dict) - - -@schema -class ClusterQualityLabel(dj.Lookup): - """Quality label for each spike sorted cluster. - - Attributes: - cluster_quality_label (foreign key, varchar(100) ): Cluster quality type. - cluster_quality_description (varchar(4000) ): Description of the cluster quality type. - """ - - definition = """ - # Quality - cluster_quality_label: varchar(100) - --- - cluster_quality_description: varchar(4000) - """ - contents = [ - ("good", "single unit"), - ("ok", "probably a single unit, but could be contaminated"), - ("mua", "multi-unit activity"), - ("noise", "bad unit"), - ] - - -@schema -class ClusteringTask(dj.Manual): - """A clustering task to spike sort electrophysiology datasets. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - ClusteringParamSet (foreign key): ClusteringParamSet primary key. 
- clustering_output_dir (varchar(255) ): Relative path to output clustering results. - task_mode (enum): `trigger` computes clustering and `load` imports existing data. - """ - - definition = """ - # Manual table for defining a clustering task ready to be run - -> PreCluster - -> ClusteringParamSet - --- - clustering_output_dir: varchar(255) # clustering output directory relative to the clustering root data directory - task_mode='load': enum('load', 'trigger') # 'load': load computed analysis results, 'trigger': trigger computation - """ - - -@schema -class Clustering(dj.Imported): - """A processing table to handle each clustering task. - - Attributes: - ClusteringTask (foreign key): ClusteringTask primary key. - clustering_time (datetime): Time when clustering results are generated. - package_version (varchar(16) ): Package version used for a clustering analysis. - """ - - definition = """ - # Clustering Procedure - -> ClusteringTask - --- - clustering_time: datetime # time of generation of this set of clustering results - package_version='': varchar(16) - """ - - def make(self, key): - """Triggers or imports clustering analysis.""" - task_mode, output_dir = (ClusteringTask & key).fetch1( - "task_mode", "clustering_output_dir" - ) - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - if task_mode == "load": - _ = kilosort.Kilosort( - kilosort_dir - ) # check if the directory is a valid Kilosort output - creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir) - elif task_mode == "trigger": - raise NotImplementedError( - "Automatic triggering of" " clustering analysis is not yet supported" - ) - else: - raise ValueError(f"Unknown task mode: {task_mode}") - - self.insert1({**key, "clustering_time": creation_time, "package_version": ""}) - - -@schema -class Curation(dj.Manual): - """Curation procedure table. - - Attributes: - Clustering (foreign key): Clustering primary key. - curation_id (foreign key, int): Unique curation ID. - curation_time (datetime): Time when curation results are generated. - curation_output_dir (varchar(255) ): Output directory of the curated results. - quality_control (bool): If True, this clustering result has undergone quality control. - manual_curation (bool): If True, manual curation has been performed on this clustering result. - curation_note (varchar(2000) ): Notes about the curation task. - """ - - definition = """ - # Manual curation procedure - -> Clustering - curation_id: int - --- - curation_time: datetime # time of generation of this set of curated clustering results - curation_output_dir: varchar(255) # output directory of the curated results, relative to root data directory - quality_control: bool # has this clustering result undergone quality control? - manual_curation: bool # has manual curation been performed on this clustering result? 
- curation_note='': varchar(2000) - """ - - def create1_from_clustering_task(self, key, curation_note: str = ""): - """ - A function to create a new corresponding "Curation" for a particular - "ClusteringTask". - """ - if key not in Clustering(): - raise ValueError( - f"No corresponding entry in Clustering available" - f" for: {key}; do `Clustering.populate(key)`" - ) - - task_mode, output_dir = (ClusteringTask & key).fetch1( - "task_mode", "clustering_output_dir" - ) - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - creation_time, is_curated, is_qc = kilosort.extract_clustering_info( - kilosort_dir - ) - # Synthesize curation_id - curation_id = ( - dj.U().aggr(self & key, n="ifnull(max(curation_id)+1,1)").fetch1("n") - ) - self.insert1( - { - **key, - "curation_id": curation_id, - "curation_time": creation_time, - "curation_output_dir": output_dir, - "quality_control": is_qc, - "manual_curation": is_curated, - "curation_note": curation_note, - } - ) - - -@schema -class CuratedClustering(dj.Imported): - """Clustering results after curation. - - Attributes: - Curation (foreign key): Curation primary key. - """ - - definition = """ - # Clustering results of a curation. - -> Curation - """ - - class Unit(dj.Part): - """Single unit properties after clustering and curation. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - unit (foreign key, int): Unique integer identifying a single unit. - probe.ElectrodeConfig.Electrode (dict): probe.ElectrodeConfig.Electrode primary key. - ClusterQualityLabel (dict): ClusterQualityLabel primary key. - spike_count (int): Number of spikes in this recording for this unit. - spike_times (longblob): Spike times of this unit, relative to start time of EphysRecording. - spike_sites (longblob): Array of electrodes associated with each spike. - spike_depths (longblob): Array of depths associated with each spike, relative to the (0, 0) of the probe. 
- """ - - definition = """ - # Properties of a given unit from a round of clustering (and curation) - -> master - unit: int - --- - -> probe.ElectrodeConfig.Electrode # electrode with highest waveform amplitude for this unit - -> ClusterQualityLabel - spike_count: int # how many spikes in this recording for this unit - spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording - spike_sites : longblob # array of electrode associated with each spike - spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe - """ - - def make(self, key): - """Automated population of Unit information.""" - output_dir = (Curation & key).fetch1("curation_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - acq_software = (EphysRecording & key).fetch1("acq_software") - - # ---------- Unit ---------- - # -- Remove 0-spike units - withspike_idx = [ - i - for i, u in enumerate(kilosort_dataset.data["cluster_ids"]) - if (kilosort_dataset.data["spike_clusters"] == u).any() - ] - valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx] - valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx] - # -- Get channel and electrode-site mapping - channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software) - - # -- Spike-times -- - # spike_times_sec_adj > spike_times_sec > spike_times - spike_time_key = ( - "spike_times_sec_adj" - if "spike_times_sec_adj" in kilosort_dataset.data - else ( - "spike_times_sec" - if "spike_times_sec" in kilosort_dataset.data - else "spike_times" - ) - ) - spike_times = kilosort_dataset.data[spike_time_key] - kilosort_dataset.extract_spike_depths() - - # -- Spike-sites and Spike-depths -- - spike_sites = np.array( - [ - channel2electrodes[s]["electrode"] - for s in kilosort_dataset.data["spike_sites"] - ] - ) - spike_depths = kilosort_dataset.data["spike_depths"] - - # -- Insert unit, label, peak-chn - units = [] - for unit, unit_lbl in zip(valid_units, valid_unit_labels): - if (kilosort_dataset.data["spike_clusters"] == unit).any(): - unit_channel, _ = kilosort_dataset.get_best_channel(unit) - unit_spike_times = ( - spike_times[kilosort_dataset.data["spike_clusters"] == unit] - / kilosort_dataset.data["params"]["sample_rate"] - ) - spike_count = len(unit_spike_times) - - units.append( - { - "unit": unit, - "cluster_quality_label": unit_lbl, - **channel2electrodes[unit_channel], - "spike_times": unit_spike_times, - "spike_count": spike_count, - "spike_sites": spike_sites[ - kilosort_dataset.data["spike_clusters"] == unit - ], - "spike_depths": ( - spike_depths[ - kilosort_dataset.data["spike_clusters"] == unit - ] - if spike_depths is not None - else None - ), - } - ) - - self.insert1(key) - self.Unit.insert([{**key, **u} for u in units]) - - -@schema -class WaveformSet(dj.Imported): - """A set of spike waveforms for units out of a given CuratedClustering. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - """ - - definition = """ - # A set of spike waveforms for units out of a given CuratedClustering - -> CuratedClustering - """ - - class PeakWaveform(dj.Part): - """Mean waveform across spikes for a given unit. - - Attributes: - WaveformSet (foreign key): WaveformSet primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. 
- peak_electrode_waveform (longblob): Mean waveform for a given unit at its representative electrode. - """ - - definition = """ - # Mean waveform across spikes for a given unit at its representative electrode - -> master - -> CuratedClustering.Unit - --- - peak_electrode_waveform: longblob # (uV) mean waveform for a given unit at its representative electrode - """ - - class Waveform(dj.Part): - """Spike waveforms for a given unit. - - Attributes: - WaveformSet (foreign key): WaveformSet primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key. - waveform_mean (longblob): mean waveform across spikes of the unit in microvolts. - waveforms (longblob): waveforms of a sampling of spikes at the given electrode and unit. - """ - - definition = """ - # Spike waveforms and their mean across spikes for the given unit - -> master - -> CuratedClustering.Unit - -> probe.ElectrodeConfig.Electrode - --- - waveform_mean: longblob # (uV) mean waveform across spikes of the given unit - waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit - """ - - def make(self, key): - """Populates waveform tables.""" - output_dir = (Curation & key).fetch1("curation_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - - acq_software, probe_serial_number = ( - EphysRecording * ProbeInsertion & key - ).fetch1("acq_software", "probe") - - # -- Get channel and electrode-site mapping - recording_key = (EphysRecording & key).fetch1("KEY") - channel2electrodes = get_neuropixels_channel2electrode_map( - recording_key, acq_software - ) - - is_qc = (Curation & key).fetch1("quality_control") - - # Get all units - units = { - u["unit"]: u - for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit") - } - - if is_qc: - unit_waveforms = np.load( - kilosort_dir / "mean_waveforms.npy" - ) # unit x channel x sample - - def yield_unit_waveforms(): - for unit_no, unit_waveform in zip( - kilosort_dataset.data["cluster_ids"], unit_waveforms - ): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - if unit_no in units: - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], unit_waveform - ): - unit_electrode_waveforms.append( - { - **units[unit_no], - **channel2electrodes[channel], - "waveform_mean": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == units[unit_no]["electrode"] - ): - unit_peak_waveform = { - **units[unit_no], - "peak_electrode_waveform": channel_waveform, - } - yield unit_peak_waveform, unit_electrode_waveforms - - else: - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - elif acq_software == "Open Ephys": - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - openephys_dataset = openephys.OpenEphys(session_dir) - neuropixels_recording = openephys_dataset.probes[probe_serial_number] - - def yield_unit_waveforms(): - for unit_dict in units.values(): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - - spikes = unit_dict["spike_times"] - waveforms = neuropixels_recording.extract_spike_waveforms( - spikes, kilosort_dataset.data["channel_map"] - ) # (sample x channel x spike) - waveforms = waveforms.transpose( 
- (1, 2, 0) - ) # (channel x spike x sample) - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], waveforms - ): - unit_electrode_waveforms.append( - { - **unit_dict, - **channel2electrodes[channel], - "waveform_mean": channel_waveform.mean(axis=0), - "waveforms": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == unit_dict["electrode"] - ): - unit_peak_waveform = { - **unit_dict, - "peak_electrode_waveform": channel_waveform.mean( - axis=0 - ), - } - - yield unit_peak_waveform, unit_electrode_waveforms - - # insert waveform on a per-unit basis to mitigate potential memory issue - self.insert1(key) - for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms(): - self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True) - self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True) - - -@schema -class QualityMetrics(dj.Imported): - """Clustering and waveform quality metrics. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - """ - - definition = """ - # Clusters and waveforms metrics - -> CuratedClustering - """ - - class Cluster(dj.Part): - """Cluster metrics for a unit. - - Attributes: - QualityMetrics (foreign key): QualityMetrics primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - firing_rate (float): Firing rate of the unit. - snr (float): Signal-to-noise ratio for a unit. - presence_ratio (float): Fraction of time where spikes are present. - isi_violation (float): Rate of ISI violation as a fraction of overall rate. - number_violation (int): Total ISI violations. - amplitude_cutoff (float): Estimate of miss rate based on amplitude histogram. - isolation_distance (float): Distance to nearest cluster. - l_ratio (float): Amount of empty space between a cluster and other spikes in dataset. - d_prime (float): Classification accuracy based on LDA. - nn_hit_rate (float): Fraction of neighbors for target cluster that are also in target cluster. - nn_miss_rate (float): Fraction of neighbors outside target cluster that are in the target cluster. - silhouette_score (float): Standard metric for cluster overlap. - max_drift (float): Maximum change in spike depth throughout recording. - cumulative_drift (float): Cumulative change in spike depth throughout recording. - contamination_rate (float): Frequency of spikes in the refractory period. 
- """ - - definition = """ - # Cluster metrics for a particular unit - -> master - -> CuratedClustering.Unit - --- - firing_rate=null: float # (Hz) firing rate for a unit - snr=null: float # signal-to-noise ratio for a unit - presence_ratio=null: float # fraction of time in which spikes are present - isi_violation=null: float # rate of ISI violation as a fraction of overall rate - number_violation=null: int # total number of ISI violations - amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram - isolation_distance=null: float # distance to nearest cluster in Mahalanobis space - l_ratio=null: float # - d_prime=null: float # Classification accuracy based on LDA - nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster - nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster - silhouette_score=null: float # Standard metric for cluster overlap - max_drift=null: float # Maximum change in spike depth throughout recording - cumulative_drift=null: float # Cumulative change in spike depth throughout recording - contamination_rate=null: float # - """ - - class Waveform(dj.Part): - """Waveform metrics for a particular unit. - - Attributes: - QualityMetrics (foreign key): QualityMetrics primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - amplitude (float): Absolute difference between waveform peak and trough in microvolts. - duration (float): Time between waveform peak and trough in milliseconds. - halfwidth (float): Spike width at half max amplitude. - pt_ratio (float): Absolute amplitude of peak divided by absolute amplitude of trough relative to 0. - repolarization_slope (float): Slope of the regression line fit to first 30 microseconds from trough to peak. - recovery_slope (float): Slope of the regression line fit to first 30 microseconds from peak to tail. - spread (float): The range with amplitude over 12-percent of maximum amplitude along the probe. - velocity_above (float): inverse velocity of waveform propagation from soma to the top of the probe. - velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe. 
- """ - - definition = """ - # Waveform metrics for a particular unit - -> master - -> CuratedClustering.Unit - --- - amplitude: float # (uV) absolute difference between waveform peak and trough - duration: float # (ms) time between waveform peak and trough - halfwidth=null: float # (ms) spike width at half max amplitude - pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0 - repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak - recovery_slope=null: float # the recovery slope was defined by fitting a regression line to the first 30us from peak to tail - spread=null: float # (um) the range with amplitude above 12-percent of the maximum amplitude along the probe - velocity_above=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the top of the probe - velocity_below=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the bottom of the probe - """ - - def make(self, key): - """Populates tables with quality metrics data.""" - output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - metric_fp = kilosort_dir / "metrics.csv" - rename_dict = { - "isi_viol": "isi_violation", - "num_viol": "number_violation", - "contam_rate": "contamination_rate", - } - - if not metric_fp.exists(): - raise FileNotFoundError(f"QC metrics file not found: {metric_fp}") - - metrics_df = pd.read_csv(metric_fp) - metrics_df.set_index("cluster_id", inplace=True) - metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True) - metrics_df.columns = metrics_df.columns.str.lower() - metrics_df.rename(columns=rename_dict, inplace=True) - metrics_list = [ - dict(metrics_df.loc[unit_key["unit"]], **unit_key) - for unit_key in (CuratedClustering.Unit & key).fetch("KEY") - ] - - self.insert1(key) - self.Cluster.insert(metrics_list, ignore_extra_fields=True) - self.Waveform.insert(metrics_list, ignore_extra_fields=True) - - -# ---------------- HELPER FUNCTIONS ---------------- - - -def get_spikeglx_meta_filepath(ephys_recording_key: dict) -> str: - """Get spikeGLX data filepath.""" - # attempt to retrieve from EphysRecording.EphysFile - spikeglx_meta_filepath = ( - EphysRecording.EphysFile & ephys_recording_key & 'file_path LIKE "%.ap.meta"' - ).fetch1("file_path") - - try: - spikeglx_meta_filepath = find_full_path( - get_ephys_root_data_dir(), spikeglx_meta_filepath - ) - except FileNotFoundError: - # if not found, search in session_dir again - if not spikeglx_meta_filepath.exists(): - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(ephys_recording_key) - ) - inserted_probe_serial_number = ( - ProbeInsertion * probe.Probe & ephys_recording_key - ).fetch1("probe") - - spikeglx_meta_filepaths = [fp for fp in session_dir.rglob("*.ap.meta")] - for meta_filepath in spikeglx_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - spikeglx_meta_filepath = meta_filepath - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format( - ephys_recording_key - ) - ) - - return spikeglx_meta_filepath - - -def get_neuropixels_channel2electrode_map( - ephys_recording_key: dict, acq_software: str -) -> dict: - """Get the channel map for neuropixels probe.""" - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = 
get_spikeglx_meta_filepath(ephys_recording_key) - spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath) - electrode_config_key = ( - EphysRecording * probe.ElectrodeConfig & ephys_recording_key - ).fetch1("KEY") - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & electrode_config_key - ) - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - channel2electrode_map = { - recorded_site: probe_electrodes[(shank, shank_col, shank_row)] - for recorded_site, (shank, shank_col, shank_row, _) in enumerate( - spikeglx_meta.shankmap["data"] - ) - } - elif acq_software == "Open Ephys": - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(ephys_recording_key) - ) - openephys_dataset = openephys.OpenEphys(session_dir) - probe_serial_number = (ProbeInsertion & ephys_recording_key).fetch1("probe") - probe_dataset = openephys_dataset.probes[probe_serial_number] - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording - & ephys_recording_key - ) - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - channel2electrode_map = { - channel_idx: probe_electrodes[channel_idx] - for channel_idx in probe_dataset.ap_meta["channels_ids"] - } - - return channel2electrode_map - - -def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict: - """Generate and insert new ElectrodeConfig - - Args: - probe_type (str): probe type (e.g. neuropixels 2.0 - SS) - electrode_keys (list): list of keys of the probe.ProbeType.Electrode table - - Returns: - dict: representing a key of the probe.ElectrodeConfig table - """ - # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) - electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys}) - - electrode_list = sorted([k["electrode"] for k in electrode_keys]) - electrode_gaps = ( - [-1] - + np.where(np.diff(electrode_list) > 1)[0].tolist() - + [len(electrode_list) - 1] - ) - electrode_config_name = "; ".join( - [ - f"{electrode_list[start + 1]}-{electrode_list[end]}" - for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) - ] - ) - - electrode_config_key = {"electrode_config_hash": electrode_config_hash} - - # ---- make new ElectrodeConfig if needed ---- - if not probe.ElectrodeConfig & electrode_config_key: - probe.ElectrodeConfig.insert1( - { - **electrode_config_key, - "probe_type": probe_type, - "electrode_config_name": electrode_config_name, - } - ) - probe.ElectrodeConfig.Electrode.insert( - {**electrode_config_key, **electrode} for electrode in electrode_keys - ) - - return electrode_config_key diff --git a/element_array_ephys/ephys_report.py b/element_array_ephys/ephys_report.py index 48bcf613..c962d33d 100644 --- a/element_array_ephys/ephys_report.py +++ b/element_array_ephys/ephys_report.py @@ -7,26 +7,24 @@ import datajoint as dj from element_interface.utils import dict_to_uuid -from . import probe +from . import probe, ephys schema = dj.schema() -ephys = None - -def activate(schema_name, ephys_schema_name, *, create_schema=True, create_tables=True): +def activate(schema_name, *, create_schema=True, create_tables=True): """Activate the current schema. Args: schema_name (str): schema name on the database server to activate the `ephys_report` schema. 
- ephys_schema_name (str): schema name of the activated ephys element for which - this ephys_report schema will be downstream from. create_schema (bool, optional): If True (default), create schema in the database if it does not yet exist. create_tables (bool, optional): If True (default), create tables in the database if they do not yet exist. """ + if not probe.schema.is_activated(): + raise RuntimeError("Please activate the `probe` schema first.") + if not ephys.schema.is_activated(): + raise RuntimeError("Please activate the `ephys` schema first.") - global ephys - ephys = dj.create_virtual_module("ephys", ephys_schema_name) schema.activate( schema_name, create_schema=create_schema, diff --git a/element_array_ephys/export/nwb/nwb.py b/element_array_ephys/export/nwb/nwb.py index a45eb754..8d7da8f5 100644 --- a/element_array_ephys/export/nwb/nwb.py +++ b/element_array_ephys/export/nwb/nwb.py @@ -17,14 +17,7 @@ from spikeinterface import extractors from tqdm import tqdm -from ... import ephys_no_curation as ephys -from ... import probe - -ephys_mode = os.getenv("EPHYS_MODE", dj.config["custom"].get("ephys_mode", "acute")) -if ephys_mode != "no-curation": - raise NotImplementedError( - "This export function is designed for the no_curation " + "schema" - ) +from ... import probe, ephys class DecimalEncoder(json.JSONEncoder): diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 550ae4a1..547fd8ce 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -1,5 +1,7 @@ """ -The following DataJoint pipeline implements the sequence of steps in the spike-sorting routine featured in the "spikeinterface" pipeline. Spikeinterface was developed by Alessio Buccino, Samuel Garcia, Cole Hurwitz, Jeremy Magland, and Matthias Hennig (https://github.com/SpikeInterface) +The following DataJoint pipeline implements the sequence of steps in the spike-sorting routine featured in the "spikeinterface" pipeline. +SpikeInterface was developed by Alessio Buccino, Samuel Garcia, Cole Hurwitz, Jeremy Magland, and Matthias Hennig (https://github.com/SpikeInterface). +If you use this pipeline, please cite SpikeInterface and the relevant sorter(s) used in your publication (see https://github.com/SpikeInterface for citation details). """ from datetime import datetime @@ -7,7 +9,7 @@ import datajoint as dj import pandas as pd import spikeinterface as si -from element_array_ephys import probe, readers +from element_array_ephys import probe, ephys, readers from element_interface.utils import find_full_path, memoized_result from spikeinterface import exporters, extractors, sorters @@ -17,25 +19,25 @@ schema = dj.schema() -ephys = None - def activate( schema_name, *, - ephys_module, create_schema=True, create_tables=True, ): + """Activate the current schema. + + Args: + schema_name (str): schema name on the database server to activate the `si_spike_sorting` schema. + create_schema (bool, optional): If True (default), create schema in the database if it does not yet exist. + create_tables (bool, optional): If True (default), create tables in the database if they do not yet exist. 
""" - activate(schema_name, *, create_schema=True, create_tables=True, activated_ephys=None) - :param schema_name: schema name on the database server to activate the `spike_sorting` schema - :param ephys_module: the activated ephys element for which this `spike_sorting` schema will be downstream from - :param create_schema: when True (default), create schema in the database if it does not yet exist. - :param create_tables: when True (default), create tables in the database if they do not yet exist. - """ - global ephys - ephys = ephys_module + if not probe.schema.is_activated(): + raise RuntimeError("Please activate the `probe` schema first.") + if not ephys.schema.is_activated(): + raise RuntimeError("Please activate the `ephys` schema first.") + schema.activate( schema_name, create_schema=create_schema, diff --git a/element_array_ephys/version.py b/element_array_ephys/version.py index 2e6de55a..19ba4c76 100644 --- a/element_array_ephys/version.py +++ b/element_array_ephys/version.py @@ -1,3 +1,3 @@ """Package metadata.""" -__version__ = "0.4.0" +__version__ = "1.0.0" diff --git a/tests/tutorial_pipeline.py b/tests/tutorial_pipeline.py index 74b27ddc..1b27027d 100644 --- a/tests/tutorial_pipeline.py +++ b/tests/tutorial_pipeline.py @@ -3,7 +3,7 @@ import datajoint as dj from element_animal import subject from element_animal.subject import Subject -from element_array_ephys import probe, ephys_no_curation as ephys, ephys_report +from element_array_ephys import probe, ephys, ephys_report from element_lab import lab from element_lab.lab import Lab, Location, Project, Protocol, Source, User from element_lab.lab import Device as Equipment @@ -62,7 +62,9 @@ def get_session_directory(session_key): return pathlib.Path(session_directory) -ephys.activate(db_prefix + "ephys", db_prefix + "probe", linking_module=__name__) +probe.activate(db_prefix + "probe") +ephys.activate(db_prefix + "ephys", linking_module=__name__) +ephys_report.activate(db_prefix + "ephys_report") probe.create_neuropixels_probe_types() From 0eef1cbaec2494b7dec7a5af2e8d9d62986280cb Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Tue, 10 Sep 2024 15:07:34 -0500 Subject: [PATCH 143/152] rearrange: remove the `ecephys_spike_sorting` flow --- element_array_ephys/ephys.py | 2 +- .../spike_sorting/ecephys_spike_sorting.py | 317 ------------------ .../kilosort_triggering.py | 0 3 files changed, 1 insertion(+), 318 deletions(-) delete mode 100644 element_array_ephys/spike_sorting/ecephys_spike_sorting.py rename element_array_ephys/{readers => spike_sorting}/kilosort_triggering.py (100%) diff --git a/element_array_ephys/ephys.py b/element_array_ephys/ephys.py index 3025d289..f17527c1 100644 --- a/element_array_ephys/ephys.py +++ b/element_array_ephys/ephys.py @@ -897,7 +897,7 @@ def make(self, key): ).fetch1("acq_software", "clustering_method", "params") if "kilosort" in clustering_method: - from element_array_ephys.readers import kilosort_triggering + from .spike_sorting import kilosort_triggering # add additional probe-recording and channels details into `params` params = {**params, **get_recording_channels_details(key)} diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py deleted file mode 100644 index 3a43c384..00000000 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ /dev/null @@ -1,317 +0,0 @@ -""" -The following DataJoint pipeline implements the sequence of steps in the spike-sorting routine featured in the -"ecephys_spike_sorting" 
pipeline. -The "ecephys_spike_sorting" pipeline was originally developed by the Allen Institute (https://github.com/AllenInstitute/ecephys_spike_sorting) for Neuropixels data acquired with the Open Ephys acquisition system. -It was then forked by Jennifer Colonell from the Janelia Research Campus (https://github.com/jenniferColonell/ecephys_spike_sorting) to support the SpikeGLX acquisition system. - -At DataJoint, we forked from Jennifer's fork and implemented a version that supports both Open Ephys and SpikeGLX: -https://github.com/datajoint-company/ecephys_spike_sorting - -The following pipeline features three tables: -1. KilosortPreProcessing - for preprocessing steps (no GPU required) - - median_subtraction for Open Ephys - - or the CatGT step for SpikeGLX -2. KilosortClustering - kilosort (MATLAB) - requires GPU - - supports Kilosort 2.0, 2.5, or 3.0 (https://github.com/MouseLand/Kilosort.git) -3. KilosortPostProcessing - for postprocessing steps (no GPU required) - - kilosort_postprocessing - - noise_templates - - mean_waveforms - - quality_metrics -""" - - -import datajoint as dj -from decimal import Decimal -import json -from datetime import datetime, timedelta - -from element_interface.utils import find_full_path -from element_array_ephys.readers import ( - spikeglx, - kilosort_triggering, -) - -log = dj.logger - -schema = dj.schema() - -ephys = None - -_supported_kilosort_versions = [ - "kilosort2", - "kilosort2.5", - "kilosort3", -] - - -def activate( - schema_name, - *, - ephys_module, - create_schema=True, - create_tables=True, -): - """ - activate(schema_name, *, ephys_module, create_schema=True, create_tables=True) - :param schema_name: schema name on the database server to activate the `spike_sorting` schema - :param ephys_module: the activated ephys element for which this `spike_sorting` schema will be downstream from - :param create_schema: when True (default), create schema in the database if it does not yet exist. - :param create_tables: when True (default), create tables in the database if they do not yet exist. 
- """ - global ephys - ephys = ephys_module - schema.activate( - schema_name, - create_schema=create_schema, - create_tables=create_tables, - add_objects=ephys.__dict__, - ) - - -@schema -class KilosortPreProcessing(dj.Imported): - """A processing table to handle each clustering task.""" - - definition = """ - -> ephys.ClusteringTask - --- - params: longblob # finalized parameterset for this run - execution_time: datetime # datetime of the start of this step - execution_duration: float # (hour) execution duration - """ - - @property - def key_source(self): - return ( - ephys.ClusteringTask * ephys.ClusteringParamSet - & {"task_mode": "trigger"} - & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")' - ) - ephys.Clustering - - def make(self, key): - """Triggers or imports clustering analysis.""" - execution_time = datetime.utcnow() - - task_mode, output_dir = (ephys.ClusteringTask & key).fetch1( - "task_mode", "clustering_output_dir" - ) - - assert task_mode == "trigger", 'Supporting "trigger" task_mode only' - - if not output_dir: - output_dir = ephys.ClusteringTask.infer_output_dir( - key, relative=True, mkdir=True - ) - # update clustering_output_dir - ephys.ClusteringTask.update1( - {**key, "clustering_output_dir": output_dir.as_posix()} - ) - - kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - acq_software, clustering_method, params = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method", "params") - - assert ( - clustering_method in _supported_kilosort_versions - ), f'Clustering_method "{clustering_method}" is not supported' - - # add additional probe-recording and channels details into `params` - params = {**params, **ephys.get_recording_channels_details(key)} - params["fs"] = params["sample_rate"] - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - spikeglx_recording.validate_file("ap") - run_CatGT = ( - params.get("run_CatGT", True) - and "_tcat." 
not in spikeglx_meta_filepath.stem - ) - - run_kilosort = kilosort_triggering.SGLXKilosortPipeline( - npx_input_dir=spikeglx_meta_filepath.parent, - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=run_CatGT, - ) - run_kilosort.run_CatGT() - elif acq_software == "Open Ephys": - oe_probe = ephys.get_openephys_probe_data(key) - - assert len(oe_probe.recording_info["recording_files"]) == 1 - - # run kilosort - run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( - npx_input_dir=oe_probe.recording_info["recording_files"][0], - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - ) - run_kilosort._modules = ["depth_estimation", "median_subtraction"] - run_kilosort.run_modules() - - self.insert1( - { - **key, - "params": params, - "execution_time": execution_time, - "execution_duration": ( - datetime.utcnow() - execution_time - ).total_seconds() - / 3600, - } - ) - - -@schema -class KilosortClustering(dj.Imported): - """A processing table to handle each clustering task.""" - - definition = """ - -> KilosortPreProcessing - --- - execution_time: datetime # datetime of the start of this step - execution_duration: float # (hour) execution duration - """ - - def make(self, key): - execution_time = datetime.utcnow() - - output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - acq_software, clustering_method = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method") - - params = (KilosortPreProcessing & key).fetch1("params") - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - spikeglx_recording.validate_file("ap") - - run_kilosort = kilosort_triggering.SGLXKilosortPipeline( - npx_input_dir=spikeglx_meta_filepath.parent, - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=True, - ) - run_kilosort._modules = ["kilosort_helper"] - run_kilosort._CatGT_finished = True - run_kilosort.run_modules() - elif acq_software == "Open Ephys": - oe_probe = ephys.get_openephys_probe_data(key) - - assert len(oe_probe.recording_info["recording_files"]) == 1 - - # run kilosort - run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( - npx_input_dir=oe_probe.recording_info["recording_files"][0], - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - ) - run_kilosort._modules = ["kilosort_helper"] - run_kilosort.run_modules() - - self.insert1( - { - **key, - "execution_time": execution_time, - "execution_duration": ( - datetime.utcnow() - execution_time - ).total_seconds() - / 3600, - } - ) - - -@schema -class KilosortPostProcessing(dj.Imported): - """A processing table to handle each clustering task.""" - - definition = """ - -> KilosortClustering - --- - modules_status: longblob # dictionary of summary status for all modules - execution_time: datetime # datetime of the start of this step - execution_duration: float # (hour) execution duration - """ - - def make(self, key): - execution_time = datetime.utcnow() - - output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = 
find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - acq_software, clustering_method = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method") - - params = (KilosortPreProcessing & key).fetch1("params") - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - spikeglx_recording.validate_file("ap") - - run_kilosort = kilosort_triggering.SGLXKilosortPipeline( - npx_input_dir=spikeglx_meta_filepath.parent, - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=True, - ) - run_kilosort._modules = [ - "kilosort_postprocessing", - "noise_templates", - "mean_waveforms", - "quality_metrics", - ] - run_kilosort._CatGT_finished = True - run_kilosort.run_modules() - elif acq_software == "Open Ephys": - oe_probe = ephys.get_openephys_probe_data(key) - - assert len(oe_probe.recording_info["recording_files"]) == 1 - - # run kilosort - run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( - npx_input_dir=oe_probe.recording_info["recording_files"][0], - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - ) - run_kilosort._modules = [ - "kilosort_postprocessing", - "noise_templates", - "mean_waveforms", - "quality_metrics", - ] - run_kilosort.run_modules() - - with open(run_kilosort._modules_input_hash_fp) as f: - modules_status = json.load(f) - - self.insert1( - { - **key, - "modules_status": modules_status, - "execution_time": execution_time, - "execution_duration": ( - datetime.utcnow() - execution_time - ).total_seconds() - / 3600, - } - ) - - # all finished, insert this `key` into ephys.Clustering - ephys.Clustering.insert1( - {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True - ) diff --git a/element_array_ephys/readers/kilosort_triggering.py b/element_array_ephys/spike_sorting/kilosort_triggering.py similarity index 100% rename from element_array_ephys/readers/kilosort_triggering.py rename to element_array_ephys/spike_sorting/kilosort_triggering.py From c2bd5adb07096a04c7afc28418f1f996533fdf8b Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Tue, 10 Sep 2024 15:23:04 -0500 Subject: [PATCH 144/152] chore: clean up diagrams --- ...n.svg => attached_array_ephys_element.svg} | 0 images/attached_array_ephys_element_acute.svg | 451 --------------- .../attached_array_ephys_element_chronic.svg | 456 --------------- ...ttached_array_ephys_element_precluster.svg | 535 ------------------ 4 files changed, 1442 deletions(-) rename images/{attached_array_ephys_element_no_curation.svg => attached_array_ephys_element.svg} (100%) delete mode 100644 images/attached_array_ephys_element_acute.svg delete mode 100644 images/attached_array_ephys_element_chronic.svg delete mode 100644 images/attached_array_ephys_element_precluster.svg diff --git a/images/attached_array_ephys_element_no_curation.svg b/images/attached_array_ephys_element.svg similarity index 100% rename from images/attached_array_ephys_element_no_curation.svg rename to images/attached_array_ephys_element.svg diff --git a/images/attached_array_ephys_element_acute.svg b/images/attached_array_ephys_element_acute.svg deleted file mode 100644 index 5b2bc265..00000000 --- a/images/attached_array_ephys_element_acute.svg +++ /dev/null @@ -1,451 +0,0 @@
[451 deleted lines: SVG schema diagram of the acute array-ephys pipeline (probe and ephys tables); graph markup omitted] \ No newline at end of file
diff --git a/images/attached_array_ephys_element_chronic.svg b/images/attached_array_ephys_element_chronic.svg deleted file mode 100644 index 808a2f17..00000000 --- a/images/attached_array_ephys_element_chronic.svg +++ /dev/null @@ -1,456 +0,0 @@
[456 deleted lines: SVG schema diagram of the chronic array-ephys pipeline (probe and ephys tables); graph markup omitted] \ No newline at end of file
diff --git a/images/attached_array_ephys_element_precluster.svg b/images/attached_array_ephys_element_precluster.svg deleted file mode 100644 index 7d854d2e..00000000 --- a/images/attached_array_ephys_element_precluster.svg +++ /dev/null @@ -1,535 +0,0 @@
[535 deleted lines: SVG schema diagram of the precluster array-ephys pipeline (probe and ephys tables, including PreCluster steps); graph markup omitted] \ No newline at end of file
From 497110816058ae0655cac8f9414b4622905f78a6 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Thu, 19 Sep 2024 13:29:45 -0500 Subject: [PATCH 145/152] fix: use tempfile.TemporaryDirectory --- element_array_ephys/ephys_report.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/element_array_ephys/ephys_report.py b/element_array_ephys/ephys_report.py index c962d33d..0c6836a0 100644 --- a/element_array_ephys/ephys_report.py +++ b/element_array_ephys/ephys_report.py @@ -2,6 +2,7 @@ import datetime import pathlib +import tempfile from uuid import UUID import datajoint as dj @@ -53,7 +54,7 @@ class ProbeLevelReport(dj.Computed): def make(self, key): from .plotting.probe_level import plot_driftmap - save_dir = _make_save_dir() + save_dir = tempfile.TemporaryDirectory() units = ephys.CuratedClustering.Unit & key & "cluster_quality_label='good'" @@ -88,13 +89,15 @@ def make(self, key): fig_dict = _save_figs( figs=(fig,), fig_names=("drift_map_plot",), - save_dir=save_dir, + save_dir=save_dir.name, fig_prefix=fig_prefix, extension=".png", ) self.insert1({**key, **fig_dict, "shank": shank_no}) + save_dir.cleanup() + @schema class UnitLevelReport(dj.Computed): @@ -266,17 +269,10 @@ def make(self, key): ) -def _make_save_dir(root_dir: pathlib.Path = 
None) -> pathlib.Path: - if root_dir is None: - root_dir = pathlib.Path().absolute() - save_dir = root_dir / "temp_ephys_figures" - save_dir.mkdir(parents=True, exist_ok=True) - return save_dir - - def _save_figs( figs, fig_names, save_dir, fig_prefix, extension=".png" ) -> dict[str, pathlib.Path]: + save_dir = pathlib.Path(save_dir) fig_dict = {} for fig, fig_name in zip(figs, fig_names): fig_filepath = save_dir / (fig_prefix + "_" + fig_name + extension) From 63df4cda7d5ab97d1d52195acf0d7031fe2496f6 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Thu, 19 Sep 2024 16:27:13 -0500 Subject: [PATCH 146/152] format: black --- element_array_ephys/ephys.py | 9 +++++++-- element_array_ephys/spike_sorting/si_spike_sorting.py | 8 ++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/element_array_ephys/ephys.py b/element_array_ephys/ephys.py index 02e1366e..ad9bb8d7 100644 --- a/element_array_ephys/ephys.py +++ b/element_array_ephys/ephys.py @@ -1068,9 +1068,14 @@ def make(self, key): } spike_locations = sorting_analyzer.get_extension("spike_locations") - extremum_channel_inds = si.template_tools.get_template_extremum_channel(sorting_analyzer, outputs="index") + extremum_channel_inds = si.template_tools.get_template_extremum_channel( + sorting_analyzer, outputs="index" + ) spikes_df = pd.DataFrame( - sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)) + sorting_analyzer.sorting.to_spike_vector( + extremum_channel_inds=extremum_channel_inds + ) + ) units = [] for unit_idx, unit_id in enumerate(si_sorting.unit_ids): diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 547fd8ce..e2f011e1 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -114,7 +114,9 @@ def make(self, key): spikeglx_recording.validate_file("ap") data_dir = spikeglx_meta_filepath.parent - si_extractor = si.extractors.neoextractors.spikeglx.SpikeGLXRecordingExtractor + si_extractor = ( + si.extractors.neoextractors.spikeglx.SpikeGLXRecordingExtractor + ) stream_names, stream_ids = si.extractors.get_neo_streams( "spikeglx", folder_path=data_dir ) @@ -125,7 +127,9 @@ def make(self, key): oe_probe = ephys.get_openephys_probe_data(key) assert len(oe_probe.recording_info["recording_files"]) == 1 data_dir = oe_probe.recording_info["recording_files"][0] - si_extractor = si.extractors.neoextractors.openephys.OpenEphysBinaryRecordingExtractor + si_extractor = ( + si.extractors.neoextractors.openephys.OpenEphysBinaryRecordingExtractor + ) stream_names, stream_ids = si.extractors.get_neo_streams( "openephysbinary", folder_path=data_dir From 800d060e83cac12bc87adcaf0061da5100263932 Mon Sep 17 00:00:00 2001 From: kushalbakshi Date: Fri, 20 Sep 2024 18:44:21 -0400 Subject: [PATCH 147/152] Update docs for new ephys release --- docs/docker-compose.yaml | 6 - docs/mkdocs.yaml | 13 +- docs/src/concepts.md | 32 ++--- docs/src/index.md | 32 ++--- docs/src/tutorials/index.md | 40 ++----- notebooks/demo_prepare.ipynb | 224 ----------------------------------- notebooks/demo_run.ipynb | 107 ----------------- 7 files changed, 36 insertions(+), 418 deletions(-) delete mode 100644 notebooks/demo_prepare.ipynb delete mode 100644 notebooks/demo_run.ipynb diff --git a/docs/docker-compose.yaml b/docs/docker-compose.yaml index 5ba221df..bc2c2b8b 100644 --- a/docs/docker-compose.yaml +++ b/docs/docker-compose.yaml @@ -30,12 +30,6 @@ services: export 
ELEMENT_UNDERSCORE=$$(echo $${PACKAGE} | sed 's/element_//g')
         export ELEMENT_HYPHEN=$$(echo $${ELEMENT_UNDERSCORE} | sed 's/_/-/g')
         export PATCH_VERSION=$$(cat /main/$${PACKAGE}/version.py | grep -oE '\d+\.\d+\.[a-z0-9]+')
-        git clone https://github.com/datajoint/workflow-$${ELEMENT_HYPHEN}.git /main/delete || true
-        if [ -d /main/delete/ ]; then
-          mv /main/delete/workflow_$${ELEMENT_UNDERSCORE} /main/
-          mv /main/delete/notebooks/*ipynb /main/docs/src/tutorials/
-          rm -fR /main/delete
-        fi
         if echo "$${MODE}" | grep -i live &>/dev/null; then
           mkdocs serve --config-file ./docs/mkdocs.yaml -a 0.0.0.0:80 2>&1 | tee docs/temp_mkdocs.log
         elif echo "$${MODE}" | grep -iE "qa|push" &>/dev/null; then
diff --git a/docs/mkdocs.yaml b/docs/mkdocs.yaml
index 5fdbffd2..e211069a 100644
--- a/docs/mkdocs.yaml
+++ b/docs/mkdocs.yaml
@@ -9,18 +9,7 @@ nav:
   - Concepts: concepts.md
   - Tutorials:
     - Overview: tutorials/index.md
-    - Data Download: tutorials/00-data-download-optional.ipynb
-    - Configure: tutorials/01-configure.ipynb
-    - Workflow Structure: tutorials/02-workflow-structure-optional.ipynb
-    - Process: tutorials/03-process.ipynb
-    - Automate: tutorials/04-automate-optional.ipynb
-    - Explore: tutorials/05-explore.ipynb
-    - Drop: tutorials/06-drop-optional.ipynb
-    - Downstream Analysis: tutorials/07-downstream-analysis.ipynb
-    - Visualizations: tutorials/10-data_visualization.ipynb
-    - Electrode Localization: tutorials/08-electrode-localization.ipynb
-    - NWB Export: tutorials/09-NWB-export.ipynb
-    - Quality Metrics: tutorials/quality_metrics.ipynb
+    - Tutorial: tutorials/tutorial.ipynb
   - Citation: citation.md
   - API: api/ # defer to gen-files + literate-nav
   - Changelog: changelog.md
diff --git a/docs/src/concepts.md b/docs/src/concepts.md
index f864b306..cb57a802 100644
--- a/docs/src/concepts.md
+++ b/docs/src/concepts.md
@@ -59,12 +59,16 @@ significant community uptake:
 
 Kilosort provides most automation and has gained significant popularity, being adopted
 as one of the key spike sorting methods in the majority of the teams/collaborations we
 have worked with. As part of our Year-1 NIH U24 effort, we provide support for data
-ingestion of spike sorting results from Kilosort. Further effort will be devoted for the
+ingestion of spike sorting results from Kilosort.
+
+Further effort has been devoted to the
 ingestion support of other spike sorting methods. On this end, a framework for unifying
 existing spike sorting methods, named
 [SpikeInterface](https://github.com/SpikeInterface/spikeinterface), has been developed
 by Alessio Buccino, et al. SpikeInterface provides a convenient Python-based wrapper to
-invoke, extract, compare spike sorting results from different sorting algorithms.
+invoke, extract, compare spike sorting results from different sorting algorithms. 
+SpikeInterface is the primary tool supported by Element Array Electrophysiology for
+spike sorting as of version `1.0.0`.
 
 ## Key Partnerships
 
@@ -95,22 +99,10 @@ Each of the DataJoint Elements creates a set of tables for common neuroscience d
 modalities to organize, preprocess, and analyze data. Each node in the following
 diagram is a table within the Element or a table connected to the Element. 
-### `ephys_acute` module +### `ephys` module ![diagram](https://raw.githubusercontent.com/datajoint/element-array-ephys/main/images/attached_array_ephys_element_acute.svg) -### `ephys_chronic` module - -![diagram](https://raw.githubusercontent.com/datajoint/element-array-ephys/main/images/attached_array_ephys_element_chronic.svg) - -### `ephys_precluster` module - -![diagram](https://raw.githubusercontent.com/datajoint/element-array-ephys/main/images/attached_array_ephys_element_precluster.svg) - -### `ephys_no_curation` module - -![diagram](https://raw.githubusercontent.com/datajoint/element-array-ephys/main/images/attached_array_ephys_element_no_curation.svg) - ### `subject` schema ([API docs](https://datajoint.com/docs/elements/element-animal/api/element_animal/subject)) Although not required, most choose to connect the `Session` table to a `Subject` table. @@ -181,12 +173,11 @@ Major features of the Array Electrophysiology Element include: + Probe-insertion, ephys-recordings, LFP extraction, clusterings, curations, sorted units and the associated data (e.g. spikes, waveforms, etc.). - + Store/track/manage different curations of the spike sorting results - supporting - both curated clustering and kilosort triggered clustering (i.e., `no_curation`). + + Store/track/manage the spike sorting results. + Ingestion support for data acquired with SpikeGLX and OpenEphys acquisition systems. -+ Ingestion support for spike sorting outputs from Kilosort. -+ Triggering support for workflow integrated Kilosort processing. ++ Ingestion support for spike sorting outputs from SpikeInterface. ++ Triggering support for workflow integrated SpikeInterface processing. + Sample data and complete test suite for quality assurance. ## Data Export and Publishing @@ -208,8 +199,7 @@ pip install element-array-ephys[nwb] ## Roadmap -Incorporation of SpikeInterface into the Array Electrophysiology Element will be -on DataJoint Elements development roadmap. Dr. Loren Frank has led a development +Dr. Loren Frank has led a development effort of a DataJoint pipeline with SpikeInterface framework and NeurodataWithoutBorders format integrated [https://github.com/LorenFrankLab/nwb_datajoint](https://github.com/LorenFrankLab/nwb_datajoint). diff --git a/docs/src/index.md b/docs/src/index.md index b21edcfc..0c828c00 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,29 +1,23 @@ # Element Array Electrophysiology This Element features DataJoint schemas for analyzing extracellular array -electrophysiology data acquired with Neuropixels probes and spike sorted using Kilosort -spike sorter. Each Element is a modular pipeline for data storage and processing with +electrophysiology data acquired with Neuropixels probes and spike sorted using [SpikeInterface](https://github.com/SpikeInterface/spikeinterface). +Each Element is a modular pipeline for data storage and processing with corresponding database tables that can be combined with other Elements to assemble a fully functional pipeline. ![diagram](https://raw.githubusercontent.com/datajoint/element-array-ephys/main/images/diagram_flowchart.svg) -The Element is comprised of `probe` and `ephys` schemas. Several `ephys` schemas are -developed to handle various use cases of this pipeline and workflow: - -+ `ephys_acute`: A probe is inserted into a new location during each session. - -+ `ephys_chronic`: A probe is inserted once and used to record across multiple - sessions. - -+ `ephys_precluster`: A probe is inserted into a new location during each session. 
-  Pre-clustering steps are performed on the data from each probe prior to Kilosort
-  analysis.
-
-+ `ephys_no_curation`: A probe is inserted into a new location during each session and
-  Kilosort-triggered clustering is performed without the option to manually curate the
-  results.
-
-Visit the [Concepts page](./concepts.md) for more information about the use cases of
+The Element is comprised of `probe` and `ephys` schemas. Visit the
+[Concepts page](./concepts.md) for more information about the
 `probe` and `ephys` schemas and an explanation of the tables. To get started with
 building your own data pipeline, visit the [Tutorials page](./tutorials/index.md).
+
+Prior to version `1.0.0`, several `ephys` schemas were
+developed and supported to handle various use cases of this pipeline and workflow. These
+are now deprecated but still available on their own branch within the repository:
+
+* [`ephys_acute`](https://github.com/datajoint/element-array-ephys/tree/main_ephys_acute)
+* [`ephys_chronic`](https://github.com/datajoint/element-array-ephys/tree/main_ephys_chronic)
+* [`ephys_precluster`](https://github.com/datajoint/element-array-ephys/tree/main_ephys_precluster)
+* [`ephys_no_curation`](https://github.com/datajoint/element-array-ephys/tree/main_ephys_no_curation)
diff --git a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md
index 5f367cd9..ff0bd1f5 100644
--- a/docs/src/tutorials/index.md
+++ b/docs/src/tutorials/index.md
@@ -1,14 +1,18 @@
 # Tutorials
 
+## Executing the Tutorial Notebooks
+
+The tutorials are set up to run using GitHub Codespaces. To run the tutorials, click on
+the "Open in Codespaces" button from the GitHub repository. This will open a
+pre-configured environment with a VSCode IDE in your browser. The environment contains
+all the necessary dependencies and sample data to run the tutorials.
+
 ## Installation
 
 Installation of the Element requires an integrated development environment and
 database. Instructions to setup each of the components can be found on the
-[User Instructions](https://datajoint.com/docs/elements/user-guide/) page. These
-instructions use the example
-[workflow for Element Array Ephys](https://github.com/datajoint/workflow-array-ephys),
-which can be modified for a user's specific experimental requirements. This example
-workflow uses several Elements (Lab, Animal, Session, Event, and Electrophysiology) to construct
+[User Instructions](https://datajoint.com/docs/elements/user-guide/) page. The example
+tutorial uses several Elements (Lab, Animal, Session, Event, and Electrophysiology) to construct
 a complete pipeline, and is able to ingest experimental metadata and run model training
 and inference.
 
@@ -23,32 +27,10 @@ Electrophysiology.
 ### Notebooks
 
 Each of the notebooks in the workflow
-([download here](https://github.com/datajoint/workflow-array-ephys/tree/main/notebooks)
+([download here](https://github.com/datajoint/workflow-array-ephys/tree/main/notebooks))
 steps through ways to interact with the Element itself. For convenience, these
 notebooks are also rendered as part of this site. To try out the Elements notebooks in
 an online Jupyter environment with access to example data, visit
 [CodeBook](https://codebook.datajoint.io/). (Electrophysiology notebooks coming soon!)
 
-- [Data Download](./00-data-download-optional.ipynb) highlights how to use DataJoint
-  tools to download a sample model for trying out the Element. 
-- [Configure](./01-configure.ipynb) helps configure your local DataJoint installation to - point to the correct database. -- [Workflow Structure](./02-workflow-structure-optional.ipynb) demonstrates the table - architecture of the Element and key DataJoint basics for interacting with these - tables. -- [Process](./03-process.ipynb) steps through adding data to these tables and launching - key Electrophysiology features, like model training. -- [Automate](./04-automate-optional.ipynb) highlights the same steps as above, but - utilizing all built-in automation tools. -- [Explore](./05-explore.ipynb) demonstrates how to fetch data from the Element. -- [Drop schemas](./06-drop-optional.ipynb) provides the steps for dropping all the - tables to start fresh. -- [Downstream Analysis](./07-downstream-analysis.ipynb) highlights how to link - this Element to Element Event for event-based analyses. -- [Visualizations](./10-data_visualization.ipynb) highlights how to use a built-in module - for visualizing units, probes and quality metrics. -- [Electrode Localization](./08-electrode-localization.ipynb) demonstrates how to link - this Element to - [Element Electrode Localization](https://datajoint.com/docs/elements/element-electrode-localization/). -- [NWB Export](./09-NWB-export.ipynb) highlights the export functionality available for the - `no-curation` schema. +- [Tutorial](../../../notebooks/tutorial.ipynb) diff --git a/notebooks/demo_prepare.ipynb b/notebooks/demo_prepare.ipynb deleted file mode 100644 index 85ee1be2..00000000 --- a/notebooks/demo_prepare.ipynb +++ /dev/null @@ -1,224 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Demo Preparation Notebook\n", - "\n", - "**Please Note**: This notebook (`demo_prepare.ipynb`) and `demo_run.ipynb` are **NOT** intended to be used as learning materials. To gain\n", - "a thorough understanding of the DataJoint workflow for extracellular electrophysiology, please\n", - "see the [`tutorial`](./tutorial.ipynb) notebook." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Runs in about 45s\n", - "import datajoint as dj\n", - "import datetime\n", - "from tutorial_pipeline import subject, session, probe, ephys\n", - "from element_array_ephys import ephys_report" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "subject.Subject.insert1(\n", - " dict(subject=\"subject5\", subject_birth_date=\"2023-01-01\", sex=\"U\")\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "session_key = dict(subject=\"subject5\", session_datetime=\"2023-01-01 00:00:00\")\n", - "\n", - "session.Session.insert1(session_key)\n", - "\n", - "session.SessionDirectory.insert1(dict(session_key, session_dir=\"raw/subject5/session1\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "probe.Probe.insert1(dict(probe=\"714000838\", probe_type=\"neuropixels 1.0 - 3B\"))\n", - "\n", - "ephys.ProbeInsertion.insert1(\n", - " dict(\n", - " session_key,\n", - " insertion_number=1,\n", - " probe=\"714000838\",\n", - " )\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "populate_settings = {\"display_progress\": True}\n", - "\n", - "ephys.EphysRecording.populate(**populate_settings)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "kilosort_params = {\n", - " \"fs\": 30000,\n", - " \"fshigh\": 150,\n", - " \"minfr_goodchannels\": 0.1,\n", - " \"Th\": [10, 4],\n", - " \"lam\": 10,\n", - " \"AUCsplit\": 0.9,\n", - " \"minFR\": 0.02,\n", - " \"momentum\": [20, 400],\n", - " \"sigmaMask\": 30,\n", - " \"ThPr\": 8,\n", - " \"spkTh\": -6,\n", - " \"reorder\": 1,\n", - " \"nskip\": 25,\n", - " \"GPU\": 1,\n", - " \"Nfilt\": 1024,\n", - " \"nfilt_factor\": 4,\n", - " \"ntbuff\": 64,\n", - " \"whiteningRange\": 32,\n", - " \"nSkipCov\": 25,\n", - " \"scaleproc\": 200,\n", - " \"nPCs\": 3,\n", - " \"useRAM\": 0,\n", - "}\n", - "\n", - "ephys.ClusteringParamSet.insert_new_params(\n", - " clustering_method=\"kilosort2\",\n", - " paramset_idx=1,\n", - " params=kilosort_params,\n", - " paramset_desc=\"Spike sorting using Kilosort2\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ephys.ClusteringTask.insert1(\n", - " dict(\n", - " session_key,\n", - " insertion_number=1,\n", - " paramset_idx=1,\n", - " task_mode=\"load\", # load or trigger\n", - " clustering_output_dir=\"processed/subject5/session1/probe_1/kilosort2-5_1\",\n", - " )\n", - ")\n", - "\n", - "ephys.Clustering.populate(**populate_settings)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "clustering_key = (ephys.ClusteringTask & session_key).fetch1(\"KEY\")\n", - "ephys.Curation().create1_from_clustering_task(clustering_key)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Runs in about 12m\n", - "ephys.CuratedClustering.populate(**populate_settings)\n", - "ephys.WaveformSet.populate(**populate_settings)\n", - "ephys_report.ProbeLevelReport.populate(**populate_settings)\n", - "ephys_report.UnitLevelReport.populate(**populate_settings)" - ] - }, - { - "attachments": {}, - "cell_type": 
"markdown", - "metadata": {}, - "source": [ - "### Drop schemas\n", - "- Schemas are not typically dropped in a production workflow with real data in it.\n", - "- At the developmental phase, it might be required for the table redesign.\n", - "- When dropping all schemas is needed, the following is the dependency order." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def drop_databases(databases):\n", - " import pymysql.err\n", - "\n", - " conn = dj.conn()\n", - "\n", - " with dj.config(safemode=False):\n", - " for database in databases:\n", - " schema = dj.Schema(f'{dj.config[\"custom\"][\"database.prefix\"]}{database}')\n", - " while schema.list_tables():\n", - " for table in schema.list_tables():\n", - " try:\n", - " conn.query(f\"DROP TABLE `{schema.database}`.`{table}`\")\n", - " except pymysql.err.OperationalError:\n", - " print(f\"Can't drop `{schema.database}`.`{table}`. Retrying...\")\n", - " schema.drop()\n", - "\n", - "\n", - "# drop_databases(databases=['analysis', 'trial', 'event', 'ephys_report', 'ephys', 'probe', 'session', 'subject', 'project', 'lab'])" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.17" - }, - "vscode": { - "interpreter": { - "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/demo_run.ipynb b/notebooks/demo_run.ipynb deleted file mode 100644 index 70fbb746..00000000 --- a/notebooks/demo_run.ipynb +++ /dev/null @@ -1,107 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# DataJoint Workflow for Neuropixels Analysis\n", - "\n", - "+ This notebook demonstrates using the open-source DataJoint Element to build a workflow for extracellular electrophysiology.\n", - "+ For a detailed tutorial, please see the [tutorial notebook](./tutorial.ipynb)." - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Import dependencies" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import datajoint as dj\n", - "from tutorial_pipeline import subject, session, probe, ephys\n", - "from element_array_ephys.plotting.widget import main" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### View workflow" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " dj.Diagram(subject.Subject)\n", - " + dj.Diagram(session.Session)\n", - " + dj.Diagram(probe)\n", - " + dj.Diagram(ephys)\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Visualize processed data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "main(ephys)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "For an in-depth tutorial please see the [tutorial notebook](./tutorial.ipynb)." 
- ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "python3p10", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.17" - }, - "vscode": { - "interpreter": { - "hash": "ff52d424e56dd643d8b2ec122f40a2e279e94970100b4e6430cb9025a65ba4cf" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 463ec527ff216a1078ea97665d365d08d140c621 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Tue, 8 Oct 2024 11:56:56 -0500 Subject: [PATCH 148/152] feat(spike_sorting): handle cases when no units/spikes are found --- element_array_ephys/spike_sorting/si_spike_sorting.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 550ae4a1..7a652076 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -275,11 +275,17 @@ def make(self, key): analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer" + has_units = si_sorting.unit_ids.size > 0 + @memoized_result( uniqueness_dict=postprocessing_params, output_directory=analyzer_output_dir, ) def _sorting_analyzer_compute(): + if not has_units: + log.info("No units found in sorting object. Skipping sorting analyzer.") + return + # Sorting Analyzer sorting_analyzer = si.create_sorting_analyzer( sorting=si_sorting, @@ -303,6 +309,8 @@ def _sorting_analyzer_compute(): _sorting_analyzer_compute() + do_si_export = postprocessing_params.get("export_to_phy", False) or postprocessing_params.get("export_report", False) + self.insert1( { **key, @@ -311,8 +319,7 @@ def _sorting_analyzer_compute(): datetime.utcnow() - execution_time ).total_seconds() / 3600, - "do_si_export": postprocessing_params.get("export_to_phy", False) - or postprocessing_params.get("export_report", False), + "do_si_export": do_si_export and has_units, } ) From 451571de595114e59f732c2e1e66298a26d6eeeb Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Tue, 8 Oct 2024 15:06:24 -0500 Subject: [PATCH 149/152] feat(spike_sorting): update downstream ephys tables ingestion when NO UNITs found --- element_array_ephys/ephys_no_curation.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index dae9e4a9..5a30d81d 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1038,6 +1038,16 @@ def make(self, key): if si_sorting_analyzer_dir.exists(): # Read from spikeinterface outputs import spikeinterface as si + from spikeinterface import sorters + + sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" + si_sorting_: si.sorters.BaseSorter = si.load_extractor( + sorting_file, base_folder=output_dir + ) + if si_sorting_.unit_ids.size == 0: + logger.info(f"No units found in {sorting_file}. 
Skipping Unit ingestion...") + self.insert1(key) + return sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir) si_sorting = sorting_analyzer.sorting @@ -1241,6 +1251,11 @@ def make(self, key): output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) sorter_name = clustering_method.replace(".", "_") + self.insert1(key) + if not len(CuratedClustering.Unit & key): + logger.info(f"No CuratedClustering.Unit found for {key}, skipping Waveform ingestion.") + return + # Get channel and electrode-site mapping electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name") channel2electrode_map: dict[int, dict] = { @@ -1294,7 +1309,6 @@ def yield_unit_waveforms(): ] yield unit_peak_waveform, unit_electrode_waveforms - else: # read from kilosort outputs (ecephys pipeline) kilosort_dataset = kilosort.Kilosort(output_dir) @@ -1394,7 +1408,6 @@ def yield_unit_waveforms(): yield unit_peak_waveform, unit_electrode_waveforms # insert waveform on a per-unit basis to mitigate potential memory issue - self.insert1(key) for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms(): if unit_peak_waveform: self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True) @@ -1501,6 +1514,11 @@ def make(self, key): output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) sorter_name = clustering_method.replace(".", "_") + self.insert1(key) + if not len(CuratedClustering.Unit & key): + logger.info(f"No CuratedClustering.Unit found for {key}, skipping QualityMetrics ingestion.") + return + si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer" if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs import spikeinterface as si @@ -1556,7 +1574,6 @@ def make(self, key): for unit_key in (CuratedClustering.Unit & key).fetch("KEY") ] - self.insert1(key) self.Cluster.insert(metrics_list, ignore_extra_fields=True) self.Waveform.insert(metrics_list, ignore_extra_fields=True) From b1104ce09158c09175af94682e6e7a0281a7cda2 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Wed, 9 Oct 2024 09:37:00 -0500 Subject: [PATCH 150/152] fix(spike_sorting): create empty `sorting_analyzer` folder when no units found --- element_array_ephys/spike_sorting/si_spike_sorting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 7a652076..a47f1d89 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -284,6 +284,7 @@ def make(self, key): def _sorting_analyzer_compute(): if not has_units: log.info("No units found in sorting object. Skipping sorting analyzer.") + analyzer_output_dir.mkdir(parents=True, exist_ok=True) # create empty directory anyway, for consistency return # Sorting Analyzer From 0508c94df3881e8d7fb28fc7174e958336834962 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Mon, 13 Jan 2025 16:30:55 -0600 Subject: [PATCH 151/152] update: fix docs, new version is `0.4.0` --- docs/src/concepts.md | 2 +- docs/src/index.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/concepts.md b/docs/src/concepts.md index cb57a802..b5da5081 100644 --- a/docs/src/concepts.md +++ b/docs/src/concepts.md @@ -68,7 +68,7 @@ existing spike sorting methods, named by Alessio Buccino, et al. SpikeInterface provides a convenient Python-based wrapper to invoke, extract, compare spike sorting results from different sorting algorithms. 
SpikeInterface is the primary tool supported by Element Array Electrophysiology for
-spike sorting as of version `1.0.0`.
+spike sorting as of version `0.4.0`.
 
 ## Key Partnerships
 
diff --git a/docs/src/index.md b/docs/src/index.md
index 0c828c00..5d9b7f19 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -13,7 +13,7 @@ The Element is comprised of `probe` and `ephys` schemas. Visit the
 `ephys` schemas and an explanation of the tables. To get started with building your
 own data pipeline, visit the [Tutorials page](./tutorials/index.md).
 
-Prior to version `1.0.0`, several `ephys` schemas were
+Prior to version `0.4.0`, several `ephys` schemas were
 developed and supported to handle various use cases of this pipeline and workflow. These
 are now deprecated but still available on their own branch within the repository:
 
From 284106a7e3820971ade67153b7b231435fd3c1eb Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Tue, 14 Jan 2025 08:50:20 -0600
Subject: [PATCH 152/152] update: make this the `1.0.0` release

---
 docs/src/concepts.md           | 2 +-
 docs/src/index.md              | 2 +-
 element_array_ephys/version.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/docs/src/concepts.md b/docs/src/concepts.md
index b5da5081..cb57a802 100644
--- a/docs/src/concepts.md
+++ b/docs/src/concepts.md
@@ -68,7 +68,7 @@ existing spike sorting methods, named
 by Alessio Buccino, et al. SpikeInterface provides a convenient Python-based wrapper to
 invoke, extract, compare spike sorting results from different sorting algorithms. 
 SpikeInterface is the primary tool supported by Element Array Electrophysiology for
-spike sorting as of version `0.4.0`.
+spike sorting as of version `1.0.0`.
 
 ## Key Partnerships
 
diff --git a/docs/src/index.md b/docs/src/index.md
index 5d9b7f19..0c828c00 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -13,7 +13,7 @@ The Element is comprised of `probe` and `ephys` schemas. Visit the
 `ephys` schemas and an explanation of the tables. To get started with building your
 own data pipeline, visit the [Tutorials page](./tutorials/index.md).
 
-Prior to version `0.4.0`, several `ephys` schemas were
+Prior to version `1.0.0`, several `ephys` schemas were
 developed and supported to handle various use cases of this pipeline and workflow. These
 are now deprecated but still available on their own branch within the repository:
 
diff --git a/element_array_ephys/version.py b/element_array_ephys/version.py
index 7a2d5521..19ba4c76 100644
--- a/element_array_ephys/version.py
+++ b/element_array_ephys/version.py
@@ -1,3 +1,3 @@
 """Package metadata."""
 
-__version__ = "4.0.0"
+__version__ = "1.0.0"
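
A note on the `tempfile.TemporaryDirectory` pattern adopted in PATCH 145: the returned object exposes the directory path as `.name` and keeps the directory alive until `.cleanup()` is called, which is why `make()` hands `save_dir.name` to `_save_figs` and defers `save_dir.cleanup()` until after `insert1` has stored the rendered figures. A minimal, self-contained sketch of that lifecycle (the byte-writing and size-reading steps below are stand-ins for the element's actual figure-export and ingest calls, which are not shown in the diff):

import tempfile
from pathlib import Path


def save_and_ingest(fig_bytes: bytes, fig_name: str) -> int:
    """Write a figure into a self-deleting scratch directory, consume it,
    then clean up, mirroring ProbeLevelReport.make after PATCH 145."""
    save_dir = tempfile.TemporaryDirectory()  # .name holds the directory path
    try:
        fig_path = Path(save_dir.name) / (fig_name + ".png")
        fig_path.write_bytes(fig_bytes)  # stand-in for the real figure export
        size = fig_path.stat().st_size   # consume the file while it still exists
    finally:
        save_dir.cleanup()               # remove the directory and its contents
    return size


save_and_ingest(b"fake-png-bytes", "drift_map_plot")  # returns 14; directory is gone

Explicit `.cleanup()` (rather than a `with` block) suits the table's `make()` because the scratch directory has to outlive the `_save_figs` helper call until the row insert completes; when everything happens in one scope, `with tempfile.TemporaryDirectory() as dirname:` achieves the same with less bookkeeping.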
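PATCHes 148-150 converge on one guard: inspect `si_sorting.unit_ids` before any unit-dependent work. PATCH 148 skips building the sorting analyzer, PATCH 149 inserts the master row but returns before populating part tables, and PATCH 150 still materializes the analyzer folder so the downstream `si_sorting_analyzer_dir.exists()` checks behave consistently. A condensed sketch of the guard, lifted out of `make()` in `si_spike_sorting.py` into a standalone function for illustration (the wrapper function is not part of the element, and the keyword arguments beyond `sorting=` follow spikeinterface's documented `create_sorting_analyzer` signature rather than the truncated call shown in the hunk):

from pathlib import Path

import spikeinterface as si  # same top-level import the patches use


def build_sorting_analyzer(si_sorting, recording, analyzer_output_dir: Path):
    """Return a sorting analyzer, or None when the sorter found no units."""
    if si_sorting.unit_ids.size == 0:  # PATCH 148: nothing to analyze
        # PATCH 150: create the (empty) folder anyway so downstream
        # existence checks on the analyzer directory stay consistent.
        analyzer_output_dir.mkdir(parents=True, exist_ok=True)
        return None
    return si.create_sorting_analyzer(  # as in the pipeline's make() step
        sorting=si_sorting,
        recording=recording,
        format="binary_folder",  # persist to disk under analyzer_output_dir
        folder=analyzer_output_dir,
    )

Downstream, PATCH 149 pairs this with master-first inserts: each `make()` calls `self.insert1(key)` before the unit check, so an empty sorting still leaves an auditable row while the part tables stay empty by design.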