diff --git a/docs/src/tutorials/basal_ganglia.jl b/docs/src/tutorials/basal_ganglia.jl index 6a2de78a..a90c81cb 100644 --- a/docs/src/tutorials/basal_ganglia.jl +++ b/docs/src/tutorials/basal_ganglia.jl @@ -55,10 +55,10 @@ rasterplot(msn, sol, threshold = -35, title = "Neuron's Spikes - Mean Firing Rat # Compute and plot the power spectrum of the GABAa current fig = Figure(size = (1500, 500)) -powerspectrumplot(fig[1,1], msn, sol, state = "G", +powerspectrumplot(fig[1,1], msn, sol, state = "I_syn_msn", title = "FFT with no window") -powerspectrumplot(fig[1,2], msn, sol, state = "G", +powerspectrumplot(fig[1,2], msn, sol, state = "I_syn_msn", method = welch_pgram, window = hanning, title = "Welch's method + Hanning window") fig @@ -105,12 +105,12 @@ fig = Figure(size = (1000, 800)) rasterplot(fig[1,1], msn, ens_sol[1], threshold = -35, title = "MSN - Mean Firing Rate: $(round(fr_msn[1], digits=2)) spikes/s") rasterplot(fig[1,2], fsi, ens_sol[1], threshold = -35, title = "FSI - Mean Firing Rate: $(round(fr_fsi[1], digits=2)) spikes/s") -powerspectrumplot(fig[2,1], msn, ens_sol, state = "G", +powerspectrumplot(fig[2,1], msn, ens_sol, state = "I_syn_msn", method = welch_pgram, window = hanning, ylims= (-35, 15), xlims= (8, 100)) -powerspectrumplot(fig[2,2], fsi, ens_sol, state = "G", +powerspectrumplot(fig[2,2], fsi, ens_sol, state = "I_syn_fsi", method=welch_pgram, window=hanning, ylims= (-35, 15), xlims= (8, 100)) @@ -153,22 +153,22 @@ ens_sol = solve(ens_prob, RKMil(), dt=dt, saveat = dt, trajectories = 3); # Compute and plot power spectra for all components fig = Figure(size = (1600, 450)) -powerspectrumplot(fig[1,1], msn, ens_sol, state = "G", +powerspectrumplot(fig[1,1], msn, ens_sol, state = "I_syn_msn", method = welch_pgram, window = hanning, ylims=(-40, 25), title = "MSN (Baseline)") -powerspectrumplot(fig[1,2], fsi, ens_sol, state = "G", +powerspectrumplot(fig[1,2], fsi, ens_sol, state = "I_syn_fsi", method = welch_pgram, window = hanning, ylims=(-40, 25), title = "FSI (Baseline)") -powerspectrumplot(fig[1,3], gpe, ens_sol, state = "G", +powerspectrumplot(fig[1,3], gpe, ens_sol, state = "V", method = welch_pgram, window = hanning, ylims=(-40, 25), title = "GPe (Baseline)") -powerspectrumplot(fig[1,4], stn, ens_sol, state = "G", +powerspectrumplot(fig[1,4], stn, ens_sol, state = "V", method = welch_pgram, window = hanning, ylims=(-40, 25), title = "STN (Baseline)") @@ -215,22 +215,22 @@ ens_prob = EnsembleProblem(prob) ens_sol = solve(ens_prob, RKMil(), dt = dt, saveat = dt, trajectories = 3); # Compute and compare power spectra for all neural populations in Parkinsonian condition against their counterparts in baseline conditions. -powerspectrumplot(fig[2,1], msn, ens_sol, state = "G", +powerspectrumplot(fig[2,1], msn, ens_sol, state = "I_syn_msn", method = welch_pgram, window = hanning, ylims=(-40, 25), title = "MSN (PD)") -powerspectrumplot(fig[2,2], fsi, ens_sol, state = "G", +powerspectrumplot(fig[2,2], fsi, ens_sol, state = "I_syn_fsi", method = welch_pgram, window = hanning, ylims=(-40, 25), title = "FSI (PD)") -powerspectrumplot(fig[2,3], gpe, ens_sol, state = "G", +powerspectrumplot(fig[2,3], gpe, ens_sol, state = "V", method = welch_pgram, window = hanning, ylims=(-40, 25), title = "GPe (PD)") -powerspectrumplot(fig[2,4], stn, ens_sol, state = "G", +powerspectrumplot(fig[2,4], stn, ens_sol, state = "V", method = welch_pgram, window = hanning, ylims=(-40, 25), title = "STN (PD)") diff --git a/examples/tune_parameters.jl b/examples/tune_parameters.jl new file mode 100644 index 00000000..4f8d6d4e --- /dev/null +++ b/examples/tune_parameters.jl @@ -0,0 +1,371 @@ +using Neuroblox +using StochasticDiffEq +using Random +using Statistics +using Optimization +using OptimizationOptimJL +using ModelingToolkit: setp, getp + +abstract type AbstractPopulationMetric end + +struct Population{B,N,MT<:NTuple{N,AbstractPopulationMetric}} + name::Symbol + blox::B + metrics::MT + tuning::Dict{String, Vector{Int}} + tunable::Bool +end + +function Population( + name, + blox; + frm=nothing, + freqm=nothing, + tuning_params=String[], + prob=nothing, + tunable::Bool=false +) + mt = () + if frm !== nothing + fr_target, fr_weight, fr_threshold, fr_transient = frm + mt = (FiringRateMetric(fr_target, fr_weight, fr_threshold, fr_transient),) + end + if freqm !== nothing + freq_target, freq_weight, fmin, fmax, agg = freqm + freq = FrequencyMetric(freq_target, freq_weight, fmin, fmax, agg) + mt = tuple(mt..., freq) + end + + tspec = isempty(tuning_params) || prob === nothing ? + Dict() : build_tuning_spec(prob, string(name), tuning_params) + + local N = length(mt) + return Population{typeof(blox), N, typeof(mt)}(name, blox, mt, tspec, tunable) +end + +# Find indexes of the parameters to be tuned +function build_tuning_spec(prob, pop_name::String, param_names::Vector{String}) + paramlist = string.(tunable_parameters(prob.f.sys)) + param_map = Dict{String,Vector{Int}}() + for pname in param_names + inds = findall(str -> occursin(pname, str) && occursin(pop_name, str), + paramlist) + param_map[pname] = inds + end + return param_map +end + +struct OptimizationConfig{P,PopT,GetterT,SetterT,SolverT,EnsembleAlgT,dtT,DiffeqkargsT} + prob::P + populations::PopT + get_ps::GetterT + set_ps!::SetterT + solver::SolverT + ensemblealg::EnsembleAlgT + dt::dtT + other_diffeq_kwargs::DiffeqkargsT + trajectories::Int + seed::Int +end + +########### +# Metrics # +########### +struct FiringRateMetric{T} <: AbstractPopulationMetric + target::T + weight::T + threshold::T + transient::T +end +struct FrequencyMetric{T} <: AbstractPopulationMetric + target::T + weight::T + min_freq::T + max_freq::T + aggregate_state::String +end + +function compute_error(m::FiringRateMetric, pop, sol; logging::Bool=false) + fr, fr_std = firing_rate(pop.blox, sol; + threshold=m.threshold, + transient=m.transient, + scheduler=:dynamic) + val = (fr[1] - m.target)^2 * m.weight + + if logging + @info "[$(pop.name)] Firing Rate = $(fr[1]) ± $(fr_std[1]) " * + " (target=$(m.target)), metric error = $val" + end + return val +end + +function compute_error(m::FrequencyMetric, pop, sol; logging::Bool=false) + powspecs = powerspectrum(pop.blox, sol, m.aggregate_state; + method=welch_pgram, + window=hamming) + peak_freq, peak_freq_std = get_peak_freq(powspecs, m.min_freq, m.max_freq) + val = abs(peak_freq - m.target) * m.weight + + if logging + @info "[$(pop.name)] Peak Frequency = $peak_freq ± $peak_freq_std " * + " in [$(m.min_freq), $(m.max_freq)] (target=$(m.target)), metric error = $val" + end + return val +end + +#################### +# Helper Functions # +#################### + +function get_peak_freq(powspecs, freq_min, freq_max) + freq_ind = get_freq_inds(powspecs[1].freq, freq_min, freq_max) + if isempty(freq_ind) + return NaN64, NaN64 + end + + # Average the power across the different trajectories + mean_power = mean(powspec.power[freq_ind] for powspec in powspecs) + ind = argmax(mean_power) + + # alternative method + # peak_freqs = [get_peak_freq(powspec, freq_min, freq_max, freq_ind=freq_ind) for powspec in powspecs] + # mean_peak_freq = mean(peak_freqs) + # std_peak_freq = std(peak_freqs) + + return powspecs[1].freq[freq_ind][ind], NaN64 +end + +function get_freq_inds(freq, freq_min, freq_max) + return findall(x -> x > freq_min && x < freq_max, freq) +end + +# Mason's unroll macro, from https://github.com/Neuroblox/GraphDynamics.jl/blob/6c0bbb81abf1981c52a4605dc32d7073fea2ff0d/src/utils.jl#L10 +macro unroll(N::Int, loop) + Base.isexpr(loop, :for) || error("only works on for loops") + Base.isexpr(loop.args[1], :(=)) || error("This loop pattern isn't supported") + val, itr = esc.(loop.args[1].args) + body = esc(loop.args[2]) + @gensym loopend + label = :(@label $loopend) + goto = :(@goto $loopend) + out = Expr(:block, :(itr = $itr), :(next = iterate(itr))) + unrolled = map(1:N) do _ + quote + isnothing(next) && @goto loopend + $val, state = next + $body + next = iterate(itr, state) + end + end + append!(out.args, unrolled) + remainder = quote + while !isnothing(next) + $val, state = next + $body + next = iterate(itr, state) + end + @label loopend + end + push!(out.args, remainder) + out +end + +function compute_errors(pop::Population, sol; logging::Bool=false) + total = zero(eltype(sol)) + for m in pop.metrics + total += compute_error(m, pop, sol; logging=logging) + end + return total +end + +function update_parameters!(prob, populations, p, get_ps, set_ps!) + ps_new = get_ps(prob) + offset = 1 + + @unroll 16 for pop in populations + if pop.tunable + for (param_name, inds) in pop.tuning + ps_new[inds] .= abs.(p[offset]) + offset += 1 + end + end + end + + set_ps!(prob, ps_new) + return nothing +end + +function sum_errors(pops, sol; logging=false) + total_err = zero(eltype(sol)) + @unroll 16 for pop in pops + total_err += compute_errors(pop, sol; logging=logging) + end + return total_err +end + +function loss(p, config::OptimizationConfig; logging::Bool=false) + # Set random seed + Random.seed!(config.seed) + + # Update prob in-place + update_parameters!(config.prob, config.populations, p, config.get_ps, config.set_ps!) + + # Solve + ens_prob = EnsembleProblem(config.prob) + sol = solve(ens_prob, config.solver, config.ensemblealg; + trajectories=config.trajectories, + dt=config.dt, + saveat=config.dt, + config.other_diffeq_kwargs...) + + # Sum errors from each population + total_err = sum_errors(config.populations, sol; logging=logging) + return total_err +end + +################# +# Example usage # +################# + +function create_problem(size, T) + @info "Size = $size" + Random.seed!(123) + N_MSN = round(Int64, 100*size) + N_FSI = round(Int64, 50*size) + N_GPe = round(Int64, 80*size) + N_STN = round(Int64, 40*size) + + make_conn = Neuroblox.indegree_constrained_connection_matrix + + global_ns = :g + + ḡ_FSI_MSN = 0.6 + density_FSI_MSN = 0.15 + weight_FSI_MSN = ḡ_FSI_MSN / (N_FSI * density_FSI_MSN) + conn_FSI_MSN = make_conn(density_FSI_MSN, N_FSI, N_MSN) + + ḡ_MSN_GPe = 2.5 + density_MSN_GPe = 0.33 + weight_MSN_GPe = ḡ_MSN_GPe / (N_MSN * density_MSN_GPe) + conn_MSN_GPe = make_conn(density_MSN_GPe, N_MSN, N_GPe) + + ḡ_GPe_STN = 0.3 + density_GPe_STN = 0.05 + weight_GPe_STN = ḡ_GPe_STN / (N_GPe * density_GPe_STN) + conn_GPe_STN = make_conn(density_GPe_STN, N_GPe, N_STN) + + ḡ_STN_FSI = 0.165 + density_STN_FSI = 0.1 + weight_STN_FSI = ḡ_STN_FSI / (N_STN * density_STN_FSI) + conn_STN_FSI = make_conn(density_STN_FSI, N_STN, N_FSI) + + @named msn = Striatum_MSN_Adam( + namespace=global_ns, + N_inhib=N_MSN, + I_bg=1.153064742988923*ones(N_MSN), + σ=0.17256774881503584 + ) + @named fsi = Striatum_FSI_Adam( + namespace=global_ns, + N_inhib=N_FSI, + I_bg=6.196201739395473*ones(N_FSI), + σ=0.9548801242101033 + ) + @named gpe = GPe_Adam( + namespace=global_ns, + N_inhib=N_GPe, + I_bg=3.272893843123162*ones(N_GPe), + σ=1.0959782801317943 + ) + @named stn = STN_Adam( + namespace=global_ns, + N_exci=N_STN, + I_bg=2.2010777359961953*ones(N_STN), + σ=2.9158528502583545 + ) + + g = MetaDiGraph() + add_edge!(g, fsi => msn, weight=weight_FSI_MSN, connection_matrix=conn_FSI_MSN) + add_edge!(g, msn => gpe, weight=weight_MSN_GPe, connection_matrix=conn_MSN_GPe) + add_edge!(g, gpe => stn, weight=weight_GPe_STN, connection_matrix=conn_GPe_STN) + add_edge!(g, stn => fsi, weight=weight_STN_FSI, connection_matrix=conn_STN_FSI) + + @info "Creating system from graph" + @named sys = system_from_graph(g) + + tspan = (0.0, T) + @info "Creating SDEProblem" + prob = SDEProblem{true}(sys, [], tspan, []) + + return prob, sys, msn, fsi, gpe, stn +end + +# Example +prob, sys, msn, fsi, gpe, stn = create_problem(0.1, 5500.0) + +msn_pop = Population( + :msn, msn; + frm = (1.21, 60.0, -35.0, 200.0), # FiringRateMetric: (target, weight, threshold, transient) + freqm = (10.0, 0.5, 3.0, 20.0, "I_syn_msn"), # FrequencyMetric: (target, weight, fmin, fmax, aggregate_state) + prob = prob, + tunable = false +) + +fsi_pop = Population( + :fsi, fsi; + frm = (13.0, 10.0, -35.0, 200.0), + freqm = (61.14, 1.0, 40.0, 90.0, "I_syn_fsi"), + prob = prob, + tunable = false +) + +gpe_pop = Population( + :gpe, gpe; + freqm = (85.0, 0.5, 40.0, 90.0, "V"), + tuning_params = ["I_bg", "σ"], + prob = prob, + tunable = true +) + +stn_pop = Population( + :stn, stn; + tuning_params = ["I_bg", "σ"], + prob = prob, + tunable = false +) + +other_diffeq_kwargs = (abstol=1e-3, reltol=1e-6, maxiters=1e10) +get_ps = getp(prob, tunable_parameters(sys)) +set_ps! = setp(prob, tunable_parameters(sys)) + +config = OptimizationConfig( + prob, + (msn_pop, fsi_pop, gpe_pop, stn_pop), + get_ps, + set_ps!, + RKMil(), + EnsembleThreads(), + 0.1, + other_diffeq_kwargs, + 3, + 1234 +) + +p0 = [3.272893843123162, 1.0959782801317943, 2.2010777359961953, 2.9158528502583545] + +# optprob = Optimization.OptimizationProblem(loss, p0, config) +optprob = Optimization.OptimizationProblem((p, config)->loss(p, config; logging=true), p0, config) +callback = function (state, l) + println("\n") + @info "Iteration: $(state.iter)" + @info "Parameters: $(state.u)" + @info "Loss: $l" + println("\n") + return false +end + +# Example run +res = solve(optprob, Optim.NelderMead(); + maxiters=2, + callback=callback +) \ No newline at end of file diff --git a/src/Neuroblox.jl b/src/Neuroblox.jl index 093808fa..4ac7c412 100644 --- a/src/Neuroblox.jl +++ b/src/Neuroblox.jl @@ -4,7 +4,7 @@ import Base: merge using Base.Threads: nthreads -using OhMyThreads: tmapreduce +using OhMyThreads: tmapreduce, tforeach using Reexport @reexport using ModelingToolkit diff --git a/src/blox/DBS_Model_Blox_Adam_Brown.jl b/src/blox/DBS_Model_Blox_Adam_Brown.jl index 11f7827d..c7f1598b 100644 --- a/src/blox/DBS_Model_Blox_Adam_Brown.jl +++ b/src/blox/DBS_Model_Blox_Adam_Brown.jl @@ -64,11 +64,11 @@ struct Striatum_MSN_Adam <: CompositeBlox namespace = nothing, N_inhib = 100, E_syn_inhib=-80, - I_bg=1.172*ones(N_inhib), + I_bg=1.153064742988923*ones(N_inhib), freq=zeros(N_inhib), phase=zeros(N_inhib), τ_inhib=13, - σ=0.11, + σ=0.17256774881503584, density=0.3, weight=0.1, G_M=1.3, @@ -134,12 +134,12 @@ struct Striatum_FSI_Adam <: CompositeBlox namespace = nothing, N_inhib = 50, E_syn_inhib=-80, - I_bg=6.2*ones(N_inhib), + I_bg=6.196201739395473*ones(N_inhib), freq=zeros(N_inhib), phase=zeros(N_inhib), τ_inhib=11, τ_inhib_s=6.5, - σ=1.2, + σ=0.9548801242101033, density=0.58, g_density=0.33, weight=0.6, @@ -215,17 +215,17 @@ struct GPe_Adam <: CompositeBlox namespace = nothing, N_inhib = 80, E_syn_inhib=-80, - I_bg=3.4*ones(N_inhib), + I_bg=3.272893843123162*ones(N_inhib), freq=zeros(N_inhib), phase=zeros(N_inhib), τ_inhib=10, - σ=1.7, + σ=1.0959782801317943, density=0.0, weight=0.0, connection_matrix=nothing ) n_inh = [ - HHNeuronInhib_MSN_Adam_Blox( + HHNeuronInhib_GPe_Adam_Blox( name = Symbol("inh$i"), namespace = namespaced_name(namespace, name), E_syn = E_syn_inhib, @@ -284,11 +284,11 @@ struct STN_Adam <: CompositeBlox namespace = nothing, N_exci = 40, E_syn_exci=0.0, - I_bg=1.8*ones(N_exci), + I_bg=2.2010777359961953*ones(N_exci), freq=zeros(N_exci), phase=zeros(N_exci), τ_exci=2, - σ=1.7, + σ=2.9158528502583545, density=0.0, weight=0.0, connection_matrix=nothing diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index a00fcd4c..1f18bbb1 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -532,7 +532,8 @@ function firing_rate( 1000.0 * (nnz(spikes[idx_start:idx_end, :]) / N_neurons) / win_size end - return fr + T = eltype(sol) + return fr::Vector{T} end function firing_rate( @@ -635,16 +636,23 @@ function powerspectrum(blox::AbstractNeuronBlox, sol::SciMLBase.AbstractSolution end function powerspectrum(cb::Union{CompositeBlox, AbstractVector{<:AbstractNeuronBlox}}, - sols::SciMLBase.EnsembleSolution, state::String; sampling_rate=nothing, - method=periodogram, window=nothing)::Vector{DSP.Periodograms.Periodogram} - - t_sampled, sampling_freq = get_sampling_info(sols[1]; sampling_rate=sampling_rate) - powspecs = DSP.Periodograms.Periodogram[] + sols::SciMLBase.EnsembleSolution{T}, + state::String; + sampling_rate=nothing, + method=periodogram, + window=nothing + ) where {T} - for sol in sols - s = meanfield_timeseries(cb, sol, state; ts = t_sampled) - powspec = method(s, fs=sampling_freq, window=window) - push!(powspecs, powspec) + t_sampled, sampling_freq = get_sampling_info(sols[1]; sampling_rate=sampling_rate) + + # Pre-allocate concretely typed array + powspecs = Vector{DSP.Periodograms.Periodogram{T, + DSP.Frequencies{T}, + Vector{T}}}(undef, length(sols)) + tforeach(eachindex(sols)) do i + sol = sols[i] + s = meanfield_timeseries(cb, sol, state; ts=t_sampled) + powspecs[i] = method(s, fs=sampling_freq, window=window) end return powspecs diff --git a/src/blox/connections.jl b/src/blox/connections.jl index 143e342c..dd4e2035 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -357,27 +357,14 @@ function Connector( sys_dest.I_syn ~ -w * sys_src.G * (sys_dest.V - sys_src.E_syn) end - return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w, learning_rule=Dict(w => lr)) -end - -function Connector( - blox_src::Union{HHNeuronInhib_MSN_Adam_Blox, HHNeuronExci_STN_Adam_Blox, HHNeuronInhib_GPe_Adam_Blox}, - blox_dest::Union{HHNeuronInhib_MSN_Adam_Blox, HHNeuronInhib_FSI_Adam_Blox, HHNeuronExci_STN_Adam_Blox, HHNeuronInhib_GPe_Adam_Blox}; - kwargs... -) - sys_src = get_namespaced_sys(blox_src) - sys_dest = get_namespaced_sys(blox_dest) - - w = generate_weight_param(blox_src, blox_dest; kwargs...) - - STA = get_sta(kwargs, nameof(blox_src), nameof(blox_dest)) - eq = if STA - sys_dest.I_syn ~ -w * sys_dest.Gₛₜₚ * sys_src.G * (sys_dest.V - sys_src.E_syn) + if blox_src isa HHNeuronInhib_MSN_Adam_Blox && blox_dest isa HHNeuronInhib_MSN_Adam_Blox + eq2 = sys_dest.I_syn_msn ~ -w * sys_src.G * (sys_dest.V - sys_src.E_syn) + eqs = [eq, eq2] else - sys_dest.I_syn ~ -w * sys_src.G * (sys_dest.V - sys_src.E_syn) + eqs = eq end - return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) + return Connector(nameof(sys_src), nameof(sys_dest); equation=eqs, weight=w, learning_rule=Dict(w => lr)) end function Connector( @@ -406,6 +393,7 @@ function Connector( w = generate_weight_param(blox_src, blox_dest; kwargs...) eq = sys_dest.I_syn ~ -w * sys_src.Gₛ * (sys_dest.V - sys_src.E_syn) + eq4 = sys_dest.I_syn_fsi ~ -w * sys_src.Gₛ * (sys_dest.V - sys_src.E_syn) GAP = get_gap(kwargs, nameof(blox_src), nameof(blox_dest)) if GAP @@ -413,9 +401,9 @@ function Connector( eq2 = sys_dest.I_gap ~ -w_gap * (sys_dest.V - sys_src.V) eq3 = sys_src.I_gap ~ -w_gap * (sys_src.V - sys_dest.V) - return Connector(nameof(sys_src), nameof(sys_dest); equation=[eq, eq2, eq3], weight=[w, w_gap]) + return Connector(nameof(sys_src), nameof(sys_dest); equation=[eq, eq2, eq3, eq4], weight=[w, w_gap]) else - return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) + return Connector(nameof(sys_src), nameof(sys_dest); equation=[eq, eq4], weight=w) end end diff --git a/src/blox/neuron_models.jl b/src/blox/neuron_models.jl index eeb521e5..8bd7fa4f 100644 --- a/src/blox/neuron_models.jl +++ b/src/blox/neuron_models.jl @@ -199,7 +199,14 @@ struct HHNeuronInhib_MSN_Adam_Blox <: AbstractInhNeuronBlox I_asc(t) [input=true] G(t)=0.0 - [output =true] + [output = true] + + spikes_cumulative(t)=0.0 + spikes_window(t)=0.0 + + # observables + I_syn_msn(t)=0.0 + [input=true] end ps = @parameters begin @@ -291,7 +298,14 @@ struct HHNeuronInhib_FSI_Adam_Blox <: AbstractInhNeuronBlox G(t)=0.0 [output=true] Gₛ(t)=0.0 - [output=true] + [output = true] + + spikes_cumulative(t)=0.0 + spikes_window(t)=0.0 + + # observables + I_syn_fsi(t)=0.0 + [input=true] end ps = @parameters begin @@ -379,6 +393,9 @@ struct HHNeuronExci_STN_Adam_Blox <: AbstractExciNeuronBlox [input=true] G(t)=0.0 [output = true] + + spikes_cumulative(t)=0.0 + spikes_window(t)=0.0 end ps = @parameters begin @@ -458,6 +475,9 @@ struct HHNeuronInhib_GPe_Adam_Blox <: AbstractInhNeuronBlox [input=true] G(t)=0.0 [output = true] + + spikes_cumulative(t)=0.0 + spikes_window(t)=0.0 end ps = @parameters begin