Skip to content

Commit

Permalink
Include basal ganglia tutorial (#434)
Browse files Browse the repository at this point in the history
* rename tutorial just for brevity

* include basal ganglia tutorial

* increase `abstol` and `reltol`, and reduce `trajectories`

* improve performance of `state_timeseries`

* rename tutorial just for brevity

* include basal ganglia tutorial

* increase `abstol` and `reltol`, and reduce `trajectories`

* improve performance of `state_timeseries`

* fix `voltage_timeseries` for single neurons

* reduce time span

* change `state_timeseries` dispatch to work with vector of AbstractBlox

---------

Co-authored-by: gabrevaya <[email protected]>
  • Loading branch information
harisorgn and gabrevaya authored Oct 8, 2024
1 parent 5faa7cd commit 1ea99b9
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 25 deletions.
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Literate.markdown.([
"./docs/src/tutorials/parkinsons.jl",
"./docs/src/tutorials/neural_assembly.jl",
"./docs/src/tutorials/ping_network.jl",
"./docs/src/tutorials/basal_ganglia.jl",
"./docs/src/tutorials/spectralDCM.jl"
],
"./docs/src/tutorials";
Expand Down
1 change: 1 addition & 0 deletions docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pages = ["index.md",
"tutorials/parkinsons.md",
"tutorials/neural_assembly.md",
"tutorials/ping_network.md",
"tutorials/basal_ganglia.md",
"tutorials/spectralDCM.md"
],
"API" => "api.md",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

using Neuroblox
using StochasticDiffEq ## For building and solving differential equations problems
using MetaGraphs ## use its MetaGraph type to build the circuit
using CairoMakie ## For plotting
using Random ## For setting a random seed

Expand All @@ -28,12 +27,12 @@ unknowns(sys)

# Create and solve the SDE problem
## Define simulation parameters
tspan = (0.0, 5500.0) ## simulation time span [ms]
tspan = (0.0, 2000.0) ## simulation time span [ms]
dt = 0.05 ## time step for solving and saving [ms]

## Create a stochastic differential equation problem and use the RKMil method to solve it
prob = SDEProblem(sys, [], tspan, [])
sol = solve(prob, RKMil(); dt=dt, saveat=dt)
sol = solve(prob, RKMil(); dt=dt, saveat=dt, abstol = 1e-2, reltol = 1e-2)

# Plot voltage of a single neuron
plot(sol, idxs=1, axis = (xlabel = "time (ms)", ylabel = "membrane potential (mV)"))
Expand Down Expand Up @@ -71,7 +70,7 @@ fig

# We can also run multiple simulations in parallel and compute the average power spectrum
ens_prob = EnsembleProblem(prob)
ens_sol = solve(ens_prob, RKMil(); dt=dt, saveat=dt, trajectories=5)
ens_sol = solve(ens_prob, RKMil(); dt=dt, saveat=dt, trajectories=3, abstol = 1e-2, reltol = 1e-2);

powerspectrumplot(msn, ens_sol; state = "G",
method=welch_pgram, window=hanning,
Expand Down Expand Up @@ -104,7 +103,7 @@ add_edge!(g, 2, 1, Dict(:weight => weight_FSI_MSN, :density => density_FSI_MSN))
@named sys = system_from_graph(g)
prob = SDEProblem(sys, [], tspan, [])
ens_prob = EnsembleProblem(prob)
ens_sol = solve(ens_prob, RKMil(); dt=dt, saveat=dt, trajectories=5)
ens_sol = solve(ens_prob, RKMil(); dt=dt, saveat=dt, trajectories=3, abstol = 1e-2, reltol = 1e-2);

# Detect spikes and compute firing rates
spikes_msn = detect_spikes(msn, ens_sol[1]; threshold=-35)
Expand Down Expand Up @@ -170,7 +169,7 @@ add_edge!(g, 4, 2, Dict(:weight => weight_STN_FSI, :density => density_STN_FSI))
@named sys = system_from_graph(g)
prob = SDEProblem(sys, [], tspan, [])
ens_prob = EnsembleProblem(prob)
ens_sol = solve(ens_prob, RKMil(); dt=dt, saveat=dt, trajectories=5)
ens_sol = solve(ens_prob, RKMil(); dt=dt, saveat=dt, trajectories=3, abstol = 1e-2, reltol = 1e-2);

# Compute and plot power spectra for all components
fig = Figure(size = (1500, 600))
Expand Down Expand Up @@ -249,7 +248,7 @@ add_edge!(g, 4, 2, Dict(:weight => weight_STN_FSI, :density => density_STN_FSI))

prob = SDEProblem(sys, [], tspan, [])
ens_prob = EnsembleProblem(prob)
ens_sol = solve(ens_prob, RKMil(); dt=dt, saveat=dt, trajectories=5)
ens_sol = solve(ens_prob, RKMil(); dt=dt, saveat=dt, trajectories=3, abstol = 1e-2, reltol = 1e-2);

# Compute and compare power spectra for all neural populations in Parkinsonian condition against their counterparts in baseline conditions.
powerspectrumplot(fig[2,1], msn, ens_sol; state = "G",
Expand Down Expand Up @@ -291,4 +290,4 @@ fig


# ## References
# [Adam, Elie M., et al. "Deep brain stimulation in the subthalamic nucleus for Parkinson's disease can restore dynamics of striatal networks." Proceedings of the National Academy of Sciences 119.19 (2022): e2120808119.](https://doi.org/10.1073/pnas.2120808119)
# [Adam, Elie M., et al. "Deep brain stimulation in the subthalamic nucleus for Parkinson's disease can restore dynamics of striatal networks." Proceedings of the National Academy of Sciences 119.19 (2022): e2120808119.](https://doi.org/10.1073/pnas.2120808119)
42 changes: 25 additions & 17 deletions src/blox/blox_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -491,8 +491,7 @@ function mean_firing_rate(spikes::SparseMatrixCSC, sol; trim_transient = 0,
return t, rₘ
end

function state_timeseries(blox, sol::SciMLBase.AbstractSolution,
state::String; ts=nothing)
function state_timeseries(blox, sol::SciMLBase.AbstractSolution, state::String; ts=nothing)

namespaced_name = namespaced_nameof(blox)
state_name = Symbol(namespaced_name, "$(state)")
Expand All @@ -504,15 +503,23 @@ function state_timeseries(blox, sol::SciMLBase.AbstractSolution,
end
end

function state_timeseries(cb::Union{CompositeBlox, AbstractVector{<:AbstractNeuronBlox}}, sol::SciMLBase.AbstractSolution, state::String; ts=nothing)
function state_timeseries(cb::Union{CompositeBlox, AbstractVector{<:AbstractBlox}},
sol::SciMLBase.AbstractSolution, state::String; ts=nothing)

neurons = get_neurons(cb)
state_names = map(neuron -> Symbol(namespaced_nameof(neuron), "", state), neurons)

return mapreduce(hcat, get_neurons(cb)) do neuron
state_timeseries(neuron, sol, state; ts)
if isnothing(ts)
s = stack(sol[state_names], dims=1)
else
s = transpose(Array(sol(ts; idxs=state_names)))
end

return s
end

function meanfield_timeseries(cb::Union{CompositeBlox, AbstractVector{<:AbstractNeuronBlox}}, sol::SciMLBase.AbstractSolution,
state::String; ts=nothing)
function meanfield_timeseries(cb::Union{CompositeBlox, AbstractVector{<:AbstractNeuronBlox}},
sol::SciMLBase.AbstractSolution, state::String; ts=nothing)

s = state_timeseries(cb, sol, state; ts)

Expand All @@ -522,31 +529,32 @@ end
voltage_timeseries(blox, sol::SciMLBase.AbstractSolution; ts=nothing) =
state_timeseries(blox, sol, "V"; ts)

function voltage_timeseries(cb::Union{CompositeBlox, AbstractVector{<:AbstractBlox}}, sol::SciMLBase.AbstractSolution; ts=nothing)

return mapreduce(hcat, get_neurons(cb)) do neuron
voltage_timeseries(neuron, sol; ts)
end
function voltage_timeseries(cb::Union{CompositeBlox, AbstractVector{<:AbstractBlox}},
sol::SciMLBase.AbstractSolution; ts=nothing)
return state_timeseries(cb, sol, "V"; ts)
end

function meanfield_timeseries(cb::Union{CompositeBlox, AbstractVector{<:AbstractNeuronBlox}}, sol::SciMLBase.AbstractSolution; ts=nothing)
function meanfield_timeseries(cb::Union{CompositeBlox, AbstractVector{<:AbstractNeuronBlox}},
sol::SciMLBase.AbstractSolution; ts=nothing)
V = voltage_timeseries(cb, sol; ts)
replace_refractory!(V, cb, sol)

return vec(mapslices(nanmean, V; dims = 2))
end

function powerspectrum(cb::Union{CompositeBlox, AbstractVector{<:AbstractNeuronBlox}}, sol::SciMLBase.AbstractSolution, state::String;
sampling_rate=nothing, method=periodogram, window=nothing)
function powerspectrum(cb::Union{CompositeBlox, AbstractVector{<:AbstractNeuronBlox}},
sol::SciMLBase.AbstractSolution, state::String; sampling_rate=nothing,
method=periodogram, window=nothing)

t_sampled, sampling_freq = get_sampling_info(sol; sampling_rate=sampling_rate)
s = meanfield_timeseries(cb, sol, state; ts = t_sampled)

return method(s, fs=sampling_freq, window=window)
end

function powerspectrum(cb::Union{CompositeBlox, AbstractVector{<:AbstractNeuronBlox}}, sol::SciMLBase.AbstractSolution;
sampling_rate=nothing, method=periodogram, window=nothing)
function powerspectrum(cb::Union{CompositeBlox, AbstractVector{<:AbstractNeuronBlox}},
sol::SciMLBase.AbstractSolution; sampling_rate=nothing,
method=periodogram, window=nothing)

t_sampled, sampling_freq = get_sampling_info(sol; sampling_rate=sampling_rate)
V = voltage_timeseries(cb, sol; ts = t_sampled)
Expand Down

0 comments on commit 1ea99b9

Please sign in to comment.