Skip to content

Commit

Permalink
Merge branch 'master' into AP-create-a-learning-tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
harisorgn authored Nov 13, 2024
2 parents 57bf3b6 + daa4f4b commit 9d5c39d
Show file tree
Hide file tree
Showing 17 changed files with 468 additions and 72 deletions.
2 changes: 1 addition & 1 deletion docs/src/tutorials/basal_ganglia.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Random.seed!(123) ## Set a random seed for reproducibility
# Blox definition
N_MSN = 100 ## number of Medium Spiny Neurons
@named msn = Striatum_MSN_Adam(N_inhib = N_MSN)
sys = get_system(msn, simplify = true)
sys = structural_simplify(get_system(msn))

## Check the system's variables (100 neurons, each with associated currents)
unknowns(sys)
Expand Down
150 changes: 150 additions & 0 deletions examples/Alan_ERNA_protocol.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
## Alan Bush and Vissani's DBS protocol tried on Elie Adam's model

using Neuroblox
using CairoMakie
using StochasticDiffEq

N_MSN = 100 ## number of Medium Spiny Neurons
N_FSI = 50 ## number of Fast Spiking Interneurons
N_GPe = 80 ## number of GPe neurons
N_STN = 40 ## number of STN neurons

global_ns = :g

@named msn = Striatum_MSN_Adam(namespace=global_ns, N_inhib = N_MSN, I_bg = 1.2519*ones(N_MSN), G_M = 1.2)
@named fsi = Striatum_FSI_Adam(namespace=global_ns, N_inhib = N_FSI, I_bg = 4.511*ones(N_FSI), weight = 0.2, g_weight = 0.075)
@named gpe = GPe_Adam(namespace=global_ns, N_inhib = N_GPe)
@named stn = STN_Adam(namespace=global_ns, N_exci = N_STN)

ḡ_FSI_MSN = 0.48 ## decreased maximal conductance of FSI-MSN projection [mS/cm^-2]
ḡ_MSN_GPe = 2.5 ## maximal conductance for MSN to GPe synapses [mS/cm^-2]
ḡ_GPe_STN = 0.3 ## maximal conductance for GPe to STN synapses [mS/cm^-2]
ḡ_STN_FSI = 0.165 ## maximal conductance for STN to FSI synapses [mS/cm^-2]

density_FSI_MSN = 0.15 ## fraction of FSIs connecting to the MSN population
density_MSN_GPe = 0.33 ## fraction of MSNs connecting to the GPe population
density_GPe_STN = 0.05 ## fraction of GPe neurons connecting to the STN population
density_STN_FSI = 0.1 ## fraction of STN neurons connecting to the FSI population

weight_FSI_MSN = ḡ_FSI_MSN / (N_FSI * density_FSI_MSN) ## normalized synaptic weight
weight_MSN_GPe = ḡ_MSN_GPe / (N_MSN * density_MSN_GPe)
weight_GPe_STN = ḡ_GPe_STN / (N_GPe * density_GPe_STN)
weight_STN_FSI = ḡ_STN_FSI / (N_STN * density_STN_FSI)

g = MetaDiGraph()
add_edge!(g, fsi => msn, weight = weight_FSI_MSN, density = density_FSI_MSN)
add_edge!(g, msn => gpe, weight = weight_MSN_GPe, density = density_MSN_GPe)
add_edge!(g, gpe => stn, weight = weight_GPe_STN, density = density_GPe_STN)
add_edge!(g, stn => fsi, weight = weight_STN_FSI, density = density_STN_FSI)


frequency = 130.0
amplitude = 600.0
pulse_width = 0.066
smooth = 1e-3
pulse_start_time = 0.008
offset = -300
pulses_per_burst = 10
bursts_per_block = 12
pre_block_time = 200.0
inter_burst_time = 200.0

@named dbs = ProtocolDBS(
namespace=global_ns,
frequency=frequency,
amplitude=amplitude,
pulse_width=pulse_width,
smooth=smooth,
offset=offset,
pulses_per_burst=pulses_per_burst,
bursts_per_block=bursts_per_block,
pre_block_time=pre_block_time,
inter_burst_time=inter_burst_time,
start_time = pulse_start_time);

t_end = get_protocol_duration(dbs)
t_end = t_end + inter_burst_time

# tspan = (0.0, t_end) # Simulation time span [ms]
tspan = (0.0, 900.0) # for testing when little RAM is available
dt = 0.001 # Time step for solving and saving [ms]

add_edge!(g, dbs => stn)

@named sys = system_from_graph(g, simplify=true)

t_ = tspan[1]:dt:tspan[2]
stimulus = dbs.stimulus.(t_)
transitions_inds = detect_transitions(t_, stimulus; atol=0.05)
transition_times = t_[transitions_inds]
transition_values = stimulus[transitions_inds]

# visualize stimulus
fig = Figure();
ax = Axis(fig[1,1]; xlabel = "time (ms)", ylabel = "stimulus")
lines!(ax, t_, stimulus)
transition_points = scatter!(ax, transition_times, transition_values, label="transition points")
axislegend()
fig

xlims!(ax, 200, 350)
fig

xlims!(ax, 199.8, 200.3)
fig


# Creating and solving the problem
prob = SDEProblem(sys, [], tspan, [])
ens_prob = EnsembleProblem(prob)
@time ens_sol = solve(ens_prob, RKMil(); dt=dt, saveat=dt, adaptive = true, trajectories=1, abstol = 1e-3, reltol = 1e-3, tstops = transition_times);

# visualize STN average AMPA current
stn_g = meanfield_timeseries(stn, ens_sol[1], "G")
fig = Figure();
ax = Axis(fig[1,1]; xlabel = "time (ms)", ylabel = "STN average I_AMPA ")
lines!(ax, t_, stn_g)
fig

xlims!(ax, 200, 350)
fig


stn_v = meanfield_timeseries(stn, ens_sol[1], "V")
fig = Figure();
ax = Axis(fig[1,1]; xlabel = "time (ms)", ylabel = "STN average V ")
lines!(ax, t_, stn_v)
fig

xlims!(ax, 200, 350)
fig


fsi_v = meanfield_timeseries(fsi, ens_sol[1], "V")
fig = Figure();
ax = Axis(fig[1,1]; xlabel = "time (ms)", ylabel = "FSI average V ")
lines!(ax, t_, fsi_v)
fig

xlims!(ax, 200, 350)
fig

msn_v = meanfield_timeseries(msn, ens_sol[1], "V")

fig = Figure();
ax = Axis(fig[1,1]; xlabel = "time (ms)", ylabel = "MSN average V ")
lines!(ax, t_, msn_v)
fig

xlims!(ax, 200, 350)
fig

gpe_v = meanfield_timeseries(gpe, ens_sol[1], "V")

fig = Figure();
ax = Axis(fig[1,1]; xlabel = "time (ms)", ylabel = "GPe average V ")
lines!(ax, t_, gpe_v)
fig

xlims!(ax, 200, 350)
fig
4 changes: 2 additions & 2 deletions examples/RF_learning_using_BLOX.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ begin
t_trial = env.t_trial
tspan = (0, t_trial)
stim_params_in = Neuroblox.get_trial_stimulus(env)
sys = Neuroblox.get_sys(agent)
sys = Neuroblox.get_system(agent)
prob2 = remake(prob; p = merge(stim_params_in),tspan=(0,1600))

if t_warmup > 0
Expand Down Expand Up @@ -210,7 +210,7 @@ end

# ╔═╡ b7e84b20-0b80-478c-bc88-2883f80bcbb4
begin
# sys = Neuroblox.get_sys(agent)
# sys = Neuroblox.get_system(agent)
# learning_rules = agent.learning_rules
# weights = Dict{Num, Float64}()
# for w in keys(learning_rules)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ begin
t_trial = env.t_trial
tspan = (0, t_trial)
stim_params_in = Neuroblox.get_trial_stimulus(env)
sys = Neuroblox.get_sys(agent)
sys = Neuroblox.get_system(agent)
prob2 = remake(prob; p = merge(stim_params_in),tspan=(0,1600))

if t_warmup > 0
Expand Down Expand Up @@ -210,7 +210,7 @@ end

# ╔═╡ b7e84b20-0b80-478c-bc88-2883f80bcbb4
begin
# sys = Neuroblox.get_sys(agent)
# sys = Neuroblox.get_system(agent)
# learning_rules = agent.learning_rules
# weights = Dict{Num, Float64}()
# for w in keys(learning_rules)
Expand Down
4 changes: 2 additions & 2 deletions src/GraphDynamicsInterop/neuron_interop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ function define_neurons()
end
function to_subsystem($name::$T)
states = SubsystemStates{$T}(NamedTuple{$(Expr(:tuple, QuoteNode.(s_syms)...))}(
$(Expr(:tuple, (:(float($recursive_getdefault($getproperty(Neuroblox.get_sys($name), $(QuoteNode(s)))))) for s s_syms)...))
$(Expr(:tuple, (:(float($recursive_getdefault($getproperty(Neuroblox.get_system($name), $(QuoteNode(s)))))) for s s_syms)...))
))
params = SubsystemParams{$T}(NamedTuple{$(Expr(:tuple, QuoteNode.(p_syms)...))}(
$(Expr(:tuple, (:($recursive_getdefault($getproperty(Neuroblox.get_sys($name), $(QuoteNode(s))))) for s p_syms)...))
$(Expr(:tuple, (:($recursive_getdefault($getproperty(Neuroblox.get_system($name), $(QuoteNode(s))))) for s p_syms)...))
))
Subsystem(states, params)
end
Expand Down
2 changes: 1 addition & 1 deletion src/Neuroblox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ export Matrisome, Striosome, Striatum, GPi, GPe, Thalamus, STN, TAN, SNc
export HebbianPlasticity, HebbianModulationPlasticity
export Agent, ClassificationEnvironment, GreedyPolicy, reset!
export LearningBlox
export CosineSource, CosineBlox, NoisyCosineBlox, PhaseBlox, ImageStimulus, ExternalInput, PoissonSpikeTrain, DBS, detect_transitions, compute_transition_times, compute_transition_values
export CosineSource, CosineBlox, NoisyCosineBlox, PhaseBlox, ImageStimulus, ExternalInput, PoissonSpikeTrain, DBS, ProtocolDBS, detect_transitions, compute_transition_times, compute_transition_values, get_protocol_duration
export BandPassFilterBlox
export OUBlox, OUCouplingBlox
export phase_inter, phase_sin_blox, phase_cos_blox
Expand Down
6 changes: 3 additions & 3 deletions src/Neurographs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function get_bloxs(g::MetaDiGraph)
return bs
end

get_sys(g::MetaDiGraph) = get_sys.(get_bloxs(g))
get_system(g::MetaDiGraph) = get_system.(get_bloxs(g))

get_dynamics_bloxs(blox) = [blox]
get_dynamics_bloxs(blox::CompositeBlox) = get_parts(blox)
Expand Down Expand Up @@ -185,7 +185,7 @@ end

function system_from_graph(g::MetaDiGraph, bc::BloxConnector, p::Vector{Num}=Num[];
name, t_block=missing, simplify=true, simplify_kwargs...)
blox_syss = get_sys(g)
blox_syss = get_system(g)
connection_eqs = get_equations_with_state_lhs(bc)

discrete_cbs = identity.(generate_discrete_callbacks(g, bc; t_block))
Expand All @@ -200,7 +200,7 @@ end


function system_from_parts(parts::AbstractVector; name)
return compose(System(Equation[], t; name), get_sys.(parts))
return compose(System(Equation[], t; name), get_system.(parts))
end

function action_selection_from_graph(g::MetaDiGraph)
Expand Down
28 changes: 28 additions & 0 deletions src/adjacency.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,34 @@ function get_adjacency(g::MetaDiGraph)
return get_adjacency(bc)
end

function get_adjacency(bc::BloxConnector, sys::AbstractODESystem, prob::ODEProblem)
A = get_adjacency(bc)
names = A.names
mat = A.matrix

I, J, _ = findnz(mat)

ps = String.(Symbol.(parameters(sys)))

w_idxs = map(zip(I,J)) do (src_idx, dst_idx)
w = join(["w", names[src_idx], names[dst_idx]], "_")
findfirst(p -> p == w, ps)
end

W = getp(prob, parameters(sys)[w_idxs])(prob)
S = sparse(I, J, W, size(mat)...)

return AdjacencyMatrix(S, names)
end

function get_adjacency(agent::Agent)
prob = agent.problem
sys = get_system(agent)
bc = get_connector(agent)

return get_adjacency(bc, sys, prob)
end

function adjmatrixfromdigraph(g::MetaDiGraph)
myadj = map(Num, adjacency_matrix(g))
for edge in edges(g)
Expand Down
Loading

0 comments on commit 9d5c39d

Please sign in to comment.