From 7bfc818f99beb90ce956fd0943868ff7c46ef1c5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 29 Aug 2024 16:20:55 +0200 Subject: [PATCH] Fix tests --- src/spikeinterface_pipelines/spikesorting/params.py | 5 +++-- tests/test_pipeline.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface_pipelines/spikesorting/params.py b/src/spikeinterface_pipelines/spikesorting/params.py index 4edd2c5..25edb4b 100644 --- a/src/spikeinterface_pipelines/spikesorting/params.py +++ b/src/spikeinterface_pipelines/spikesorting/params.py @@ -129,9 +129,10 @@ class Kilosort4Model(BaseModel): save_extra_kwargs: bool = Field(default=False, description="If True, additional kwargs are saved to the output") skip_kilosort_preprocessing: bool = Field(default=False, description="Can optionally skip the internal kilosort preprocessing") scaleproc: Union[None, int] = Field(default=None, description="int16 scaling of whitened data, if None set to 200.") - save_preprocessed_copy: bool = Field(default=False, description="save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data") torch_device: str = Field(default="auto", description="Select the torch device auto/cuda/cpu") - # bad_channels: Optional[List[int]] = Field(default=None, description="List of bad channels to exclude from spike detection and clustering.") + # THESE 2 PARAMS REQUIRE A NEW RELEASE OF SPIKEINTERFACE + # bad_channels: Optional[List[int]] = Field(default=None, description="List of bad channels to exclude from spike detection and clustering.") + # save_preprocessed_copy: bool = Field(default=False, description="save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data") diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index e3fb868..6d3ba35 100644 --- a/tests/test_pipeline.py +++ 
b/tests/test_pipeline.py @@ -32,6 +32,7 @@ def _generate_gt_recording(): analyzer.compute( [ "random_spikes", + "waveforms", "templates", "noise_levels", "spike_amplitudes",