Skip to content

Commit

Permalink
Merge pull request #3553 from samuelgarcia/bench_motion
Browse files Browse the repository at this point in the history
MotionEstimationStudy : plot drift with the scatter plot
  • Loading branch information
alejoe91 authored Jan 8, 2025
2 parents d38dbf4 + 06e7d81 commit 9b022da
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion src/spikeinterface/benchmark/benchmark_motion_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def run(self, **job_kwargs):
estimate_motion=t4 - t3,
)

self.result["peaks"] = peaks
self.result["peak_locations"] = peak_locations
self.result["step_run_times"] = step_run_times
self.result["raw_motion"] = motion

Expand All @@ -131,6 +133,8 @@ def compute_result(self, **result_params):
self.result["motion"] = motion

_run_key_saved = [
("peaks", "npy"),
("peak_locations", "npy"),
("raw_motion", "Motion"),
("step_run_times", "pickle"),
]
Expand Down Expand Up @@ -161,7 +165,9 @@ def create_benchmark(self, key):
def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)):
self.plot_drift(case_keys=case_keys, tested_drift=False, scaling_probe=scaling_probe, figsize=figsize)

def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, scaling_probe=1.0, figsize=(8, 6)):
def plot_drift(
self, case_keys=None, gt_drift=True, tested_drift=True, raster=False, scaling_probe=1.0, figsize=(8, 6)
):
import matplotlib.pyplot as plt

if case_keys is None:
Expand Down Expand Up @@ -195,6 +201,13 @@ def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, scaling_p

# for i in range(self.gt_unit_positions.shape[1]):
# ax.plot(temporal_bins_s, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5")
if raster:
peaks = bench.result["peaks"]
peak_locations = bench.result["peak_locations"]
rec = bench.recording
x = peaks["sample_index"] / rec.sampling_frequency
y = peak_locations[bench.direction]
ax.scatter(x, y, alpha=0.2, s=2, c=np.abs(peaks["amplitude"]), cmap="inferno")

for i in range(gt_motion.displacement[0].shape[1]):
depth = motion.spatial_bins_um[i]
Expand Down

0 comments on commit 9b022da

Please sign in to comment.