Skip to content

Commit

Permalink
Merge branch 'master' into ho/german_tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
harisorgn authored Oct 8, 2024
2 parents a02fca7 + d4ec0bc commit 1d2e814
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 146 deletions.
Binary file added docs/src/assets/spectral_DCM_illustration.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
36 changes: 27 additions & 9 deletions docs/src/tutorials/spectralDCM.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,34 @@
# # Spectral Dynamic Causal Modeling Tutorial
# # Introduction
# Here we roughly resemble the simulation in the [SPM12](https://www.fil.ion.ucl.ac.uk/spm/software/spm12/) script DEM_demo_induced_fMRI.m in [Neuroblox](https://www.neuroblox.org/).
# This work was also presented in Hofmann et al.[1]

# <table style="width:100%">
# <tr>
# <td style="width:60%">
#
# In this tutorial we will introduce how to perform a spectral Dynamic Causal Modeling analysis on simulated data [1,2].
# To do so we roughly resemble the procedure in the [SPM12](https://www.fil.ion.ucl.ac.uk/spm/software/spm12/) script `DEM_demo_induced_fMRI.m` in [Neuroblox](https://www.neuroblox.org/).
# This work was also presented in Hofmann et al.[2]
#
# In this tutorial we will define a circuit of three linear neuronal mass models, all driven by an Ornstein-Uhlenbeck process.
# We will model fMRI data by a balloon model and BOLD signal on top.
# After simulation of this simple model we will use spectral Dynamic Causal Modeling to infer some of the model parameters from the simulation time series.
# - define the graph, add blocks
# - simulate the model
# - compute the cross spectral density
# After simulation of this simple model we will use spectral Dynamic Causal Modeling to infer some of the model parameters from the simulation time series.
#
# A brief outline of the procedure we will pursue:
# - define the graph, add blocks -> section A, B and C in the figure
# - simulate the model -> instead we could also use actual data, section D in figure
# - compute the cross spectral density
# - setup the DCM
# - estimate
# - estimate parameters
# - plot the results
#
# </td>
# <td>
#
# <img src="./docs/src/assets/spectral_DCM_illustration.png" width="350" height="470" />
#
#
# </td>
# </tr>
# </table>

using Neuroblox
using LinearAlgebra
Expand Down Expand Up @@ -192,4 +209,5 @@ freeenergy(state)
ecbarplot(state, setup, A_true)

# ## References
# [Hofmann, David, Anthony G. Chesebro, Chris Rackauckas, Lilianne R. Mujica-Parodi, Karl J. Friston, Alan Edelman, and Helmut H. Strey. “Leveraging Julia’s Automated Differentiation and Symbolic Computation to Increase Spectral DCM Flexibility and Speed.” bioRxiv: The Preprint Server for Biology, 2023.](https://doi.org/10.1101/2023.10.27.564407)
# [1] [Novelli, Leonardo, Karl Friston, and Adeel Razi. “Spectral Dynamic Causal Modeling: A Didactic Introduction and Its Relationship with Functional Connectivity.” Network Neuroscience 8, no. 1 (April 1, 2024): 178–202.](https://doi.org/10.1162/netn_a_00348) \
# [2] [Hofmann, David, Anthony G. Chesebro, Chris Rackauckas, Lilianne R. Mujica-Parodi, Karl J. Friston, Alan Edelman, and Helmut H. Strey. “Leveraging Julia’s Automated Differentiation and Symbolic Computation to Increase Spectral DCM Flexibility and Speed.” bioRxiv: The Preprint Server for Biology, 2023.](https://doi.org/10.1101/2023.10.27.564407)
2 changes: 1 addition & 1 deletion src/blox/canonicalmicrocircuit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ mutable struct JansenRitSPM12 <: NeuralMassBlox
eqs = [D(x) ~ y - ((2/τ)*x),
D(y) ~ -x/*τ) + jcn/τ]

sys = System(eqs, t, name=name)
sys = System(eqs, t, sts, p, name=name)
new(p, sts[1], sts[3], sys, namespace)
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/measurementmodels/fmri.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ struct BalloonModel <: ObserverBlox
D(lnq) ~ (exp(lnu)/exp(lnq)*((1 - (1 - H[5])^(exp(lnu)^-1))/H[5]) - exp(lnν)^(H[4]^-1 - 1))/(H[3]*exp(lnτ)),
bold ~ B[2]*(k1 - k1*exp(lnq) + exp(lnϵ)*B[3]*B[5]*B[1] - exp(lnϵ)*B[3]*B[5]*B[1]*exp(lnq)/exp(lnν) + 1-exp(lnϵ) - (1-exp(lnϵ))*exp(lnν))
]
sys = System(eqs, t, name=name)
sys = System(eqs, t, sts, p; name=name)
new(p, Num(0), sts[5], sys, namespace)
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/measurementmodels/lfp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct LeadField <: ObserverBlox
lfp ~ L * jcn
]

sys = System(eqs, t; name=name)
sys = System(eqs, t, sts, p; name=name)
new(p, Num(0), sts[2], sys, namespace)
end
end
272 changes: 138 additions & 134 deletions test/datafitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using MAT
max_iter = 128
dt = 2.0 # time bin in seconds
freq = range(min(128, ns*dt)^-1, max(8, 2*dt)^-1, 32) # define frequencies at which to evaluate the CSD


########## assemble the model ##########
g = MetaDiGraph()
Expand Down Expand Up @@ -100,159 +100,163 @@ using MAT
end

@testset "LFP test" begin
### Load data ###
vars = matread(joinpath(@__DIR__, "spm12_cmc.mat"));
data = DataFrame(vars["data"], :auto) # turn data into DataFrame, name column names after building the model.
x = vars["x"] # point around which expansion is computed
nrr = ncol(data) # number of recorded regions
ns = nrow(data) # number of samples
max_iter = 128
dt = 2.0 # time bin in seconds
freq = range(1.0, 64.0) # define frequencies at which to evaluate the CSD
# fixme
@test_broken begin
### Load data ###
vars = matread(joinpath(@__DIR__, "spm12_cmc.mat"));
data = DataFrame(vars["data"], :auto) # turn data into DataFrame, name column names after building the model.
x = vars["x"] # point around which expansion is computed
nrr = ncol(data) # number of recorded regions
ns = nrow(data) # number of samples
max_iter = 128
dt = 2.0 # time bin in seconds
freq = range(1.0, 64.0) # define frequencies at which to evaluate the CSD

########## assemble the model ##########
g = MetaDiGraph()
global_ns = :g # global namespace
regions = Dict()
########## assemble the model ##########
g = MetaDiGraph()
global_ns = :g # global namespace
regions = Dict()

@parameters lnr = 0.0
@parameters lnτ_ss=0 lnτ_sp=0 lnτ_ii=0 lnτ_dp=0
@parameters C=512.0 [tunable = false] # TODO: SPM12 has this seemingly arbitrary 512 pre-factor in spm_fx_cmc.m. Can we understand why?
for ii = 1:nrr
region = CanonicalMicroCircuitBlox(;namespace=global_ns, name=Symbol("r$(ii)₊cmc"),
τ_ss=exp(lnτ_ss)*0.002, τ_sp=exp(lnτ_sp)*0.002, τ_ii=exp(lnτ_ii)*0.016, τ_dp=exp(lnτ_dp)*0.028,
r_ss=exp(lnr)*2.0/3, r_sp=exp(lnr)*2.0/3, r_ii=exp(lnr)*2.0/3, r_dp=exp(lnr)*2.0/3)
add_blox!(g, region)
regions[ii] = nv(g) # store index of neural mass model
input = ExternalInput(;name=Symbol("r$(ii)₊ei"), I=1.0)
add_blox!(g, input)
add_edge!(g, nv(g), nv(g) - 1, Dict(:weight => C))
@parameters lnr = 0.0
@parameters lnτ_ss=0 lnτ_sp=0 lnτ_ii=0 lnτ_dp=0
@parameters C=512.0 [tunable = false] # TODO: SPM12 has this seemingly arbitrary 512 pre-factor in spm_fx_cmc.m. Can we understand why?
for ii = 1:nrr
region = CanonicalMicroCircuitBlox(;namespace=global_ns, name=Symbol("r$(ii)₊cmc"),
τ_ss=exp(lnτ_ss)*0.002, τ_sp=exp(lnτ_sp)*0.002, τ_ii=exp(lnτ_ii)*0.016, τ_dp=exp(lnτ_dp)*0.028,
r_ss=exp(lnr)*2.0/3, r_sp=exp(lnr)*2.0/3, r_ii=exp(lnr)*2.0/3, r_dp=exp(lnr)*2.0/3)
add_blox!(g, region)
regions[ii] = nv(g) # store index of neural mass model
input = ExternalInput(;name=Symbol("r$(ii)₊ei"), I=1.0)
add_blox!(g, input)
add_edge!(g, nv(g), nv(g) - 1, Dict(:weight => C))

# add lead field (LFP measurement)
measurement = LeadField(;name=Symbol("r$(ii)₊lf"))
add_blox!(g, measurement)
# connect measurement with neuronal signal
add_edge!(g, nv(g) - 2, nv(g), Dict(:weight => 1.0))
end
# add lead field (LFP measurement)
measurement = LeadField(;name=Symbol("r$(ii)₊lf"))
add_blox!(g, measurement)
# connect measurement with neuronal signal
add_edge!(g, nv(g) - 2, nv(g), Dict(:weight => 1.0))
end

nl = Int((nrr^2-nrr)/2) # number of links unidirectional
@parameters a_sp_ss[1:nl] = repeat([0.0], nl) # forward connection parameter sp -> ss
@parameters a_sp_dp[1:nl] = repeat([0.0], nl) # forward connection parameter sp -> dp
@parameters a_dp_sp[1:nl] = repeat([0.0], nl) # backward connection parameter dp -> sp
@parameters a_dp_ii[1:nl] = repeat([0.0], nl) # backward connection parameters dp -> ii
nl = Int((nrr^2-nrr)/2) # number of links unidirectional
@parameters a_sp_ss[1:nl] = repeat([0.0], nl) # forward connection parameter sp -> ss
@parameters a_sp_dp[1:nl] = repeat([0.0], nl) # forward connection parameter sp -> dp
@parameters a_dp_sp[1:nl] = repeat([0.0], nl) # backward connection parameter dp -> sp
@parameters a_dp_ii[1:nl] = repeat([0.0], nl) # backward connection parameter dp -> ii

k = 0
for i in 1:nrr
for j in (i+1):nrr
k += 1
# forward connection matrix
add_edge!(g, regions[i], regions[j], :weightmatrix,
[0 exp(a_sp_ss[k]) 0 0; # connection from sp to ss
0 0 0 0;
0 0 0 0;
0 exp(a_sp_dp[k])/2 0 0] * 200) # connection from sp to dp
# backward connection matrix
add_edge!(g, regions[j], regions[i], :weightmatrix,
[0 0 0 0;
0 0 0 -exp(a_dp_sp[k]); # connection from dp to sp
0 0 0 -exp(a_dp_ii[k])/2; # connection from dp to ii
0 0 0 0] * 200)
k = 0
for i in 1:nrr
for j in (i+1):nrr
k += 1
# forward connection matrix
add_edge!(g, regions[i], regions[j], :weightmatrix,
[0 exp(a_sp_ss[k]) 0 0; # connection from sp to ss
0 0 0 0;
0 0 0 0;
0 exp(a_sp_dp[k])/2 0 0] * 200) # connection from sp to dp
# backward connection matrix
add_edge!(g, regions[j], regions[i], :weightmatrix,
[0 0 0 0;
0 0 0 -exp(a_dp_sp[k]); # connection from dp to sp
0 0 0 -exp(a_dp_ii[k])/2; # connection from dp to ii
0 0 0 0] * 200)
end
end
end

@named fullmodel = system_from_graph(g; split=false)
@named fullmodel = system_from_graph(g; split=false)

# attribute initial conditions to states
sts, idx_sts = get_dynamic_states(fullmodel)
idx_u = get_idx_tagged_vars(fullmodel, "ext_input") # get index of external input state
idx_measurement, obsvars = get_eqidx_tagged_vars(fullmodel, "measurement") # get index of equation of bold state
rename!(data, Symbol.(obsvars))
# attribute initial conditions to states
sts, idx_sts = get_dynamic_states(fullmodel)
idx_u = get_idx_tagged_vars(fullmodel, "ext_input") # get index of external input state
idx_measurement, obsvars = get_eqidx_tagged_vars(fullmodel, "measurement") # get index of equation of bold state
rename!(data, Symbol.(obsvars))

initcond = OrderedDict(sts .=> 0.0)
rnames = []
map(x->push!(rnames, split(string(x), "")[1]), sts);
rnames = unique(rnames);
for (i, r) in enumerate(rnames)
for (j, s) in enumerate(sts[r .== map(x -> x[1], split.(string.(sts), ""))])
initcond[s] = x[i, j]
initcond = OrderedDict(sts .=> 0.0)
rnames = []
map(x->push!(rnames, split(string(x), "")[1]), sts);
rnames = unique(rnames);
for (i, r) in enumerate(rnames)
for (j, s) in enumerate(sts[r .== map(x -> x[1], split.(string.(sts), ""))])
initcond[s] = x[i, j]
end
end
end

modelparam = OrderedDict()
np = sum(tunable_parameters(fullmodel); init=0) do par
val = Symbolics.getdefaultval(par)
modelparam[par] = val
length(val)
end
indices = Dict(:dspars => collect(1:np))
# Noise parameter mean
modelparam[:lnα] = zeros(Float64, 2, nrr); # intrinsic fluctuations, ln(α) as in equation 2 of Friston et al. 2014
n = length(modelparam[:lnα]);
indices[:lnα] = collect(np+1:np+n);
np += n;
modelparam[:lnβ] = [-16.0, -16.0]; # global observation noise, ln(β) as above
n = length(modelparam[:lnβ]);
indices[:lnβ] = collect(np+1:np+n);
np += n;
modelparam[:lnγ] = [-16.0, -16.0]; # region specific observation noise
indices[:lnγ] = collect(np+1:np+nrr);
np += nrr
indices[:u] = idx_u
indices[:m] = idx_measurement
indices[:sts] = idx_sts
modelparam = OrderedDict()
np = sum(tunable_parameters(fullmodel); init=0) do par
val = Symbolics.getdefaultval(par)
modelparam[par] = val
length(val)
end
indices = Dict(:dspars => collect(1:np))
# Noise parameter mean
modelparam[:lnα] = zeros(Float64, 2, nrr); # intrinsic fluctuations, ln(α) as in equation 2 of Friston et al. 2014
n = length(modelparam[:lnα]);
indices[:lnα] = collect(np+1:np+n);
np += n;
modelparam[:lnβ] = [-16.0, -16.0]; # global observation noise, ln(β) as above
n = length(modelparam[:lnβ]);
indices[:lnβ] = collect(np+1:np+n);
np += n;
modelparam[:lnγ] = [-16.0, -16.0]; # region specific observation noise
indices[:lnγ] = collect(np+1:np+nrr);
np += nrr
indices[:u] = idx_u
indices[:m] = idx_measurement
indices[:sts] = idx_sts

# define prior variances
paramvariance = copy(modelparam)
paramvariance[:lnα] = ones(Float64, size(modelparam[:lnα]))./128.0;
paramvariance[:lnβ] = ones(Float64, nrr)./128.0;
paramvariance[:lnγ] = ones(Float64, nrr)./128.0;
for (k, v) in paramvariance
if occursin("a_", string(k))
paramvariance[k] = 1/16.0
elseif "lnr" == string(k)
paramvariance[k] = 1/64.0;
elseif occursin("lnτ", string(k))
paramvariance[k] = 1/32.0;
elseif occursin("lf₊L", string(k))
paramvariance[k] = 64;
# define prior variances
paramvariance = copy(modelparam)
paramvariance[:lnα] = ones(Float64, size(modelparam[:lnα]))./128.0;
paramvariance[:lnβ] = ones(Float64, nrr)./128.0;
paramvariance[:lnγ] = ones(Float64, nrr)./128.0;
for (k, v) in paramvariance
if occursin("a_", string(k))
paramvariance[k] = 1/16.0
elseif "lnr" == string(k)
paramvariance[k] = 1/64.0;
elseif occursin("lnτ", string(k))
paramvariance[k] = 1/32.0;
elseif occursin("lf₊L", string(k))
paramvariance[k] = 64;
end
end
end

# priors = DataFrame(name=[k for k in keys(modelparam)], mean=[m for m in values(modelparam)], variance=[v for v in values(paramvariance)])
priors = (μθ_pr = modelparam,
Σθ_pr = paramvariance
);
# priors = DataFrame(name=[k for k in keys(modelparam)], mean=[m for m in values(modelparam)], variance=[v for v in values(paramvariance)])
priors = (μθ_pr = modelparam,
Σθ_pr = paramvariance
);

hype = matread(joinpath(@__DIR__, "spm12_cmc_hyperpriors.mat"));
hyperpriors = Dict(:Πλ_pr => hype["ihC"], # prior metaparameter precision, needs to be a matrix
:μλ_pr => vec(hype["hE"]), # prior metaparameter mean, needs to be a vector
:Q => hype["Q"]
);
hype = matread(joinpath(@__DIR__, "spm12_cmc_hyperpriors.mat"));
hyperpriors = Dict(:Πλ_pr => hype["ihC"], # prior metaparameter precision, needs to be a matrix
:μλ_pr => vec(hype["hE"]), # prior metaparameter mean, needs to be a vector
:Q => hype["Q"]
);

csdsetup = (mar_order = 8, freq = freq, dt = dt);
csdsetup = (mar_order = 8, freq = freq, dt = dt);

(state, setup) = setup_sDCM(data, fullmodel, initcond, csdsetup, priors, hyperpriors, indices, modelparam, "LFP");
# HACK: on machines with very small amounts of RAM, Julia can run out of stack space while compiling the code called in this loop
# this should be rewritten to abuse the compiler less, but for now, an easy solution is just to run it with more allocated stack space.
with_stack(f, n) = fetch(schedule(Task(f,n)))
(state, setup) = setup_sDCM(data, fullmodel, initcond, csdsetup, priors, hyperpriors, indices, modelparam, "LFP");
# HACK: on machines with very small amounts of RAM, Julia can run out of stack space while compiling the code called in this loop
# this should be rewritten to abuse the compiler less, but for now, an easy solution is just to run it with more allocated stack space.
with_stack(f, n) = fetch(schedule(Task(f,n)))

with_stack(5_000_000) do # 5MB of stack space
for iter in 1:128
state.iter = iter
run_sDCM_iteration!(state, setup)
print("iteration: ", iter, " - F:", state.F[end] - state.F[2], " - dF predicted:", state.dF[end], "\n")
if iter >= 4
criterion = state.dF[end-3:end] .< setup.tolerance
if all(criterion)
print("convergence\n")
break
with_stack(5_000_000) do # 5MB of stack space
for iter in 1:128
state.iter = iter
run_sDCM_iteration!(state, setup)
print("iteration: ", iter, " - F:", state.F[end] - state.F[2], " - dF predicted:", state.dF[end], "\n")
if iter >= 4
criterion = state.dF[end-3:end] .< setup.tolerance
if all(criterion)
print("convergence\n")
break
end
end
end
end
end

### COMPARE RESULTS WITH MATLAB RESULTS ###
@show state.F[end]
@test state.F[end] > 1891*0.99
@test state.F[end] < 1891*1.01
### COMPARE RESULTS WITH MATLAB RESULTS ###
@show state.F[end]
@test state.F[end] > 1891*0.99
@test state.F[end] < 1891*1.01
true
end
end

0 comments on commit 1d2e814

Please sign in to comment.