Skip to content

Commit

Permalink
DBS protocol (#481)
Browse files Browse the repository at this point in the history
* initial DBS protocol for getting ERNAs

* unify DBS and DBSProtocol into single type

* add DBS protocol example

* add docstrings for `DBS` and `protocol_dbs` constructors

* add a tunable option to `paramscoping` function

* set DBS and protocol_dbs parameters as not tunable

* change constructor name to ProtocolDBS

* change input frequency units to Hz

* add tests for ProtocolDBS

* adjust `frequency` in `compute_transition_times`

---------

Co-authored-by: haris organtzidis <[email protected]>
  • Loading branch information
gabrevaya and harisorgn authored Nov 8, 2024
1 parent 529ba66 commit 6f59b05
Show file tree
Hide file tree
Showing 5 changed files with 386 additions and 32 deletions.
150 changes: 150 additions & 0 deletions examples/Alan_ERNA_protocol.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
## Alan Bush and Vissani's DBS protocol tried on Elie Adam's model

using Neuroblox
using CairoMakie
using StochasticDiffEq

N_MSN = 100 ## number of Medium Spiny Neurons
N_FSI = 50 ## number of Fast Spiking Interneurons
N_GPe = 80 ## number of GPe neurons
N_STN = 40 ## number of STN neurons

global_ns = :g

@named msn = Striatum_MSN_Adam(namespace=global_ns, N_inhib = N_MSN, I_bg = 1.2519*ones(N_MSN), G_M = 1.2)
@named fsi = Striatum_FSI_Adam(namespace=global_ns, N_inhib = N_FSI, I_bg = 4.511*ones(N_FSI), weight = 0.2, g_weight = 0.075)
@named gpe = GPe_Adam(namespace=global_ns, N_inhib = N_GPe)
@named stn = STN_Adam(namespace=global_ns, N_exci = N_STN)

ḡ_FSI_MSN = 0.48 ## decreased maximal conductance of FSI-MSN projection [mS/cm^-2]
ḡ_MSN_GPe = 2.5 ## maximal conductance for MSN to GPe synapses [mS/cm^-2]
ḡ_GPe_STN = 0.3 ## maximal conductance for GPe to STN synapses [mS/cm^-2]
ḡ_STN_FSI = 0.165 ## maximal conductance for STN to FSI synapses [mS/cm^-2]

density_FSI_MSN = 0.15 ## fraction of FSIs connecting to the MSN population
density_MSN_GPe = 0.33 ## fraction of MSNs connecting to the GPe population
density_GPe_STN = 0.05 ## fraction of GPe neurons connecting to the STN population
density_STN_FSI = 0.1 ## fraction of STN neurons connecting to the FSI population

weight_FSI_MSN = ḡ_FSI_MSN / (N_FSI * density_FSI_MSN) ## normalized synaptic weight
weight_MSN_GPe = ḡ_MSN_GPe / (N_MSN * density_MSN_GPe)
weight_GPe_STN = ḡ_GPe_STN / (N_GPe * density_GPe_STN)
weight_STN_FSI = ḡ_STN_FSI / (N_STN * density_STN_FSI)

g = MetaDiGraph()
add_edge!(g, fsi => msn, weight = weight_FSI_MSN, density = density_FSI_MSN)
add_edge!(g, msn => gpe, weight = weight_MSN_GPe, density = density_MSN_GPe)
add_edge!(g, gpe => stn, weight = weight_GPe_STN, density = density_GPe_STN)
add_edge!(g, stn => fsi, weight = weight_STN_FSI, density = density_STN_FSI)


frequency = 130.0
amplitude = 600.0
pulse_width = 0.066
smooth = 1e-3
pulse_start_time = 0.008
offset = -300
pulses_per_burst = 10
bursts_per_block = 12
pre_block_time = 200.0
inter_burst_time = 200.0

@named dbs = ProtocolDBS(
namespace=global_ns,
frequency=frequency,
amplitude=amplitude,
pulse_width=pulse_width,
smooth=smooth,
offset=offset,
pulses_per_burst=pulses_per_burst,
bursts_per_block=bursts_per_block,
pre_block_time=pre_block_time,
inter_burst_time=inter_burst_time,
start_time = pulse_start_time);

t_end = get_protocol_duration(dbs)
t_end = t_end + inter_burst_time

# tspan = (0.0, t_end) # Simulation time span [ms]
tspan = (0.0, 900.0) # for testing when little RAM is available
dt = 0.001 # Time step for solving and saving [ms]

add_edge!(g, dbs => stn)

@named sys = system_from_graph(g, simplify=true)

t_ = tspan[1]:dt:tspan[2]
stimulus = dbs.stimulus.(t_)
transitions_inds = detect_transitions(t_, stimulus; atol=0.05)
transition_times = t_[transitions_inds]
transition_values = stimulus[transitions_inds]

# visualize stimulus
fig = Figure();
ax = Axis(fig[1,1]; xlabel = "time (ms)", ylabel = "stimulus")
lines!(ax, t_, stimulus)
transition_points = scatter!(ax, transition_times, transition_values, label="transition points")
axislegend()
fig

xlims!(ax, 200, 350)
fig

xlims!(ax, 199.8, 200.3)
fig


# Creating and solving the problem
prob = SDEProblem(sys, [], tspan, [])
ens_prob = EnsembleProblem(prob)
@time ens_sol = solve(ens_prob, RKMil(); dt=dt, saveat=dt, adaptive = true, trajectories=1, abstol = 1e-3, reltol = 1e-3, tstops = transition_times);

# visualize STN average AMPA current
stn_g = meanfield_timeseries(stn, ens_sol[1], "G")
fig = Figure();
ax = Axis(fig[1,1]; xlabel = "time (ms)", ylabel = "STN average I_AMPA ")
lines!(ax, t_, stn_g)
fig

xlims!(ax, 200, 350)
fig


stn_v = meanfield_timeseries(stn, ens_sol[1], "V")
fig = Figure();
ax = Axis(fig[1,1]; xlabel = "time (ms)", ylabel = "STN average V ")
lines!(ax, t_, stn_v)
fig

xlims!(ax, 200, 350)
fig


fsi_v = meanfield_timeseries(fsi, ens_sol[1], "V")
fig = Figure();
ax = Axis(fig[1,1]; xlabel = "time (ms)", ylabel = "FSI average V ")
lines!(ax, t_, fsi_v)
fig

xlims!(ax, 200, 350)
fig

msn_v = meanfield_timeseries(msn, ens_sol[1], "V")

fig = Figure();
ax = Axis(fig[1,1]; xlabel = "time (ms)", ylabel = "MSN average V ")
lines!(ax, t_, msn_v)
fig

xlims!(ax, 200, 350)
fig

gpe_v = meanfield_timeseries(gpe, ens_sol[1], "V")

fig = Figure();
ax = Axis(fig[1,1]; xlabel = "time (ms)", ylabel = "GPe average V ")
lines!(ax, t_, gpe_v)
fig

xlims!(ax, 200, 350)
fig
2 changes: 1 addition & 1 deletion src/Neuroblox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ export Matrisome, Striosome, Striatum, GPi, GPe, Thalamus, STN, TAN, SNc
export HebbianPlasticity, HebbianModulationPlasticity
export Agent, ClassificationEnvironment, GreedyPolicy, reset!
export LearningBlox
export CosineSource, CosineBlox, NoisyCosineBlox, PhaseBlox, ImageStimulus, ExternalInput, PoissonSpikeTrain, DBS, detect_transitions, compute_transition_times, compute_transition_values
export CosineSource, CosineBlox, NoisyCosineBlox, PhaseBlox, ImageStimulus, ExternalInput, PoissonSpikeTrain, DBS, ProtocolDBS, detect_transitions, compute_transition_times, compute_transition_values, get_protocol_duration
export BandPassFilterBlox
export OUBlox, OUCouplingBlox
export phase_inter, phase_sin_blox, phase_cos_blox
Expand Down
199 changes: 172 additions & 27 deletions src/blox/DBS_sources.jl
Original file line number Diff line number Diff line change
@@ -1,41 +1,164 @@
# Defines a DBS (Deep Brain Stimulation) stimulus that can be either continuous or burst protocol
struct DBS <: StimulusBlox
params::Vector{Num}
odesystem::ODESystem
namespace::Union{Symbol, Nothing}
stimulus::Function
end

"""
DBS(; name, namespace=nothing, frequency=130.0, amplitude=2.5, pulse_width=0.066,
offset=0.0, start_time=0.0, smooth=1e-4)
Create a continuous deep brain stimulation (DBS) stimulus with regular pulses.
Arguments:
- name: Name given to ODESystem object within the blox
- namespace: Additional namespace above name if needed for inheritance
- frequency: Pulse frequency in Hz
- amplitude: Pulse amplitude in arbitrary units
- pulse_width: Duration of each pulse in ms
- offset: Baseline value of the signal between pulses
- start_time: Time delay before stimulation begins in ms
- smooth: Smoothing parameter for pulse transitions, set to 0 for sharp transitions
Returns a DBS stimulus blox that outputs square pulses with specified parameters.
"""
function DBS(;
name,
namespace=nothing,
frequency=130.0,
amplitude=2.5,
pulse_width=0.066,
offset=0.0,
start_time=0.0,
smooth=1e-4
)
# Ensure consistent numeric types for all parameters
frequency, amplitude, pulse_width, offset, start_time, smooth =
promote(frequency, amplitude, pulse_width, offset, start_time, smooth)

# Convert to kHz (to match interal time in ms)
frequency_khz = frequency/1000.0

# Create stimulus function based on smooth/non-smooth square wave
stimulus = if smooth == 0
t -> square(t, frequency_khz, amplitude, offset, start_time, pulse_width)
else
t -> square(t, frequency_khz, amplitude, offset, start_time, pulse_width, smooth)
end

function DBS(;
name,
namespace=nothing,
frequency=130.0,
amplitude=100.0,
pulse_width=0.15,
offset=0.0,
start_time=0.0,
smooth=1e-4
p = paramscoping(
tunable=false;
frequency=frequency,
amplitude=amplitude,
pulse_width=pulse_width,
offset=offset,
start_time=start_time
)

if smooth == 0
stimulus = t -> square(t, frequency, amplitude, offset, start_time, pulse_width)
else
stimulus = t -> square(t, frequency, amplitude, offset, start_time, pulse_width, smooth)
end

p = paramscoping(
frequency=frequency,
amplitude=amplitude,
pulse_width=pulse_width,
offset=offset,
start_time=start_time
)
sts = @variables u(t) [output = true]
eqs = [u ~ stimulus(t)]
sys = System(eqs, t, sts, p; name=name)

DBS(p, sys, namespace, stimulus)
end

"""
ProtocolDBS(; name, namespace=nothing, frequency=130.0, amplitude=2.5,
pulse_width=0.066, offset=0.0, start_time=0.0, smooth=1e-4,
pulses_per_burst=10, bursts_per_block=12,
pre_block_time=200.0, inter_burst_time=200.0)
Create a deep brain stimulation (DBS) stimulus consisting of a block of pulse bursts.
Arguments:
- name: Name given to ODESystem object within the blox
- namespace: Additional namespace above name if needed for inheritance
- frequency: Pulse frequency in Hz
- amplitude: Pulse amplitude in arbitrary units
- pulse_width: Duration of each pulse in ms
- offset: Baseline value of the signal between pulses
- start_time: Time delay before stimulation begins in ms
- smooth: Smoothing parameter for pulse transitions, set to 0 for sharp transitions
- pulses_per_burst: Number of pulses in each burst
- bursts_per_block: Number of bursts in the block
- pre_block_time: Time before the block starts in ms
- inter_burst_time: Time between bursts in ms
sts = @variables u(t) [output = true]
Returns a DBS stimulus blox that outputs a block of pulse bursts.
"""
function ProtocolDBS(;
name,
namespace=nothing,
frequency=130.0,
amplitude=2.5,
pulse_width=0.066,
offset=0.0,
start_time=0.0,
smooth=1e-4,
pulses_per_burst=10,
bursts_per_block=12,
pre_block_time=200.0,
inter_burst_time=200.0
)
# Ensure consistent numeric types for all parameters
frequency, amplitude, pulse_width, offset, start_time, smooth, pre_block_time, inter_burst_time =
promote(frequency, amplitude, pulse_width, offset, start_time, smooth, pre_block_time, inter_burst_time)

eqs = [u ~ stimulus(t)]
sys = System(eqs, t, sts, p; name=name)
# Convert to kHz (to match interal time in ms)
frequency_khz = frequency/1000.0

# Pre-compute timing parameters for the protocol
pulse_period = 1/frequency_khz
burst_duration = pulse_period * pulses_per_burst
burst_plus_gap = burst_duration + inter_burst_time

function protocol_stimulus(t)
# Compute timing relative to protocol start
t_adjusted = t - pre_block_time # Time since protocol start
current_burst = floor(t_adjusted / burst_plus_gap) # Current burst number
t_within_burst_cycle = t_adjusted - current_burst * burst_plus_gap # Time within current burst

new(p, sys, namespace, stimulus)
# Nested ifelse structure (for compatibility with Symbolics) determines output at time t:
# 1. Before protocol starts: return offset
# 2. After all bursts complete: return offset
# 3. Between bursts: return offset
# 4. During burst: return square wave pulse
ifelse(t < pre_block_time,
offset,
ifelse(current_burst >= bursts_per_block,
offset,
ifelse(t_within_burst_cycle >= burst_duration - pulse_width/2,
offset,
ifelse(smooth == 0,
square(t_within_burst_cycle, frequency_khz, amplitude, offset, start_time, pulse_width),
square(t_within_burst_cycle, frequency_khz, amplitude, offset, start_time, pulse_width, smooth)
)
)
)
)
end

p = paramscoping(
tunable=false;
frequency=frequency,
amplitude=amplitude,
pulse_width=pulse_width,
offset=offset,
smooth=smooth,
start_time=start_time,
pulses_per_burst=pulses_per_burst,
bursts_per_block=bursts_per_block,
pre_block_time=pre_block_time,
inter_burst_time=inter_burst_time,
)

sts = @variables u(t) [output = true]
eqs = [u ~ protocol_stimulus(t)]
sys = System(eqs, t, sts, p; name=name)

DBS(p, sys, namespace, protocol_stimulus)
end

function sawtooth(t, f, offset)
Expand Down Expand Up @@ -94,7 +217,7 @@ function detect_transitions(t, signal::Vector{T}; atol=0) where T <: AbstractFlo
end

function compute_transition_times(stimulus::Function, f , dt, tspan, start_time, pulse_width; atol=0)
period = 1 / f
period = 1000.0 / f
n_periods = floor((tspan[end] - start_time) / period)

# Detect single pulse transition points
Expand Down Expand Up @@ -127,4 +250,26 @@ function compute_transition_values(transition_times, t, signal)
transition_values = signal[indices]

return transition_values
end

function get_protocol_duration(dbs::DBS)

# Check if this is a protocol DBS by looking at the number of parameters (in the future we may create a DBS subtype)
if length(dbs.params) < 10
error("This DBS object does not contain protocol parameters")
end

# Access parameters in correct order based on paramscoping
frequency = ModelingToolkit.getdefault(dbs.params[1])
pulses_per_burst = ModelingToolkit.getdefault(dbs.params[7])
bursts_per_block = ModelingToolkit.getdefault(dbs.params[8])
pre_block_time = ModelingToolkit.getdefault(dbs.params[9])
inter_burst_time = ModelingToolkit.getdefault(dbs.params[10])

# Calculate total protocol duration
pulse_period = 1000.0/frequency
burst_duration = pulses_per_burst * pulse_period
block_duration = bursts_per_block * (burst_duration + inter_burst_time) - inter_burst_time

return pre_block_time + block_duration
end
Loading

0 comments on commit 6f59b05

Please sign in to comment.