Skip to content

Commit

Permalink
Reenable the connection matrix storage for the DBS structures (#455)
Browse files Browse the repository at this point in the history
* reenable the connection matrix storage for the DBS structures

* - add `mean_firing_rate` dispatch for EnsembleProblems
- fix the trimming of initial transient

* add firing rate tests

* FSI fixes

* move firing rate tests to components

* fix Striatum_FSI_Adam blox in GraphDynamics

---------

Co-authored-by: haris organtzidis <[email protected]>
Co-authored-by: gabrevaya <[email protected]>
  • Loading branch information
3 people authored Oct 14, 2024
1 parent 234c0d1 commit 26ed573
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 93 deletions.
16 changes: 8 additions & 8 deletions src/GraphDynamicsInterop/connection_interop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -567,17 +567,17 @@ function blox_wiring_rule!(h, s::Striatum_FSI_Adam, v, kwargs)
# just like how its done in Neuroblox.jl
g = MetaDiGraph()
add_vertices!(g, N_inhib)
for i axes(connection_matrix, 1)
for j axes(connection_matrix, 2)
cij = connection_matrix[i,j]
if iszero(cij.weight) && iszero(cij.g_weight)
for i axes(connection_matrix, 2)
for j axes(connection_matrix, 1)
cji = connection_matrix[j, i]
if iszero(cji.weight) && iszero(cji.g_weight)
nothing
elseif iszero(cij.g_weight)
add_edge!(g, i, i, Dict(:weight=>cij.weight))
elseif iszero(cji.g_weight)
add_edge!(g, j, i, Dict(:weight=>cji.weight))
else
add_edge!(g, i, j, Dict(:weight=>cij.weight,
add_edge!(g, j, i, Dict(:weight=>cji.weight,
:gap => true,
:gap_weight => cij.g_weight))
:gap_weight => cji.g_weight))
end
end
end
Expand Down
178 changes: 108 additions & 70 deletions src/blox/DBS_Model_Blox_Adam_Brown.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,62 @@
Subcortical Blox used for DBS model in Adam et al,2021
"""

function adam_connection_matrix(density, N, weight)
connection_matrix = zeros(N, N)
in_degree = Int(ceil(density*(N)))
idxs = 1:N
for i in idxs
source_set = setdiff(idxs, i)
source = sample(source_set, in_degree; replace=false)
for j in source
connection_matrix[j, i] = weight / in_degree
end
end
connection_matrix
end

function adam_connection_matrix_gap(density, g_density, N, weight, g_weight)
connection_matrix = [(weight = 0.0, g_weight = 0.0) for _ 1:N, _ 1:N]
in_degree = Int(ceil(density*N))
gap_degree = Int(ceil(g_density*N))
idxs = 1:N
gap_junctions = zeros(Int, N)
for i in idxs
if gap_junctions[i] < gap_degree
other_fsi = setdiff(idxs,i)
rem = findall(x -> x < gap_degree, gap_junctions[other_fsi])
gap_idx = sample(rem, min(gap_degree, length(rem)); replace=false)
gap_nbr = other_fsi[gap_idx]
gap_junctions[i] += length(gap_idx)
gap_junctions[gap_nbr] .+= 1
else
gap_nbr = []
end
source_set = setdiff(idxs, i)
syn_source = sample(source_set, in_degree; replace=false)
only_syn=setdiff(syn_source,gap_nbr)
only_gap=setdiff(gap_nbr,syn_source)
syn_gap=intersect(syn_source,gap_nbr)
for j in only_syn
connection_matrix[j, i] = (;weight = weight/in_degree, g_weight=0)
end
for j in only_gap
connection_matrix[j, i] = (;weight = 0, g_weight=g_weight/gap_degree)
end
for j in syn_gap
connection_matrix[j, i] = (;weight = weight/in_degree, g_weight=g_weight/gap_degree)
end
end
connection_matrix
end

struct Striatum_MSN_Adam <: CompositeBlox
namespace
parts
odesystem
connector
mean
connection_matrix

function Striatum_MSN_Adam(;
name,
Expand All @@ -21,7 +71,8 @@ struct Striatum_MSN_Adam <: CompositeBlox
σ=0.11,
density=0.3,
weight=0.1,
G_M=1.3
G_M=1.3,
connection_matrix=nothing
)
n_inh = [
HHNeuronInhib_MSN_Adam_Blox(
Expand All @@ -42,16 +93,17 @@ struct Striatum_MSN_Adam <: CompositeBlox
for i in Base.OneTo(N_inhib)
add_blox!(g, n_inh[i])
end
in_degree = Int(ceil(density*(N_inhib)))
idxs = Base.OneTo(N_inhib)
for i in idxs
source_set = setdiff(idxs, i)
source = sample(source_set, in_degree; replace=false)
for j in source
add_edge!(g, j, i, Dict(:weight=>weight/in_degree))
if isnothing(connection_matrix)
connection_matrix = adam_connection_matrix(density, N_inhib, weight)
end
for i axes(connection_matrix, 2)
for j axes(connection_matrix, 1)
cji = connection_matrix[j,i]
if !iszero(cji)
add_edge!(g, j, i, Dict(:weight => cji))
end
end
end

parts = n_inh

bc = connector_from_graph(g)
Expand All @@ -61,12 +113,10 @@ struct Striatum_MSN_Adam <: CompositeBlox
m = if isnothing(namespace)
[s for s in unknowns.((sys,), unknowns(sys)) if contains(string(s), "V(t)")]
else
@variables t
sys_namespace = System(Equation[], t; name=namespaced_name(namespace, name))
[s for s in unknowns.((sys_namespace,), unknowns(sys)) if contains(string(s), "V(t)")]
end

new(namespace, parts, sys, bc, m)
new(namespace, parts, sys, bc, m, connection_matrix)
end

end
Expand All @@ -77,6 +127,7 @@ struct Striatum_FSI_Adam <: CompositeBlox
odesystem
connector
mean
connection_matrix

function Striatum_FSI_Adam(;
name,
Expand All @@ -92,7 +143,8 @@ struct Striatum_FSI_Adam <: CompositeBlox
density=0.58,
g_density=0.33,
weight=0.6,
g_weight=0.15
g_weight=0.15,
connection_matrix=nothing
)
n_inh = [
HHNeuronInhib_FSI_Adam_Blox(
Expand All @@ -113,39 +165,22 @@ struct Striatum_FSI_Adam <: CompositeBlox
for i in Base.OneTo(N_inhib)
add_blox!(g, n_inh[i])
end
in_degree = Int(ceil(density*(N_inhib)))
gap_degree = Int(ceil(g_density*(N_inhib)))
idxs = Base.OneTo(N_inhib)

gap_junctions = zeros(Int, N_inhib)
for i in idxs
if gap_junctions[i]<gap_degree
other_fsi = setdiff(idxs,i)
rem = findall(x -> x < gap_degree, gap_junctions[other_fsi])
gap_idx = sample(rem, min(gap_degree, length(rem)); replace=false)
gap_nbr = other_fsi[gap_idx]
gap_junctions[i] += length(gap_idx)
gap_junctions[gap_nbr] .+= 1

else
gap_nbr = []
end
source_set = setdiff(idxs, i)
syn_source = sample(source_set, in_degree; replace=false)
only_syn=setdiff(syn_source,gap_nbr)
only_gap=setdiff(gap_nbr,syn_source)
syn_gap=intersect(syn_source,gap_nbr)
for j in only_syn
add_edge!(g, j, i, Dict(:weight=>weight/in_degree))
end
for j in only_gap
add_edge!(g, j, i, Dict(:weight=>0, :gap => true, :gap_weight => g_weight/gap_degree))
end

for j in syn_gap
add_edge!(g, j, i, Dict(:weight=>weight/in_degree, :gap => true, :gap_weight => g_weight/gap_degree))
if isnothing(connection_matrix)
connection_matrix = adam_connection_matrix_gap(density, g_density, N_inhib, weight, g_weight)
end
for i axes(connection_matrix, 2)
for j axes(connection_matrix, 1)
cji = connection_matrix[j, i]
if iszero(cji.weight) && iszero(cji.g_weight)
nothing
elseif iszero(cji.g_weight)
add_edge!(g, j, i, Dict(:weight=>cji.weight))
else
add_edge!(g, j, i, Dict(:weight=>cji.weight,
:gap => true,
:gap_weight => cji.g_weight))
end
end

end

parts = n_inh
Expand All @@ -162,17 +197,18 @@ struct Striatum_FSI_Adam <: CompositeBlox
[s for s in unknowns.((sys_namespace,), unknowns(sys)) if contains(string(s), "V(t)")]
end

new(namespace, parts, sys, bc, m)
new(namespace, parts, sys, bc, m, connection_matrix)
end

end
end

struct GPe_Adam <: CompositeBlox
namespace
parts
odesystem
connector
mean
connection_matrix

function GPe_Adam(;
name,
Expand All @@ -186,6 +222,7 @@ struct GPe_Adam <: CompositeBlox
σ=1.7,
density=0.0,
weight=0.0,
connection_matrix=nothing
)
n_inh = [
HHNeuronInhib_MSN_Adam_Blox(
Expand All @@ -205,16 +242,17 @@ struct GPe_Adam <: CompositeBlox
for i in Base.OneTo(N_inhib)
add_blox!(g, n_inh[i])
end
in_degree = Int(ceil(density*(N_inhib)))
idxs = Base.OneTo(N_inhib)
for i in idxs
source_set = setdiff(idxs, i)
source = sample(source_set, in_degree; replace=false)
for j in source
add_edge!(g, j, i, Dict(:weight=>weight/in_degree))
if isnothing(connection_matrix)
connection_matrix = adam_connection_matrix(density, N_inhib, weight)
end
for i axes(connection_matrix, 2)
for j axes(connection_matrix, 1)
cji = connection_matrix[j,i]
if !iszero(cji)
add_edge!(g, j, i, Dict(:weight => cji))
end
end
end

parts = n_inh

bc = connector_from_graph(g)
Expand All @@ -224,12 +262,11 @@ struct GPe_Adam <: CompositeBlox
m = if isnothing(namespace)
[s for s in unknowns.((sys,), unknowns(sys)) if contains(string(s), "V(t)")]
else
@variables t
sys_namespace = System(Equation[], t; name=namespaced_name(namespace, name))
[s for s in unknowns.((sys_namespace,), unknowns(sys)) if contains(string(s), "V(t)")]
end

new(namespace, parts, sys, bc, m)
new(namespace, parts, sys, bc, m, connection_matrix)
end

end
Expand All @@ -240,6 +277,7 @@ struct STN_Adam <: CompositeBlox
odesystem
connector
mean
connection_matrix

function STN_Adam(;
name,
Expand All @@ -252,7 +290,8 @@ struct STN_Adam <: CompositeBlox
τ_exci=2,
σ=1.7,
density=0.0,
weight=0.0
weight=0.0,
connection_matrix=nothing
)
n_exci = [
HHNeuronExci_STN_Adam_Blox(
Expand All @@ -272,16 +311,17 @@ struct STN_Adam <: CompositeBlox
for i in Base.OneTo(N_exci)
add_blox!(g, n_exci[i])
end
in_degree = Int(ceil(density*(N_exci)))
idxs = Base.OneTo(N_exci)
for i in idxs
source_set = setdiff(idxs, i)
source = sample(source_set, in_degree; replace=false)
for j in source
add_edge!(g, j, i, Dict(:weight=>weight/in_degree))
if isnothing(connection_matrix)
connection_matrix = adam_connection_matrix(density, N_exci, weight)
end
for i axes(connection_matrix, 2)
for j axes(connection_matrix, 1)
cji = connection_matrix[j,i]
if !iszero(cji)
add_edge!(g, j, i, Dict(:weight => cji))
end
end
end

parts = n_exci

bc = connector_from_graph(g)
Expand All @@ -291,12 +331,10 @@ struct STN_Adam <: CompositeBlox
m = if isnothing(namespace)
[s for s in unknowns.((sys,), unknowns(sys)) if contains(string(s), "V(t)")]
else
@variables t
sys_namespace = System(Equation[], t; name=namespaced_name(namespace, name))
[s for s in unknowns.((sys_namespace,), unknowns(sys)) if contains(string(s), "V(t)")]
end

new(namespace, parts, sys, bc, m)
new(namespace, parts, sys, bc, m, connection_matrix)
end

end
end
42 changes: 32 additions & 10 deletions src/blox/blox_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -472,26 +472,48 @@ function firing_rate(
end

function mean_firing_rate(spikes::SparseMatrixCSC, sol; trim_transient = 0,
firing_rate_Δt = last(sol.t) - trim_transient,)
Δt = last(sol.t) - trim_transient)

spikes = transpose(spikes)
tmax = last(sol.t) - trim_transient
t = trim_transient:firing_rate_Δt:tmax
t_fr = trim_transient:Δt:last(sol.t)
fr = _mean_firing_rate(spikes, unique(sol.t), t_fr)

return t_fr, fr
end

function mean_firing_rate(blox, sols::SciMLBase.EnsembleSolution; trim_transient = 0,
Δt = last(sols[1].t) - trim_transient, threshold = -35)

t_fr = trim_transient:Δt:last(sols[1].t)
firing_rates = Vector{Float64}[]

for sol in sols
spikes = detect_spikes(blox, sol; threshold = threshold)
spikes = transpose(spikes)
fr = _mean_firing_rate(spikes, unique(sol.t), t_fr)
push!(firing_rates, fr)
end

mean_fr = mean(firing_rates)
std_fr = std(firing_rates)

return t_fr, mean_fr, std_fr
end

function _mean_firing_rate(spikes, t, t_fr)

tᵤ = unique(sol.t)
counts = vec(sum(spikes, dims=1))
fr = fill(NaN64, length(t_fr) - 1)

rₘ = fill(NaN64, length(t) - 1)
for i in 2:length(t)
idx = intersect(findall(tᵤ .<= t[i]), findall(tᵤ .> t[i-1]))
for i in 2:length(t_fr)
idx = intersect(findall(t .<= t_fr[i]), findall(t .> t_fr[i-1]))
if ~isempty(idx)
rₘ[i-1] = sum(counts[idx])
fr[i-1] = sum(counts[idx])
end
end

# firing rate in spikes/s averaged over the population
rₘ = rₘ*1000 ./ (size(spikes,1)*firing_rate_Δt)
return t, rₘ
fr = fr*1000 ./ (size(spikes, 1)*(t_fr[2] - t_fr[1]))
end

function state_timeseries(blox, sol::SciMLBase.AbstractSolution, state::String; ts=nothing)
Expand Down
Loading

0 comments on commit 26ed573

Please sign in to comment.