From faca0e5c4b42a0556480be4f1fd734bf3f74b770 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 8 Jan 2025 17:42:18 +0000 Subject: [PATCH] Update docstring on 'calculate_displacement_unit_factor. --- .../generation/drifting_generator.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/generation/drifting_generator.py b/src/spikeinterface/generation/drifting_generator.py index 49b5d5929e..e4ab723ae3 100644 --- a/src/spikeinterface/generation/drifting_generator.py +++ b/src/spikeinterface/generation/drifting_generator.py @@ -253,14 +253,18 @@ def calculate_displacement_unit_factor( non_rigid_gradient: float, unit_locations: np.array, drift_start_um: np.array, drift_stop_um: np.array ) -> np.array: """ - In the case of introducing non-rigid drift, a set of scaling - factors (one per unit) is generated for scaling the displacement - as a function of unit position. + Introduces a non-rigid drift across the probe, this is a linear + scaling of the displacement based on the unit position. + + To introduce non-rigid drift, a set of scaling factors (one per unit) + are generated. These scale the displacement applied to each unit + as a function of unit position. The smaller the `non_rigid_gradient`, + the larger the influence of the unit position is on scaling the + displacement (more non-linearity). The projections of the gradient vector (x, y) and unit locations (x, y) are normalised to range between 0 and 1 (i.e. based on relative location to the gradient). - These factors are scaled by `non_rigid_gradient`. Parameters ---------- @@ -270,6 +274,7 @@ def calculate_displacement_unit_factor( that are based on unit location. This sets the weighting given to the factors based on unit locations. When 1, the factors will all equal 1 (no effect), when 0, the scaling factor based on unit location will be used directly. + Smaller number results in more nonlinearity. unit_locations : np.array The unit location with shape (num_units, 2) drift_start_um : np.array @@ -293,6 +298,7 @@ def calculate_displacement_unit_factor( factors = 1 - factors f = np.abs(non_rigid_gradient) + print("f", f) displacement_unit_factor = factors * (1 - f) + f return displacement_unit_factor