Skip to content

Commit

Permalink
Replace output field with MTK.outputs function and cleanup (#507)
Browse files Browse the repository at this point in the history
* reexport t and D from MTK

* add alias for `AbstractNeuronBlox` for brevity

* export functions for easier custom blox+connector definitions

* add generic spike affect and restructure existing one for LIF neurons

* add `system(blox)` to return (simplified) system

* option to choose between namespaced or not input equations in getter function

* add generic Connector dispatch

* add getter functions for spike affect states and params

* remove abstract neuron -> neuron Connector dispatch

* change values in `spike_affects` to `Vector{Tuple}` instead of `Tuple{Vector,Vector}`

* move functional spike affect to `Neurographs.jl`

* add generic fallback for discrete callbacks

* fix typo

* add `weight` argument to `connection_spike_affects`

* import & export `outputs`

* extend `MTK.outputs` util fuction

* remove `output, jcn, voltage` fields, only keep metadata

* apply `outputs` function instead of field access

* rename `odesystem` field to `system`

* wrap `outputs` in `Num`

* capitalize voltage state `V`

* capitalize all occurences of `v` state

* actually capitalize all `v`

* try fixing the GD `output` error

* actually try fixing the GD test

* add `namespaced` kwarg for `outputs`

* fix output tag in PING neurons

* default to `namespaced=false` for shorter `outputs` return

* also expand equations, inputs, unknowns, parameters from MTK to work with blox

* use namespaced system in system util funcs
  • Loading branch information
harisorgn authored Dec 24, 2024
1 parent 89cadb8 commit 24d304f
Show file tree
Hide file tree
Showing 37 changed files with 243 additions and 245 deletions.
2 changes: 1 addition & 1 deletion docs/src/tutorials/neural_assembly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ add_edge!(g, ASC1 => VAC, weight=44)
add_edge!(g, ASC1 => AC, weight=44)
add_edge!(g, VAC => AC, weight=3, density=0.08)

## define odesystem and solve
## define system and solve
sys = system_from_graph(g, name=global_namespace)
prob = ODEProblem(sys, [], (0.0, 1000), []) ## tspan = (0,1000)
sol = solve(prob, Vern7(), saveat=0.1);
Expand Down
2 changes: 1 addition & 1 deletion docs/src/tutorials/parkinsons.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ sol = solve(prob, Tsit5(), saveat=0.1); ## Solve the problem and save every 0.1m
# Let's interrogate the solution to see what we have. For the purposes of this tutorial, we'll focus on the striatal oscillations.
# In this simple model, we should see relatively sharp on/off transitions in the striatal populations. To test this, let's use [Symbolic Indexing](https://docs.sciml.ai/SymbolicIndexingInterface/stable/usage/) to access the states we're interested in: the ``y`` state of the D1 neuron population.

idx_func = ModelingToolkit.getu(sol, D1.odesystem.y); ## gets the state index of the D1 neuron population in the solution object
idx_func = ModelingToolkit.getu(sol, D1.system.y); ## gets the state index of the D1 neuron population in the solution object

# Now use this indexing function to plot the solution in a Makie plot ([read more about Makie in the docs](https://docs.makie.org/stable/tutorials/getting-started)).

Expand Down
2 changes: 1 addition & 1 deletion examples/RF_learning_using_BLOX.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ end
# ╔═╡ d80b3bf7-24d8-4395-8903-de974e6445f9
begin
#extract membrane voltages of every neuron
getsys=agent.odesystem;
getsys=agent.system;
st=unknowns(getsys)
vlist=Int64[]
for ii = 1:length(st)
Expand Down
4 changes: 2 additions & 2 deletions examples/cortical_single.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ begin
end

# ╔═╡ 84410723-dcb4-46eb-bac4-871aac88cf8c
#construct odesystems for each neuron and then construct the composite odesystem for entire network using function synaptic_network()
#construct odesystems for each neuron and then construct the composite system for entire network using function synaptic_network()
begin

nrn_network=[]
Expand Down Expand Up @@ -328,7 +328,7 @@ prob = ODEProblem(syn_net, [], (0, simtime));
sol = solve(prob,Vern7(),saveat = 0.1)#,saveat = 0.1,reltol=1e-4,abstol=1e-4);

# ╔═╡ 263b26c2-b19b-4889-8e8c-fa3aef952649
# this extracts indices for input current amplitudes I_in and for weight matrix elements adj from the parameters of the entire odesystem. These are usefull for
# this extracts indices for input current amplitudes I_in and for weight matrix elements adj from the parameters of the entire system. These are usefull for
#changing the input currents and weights of already existing connections by remaking #the odeprob
begin
indexof(sym,syms) = findfirst(isequal(sym),syms)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ end
# ╔═╡ d80b3bf7-24d8-4395-8903-de974e6445f9
begin
#extract membrane voltages of every neuron
getsys=agent.odesystem;
getsys=agent.system;
st=unknowns(getsys)
vlist=Int64[]
for ii = 1:length(st)
Expand Down
2 changes: 1 addition & 1 deletion examples/makiepowerplot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ sol = solve(prob, Vern7(), saveat=0.01)
fss = powerspectrumplot(cb, sol)

@named msn = Striatum_MSN_Adam();
sys = structural_simplify(msn.odesystem)
sys = structural_simplify(msn.system)
prob = SDEProblem(sys, [], (0.0, 5500), [])
sol = solve(prob, RKMil(), dt=0.05, saveat=0.01)

Expand Down
2 changes: 1 addition & 1 deletion examples/rasterplot_firing_rate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using CairoMakie # due to a bug in CairoMakie, we need to use [email protected]


@named msn = Striatum_MSN_Adam(I_bg = 1.172*ones(100), σ = 0.11);
sys = structural_simplify(msn.odesystem)
sys = structural_simplify(msn.system)
prob = SDEProblem(sys, [], (0.0, 5500.0), [])
sol = solve(prob, RKMil(); dt=0.05, saveat=0.05)

Expand Down
2 changes: 1 addition & 1 deletion examples/signals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ begin
@named Str2 = jansen_ritC=0.0022, H=20, λ=300, r=0.3)
@parameters phase_input=0 ampl=1

sys = [Str2.odesystem]
sys = [Str2.system]
eqs = [sys[1].jcn ~ ampl*phase_input]
@named phase_system = ODESystem(eqs,systems=sys)
phase_system_simpl = structural_simplify(phase_system)
Expand Down
2 changes: 1 addition & 1 deletion examples/wilson_cowan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ end

# ╔═╡ 9f3d6efd-a884-449c-b4b8-42b0b435e245
begin
sys = [WC.odesystem]
sys = [WC.system]
eqs = [sys[1].jcn ~ 0.0, sys[1].P ~ 0.0]
@named WC_sys = ODESystem(eqs,systems=sys)
end
Expand Down
2 changes: 1 addition & 1 deletion examples/wilson_cowan2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ end
# ╔═╡ 9f3d6efd-a884-449c-b4b8-42b0b435e245
begin
blox = [WC1,WC2]
sys = [b.odesystem for b in blox]
sys = [b.system for b in blox]
connect = [b.connector for b in blox]
end

Expand Down
2 changes: 1 addition & 1 deletion examples/wilson_cowan2s.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ end
# ╔═╡ 9f3d6efd-a884-449c-b4b8-42b0b435e245
begin
blox = [WC1,WC2]
sys = [b.odesystem for b in blox]
sys = [b.system for b in blox]
connect = [b.connector for b in blox]
end

Expand Down
4 changes: 2 additions & 2 deletions src/GraphDynamicsInterop/connection_interop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function generate_gap_weight_param(blox_src, blox_dst, kwargs)
name = nameof(w)
else
w_val = w
name = Symbol("g_w_$(nameof(blox_src.odesystem))_$(nameof(blox_dst.odesystem))")
name = Symbol("g_w_$(nameof(blox_src.system))_$(nameof(blox_dst.system))")
end
(;w_val, name)
end
Expand Down Expand Up @@ -476,7 +476,7 @@ end
# name_presyn = namespaced_nameof(neuron_presyn)
# # Check names to avoid recurrent connections between the same neuron
# if (name_postsyn != name_presyn) && rand(rng, dist)
# w_name = Symbol("w_$(nameof(neuron_presyn.odesystem))_$(nameof(neuron_postsyn.odesystem))")
# w_name = Symbol("w_$(nameof(neuron_presyn.system))_$(nameof(neuron_postsyn.system))")
# (; conn) = get_connection(neuron_presyn, neuron_postsyn, kwargs)
# add_edge!(h, v_dst[j], v_src[i], Dict(:conn => conn, :names => [w_name]))
# end
Expand Down
18 changes: 10 additions & 8 deletions src/GraphDynamicsInterop/neuron_interop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ function define_neurons()
(:pinhib, :PINGNeuronInhib)
]
sys = getproperty(Neuroblox, T)(;name)
system = structural_simplify(sys.odesystem; fully_determined=false)
system = structural_simplify(sys.system; fully_determined=false)
params = get_ps(system)
t = Symbol(get_iv(system))

Expand Down Expand Up @@ -112,8 +112,10 @@ function define_neurons()
end
end

if hasproperty(sys, :output)
output_sym = hasproperty(sys.output.val, :f) ? Symbol(sys.output.val.f) : Symbol(sys.output.val)
outs = Neuroblox.outputs(sys; namespaced=false)
if length(outs) == 1
out = only(outs)
output_sym = hasproperty(out.val, :f) ? Symbol(out.val.f) : Symbol(out.val)
@eval output(s::Subsystem{$T}) = s.$output_sym
end

Expand Down Expand Up @@ -168,7 +170,7 @@ GraphDynamics.subsystem_differential_requires_inputs(::Type{PoissonSpikeTrain})
issupported(::KuramotoOscillator) = true
function to_subsystem(o::KuramotoOscillator)
states = SubsystemStates{KuramotoOscillator}((;θ=0.0,))
params = SubsystemParams{KuramotoOscillator}((;ω=getdefault(o.odesystem.ω), ζ=getdefault(o.odesystem.ζ)))
params = SubsystemParams{KuramotoOscillator}((;ω=getdefault(o.system.ω), ζ=getdefault(o.system.ζ)))
Subsystem(states, params)
end

Expand Down Expand Up @@ -238,8 +240,8 @@ function GraphDynamics.apply_subsystem_differential!(_, s::Subsystem{TAN}, jcn,
end
GraphDynamics.subsystem_differential_requires_inputs(::Type{TAN}) = false
function to_subsystem(s::TAN)
κ = getdefault(s.odesystem.κ)
λ = getdefault(s.odesystem.λ)
κ = getdefault(s.system.κ)
λ = getdefault(s.system.λ)
states = SubsystemStates{TAN, Float64, @NamedTuple{}}((;))
params = SubsystemParams{TAN}((; κ, λ))
#TODO: support observed variable R = min(κ, κ/(λ*jcn + sqrt(eps())))
Expand All @@ -259,8 +261,8 @@ GraphDynamics.subsystem_differential_requires_inputs(::Type{SNc}) = false
function to_subsystem(s::SNc)
(;N_time_blocks, κ_DA, DA_reward) = s

κ = getdefault(s.odesystem.κ)
λ = getdefault(s.odesystem.λ)
κ = getdefault(s.system.κ)
λ = getdefault(s.system.λ)
states = SubsystemStates{SNc, Float64, @NamedTuple{}}((;))
params = SubsystemParams{TAN}((;κ_DA, N_time_blocks, DA_reward, λ_DA, t_event=t_event+sqrt(eps(t_event)), jcn_=0.0))
#TODO: support observed variables R ~ min(κ_DA, κ_DA/(λ_DA*jcn + sqrt(eps())))
Expand Down
8 changes: 4 additions & 4 deletions src/Neuroblox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,10 @@ using Distributions

using SciMLBase: SciMLBase, AbstractSolution, solve, remake


using ModelingToolkit: get_namespace, get_systems, isparameter,
renamespace, namespace_equation, namespace_parameters, namespace_expr,
AbstractODESystem, VariableTunable, getp
import ModelingToolkit: equations, inputs, nameof, getdescription
import ModelingToolkit: equations, inputs, outputs, unknowns, parameters, discrete_events, nameof, getdescription

using Symbolics: @register_symbolic, getdefaultval, get_variables

Expand Down Expand Up @@ -132,9 +131,9 @@ function simulate(sys::ODESystem, u0, timespan, p, solver = AutoVern7(Rodas4());
end

function simulate(blox::CorticalBlox, u0, timespan, p, solver = AutoVern7(Rodas4()); kwargs...)
prob = ODEProblem(blox.odesystem, u0, timespan, p)
prob = ODEProblem(blox.system, u0, timespan, p)
sol = solve(prob, solver; kwargs...) # pass keyword arguments to solver
statesV = [s for s in unknowns(blox.odesystem) if contains(string(s),"V")]
statesV = [s for s in unknowns(blox.system) if contains(string(s),"V")]
vsol = sol[statesV]
vmean = vec(mean(hcat(vsol...),dims=2))
df = DataFrame(sol)
Expand Down Expand Up @@ -258,4 +257,5 @@ export powerspectrumplot, powerspectrumplot!, welch_pgram, periodogram, hanning,
export detect_spikes, mean_firing_rate, firing_rate
export voltage_timeseries, meanfield_timeseries, state_timeseries, get_neurons, get_exci_neurons, get_inh_neurons, get_neuron_color
export AdjacencyMatrix, Connector, connection_rule, connection_equations, connection_spike_affects, connection_learning_rules, connection_callbacks
export inputs, outputs, equations, unknowns, parameters, discrete_events
end
8 changes: 4 additions & 4 deletions src/blox/DBS_Model_Blox_Adam_Brown.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ end
struct Striatum_MSN_Adam <: CompositeBlox
namespace
parts
odesystem
system
connector
mean
connection_matrix
Expand Down Expand Up @@ -124,7 +124,7 @@ end
struct Striatum_FSI_Adam <: CompositeBlox
namespace
parts
odesystem
system
connector
mean
connection_matrix
Expand Down Expand Up @@ -205,7 +205,7 @@ end
struct GPe_Adam <: CompositeBlox
namespace
parts
odesystem
system
connector
mean
connection_matrix
Expand Down Expand Up @@ -274,7 +274,7 @@ end
struct STN_Adam <: CompositeBlox
namespace
parts
odesystem
system
connector
mean
connection_matrix
Expand Down
2 changes: 1 addition & 1 deletion src/blox/DBS_sources.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Defines a DBS (Deep Brain Stimulation) stimulus that can be either continuous or burst protocol
struct DBS <: StimulusBlox
params::Vector{Num}
odesystem::ODESystem
system::ODESystem
namespace::Union{Symbol, Nothing}
stimulus::Function
end
Expand Down
44 changes: 37 additions & 7 deletions src/blox/blox_utilities.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
function Base.getproperty(b::Union{AbstractNeuronBlox, NeuralMassBlox}, name::Symbol)
# TO DO : Some of the fields below besides `odesystem` and `namespace`
# TO DO : Some of the fields below besides `system` and `namespace`
# are redundant and we should clean them up.
if (name === :odesystem) || (name === :namespace) || (name === :params) || (name === :output) || (name === :voltage)
if (name === :system) || (name === :namespace) || (name === :params)
return getfield(b, name)
else
return Base.getproperty(Neuroblox.get_namespaced_sys(b), name)
Expand Down Expand Up @@ -90,7 +90,7 @@ function get_discrete_parts(b::Union{AbstractComponent, CompositeBlox})
mapreduce(x -> get_discrete_parts(x), vcat, b.parts)
end

get_system(blox) = blox.odesystem
get_system(blox) = blox.system
get_system(sys::AbstractODESystem) = sys
get_system(stim::PoissonSpikeTrain) = System(Equation[], t, [], []; name=stim.name)

Expand All @@ -105,12 +105,14 @@ end

function get_namespaced_sys(blox)
sys = get_system(blox)

System(
equations(sys),
only(independent_variables(sys)),
unknowns(sys),
parameters(sys);
name = namespaced_nameof(blox)
name = namespaced_nameof(blox),
discrete_events = discrete_events(sys)
)
end

Expand Down Expand Up @@ -147,6 +149,28 @@ function find_eq(eqs::Union{AbstractVector{<:Equation}, Equation}, lhs)
end
end

function ModelingToolkit.outputs(blox::AbstractBlox; namespaced=false)
sys = get_namespaced_sys(blox)

# Wrap in Num for convenience when checking `isa Num` to resolve delay or no delay connection.
return namespaced ? Num.(namespace_expr.(ModelingToolkit.outputs(sys), Ref(sys))) : Num.(ModelingToolkit.outputs(sys))
end

function ModelingToolkit.inputs(blox::AbstractBlox; namespaced=false)
sys = get_namespaced_sys(blox)

# Wrap in Num for convenience when checking `isa Num` to resolve delay or no delay connection.
return namespaced ? Num.(namespace_expr.(ModelingToolkit.inputs(sys), Ref(sys))) : Num.(ModelingToolkit.inputs(sys))
end

ModelingToolkit.equations(blox::AbstractBlox) = ModelingToolkit.equations(get_namespaced_sys(blox))

ModelingToolkit.discrete_events(blox::AbstractBlox) = ModelingToolkit.discrete_events(get_namespaced_sys(blox))

ModelingToolkit.unknowns(blox::AbstractBlox) = ModelingToolkit.unknowns(get_namespaced_sys(blox))

ModelingToolkit.parameters(blox::AbstractBlox) = ModelingToolkit.parameters(get_namespaced_sys(blox))

"""
Returns the equations for all input variables of a system,
assuming they have a form like : `sys.input_variable ~ ...`
Expand Down Expand Up @@ -271,7 +295,7 @@ function get_learning_rule(kwargs, name_src, name_dest)
end

function get_weights(agent::Agent, blox_out, blox_in)
ps = parameters(agent.odesystem)
ps = parameters(agent.system)
pv = agent.problem.p
map_idxs = Int.(ModelingToolkit.varmap_to_vars([ps[i] => i for i in eachindex(ps)], ps))

Expand Down Expand Up @@ -382,8 +406,14 @@ function get_connection_rule(kwargs, bloxout, bloxin, w)

# Logic based on connection rule type
if isequal(cr, "basic")
x = namespace_expr(bloxout.output, sys_out)
rhs = x*w
outs = outputs(bloxout; namespaced=true)
if !isempty(outs)
x = first(outs)
rhs = x*w
length(outs) > 1 && @warn "Blox $name_blox1 has more than one outputs. Defaulting to output=$x"
else
error("Blox $name_blox1 has no outputs. Please assign [output=true] to the variables you want to use as outputs or write a dispatch for connection_equations.")
end
elseif isequal(cr, "psp")
rhs = w*sys_out.G*(sys_out.E_syn - sys_in.V)
else
Expand Down
8 changes: 3 additions & 5 deletions src/blox/canonicalmicrocircuit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ Jansen-Rit model block for canonical micro circuit, analogous to the implementat
"""
mutable struct JansenRitSPM12 <: NeuralMassBlox
params
output
jcn
odesystem
system
namespace
function JansenRitSPM12(;name, namespace=nothing, τ=1.0, r=2.0/3.0)
p = paramscoping=τ, r=r)
Expand All @@ -18,14 +16,14 @@ mutable struct JansenRitSPM12 <: NeuralMassBlox
D(y) ~ -x/*τ) + jcn/τ]

sys = System(eqs, t, name=name)
new(p, sts[1], sts[3], sys, namespace)
new(p, sys, namespace)
end
end

mutable struct CanonicalMicroCircuitBlox <: CompositeBlox
namespace
parts
odesystem
system
connector

function CanonicalMicroCircuitBlox(;name, namespace=nothing, τ_ss=0.002, τ_sp=0.002, τ_ii=0.016, τ_dp=0.028, r_ss=2.0/3.0, r_sp=2.0/3.0, r_ii=2.0/3.0, r_dp=2.0/3.0)
Expand Down
Loading

0 comments on commit 24d304f

Please sign in to comment.