Skip to content

Commit

Permalink
patch LFP (#46)
Browse files Browse the repository at this point in the history
* fix RWSLFPFromPSCs buffer issue

* add style_plots_for_paper()

* bump to v0.14.1

* add section headings to advanced_lfp.ipynb

* fix RWSLFPFromPSCs buffer, add tests

* use rounded sample times to avoid skips

* add LFP comparison notebook
  • Loading branch information
kjohnsen authored Apr 1, 2024
1 parent d2215a4 commit 1b4553d
Show file tree
Hide file tree
Showing 9 changed files with 633 additions and 47 deletions.
8 changes: 5 additions & 3 deletions cleo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,11 @@ def set_io_processor(
return

def communicate_with_io_proc(t):
if io_processor.is_sampling_now(t / ms):
io_processor.put_state(self.get_state(), t / ms)
stim_values = io_processor.get_stim_values(t / ms)
# assuming no one will have timesteps shorter than nanoseconds...
now_ms = round(t / ms, 6)
if io_processor.is_sampling_now(now_ms):
io_processor.put_state(self.get_state(), now_ms)
stim_values = io_processor.get_stim_values(now_ms)
self.update_stimulators(stim_values)

# communication should be at every timestep. The IOProcessor
Expand Down
56 changes: 45 additions & 11 deletions cleo/ephys/lfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import deque
from datetime import datetime
from itertools import chain
from math import floor
from math import ceil
from numbers import Number
from typing import Any, Union

Expand Down Expand Up @@ -241,6 +241,12 @@ class RWSLFPSignalBase(LFPSignalBase):
"""Whether to aggregate currents across the population (as opposed to neurons having
differential contributions to LFP depending on their location). False by default."""

wslfp_kwargs: dict = field(factory=dict)
"""Keyword arguments to pass to the WSLFP calculator, e.g., ``alpha``,
``tau_ampa_ms``, ``tau_gaba_ms````source_coords_are_somata``,
``source_dendrite_length_um``, ``amp_kwargs``, ``strict_boundaries``.
"""

_wslfps: dict[NeuronGroup, wslfp.WSLFPCalculator] = field(
init=False, factory=dict, repr=False
)
Expand Down Expand Up @@ -271,6 +277,8 @@ def _init_wslfp_calc(self, neuron_group: NeuronGroup, **kwparams):
]:
if key in kwparams:
wslfp_kwargs[key] = kwparams.pop(key)
elif key in self.wslfp_kwargs:
wslfp_kwargs[key] = self.wslfp_kwargs[key]

self._wslfps[neuron_group] = wslfp.from_xyz_coords(
self._elec_coords / um,
Expand Down Expand Up @@ -595,7 +603,9 @@ def connect_to_neuron_group(self, neuron_group: NeuronGroup, **kwparams):

self._init_wslfp_calc(neuron_group, **kwparams)

buf_len_ampa, buf_len_gaba = self._get_buf_lens(self._wslfps[neuron_group])
buf_len_ampa, buf_len_gaba = self._get_buf_lens_for_wslfp(
self._wslfps[neuron_group]
)
self._t_ampa_bufs[neuron_group] = deque(maxlen=buf_len_ampa)
self._I_ampa_bufs[neuron_group] = deque(maxlen=buf_len_ampa)
self._t_gaba_bufs[neuron_group] = deque(maxlen=buf_len_gaba)
Expand All @@ -605,7 +615,10 @@ def connect_to_neuron_group(self, neuron_group: NeuronGroup, **kwparams):
self._ampa_vars[neuron_group] = [varname + "_" for varname in I_ampa_names]
self._gaba_vars[neuron_group] = [varname + "_" for varname in I_gaba_names]

def _get_buf_lens(self, wslfp_calc, **kwparams):
def _buf_len(self, tau, dt):
return ceil((tau + dt) / dt)

def _get_buf_lens_for_wslfp(self, wslfp_calc, **kwparams):
# need sampling period
sample_period_ms = kwparams.get("sample_period_ms", None)
if sample_period_ms is None:
Expand All @@ -619,14 +632,14 @@ def _get_buf_lens(self, wslfp_calc, **kwparams):
f" specify it on injection: .inject({self.probe.name}"
", sample_period_ms=...)"
)
buf_len_ampa = floor(wslfp_calc.tau_ampa_ms / sample_period_ms + 1)
buf_len_gaba = floor(wslfp_calc.tau_gaba_ms / sample_period_ms + 1)
buf_len_ampa = self._buf_len(wslfp_calc.tau_ampa_ms, sample_period_ms)
buf_len_gaba = self._buf_len(wslfp_calc.tau_gaba_ms, sample_period_ms)
return buf_len_ampa, buf_len_gaba

def _curr_from_buffer(self, t_buf_ms, I_buf, t_eval_ms: float, n_sources):
# t_eval_ms is not iterable
empty = np.zeros((1, n_sources))
if len(t_buf_ms) == 0 or t_buf_ms[0] > t_eval_ms:
if len(t_buf_ms) == 0 or t_buf_ms[0] > t_eval_ms or t_buf_ms[-1] < t_eval_ms:
return empty
# when tau is multiple of sample time, current should be collected
# right when needed, at the left end of the buffer
Expand All @@ -636,10 +649,31 @@ def _curr_from_buffer(self, t_buf_ms, I_buf, t_eval_ms: float, n_sources):
# if not, should only need to interpolate between first and second positions
# if buffer length is correct
assert len(t_buf_ms) > 1
assert t_buf_ms[0] < t_eval_ms < t_buf_ms[1]
I_interp = I_buf[0] + (I_buf[1] - I_buf[0]) * (t_eval_ms - t_buf_ms[0]) / (
t_buf_ms[1] - t_buf_ms[0]
)
if t_buf_ms[0] < t_eval_ms < t_buf_ms[1]:
i_l, i_r = 0, 1
else:
warnings.warn(
f"Time buffer is unexpected. Did a sample get skipped? "
f"t_buf_ms={t_buf_ms}, t_eval_ms={t_eval_ms}"
)
i_l, i_r = None, None
for i, t in enumerate(t_buf_ms):
if t < t_eval_ms:
i_l = i
if t >= t_eval_ms:
i_r = i
break
if i_l is None or i_r is None or i_l >= i_r:
warnings.warn(
"Signal buffer does not contain currents at needed timepoints. "
"Returning 0. "
f"t_buf_ms={t_buf_ms}, t_eval_ms={t_eval_ms}"
)
return empty

I_interp = I_buf[i_l] + (I_buf[i_r] - I_buf[i_l]) * (
t_eval_ms - t_buf_ms[i_l]
) / (t_buf_ms[i_r] - t_buf_ms[i_l])

I_interp = np.reshape(I_interp, (1, n_sources))
I_interp = np.nan_to_num(I_interp, nan=0)
Expand Down Expand Up @@ -683,7 +717,7 @@ def _needed_current(
def reset(self, **kwargs) -> None:
self._init_saved_vars()
for ng in self._t_ampa_bufs:
buf_len_ampa, buf_len_gaba = self._get_buf_lens(self._wslfps[ng])
buf_len_ampa, buf_len_gaba = self._get_buf_lens_for_wslfp(self._wslfps[ng])
self._t_ampa_bufs[ng] = deque(maxlen=buf_len_ampa)
self._I_ampa_bufs[ng] = deque(maxlen=buf_len_ampa)
self._t_gaba_bufs[ng] = deque(maxlen=buf_len_gaba)
Expand Down
2 changes: 1 addition & 1 deletion cleo/ioproc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def is_sampling_now(self, query_time_ms):
if np.isclose(query_time_ms % self.sample_period_ms, 0):
return True
elif self.sampling == "when idle":
if query_time_ms % self.sample_period_ms == 0:
if np.isclose(query_time_ms % self.sample_period_ms, 0):
if self._is_currently_idle(query_time_ms):
self._needs_off_schedule_sample = False
return True
Expand Down
13 changes: 13 additions & 0 deletions cleo/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,19 @@ def style_plots_for_docs(dark=True):
plt.rc("font", **{"sans-serif": "Open Sans"})


def style_plots_for_paper():
# some hacky workaround for params not being updated until after first plot
f = plt.figure()
plt.plot()
plt.close(f)

plt.style.use("seaborn-v0_8-paper")
plt.rc("savefig", transparent=True, bbox="tight", dpi=300)
plt.rc("svg", fonttype="none")
plt.rc("axes.spines", top=False, right=False)
plt.rc("font", **{"sans-serif": "Open Sans"})


def unit_safe_append(q1: Quantity, q2: Quantity, axis=0):
if not b2.have_same_dimensions(q1, q2):
raise ValueError("Dimensions must match")
Expand Down
9 changes: 9 additions & 0 deletions docs/tutorials/advanced_lfp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## RWSLFP recording options\n",
"\n",
"There are a few important variations on how to record RWSLFP:\n",
"- Currents can be summed over the population, so that a postsynaptic current (PSC) in one location has the same effect on LFP as one on the other side of the population.\n",
" The main advantage to this approach is it saves some memory storing currents.\n",
Expand Down Expand Up @@ -352,6 +354,13 @@
"sim.inject(probe, inh, tklfp_type=\"inh\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run simulation and plot results"
]
},
{
"cell_type": "code",
"execution_count": 9,
Expand Down
447 changes: 447 additions & 0 deletions notebooks/lfp_comparison.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cleosim"
version = "0.14.0"
version = "0.14.1"
description = "Cleo: the Closed-Loop, Electrophysiology, and Optogenetics experiment simulation testbed"
authors = [
"Kyle Johnsen <[email protected]>",
Expand Down
112 changes: 92 additions & 20 deletions tests/ephys/test_lfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,23 @@ def test_lfp_signal_to_neo(LFPSignal, n_channels, t, regular_samples):
assert neo_sig.name == f"{sig.probe.name}.{sig.name}"


_base_param_options = [
("pop_agg", (True, False)),
("amp_func", (wslfp.aussel18, wslfp.mazzoni15_pop)),
# ("wslfp_kwargs", ({}, {"tau_ampa_ms": 5})),
("wslfp_kwargs", ({}, {"tau_gaba_ms": 1})),
("wslfp_kwargs", ({}, {"alpha": 1})),
(
"wslfp_kwargs",
({"source_coords_are_somata": True}, {"source_coords_are_somata": False}),
),
(
"wslfp_kwargs",
({"source_dendrite_length_um": 250}, {"source_dendrite_length_um": 300}),
),
]


def test_RWSLFPSignalFromSpikes(rand_seed):
rng = np.random.default_rng(rand_seed)
b2.seed(rand_seed)
Expand All @@ -259,6 +276,7 @@ def add_rwslfp_sig(
pop_agg=True,
amp_func=wslfp.mazzoni15_pop,
ornt=[0, 0, -1],
wslfp_kwargs={},
tau1_ampa=2 * ms,
tau2_ampa=0.4 * ms,
tau1_gaba=5 * ms,
Expand All @@ -276,6 +294,7 @@ def add_rwslfp_sig(
tau2_gaba=tau2_gaba,
syn_delay=syn_delay,
I_threshold=I_threshold,
wslfp_kwargs=wslfp_kwargs,
)
# separate probe for each signal since some injection kwargs need to be different per signal
i = len(sim.recorders)
Expand All @@ -293,10 +312,8 @@ def add_rwslfp_sig(
)
return rwslfp_sig

signals_by_param = {}
for param_name, param_vals in [
("pop_agg", (True, False)),
("amp_func", (wslfp.aussel18, wslfp.mazzoni15_pop)),
signals_for_options = []
param_options = [
("ornt", (rng.normal(size=(n_exc, 3)), (-0.2, -0.1, -0.3))),
("tau1_ampa", (2, 1) * ms),
("tau2_ampa", (0.4, 0.2) * ms),
Expand All @@ -305,18 +322,23 @@ def add_rwslfp_sig(
("syn_delay", (1, 2) * ms),
("homog_J", (True, False)),
("I_threshold", (0.1, 0.001)),
]:
signals_by_param[param_name] = []
]
for param_name, param_vals in _base_param_options + param_options:
signals_to_compare = []
for val in param_vals:
# store result of each different value for each param
signals_by_param[param_name].append(add_rwslfp_sig(**{param_name: val}))
signals_to_compare.append(add_rwslfp_sig(**{param_name: val}))
signals_for_options.append((param_name, signals_to_compare))
sim.run(100 * ms)
# each parameter change should change the resulting signal:
for param, signals in signals_by_param.items():
assert len(signals_for_options) == len(_base_param_options) + len(param_options)
for param, signals in signals_for_options:
assert len(signals) > 1
for i, sig1 in enumerate(signals):
for sig2 in signals[i + 1 :]:
assert not np.allclose(sig1.lfp, sig2.lfp)
assert not np.allclose(
sig1.lfp, sig2.lfp
), f"{param} variation not yielding different results"


@pytest.mark.parametrize("samp_period_ms", [1, 1.4])
Expand All @@ -325,7 +347,7 @@ def test_RWSLFPSignalFromPSCs(rand_seed, samp_period_ms):
b2.seed(rand_seed)
n_exc = 16
n_elec = 4
elec_coords = rng.uniform(-1, 1, (n_elec, 3)) * mm
elec_coords = rng.uniform(-0.5, 0.5, (n_elec, 3)) * mm

exc = b2.NeuronGroup(
n_exc,
Expand All @@ -335,7 +357,7 @@ def test_RWSLFPSignalFromPSCs(rand_seed, samp_period_ms):
dIgaba2/dt = xi_4 / sqrt(dt) : 1
""",
)
assign_coords(exc, rng.uniform(-1, 1, (n_exc, 3)) * mm)
assign_coords(exc, rng.uniform(-0.5, 0.5, (n_exc, 3)) * mm)

# cleo setup
sim = CLSimulator(Network(exc))
Expand All @@ -345,13 +367,18 @@ def add_rwslfp_sig(
pop_agg=True,
amp_func=wslfp.mazzoni15_pop,
ornt=[0, 0, -1],
wslfp_kwargs={},
Iampa_var_names=["Iampa1"],
Igaba_var_names=["Igaba1"],
name=None,
):
rwslfp_sig = RWSLFPSignalFromPSCs(
pop_aggregate=pop_agg,
amp_func=amp_func,
wslfp_kwargs=wslfp_kwargs,
)
if name:
rwslfp_sig.name = name
# separate probe for each signal since some injection kwargs need to be different per signal
i = len(sim.recorders)
probe = Probe(elec_coords, [rwslfp_sig], name=f"probe{i}")
Expand All @@ -365,25 +392,70 @@ def add_rwslfp_sig(
)
return rwslfp_sig

signals_by_param = {}
for param_name, param_vals in [
("pop_agg", (True, False)),
("amp_func", (wslfp.aussel18, wslfp.mazzoni15_pop)),
signals_for_options = []
param_options = [
("ornt", (rng.normal(size=(n_exc, 3)), (-0.2, -0.1, -0.3))),
("Iampa_var_names", (["Iampa1"], ["Iampa1", "Iampa2"])),
("Igaba_var_names", (["Igaba1"], ["Igaba1", "Igaba2"])),
]:
signals_by_param[param_name] = []
]
for param_name, param_vals in _base_param_options + param_options:
signals = []
for val in param_vals:
# store result of each different value for each param
signals_by_param[param_name].append(add_rwslfp_sig(**{param_name: val}))
signals.append(
add_rwslfp_sig(**{param_name: val}, name=f"{param_name}_{val}")
)
signals_for_options.append((param_name, signals))
sim.run(30 * ms)
assert not np.allclose(exc.Iampa1_, exc.Iampa2_)
# each parameter change should change the resulting signal:
for param, signals in signals_by_param.items():
assert len(signals_for_options) == len(_base_param_options) + len(param_options)
for param, signals in signals_for_options:
assert len(signals) > 1
for i, sig1 in enumerate(signals):
for sig2 in signals[i + 1 :]:
assert not np.allclose(sig1.lfp, sig2.lfp)
assert not np.allclose(
sig1.lfp, sig2.lfp
), f"{sig1.name} and {sig2.name} not yielding different results"


def test_psc_buffer():
sig = RWSLFPSignalFromPSCs()
t_buf = [1, 2, 4] # skipped 3 for some reason; should be [2, 3, 4]
n_src = 4
I_buf = np.arange(3)[..., None] + np.arange(n_src)
print(I_buf)
t_eval = 2
with pytest.warns(match="buffer is unexpected"):
assert np.all(
sig._curr_from_buffer(t_buf, I_buf, t_eval, n_sources=n_src)
== np.arange(1, 5)
)
assert np.all(
sig._curr_from_buffer([1, 3, 4], I_buf, t_eval, n_sources=n_src)
== np.arange(0.5, 4.5)
)

assert np.all(sig._curr_from_buffer([3, 4, 5], I_buf, t_eval, n_sources=n_src) == 0)
assert np.all(
sig._curr_from_buffer([-1, 0, 1], I_buf, t_eval, n_sources=n_src) == 0
)


def test_psc_buf_len():
sig = RWSLFPSignalFromPSCs()
# args are buf_width, sample_period
assert sig._buf_len(0, 1) == 1
assert sig._buf_len(0, 1.5) == 1
assert sig._buf_len(0, 12.13) == 1
assert sig._buf_len(1, 1) == 2
assert sig._buf_len(1, 0.5) == 3
assert sig._buf_len(2, 1) == 3
assert sig._buf_len(6, 1) == 7
assert sig._buf_len(6, 1.2) == 6
assert sig._buf_len(6, 1.4) == 6
assert sig._buf_len(6, 1.5) == 5
assert sig._buf_len(6, 1.6) == 5


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 1b4553d

Please sign in to comment.