Skip to content

Commit

Permalink
Generic Connector dispatch (#505)
Browse files Browse the repository at this point in the history
* reexport t and D from MTK

* add alias for `AbstractNeuronBlox` for brevity

* export functions for easier custom blox+connector definitions

* add generic spike affect and restructure existing one for LIF neurons

* add `system(blox)` to return (simplified) system

* option to choose between namespaced or not input equations in getter function

* add generic Connector dispatch

* add getter functions for spike affect states and params

* remove abstract neuron -> neuron Connector dispatch

* change values in `spike_affects` to `Vector{Tuple}` instead of `Tuple{Vector,Vector}`

* move functional spike affect to `Neurographs.jl`

* add generic fallback for discrete callbacks

* fix typo

* add `weight` argument to `connection_spike_affects`
  • Loading branch information
harisorgn authored Dec 22, 2024
1 parent 355d24d commit 89cadb8
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 84 deletions.
11 changes: 6 additions & 5 deletions src/Neuroblox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ using OhMyThreads: tmapreduce

using Reexport
@reexport using ModelingToolkit
const t = ModelingToolkit.t_nounits
const D = ModelingToolkit.D_nounits
export t, D
@reexport using ModelingToolkit: ModelingToolkit.t_nounits as t, ModelingToolkit.D_nounits as D

@reexport using ModelingToolkitStandardLibrary.Blocks
@reexport import Graphs: add_edge!
@reexport using MetaGraphs: MetaDiGraph
Expand Down Expand Up @@ -69,6 +68,7 @@ abstract type StimulusBlox <: AbstractBlox end
abstract type ObserverBlox end # not AbstractBlox since it should not show up in the GUI
abstract type AbstractPINGNeuron <: AbstractNeuronBlox end

const Neuron = AbstractNeuronBlox

# we define these in neural_mass.jl
# abstract type HarmonicOscillatorBlox <: NeuralMassBlox end
Expand Down Expand Up @@ -223,6 +223,7 @@ function __init__()
end


export Neuron
export JansenRitSPM12, next_generation, qif_neuron, if_neuron, hh_neuron_excitatory,
hh_neuron_inhibitory, van_der_pol, Generic2dOscillator
export HHNeuronExciBlox, HHNeuronInhibBlox, IFNeuron, LIFNeuron, QIFNeuron, IzhikevichNeuron, LIFExciNeuron, LIFInhNeuron,
Expand All @@ -243,7 +244,7 @@ export powerspectrum, complexwavelet, bandpassfilter, hilberttransform, phaseang
export learningrate, ControlError
export vecparam, csd_Q, setup_sDCM, run_sDCM_iteration!, defaultprior
export simulate, random_initials
export system_from_graph, graph_delays
export system_from_graph, system, graph_delays
export create_adjacency_edges!, adjmatrixfromdigraph
export get_namespaced_sys, nameof
export run_experiment!, run_trial!
Expand All @@ -256,5 +257,5 @@ export meanfield, meanfield!, rasterplot, rasterplot!, stackplot, stackplot!, fr
export powerspectrumplot, powerspectrumplot!, welch_pgram, periodogram, hanning, hamming
export detect_spikes, mean_firing_rate, firing_rate
export voltage_timeseries, meanfield_timeseries, state_timeseries, get_neurons, get_exci_neurons, get_inh_neurons, get_neuron_color
export AdjacencyMatrix, Connector, connection_rule, connection_equation
export AdjacencyMatrix, Connector, connection_rule, connection_equations, connection_spike_affects, connection_learning_rules, connection_callbacks
end
84 changes: 70 additions & 14 deletions src/Neurographs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,37 +81,93 @@ function graph_delays(g::MetaDiGraph)
return conn.delay
end

generate_discrete_callbacks(blox, ::Connector; t_block = missing) = []

function generate_discrete_callbacks(blox::Union{LIFExciNeuron, LIFInhNeuron}, bc::Connector; t_block = missing)
sa = spike_affects(bc)
name_blox = namespaced_nameof(blox)
sys = get_namespaced_sys(blox)

states_affect, params_affect = get(sa, name_blox, (Num[], Num[]))

function make_unique_param_pairs(params)
# HACK : MTK will complain if the parameter vector passed to a functional affect
# contains non-unique parameters. Here we sometimes need to pass duplicate parameters that
# affect states in the loop in LIF_spike_affect! .
# Passing parameters with Symbol aliases bypasses this issue and allows for duplicates.
affect_pairs = if unique(params_affect) == length(params_affect)
[p => Symbol(p) for p in params_affect]
param_pairs = if unique(params) == length(params)
[p => Symbol(p) for p in params]
else
map(params_affect) do p
if count(pi -> Symbol(pi) == Symbol(p), params_affect) > 1
map(params) do p
if count(pi -> Symbol(pi) == Symbol(p), params) > 1
p => Symbol(p, "_$(rand(1:1000))")
else
p => Symbol(p)
end
end
end

return param_pairs
end

function generic_spike_affect!(integ, u, p, ctx)
N = length(u)
for i in Base.OneTo(N)
integ.u[u[i]] += integ.p[p[i]]
end
end

function LIF_spike_affect!(integ, u, p, ctx)
integ.u[u[1]] = integ.p[p[1]]

t_refract_end = integ.t + integ.p[p[2]]
integ.p[p[3]] = t_refract_end

integ.p[p[4]] = 1

SciMLBase.add_tstop!(integ, t_refract_end)

c = 1
for i in eachindex(u)[2:end]
integ.u[u[i]] += integ.p[p[c + 4]]
c += 1
end
end

generate_discrete_callbacks(blox, ::Connector; t_block = missing) = []

function generate_discrete_callbacks(blox::AbstractNeuronBlox, bc::Connector; t_block = missing)
sa = spike_affects(bc)
name_blox = namespaced_nameof(blox)
sys = get_namespaced_sys(blox)

states_affect = get_states_spikes_affect(sa, name_blox)
params_affect = get_params_spikes_affect(sa, name_blox)

if isempty(states_affect) && isempty(params_affect)
return []
else
param_pairs = make_unique_param_pairs(params_affect)

cb = (sys.V > sys.θ) => (
generic_spike_affect!,
states_affect,
param_pairs,
[],
nothing
)

return cb
end
end

function generate_discrete_callbacks(blox::Union{LIFExciNeuron, LIFInhNeuron}, bc::Connector; t_block = missing)
sa = spike_affects(bc)
name_blox = namespaced_nameof(blox)
sys = get_namespaced_sys(blox)

states_affect = get_states_spikes_affect(sa, name_blox)
params_affect = get_params_spikes_affect(sa, name_blox)

param_pairs = make_unique_param_pairs(params_affect)

ps = vcat([
sys.V_reset => Symbol(sys.V_reset),
sys.t_refract_duration => Symbol(sys.t_refract_duration),
sys.t_refract_end => Symbol(sys.t_refract_end),
sys.is_refractory => Symbol(sys.is_refractory)
], affect_pairs)
], param_pairs)

cb = (sys.V > sys.θ) => (
LIF_spike_affect!,
Expand Down
39 changes: 32 additions & 7 deletions src/blox/blox_utilities.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
function Base.getproperty(b::Union{AbstractNeuronBlox, NeuralMassBlox}, name::Symbol)
# TO DO : Some of the fields below besides `odesystem` and `namespace`
# are redundant and we should clean them up.
if (name === :odesystem) || (name === :namespace) || (name === :params) || (name === :output) || (name === :voltage)
return getfield(b, name)
else
return Base.getproperty(Neuroblox.get_namespaced_sys(b), name)
end
end

"""
function paramscoping(;tunable=true, kwargs...)
Expand Down Expand Up @@ -84,6 +94,15 @@ get_system(blox) = blox.odesystem
get_system(sys::AbstractODESystem) = sys
get_system(stim::PoissonSpikeTrain) = System(Equation[], t, [], []; name=stim.name)

function system(blox::AbstractBlox; simplify=true)
sys = get_system(blox)
eqs = get_input_equations(blox; namespaced=false)

csys = System(vcat(equations(sys), eqs), t, unknowns(sys), parameters(sys); name = nameof(sys))

return simplify ? structural_simplify(csys) : csys
end

function get_namespaced_sys(blox)
sys = get_system(blox)
System(
Expand Down Expand Up @@ -141,20 +160,26 @@ end
which holds a `Connector` object with all relevant connections
from lower levels and this level.
"""
function get_input_equations(blox::Union{AbstractBlox, ObserverBlox})
function get_input_equations(blox::Union{AbstractBlox, ObserverBlox}; namespaced=true)
sys = get_system(blox)
sys_eqs = equations(sys)

inps = inputs(sys)
filter!(inp -> isnothing(find_eq(sys_eqs, inp)), inps)

if !isempty(inps)
eqs = map(inps) do inp
namespace_equation(
inp ~ 0,
sys,
namespaced_name(inner_namespaceof(blox), nameof(blox))
)
eqs = if namespaced
map(inps) do inp
namespace_equation(
inp ~ 0,
sys,
namespaced_name(inner_namespaceof(blox), nameof(blox))
)
end
else
map(inps) do inp
inp ~ 0
end
end

return eqs
Expand Down
Loading

0 comments on commit 89cadb8

Please sign in to comment.