Skip to content

Commit

Permalink
Make neurons from Decision Making tutorial much faster with GraphDyna…
Browse files Browse the repository at this point in the history
…mics (#484)

* remove soem junk

* more junk

* switch to using `ForeachConnectedSubsystem` instead of composite events

* make connections subtypes of ConnectionRule

* allow skipping building `ODESystem`

* bump GraphDynamics compat

* remove more unused comments
  • Loading branch information
MasonProtter authored and david-hofmann committed Nov 11, 2024
1 parent e445d79 commit 53b29c2
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 211 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ DataFrames = "1.3"
Distributions = "0.25.102"
ExponentialUtilities = "1"
ForwardDiff = "0.10"
GraphDynamics = "0.1.5"
GraphDynamics = "0.2"
Graphs = "1"
Interpolations = "0.14, 0.15"
MetaGraphs = "0.7"
Expand Down
24 changes: 6 additions & 18 deletions src/GraphDynamicsInterop/GraphDynamicsInterop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ using GraphDynamics:
StateIndex,
ParamIndex,
event_times,
calculate_inputs,
connection_index
calculate_inputs

using Random:
Random,
Expand Down Expand Up @@ -256,12 +255,6 @@ function populate_flatgraph(h, g, blox, v, g_i, h_i)
if length(components(blox)) == 1 && only(components(blox)) == blox
h_i += 1
add_subsystem!(h, to_subsystem(blox), Neuroblox.namespaced_nameof(blox))
# add_vertices!(h, 1)
# subsystem = to_subsystem(blox)
# name = Neuroblox.namespaced_nameof(blox)

# set_subsystem!(h, to_subsystem(blox), h_i)
# set_name!(h, name, h_i)
if v isa Dict
@assert !haskey(v, g_i)
v[g_i] = h_i
Expand Down Expand Up @@ -341,14 +334,9 @@ situations where tiny matrices like (e.g. 5x5) get stored as sparse arrays rathe
function graphsystem_from_graph(_g::MetaDiGraph; sparsity_heuristic=1.0, sparse_length_cutoff=0)
check_all_supported_blox(_g)
g = flat_graph(_g)

total_eltype = mapreduce(promote_type, vertices(g)) do i
eltype(get_subsystem(g, i))
end
fix_eltype(s::Subsystem{Name}) where {Name} = convert(Subsystem{Name, total_eltype}, s)


subsystems_and_names_flat = map(vertices(g)) do i
(subsystem = fix_eltype(get_subsystem(g, i)), name = get_name(g, i))
(subsystem = get_subsystem(g, i), name = get_name(g, i))
end
names_flat = map(last, subsystems_and_names_flat)
subsystems_flat = map(first, subsystems_and_names_flat)
Expand Down Expand Up @@ -433,9 +421,9 @@ function graphsystem_from_graph(_g::MetaDiGraph; sparsity_heuristic=1.0, sparse_
end
end
end
states_partitioned = map(v -> map(get_states, v), subsystems)
states_partitioned = map(v -> map(get_states, v), subsystems)
params_partitioned = map(v -> map(get_params, v), subsystems)
names_partitioned = map(v -> map(last, v), subsystems_and_names)
names_partitioned = map(v -> map(last, v), subsystems_and_names)

composite_continuous_events_partitioned = let
if isempty(g.composite_continuous_events_builder)
Expand Down Expand Up @@ -479,5 +467,5 @@ function graphsystem_from_graph(_g::MetaDiGraph; sparsity_heuristic=1.0, sparse_
end
end


end#module GraphDynamicsInterop

217 changes: 73 additions & 144 deletions src/GraphDynamicsInterop/connection_interop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ end
struct BasicConnection <: ConnectionRule
weight::Float64
end
Base.zero(::BasicConnection) = Base.zero(BasicConnection)
Base.zero(::Type{<:BasicConnection}) = BasicConnection(0.0)
function (c::BasicConnection)(blox_src, blox_dst)
(; jcn = c.weight * output(blox_src))
Expand Down Expand Up @@ -177,7 +178,7 @@ function get_connection(
(;conn, names)
end

struct HHConnection_GAP
struct HHConnection_GAP <: ConnectionRule
w::Float64
w_gap::Float64
w_gap_rev::Float64
Expand Down Expand Up @@ -211,7 +212,7 @@ end


#----------------------------------------------
# Kuramoto
# Kuramoto
function get_connection(src::KuramotoOscillator, dst::KuramotoOscillator, kwargs)
(;w_val, name) = generate_weight_param(src, dst, kwargs)
(;conn=BasicConnection(w_val), names=[name])
Expand All @@ -227,37 +228,6 @@ end
#----------------------------------------------
# LIFExci / LIFInh

function blox_wiring_rule!(h, blox::Union{LIFExciNeuron, LIFInhNeuron}, v, kwargs)
evbs = h.composite_discrete_events_builder
i = only(v)
push!(evbs, SpikeAffectEventBuilder(i, Int[], Int[]))
end


function blox_wiring_rule!(h,
blox_src::Union{LIFExciNeuron, LIFInhNeuron},
blox_dst::Union{LIFExciNeuron, LIFInhNeuron},
v_src, v_dst, kwargs)
#this is the fallback method for non-composite blox, hence vi and vj should have only one element
i, j = only(v_src), only(v_dst)
(; w_val, name) = generate_weight_param(blox_src, blox_dst, kwargs)
conn = BasicConnection(w_val)

let evbs = h.composite_discrete_events_builder
idx = findfirst(evb -> (evb isa SpikeAffectEventBuilder) && (evb.idx_src == i), evbs)
if isnothing(idx)
error("SpikeAffectEventBuilder for neuron not found, this indicates its blox wiring rule never ran.")
else
if blox_dst isa LIFExciNeuron
push!(evbs[idx].idx_dsts_exci, j)
elseif blox_dst isa LIFInhNeuron
push!(evbs[idx].idx_dsts_inh, j)
end
end
end
add_edge!(h, i, j, Dict(:conn => conn, :names => [name]))
end

function (c::BasicConnection)(sys_src::Subsystem{LIFExciNeuron},
sys_dst::Union{Subsystem{LIFExciNeuron}, Subsystem{LIFInhNeuron}})
w = c.weight
Expand All @@ -271,111 +241,65 @@ function (c::BasicConnection)(::Subsystem{LIFInhNeuron},
(; jcn = 0.0)
end

struct SpikeAffectEventBuilder
idx_src::Int
idx_dsts_inh::Vector{Int}
idx_dsts_exci::Vector{Int}
const LIFExciInhNeuron = Union{LIFExciNeuron, LIFInhNeuron}
GraphDynamics.has_discrete_events(::Type{LIFExciNeuron}) = true
GraphDynamics.has_discrete_events(::Type{LIFInhNeuron}) = true
function GraphDynamics.discrete_event_condition((; t_refract_end, V, θ)::Subsystem{LIF}, t, _) where {LIF <: LIFExciInhNeuron}
# Triggers when either a refractory period is ending, or the neuron spiked (voltage exceeds threshold θ)
(V > θ) || (t_refract_end == t)
end
function GraphDynamics.apply_discrete_event!(integrator,
states_view_src, params_view_src,
neuron_src::Subsystem{LIF},
foreach_connected_neuron) where {LIF <: LIFExciInhNeuron}
t = integrator.t
if t == neuron_src.t_refract_end # Refreactory period is over
params = params_view_src[]
params_view_src[] = @set params.is_refractory = 0
else # Neuron fired
# Begin refractory period
params_src = params_view_src[]
@reset params_src.t_refract_end = t + params_src.t_refract_duration
@reset params_src.is_refractory = 1

add_tstop!(integrator, params_src.t_refract_end)
params_view_src[] = params_src

struct SpikeAffectEvent{i_src, i_LIFInh, i_LIFExci}
j_src::Int
j_dsts_inh::Vector{Int}
j_dsts_exci::Vector{Int}
end
# Reset the neuron voltage
states_view_src[:V] = params_src.V_reset

function (ev::SpikeAffectEventBuilder)(index_map)
(i_src, j_src) = index_map[ev.idx_src]
i_inh, j_dsts_inh = let v = ev.idx_dsts_inh
if isempty(v)
nothing, Int[]
else
index_map[first(v)][1], map(idx -> index_map[idx][2], v)
# Now apply a function to each connected dst neuron
foreach_connected_neuron() do conn, neuron_dst, states_view_dst, params_view_dst
lif_exci_inh_update_connected_neuron(neuron_src, states_view_src, conn, neuron_dst, states_view_dst)
end
end
i_exci, j_dsts_exci = let v = ev.idx_dsts_exci
if isempty(v)
nothing, Int[]
else
index_map[first(v)][1], map(idx -> index_map[idx][2], v)
end
end
function lif_exci_inh_update_connected_neuron(neuron_src::Subsystem{LIFExciNeuron},
states_view_src,
conn::BasicConnection,
neuron_dst::Subsystem{<:LIFExciInhNeuron},
states_view_dst)
w = conn.weight
# check if the neuron is connected to itself
if states_view_src === states_view_dst
# x is the rise variable for NMDA synapses and it only applies to self-recurrent connections
states_view_dst[:x] += w
end
SpikeAffectEvent{i_src, i_inh, i_exci}(j_src, j_dsts_inh, j_dsts_exci)
states_view_dst[:S_AMPA] += w
nothing
end

function GraphDynamics.discrete_event_condition(states,
params,
connection_matrices,
ev::SpikeAffectEvent{i_src, i_dst_inh, i_dsts_exci},
t) where {i_src, i_dst_inh, i_dsts_exci}
(; j_src) = ev
neuron_src = Subsystem(states[i_src][j_src], params[i_src][j_src])
neuron_src.V > neuron_src.θ
function lif_exci_inh_update_connected_neuron(neuron_src::Subsystem{LIFInhNeuron},
states_view_src,
conn::BasicConnection,
neuron_dst::Subsystem{<:LIFExciInhNeuron},
states_view_dst)
w = conn.weight
states_view_dst[:S_GABA] += w
nothing
end




function GraphDynamics.apply_discrete_event!(integrator,
states::NTuple{Len, Any},
params::NTuple{Len, Any},
connection_matrices,
t,
ev::SpikeAffectEvent{i_src, i_dst_inh, i_dst_exci}
) where {i_src, i_dst_inh, i_dst_exci, Len}
(; j_src, j_dsts_inh, j_dsts_exci) = ev

nc = connection_index(BasicConnection, connection_matrices)

params_src = params[i_src][j_src]
@reset params_src.t_refract_end = t + params_src.t_refract_duration
@reset params_src.is_refractory = 1

params[i_src][j_src] = params_src
add_tstop!(integrator, params_src.t_refract_end)

states_src = states[i_src][j_src]
states[i_src][:V, j_src] = params_src.V_reset
if (states_src isa SubsystemStates{LIFExciNeuron}) && (j_src j_dsts_exci)
# x is the rise variable for NMDA synapses and it only applies to self-recurrent connections
w = connection_matrices[nc][i_src, i_src][j_src, j_src].weight
states[i_src][:x, j_src] += w
end

if states_src isa SubsystemStates{LIFExciNeuron}
if !isnothing(i_dst_inh)
M = connection_matrices[nc][i_src, i_dst_inh]
for j_dst j_dsts_inh
w = M[j_src, j_dst].weight
states[i_dst_inh][:S_AMPA, j_dst] += w
end
end
if !isnothing(i_dst_exci)
M = connection_matrices[nc][i_src, i_dst_exci]
for j_dst j_dsts_exci
w = M[j_src, j_dst].weight
states[i_dst_exci][:S_AMPA, j_dst] += w
end
end
elseif states_src isa SubsystemStates{LIFInhNeuron}
if !isnothing(i_dst_inh)
M = connection_matrices[nc][i_src, i_dst_inh]
for j_dst j_dsts_inh
w = M[j_src, j_dst].weight
states[i_dst_inh][:S_GABA, j_dst] += w
end
end
if !isnothing(i_dst_exci)
M = connection_matrices[nc][i_src, i_dst_exci]
for j_dst j_dsts_exci
w = M[j_src, j_dst].weight
states[i_dst_exci][:S_GABA, j_dst] += w
end
end
else
error("this should be unreachable")
end
end

function blox_wiring_rule!(h,
stim::PoissonSpikeTrain,
blox_dst::Union{LIFExciNeuron, LIFInhNeuron},
Expand All @@ -385,7 +309,7 @@ function blox_wiring_rule!(h,
conn = PoissonSpikeConn(w_val, Set(Neuroblox.generate_spike_times(stim)))
add_edge!(h, i, j, Dict(:conn => conn, :names => [name]))
end
struct PoissonSpikeConn
struct PoissonSpikeConn <: ConnectionRule
w::Float64
t_spikes::Set{Float64}
end
Expand All @@ -394,23 +318,29 @@ function ((;w)::PoissonSpikeConn)(stim::Subsystem{PoissonSpikeTrain},
blox_dst::Union{Subsystem{LIFExciNeuron}, Subsystem{LIFInhNeuron}})
(; jcn = 0.0)
end
GraphDynamics.has_discrete_events(::PoissonSpikeConn) = true
GraphDynamics.has_discrete_events(::Type{PoissonSpikeConn}) = true
GraphDynamics.event_times((;t_spikes)::PoissonSpikeConn) = (t_spikes)
GraphDynamics.discrete_event_condition((;t_spikes)::PoissonSpikeConn, t) = (t t_spikes)

GraphDynamics.has_discrete_events(::Type{PoissonSpikeTrain}) = true
function GraphDynamics.discrete_event_condition(p::Subsystem{PoissonSpikeTrain}, t, foreach_connected_neuron::F) where {F}
# check if any of the downstream connections from p spike at time t.
cond = mapreduce(|, foreach_connected_neuron; init=false) do conn, _, _, _
t conn.t_spikes
end
end
function GraphDynamics.apply_discrete_event!(integrator,
_, _,
vstates_dst, _,
_::PoissonSpikeConn,
_::Subsystem{PoissonSpikeTrain},
_::Union{Subsystem{LIFExciNeuron}, Subsystem{LIFInhNeuron}})
states = vstates_dst[]
states = @set states.S_AMPA_ext += 1
vstates_dst[] = states
nothing
states_view_src, params_view_src,
neuron_src::Subsystem{PoissonSpikeTrain},
foreach_connected_neuron::F) where {F}
t = integrator.t
foreach_connected_neuron() do conn, neuron_dst, states_view_dst, params_view_dst
# Check each downstream connection, if it's time to spike, increment the downstream neuron's S_AMPA_ext
if t conn.t_spikes
states_view_dst[:S_AMPA_ext] += 1
end
end
end


components(blox::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}) = blox.parts

issupported(::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}) = true
Expand Down Expand Up @@ -441,7 +371,6 @@ function blox_wiring_rule!(h,
blox_dst::Union{LIFExciCircuitBlox, LIFInhCircuitBlox},
v_src, v_dst, kwargs)
neurons_dst = components(blox_dst)

for (j, neuron_dst) enumerate(neurons_dst)
blox_wiring_rule!(h, stim, neuron_dst, only(v_src), v_dst[j], kwargs)
end
Expand Down Expand Up @@ -752,7 +681,7 @@ function get_connection(discr_src::Matrisome, discr_dst::Matrisome, kwargs)
MMConn(t_event)
end

struct MMConn{T}
struct MMConn{T} <: ConnectionRule
t_event::T
end

Expand Down Expand Up @@ -828,7 +757,7 @@ function get_connection(discr_src::TAN, discr_dst::Matrisome, kwargs)
(; conn = TAN_M_Conn(w_val, t_event), names=[name])
end

struct TAN_M_Conn
struct TAN_M_Conn <: ConnectionRule
w::Float64
t_event::Float64
end
Expand Down Expand Up @@ -1022,7 +951,7 @@ end

# #-------------------------
# PING Network
struct PINGConnection
struct PINGConnection <: ConnectionRule
w::Float64
V_E::Float64
V_I::Float64
Expand Down
Loading

0 comments on commit 53b29c2

Please sign in to comment.