Skip to content

Commit

Permalink
Merge pull request #13 from SpikeInterface/non-rigid-fast-accurate
Browse files Browse the repository at this point in the history
Add nonrigid_fast_and_accurate
  • Loading branch information
luiztauffer authored Mar 18, 2024
2 parents 9f9ca21 + 2f3ad49 commit 22976c5
Show file tree
Hide file tree
Showing 11 changed files with 142 additions and 98 deletions.
2 changes: 1 addition & 1 deletion src/spikeinterface_pipelines/curation/curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def curate(
Path to the scratch folder
results_folder: Path
Path to the results folder
Returns
-------
si.BaseSorting | None
Expand Down
5 changes: 3 additions & 2 deletions src/spikeinterface_pipelines/curation/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ class CurationParams(BaseModel):
"""
Curation parameters.
"""

curation_query: str = Field(
default="isi_violations_ratio < 0.5 and amplitude_cutoff < 0.1 and presence_ratio > 0.8",
description=(
"Query to select units to keep after curation. "
"Default is 'isi_violations_ratio < 0.5 and amplitude_cutoff < 0.1 and presence_ratio > 0.8'."
)
)
),
)
5 changes: 2 additions & 3 deletions src/spikeinterface_pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def run_pipeline(
preprocessing_params = PreprocessingParams(**preprocessing_params)
if isinstance(spikesorting_params, dict):
spikesorting_params = SpikeSortingParams(
sorter_name=spikesorting_params['sorter_name'],
sorter_kwargs=spikesorting_params['sorter_kwargs']
sorter_name=spikesorting_params["sorter_name"], sorter_kwargs=spikesorting_params["sorter_kwargs"]
)
if isinstance(postprocessing_params, dict):
postprocessing_params = PostprocessingParams(**postprocessing_params)
Expand Down Expand Up @@ -84,6 +83,7 @@ def run_pipeline(

# Spike Sorting
if run_spikesorting:
# TODO: turn off sorter motion correction if motion correction is already done
sorting = spikesort(
recording=recording_preprocessed,
scratch_folder=scratch_folder,
Expand Down Expand Up @@ -126,7 +126,6 @@ def run_pipeline(
waveform_extractor = None
sorting_curated = None


# Visualization
visualization_output = None
if run_visualization:
Expand Down
12 changes: 10 additions & 2 deletions src/spikeinterface_pipelines/postprocessing/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,16 @@ class QMParams(BaseModel):
class QualityMetrics(BaseModel):
qm_params: QMParams = Field(default=QMParams(), description="Quality metric parameters.")
metric_names: List[str] = Field(
default=["presence_ratio", "snr", "isi_violation", "rp_violation", "sliding_rp_violation", "amplitude_cutoff", "amplitude_median"],
description="List of metric names to compute. If None, all available metrics are computed."
default=[
"presence_ratio",
"snr",
"isi_violation",
"rp_violation",
"sliding_rp_violation",
"amplitude_cutoff",
"amplitude_median",
],
description="List of metric names to compute. If None, all available metrics are computed.",
)
n_jobs: int = Field(default=1, description="Number of jobs.")

Expand Down
81 changes: 64 additions & 17 deletions src/spikeinterface_pipelines/preprocessing/params.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pydantic import BaseModel, Field
from typing import Optional, Union, List, Literal
from enum import Enum
import numpy as np


class PreprocessingStrategy(str, Enum):
Expand Down Expand Up @@ -52,25 +53,35 @@ class MCDetectKwargs(BaseModel):

class MCLocalizeCenterOfMass(BaseModel):
radius_um: float = Field(default=75.0, description="Radius in um for channel sparsity.")
feature: str = Field(default="ptp", description="'ptp', 'mean', 'energy' or 'peak_voltage'. Feature to consider for computation")
feature: str = Field(
default="ptp", description="'ptp', 'mean', 'energy' or 'peak_voltage'. Feature to consider for computation"
)


class MCLocalizeMonopolarTriangulation(BaseModel):
radius_um: float = Field(default=75.0, description="Radius in um for channel sparsity.")
max_distance_um: float = Field(default=150.0, description="Boundary for distance estimation.")
optimizer: str = Field(default="minimize_with_log_penality", description="")
enforce_decrease: bool = Field(default=True, description="Enforce spatial decreasingness for PTP vectors")
feature: str = Field(default="ptp", description="'ptp', 'energy' or 'peak_voltage'. The available features to consider for estimating the position via monopolar triangulation are peak-to-peak amplitudes (ptp, default), energy ('energy', as L2 norm) or voltages at the center of the waveform (peak_voltage)")
feature: str = Field(
default="ptp",
description="'ptp', 'energy' or 'peak_voltage'. The available features to consider for estimating the position via monopolar triangulation are peak-to-peak amplitudes (ptp, default), energy ('energy', as L2 norm) or voltages at the center of the waveform (peak_voltage)",
)


class MCLocalizeGridConvolution(BaseModel):
radius_um: float = Field(default=40.0, description="Radius in um for channel sparsity.")
upsampling_um: float = Field(default=5.0, description="Upsampling resolution for the grid of templates.")
sigma_um: List[float] = Field(default=[5.0, 25.0, 5], description="Spatial decays of the fake templates.")
weight_method: dict = Field(
default={"mode": "gaussian_2d", "sigma_list_um": np.linspace(5, 25, 5)}, description="Weighting strategy."
)
sigma_ms: float = Field(default=0.25, description="The temporal decay of the fake templates.")
margin_um: float = Field(default=30.0, description="The margin for the grid of fake templates.")
percentile: float = Field(default=10.0, description="The percentage in [0, 100] of the best scalar products kept to estimate the position.")
sparsity_threshold: float = Field(default=0.01, description="The sparsity threshold (in [0, 1]) below which weights should be considered as 0.")
percentile: float = Field(
default=10.0,
description="The percentage in [0, 100] of the best scalar products kept to estimate the position.",
)
prototype: Optional[list] = Field(default=None, description="Fake waveforms for the templates.")


class MCEstimateMotionDecentralized(BaseModel):
Expand Down Expand Up @@ -117,45 +128,81 @@ class MCEstimateMotionIterativeTemplate(BaseModel):


class MCInterpolateMotionKwargs(BaseModel):
direction: int = Field(default=1, description="0 | 1 | 2. Dimension along which channel_locations are shifted (0 - x, 1 - y, 2 - z).")
border_mode: str = Field(default="remove_channels", description="'remove_channels' | 'force_extrapolate' | 'force_zeros'. Control how channels are handled on border.")
spatial_interpolation_method: str = Field(default="idw", description="The spatial interpolation method used to interpolate the channel locations.")
direction: int = Field(
default=1, description="0 | 1 | 2. Dimension along which channel_locations are shifted (0 - x, 1 - y, 2 - z)."
)
border_mode: str = Field(
default="remove_channels",
description="'remove_channels' | 'force_extrapolate' | 'force_zeros'. Control how channels are handled on border.",
)
spatial_interpolation_method: str = Field(
default="idw", description="The spatial interpolation method used to interpolate the channel locations."
)
sigma_um: float = Field(default=20.0, description="Used in the 'kriging' formula")
p: int = Field(default=1, description="Used in the 'kriging' formula")
num_closest: int = Field(default=3, description="Number of closest channels used by 'idw' method for interpolation.")
num_closest: int = Field(
default=3, description="Number of closest channels used by 'idw' method for interpolation."
)


class MCNonrigidAccurate(BaseModel):
detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="")
localize_peaks_kwargs: MCLocalizeMonopolarTriangulation = Field(default=MCLocalizeMonopolarTriangulation(), description="")
estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(default=MCEstimateMotionDecentralized(), description="")
localize_peaks_kwargs: MCLocalizeMonopolarTriangulation = Field(
default=MCLocalizeMonopolarTriangulation(), description=""
)
estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(
default=MCEstimateMotionDecentralized(), description=""
)
interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(), description="")


class MCNonrigidFastAndAccurate(BaseModel):
detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="")
localize_peaks_kwargs: MCLocalizeGridConvolution = Field(default=MCLocalizeGridConvolution(), description="")
estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(
default=MCEstimateMotionDecentralized(), description=""
)
interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(), description="")


class MCRigidFast(BaseModel):
detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="")
localize_peaks_kwargs: MCLocalizeCenterOfMass = Field(default=MCLocalizeCenterOfMass(), description="")
estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(default=MCEstimateMotionDecentralized(bin_duration_s=10.0, rigid=True), description="")
estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(
default=MCEstimateMotionDecentralized(bin_duration_s=10.0, rigid=True), description=""
)
interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(), description="")


class MCKilosortLike(BaseModel):
detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="")
localize_peaks_kwargs: MCLocalizeGridConvolution = Field(default=MCLocalizeGridConvolution(), description="")
estimate_motion_kwargs: MCEstimateMotionIterativeTemplate = Field(default=MCEstimateMotionIterativeTemplate(), description="")
interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(border_mode="force_extrapolate", spatial_interpolation_method="kriging"), description="")
estimate_motion_kwargs: MCEstimateMotionIterativeTemplate = Field(
default=MCEstimateMotionIterativeTemplate(), description=""
)
interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(
default=MCInterpolateMotionKwargs(border_mode="force_extrapolate", spatial_interpolation_method="kriging"),
description="",
)


class MCPreset(str, Enum):
nonrigid_accurate = "nonrigid_accurate"
nonrigid_fast_and_accurate = "nonrigid_fast_and_accurate"
rigid_fast = "rigid_fast"
kilosort_like = "kilosort_like"


class MotionCorrection(BaseModel):
strategy: Literal["skip", "compute", "apply"] = Field(default="compute", description="What strategy to use for motion correction")
preset: MCPreset = Field(default=MCPreset.nonrigid_accurate.value, description="Preset for motion correction")
motion_kwargs: Union[MCNonrigidAccurate, MCRigidFast, MCKilosortLike] = Field(default=MCNonrigidAccurate(), description="Motion correction parameters")
strategy: Literal["skip", "compute", "apply"] = Field(
default="compute", description="What strategy to use for motion correction"
)
preset: MCPreset = Field(
default=MCPreset.nonrigid_fast_and_accurate.value, description="Preset for motion correction"
)
motion_kwargs: Union[MCNonrigidAccurate, MCNonrigidFastAndAccurate, MCRigidFast, MCKilosortLike] = Field(
default=MCNonrigidFastAndAccurate(), description="Motion correction parameters"
)


# Preprocessing params ---------------------------------------------------------------
Expand Down
18 changes: 11 additions & 7 deletions src/spikeinterface_pipelines/preprocessing/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@
import spikeinterface.preprocessing as spre

from ..logger import logger
from .params import PreprocessingParams, MCNonrigidAccurate, MCRigidFast, MCKilosortLike
from .params import PreprocessingParams, MCNonrigidAccurate, MCNonrigidFastAndAccurate, MCRigidFast, MCKilosortLike


warnings.filterwarnings("ignore")

_motion_correction_presets_to_params = dict(
nonrigid_accurate=MCNonrigidAccurate,
nonrigid_fast_and_accurate=MCNonrigidFastAndAccurate,
rigid_fast=MCKilosortLike,
kilosort_like=MCKilosortLike,
)


def preprocess(
recording: si.BaseRecording,
Expand Down Expand Up @@ -102,12 +109,9 @@ def preprocess(
# Motion correction
if preprocessing_params.motion_correction.strategy != "skip":
preset = preprocessing_params.motion_correction.preset
if preset == "nonrigid_accurate":
motion_correction_kwargs = MCNonrigidAccurate(**preprocessing_params.motion_correction.motion_kwargs.model_dump())
elif preset == "rigid_fast":
motion_correction_kwargs = MCRigidFast(**preprocessing_params.motion_correction.motion_kwargs.model_dump())
elif preset == "kilosort_like":
motion_correction_kwargs = MCKilosortLike(**preprocessing_params.motion_correction.motion_kwargs.model_dump())
motion_correction_kwargs = _motion_correction_presets_to_params[preset](
**preprocessing_params.motion_correction.motion_kwargs.model_dump()
)
logger.info(f"[Preprocessing] \tComputing motion correction with preset: {preset}")
motion_folder = results_folder / "motion_correction"
recording_corrected = spre.correct_motion(
Expand Down
34 changes: 19 additions & 15 deletions src/spikeinterface_pipelines/spikesorting/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class SorterName(str, Enum):


class Kilosort25Model(BaseModel):
model_config = ConfigDict(extra='forbid')
model_config = ConfigDict(extra="forbid")
detect_threshold: float = Field(default=6, description="Threshold for spike detection")
projection_threshold: List[float] = Field(default=[10, 4], description="Threshold on projections")
preclust_threshold: float = Field(
Expand All @@ -29,7 +29,10 @@ class Kilosort25Model(BaseModel):
sig: float = Field(default=20, description="spatial smoothness constant for registration")
freq_min: float = Field(default=150, description="High-pass filter cutoff frequency")
sigmaMask: float = Field(default=30, description="Spatial constant in um for computing residual variance of spike")
lam: float = Field(default=10.0, description="The importance of the amplitude penalty (like in Kilosort1: 0 means not used, 10 is average, 50 is a lot)")
lam: float = Field(
default=10.0,
description="The importance of the amplitude penalty (like in Kilosort1: 0 means not used, 10 is average, 50 is a lot)",
)
nPCs: int = Field(default=3, description="Number of PCA dimensions")
ntbuff: int = Field(default=64, description="Samples of symmetrical buffer for whitening and spike detection")
nfilt_factor: int = Field(default=4, description="Max number of clusters per good channel (even temporary ones) 4")
Expand All @@ -50,22 +53,18 @@ class Kilosort25Model(BaseModel):


class Kilosort3Model(BaseModel):
model_config = ConfigDict(extra='forbid')
model_config = ConfigDict(extra="forbid")
pass


class IronClustModel(BaseModel):
model_config = ConfigDict(extra='forbid')
model_config = ConfigDict(extra="forbid")
pass


class MountainSort5Model(BaseModel):
model_config = ConfigDict(extra='forbid')
scheme: str = Field(
default='2',
description="Sorting scheme",
json_schema_extra={'options': ["1", "2", "3"]}
)
model_config = ConfigDict(extra="forbid")
scheme: str = Field(default="2", description="Sorting scheme", json_schema_extra={"options": ["1", "2", "3"]})
detect_threshold: float = Field(default=5.5, description="Threshold for spike detection")
detect_sign: int = Field(default=-1, description="Sign of the peak")
detect_time_radius_msec: float = Field(default=0.5, description="Time radius in milliseconds")
Expand All @@ -77,9 +76,13 @@ class MountainSort5Model(BaseModel):
scheme1_detect_channel_radius: int = Field(default=150, description="Scheme 1 detect channel radius")
scheme2_phase1_detect_channel_radius: int = Field(default=200, description="Scheme 2 phase 1 detect channel radius")
scheme2_detect_channel_radius: int = Field(default=50, description="Scheme 2 detect channel radius")
scheme2_max_num_snippets_per_training_batch: int = Field(default=200, description="Scheme 2 max number of snippets per training batch")
scheme2_max_num_snippets_per_training_batch: int = Field(
default=200, description="Scheme 2 max number of snippets per training batch"
)
scheme2_training_duration_sec: int = Field(default=300, description="Scheme 2 training duration in seconds")
scheme2_training_recording_sampling_mode: str = Field(default='uniform', description="Scheme 2 training recording sampling mode")
scheme2_training_recording_sampling_mode: str = Field(
default="uniform", description="Scheme 2 training recording sampling mode"
)
scheme3_block_duration_sec: int = Field(default=1800, description="Scheme 3 block duration in seconds")
freq_min: int = Field(default=300, description="High-pass filter cutoff frequency")
freq_max: int = Field(default=6000, description="Low-pass filter cutoff frequency")
Expand All @@ -90,7 +93,8 @@ class MountainSort5Model(BaseModel):
class SpikeSortingParams(BaseModel):
sorter_name: SorterName = Field(description="Name of the sorter to use.")
sorter_kwargs: Union[Kilosort25Model, Kilosort3Model, IronClustModel, MountainSort5Model] = Field(
description="Sorter specific kwargs.",
union_mode='left_to_right'
description="Sorter specific kwargs.", union_mode="left_to_right"
)
spikesort_by_group: bool = Field(
default=False, description="If True, spike sorting is run for each group separately."
)
spikesort_by_group: bool = Field(default=False, description="If True, spike sorting is run for each group separately.")
2 changes: 1 addition & 1 deletion src/spikeinterface_pipelines/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .visualization import visualize
from .params import VisualizationParams
from .params import VisualizationParams
Loading

0 comments on commit 22976c5

Please sign in to comment.