improve type stability of spectralDCM code
MasonProtter committed Mar 22, 2024
1 parent 7a778bb commit 67a1e98
Showing 3 changed files with 46 additions and 50 deletions.
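The diff applies a few standard Julia type-stability patterns: abstractly typed or untyped struct fields (model_at_x0, Array{Complex}) become parametric or concrete (model_at_x0::F, Array{ComplexF64}), Dicts holding values of mixed types become NamedTuples, and an untyped empty vector [] gets a concrete element type (Float64[]). As a rough illustration of why the struct changes matter, here is a minimal sketch; it is not taken from the repository, and SlowSetup, FastSetup, and apply are made-up names.

# Illustrative sketch (not part of this commit). An abstractly typed field forces
# dynamic dispatch on every access, because the compiler only knows the declared type.
struct SlowSetup
    model              # field type defaults to Any
    y::Array{Complex}  # abstract element type, elements are boxed
end

# Making the stored callable's type a parameter and the element type concrete
# lets the compiler specialize on each instance.
struct FastSetup{F}
    model::F
    y::Array{ComplexF64}
end

apply(s) = s.model(s.y[1])

f(z) = 2z
slow = SlowSetup(f, ComplexF64[1 + 2im])
fast = FastSetup(f, ComplexF64[1 + 2im])

apply(slow)   # return type can only be inferred as Any
apply(fast)   # return type inferred as ComplexF64 (compare with @code_warntype)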
8 changes: 4 additions & 4 deletions src/Neuroblox.jl
@@ -102,14 +102,14 @@ mutable struct VLState
dFdθθ::Matrix{Float64}
end

- struct VLSetup
- model_at_x0
- y_csd::Array{Complex}
+ struct VLSetup{F}
+ model_at_x0::F
+ y_csd::Array{ComplexF64}
tolerance::Float64
systemnums::Vector{Int}
systemvecs::Vector{Vector{Float64}}
systemmatrices::Vector{Matrix{Float64}}
- Q::Matrix{Complex}
+ Q::Matrix{ComplexF64}
end


80 changes: 38 additions & 42 deletions src/datafitting/spectralDCM.jl
@@ -33,15 +33,15 @@ function LinearAlgebra.eigen(M::Matrix{Dual{T, P, np}}) where {T, P, np}
nd = size(M, 1)
A = (p->p.value).(M)
F = eigen(A, sortby=nothing, permute=true)
- λ, V = F.values, F.vectors
+ λ, V = F
local ∂λ_agg, ∂V_agg
# compute eigenvalue and eigenvector derivatives for all partials
for i = 1:np
dA = (p->p.partials[i]).(M)
tmp = V \ dA
∂K = tmp * V # V^-1 * dA * V
∂Kdiag = @view ∂K[diagind(∂K)]
- ∂λ_tmp = eltype(λ) <: Real ? real.(∂Kdiag) : copy(∂Kdiag) # why do only copy when complex??
+ ∂λ_tmp = eltype(λ) <: Real ? real.(∂Kdiag) : copy(∂Kdiag) # copy only needed for Complex because `real.(v)` makes a new array
∂K ./= transpose(λ) .- λ
fill!(∂Kdiag, 0)
∂V_tmp = mul!(tmp, V, ∂K)
@@ -54,29 +54,33 @@ function LinearAlgebra.eigen(M::Matrix{Dual{T, P, np}}) where {T, P, np}
∂λ_agg = cat(∂λ_agg, ∂λ_tmp, dims=2)
end
end
- ∂V = Array{Partials}(undef, nd, nd)
- ∂λ = Array{Partials}(undef, nd)
# reassemble the aggregated vectors and values into a Partials type
- for i = 1:nd
- ∂λ[i] = Partials(Tuple(∂λ_agg[i, :]))
- for j = 1:nd
- ∂V[i, j] = Partials(Tuple(∂V_agg[i, j, :]))
- end
+ ∂V = map(Iterators.product(1:nd, 1:nd)) do (i, j)
+ Partials(NTuple{np}(∂V_agg[i, j, :]))
+ end
+ ∂λ = map(1:nd) do i
+ Partials(NTuple{np}(∂λ_agg[i, :]))
end
if eltype(V) <: Complex
- evals = map((x,y)->Complex(Dual{T, Float64, length(y)}(real(x), Partials(Tuple(real(y)))),
- Dual{T, Float64, length(y)}(imag(x), Partials(Tuple(imag(y))))), F.values, ∂λ)
- evecs = map((x,y)->Complex(Dual{T, Float64, length(y)}(real(x), Partials(Tuple(real(y)))),
- Dual{T, Float64, length(y)}(imag(x), Partials(Tuple(imag(y))))), F.vectors, ∂V)
+ evals = map(λ, ∂λ) do x, y
+ rex, imx = reim(x)
+ rey, imy = real.(Tuple(y)), imag.(Tuple(y))
+ Complex(Dual{T}(rex, Partials(rey)), Dual{T}(imx, Partials(imy)))
+ end
+ evecs = map(V, ∂V) do x, y
+ rex, imx = reim(x)
+ rey, imy = real.(Tuple(y)), imag.(Tuple(y))
+ Complex(Dual{T}(rex, Partials(rey)), Dual{T}(imx, Partials(imy)))
+ end
else
- evals = Dual{T, Float64, length(∂λ[1])}.(F.values, ∂λ)
- evecs = Dual{T, Float64, length(∂V[1])}.(F.vectors, ∂V)
+ evals = Dual{T}.(λ, ∂λ)
+ evecs = Dual{T}.(V, ∂V)
end
return Eigen(evals, evecs)
end

function transferfunction_fmri(ω, derivatives, params, params_idx)
- ∂f = derivatives[:∂f](params[params_idx[:evolpars]])
+ ∂f = derivatives.∂f(params[params_idx[:evolpars]])
if ∂f isa Vector
∂f = reshape(∂f, sqrt(length(∂f)), sqrt(length(∂f)))
end
@@ -101,7 +105,7 @@ function transferfunction_fmri(ω, derivatives, params, params_idx)
Λ = F.values
V = F.vectors

- ∂g = derivatives[:∂g](params[params_idx[:obspars]][1])
+ ∂g = derivatives.∂g(params[params_idx[:obspars]][1])
dgdv = ∂g*V
dvdu = V\dfdu # u is external variable which we don't use right now. With external variable this would read V/dfdu

@@ -114,8 +118,7 @@ function transferfunction_fmri(ω, derivatives, params, params_idx)
for i = 1:ng
for k = 1:nk
# transfer functions (FFT of kernel)
- Sk = (1im*2*pi*ω .- Λ[k]).^-1
- S[:,i,j] .+= dgdv[i,k]*dvdu[k,j]*Sk
+ S[:,i,j] .+= (dgdv[i,k]*dvdu[k,j]) .* ((1im*2*pi) .* ω .- Λ[k]).^-1
end
end
end
@@ -149,7 +152,7 @@ function csd_approx(ω, derivatives, params, params_idx)
Gu = zeros(eltype(G), nω, nd, nd)
Gn = zeros(eltype(G), nω, nd, nd)
for i = 1:nd
- Gu[:, i, i] .+= exp(α[1])*G
+ Gu[:, i, i] .+= exp(α[1]) .* G
end
# region specific observation noise (1/f or AR(1) form)
G = ω.^(-exp(β[2])/2)
@@ -205,9 +208,9 @@ end
"""
function matlab_norm(M, p)
if p == 1
- return maximum(vec(sum(abs.(M),dims=1)))
+ return maximum(sum(abs, M, dims=1))
elseif p == Inf
- return maximum(vec(sum(abs.(M),dims=2)))
+ return maximum(sum(abs, M, dims=2))
elseif p == 2
print("Not implemented yet!\n")
return NaN
@@ -226,7 +229,7 @@ function csd_Q(csd)
end
end
end
- Q = inv(Q .+ matlab_norm(Q, 1)/32*Matrix(I, size(Q))) # TODO: MATLAB's and Julia's norm function are different! Reconciliate?
+ Q = inv(Q + matlab_norm(Q, 1)*I/32) # TODO: MATLAB's and Julia's norm function are different! Reconciliate?
return Q
end

@@ -241,10 +244,10 @@ end
function spm_logdet(M)
TOL = 1e-16
s = diag(M)
- if sum(abs.(s)) != sum(abs.(M[:]))
- ~, s, ~ = svd(M)
+ if sum(abs, s) != sum(abs, M)
+ s = svdvals(M)
end
- return sum(log.(s[(s .> TOL) .& (s .< TOL^-1)]))
+ return sum((log(sval) for sval in s if TOL < sval < inv(TOL)), init=zero(eltype(s)))
end

"""
@@ -258,7 +261,7 @@ end
function vecparam(param::OrderedDict)
flatparam = Float64[]
for v in values(param)
- if (typeof(v) <: Array)
+ if v isa Array
for vv in v
push!(flatparam, vv)
end
@@ -517,8 +520,8 @@ function spectralVI(data, neuraldynmodel, observationmodel, initcond, csdsetup,
obs = get_hemodynamic_observers(neuraldynmodel, nr)
obsstates = map(obs -> [initcond[s] for s in obs], values(obs[2]))

- derivatives = Dict(:∂f => par -> jac_f(statevals, addnontunableparams(par, neuraldynmodel), t),
- :∂g => par -> grad_full(grad_g, obsstates, obs[1], par, nr, ns))
+ derivatives = (∂f = par -> jac_f(statevals, addnontunableparams(par, neuraldynmodel), t),
+ ∂g = par -> grad_full(grad_g, obsstates, obs[1], par, nr, ns))

θΣ = diagm(vecparam(OrderedDict(priors.name .=> priors.variance)))
# depending on the definition of the priors (note that we take it from the SPM12 code), some dimensions are set to 0 and thus are not changed.
@@ -535,8 +538,8 @@ function spectralVI(data, neuraldynmodel, observationmodel, initcond, csdsetup,

### Collect prior means and covariances ###
Q = csd_Q(y_csd); # compute prior of Q, the precision (of the data) components. See Friston etal. 2007 Appendix A
- priors = Dict(:μ => OrderedDict(priors.name .=> priors.mean),
- :Σ => Dict(
+ priors = (μ = OrderedDict(priors.name .=> priors.mean),
+ Σ = Dict(
:Πθ_pr => inv(θΣ), # prior model parameter precision
:Πλ_pr => hyperpriors[:Πλ_pr], # prior metaparameter precision
:μλ_pr => hyperpriors[:μλ_pr], # prior metaparameter mean
@@ -576,8 +579,8 @@ function setup_sDCM(data, stateevolutionmodel, observationmodel, initcond, csdse
# match states of observation model with different states of evolution model
obs = get_hemodynamic_observers(stateevolutionmodel, nr)
obsstates = Dict(map((v, k) -> k => [initcond[s] for s in v], values(obs[2]), keys(obs[2])))
- derivatives = Dict(:∂f => par -> jac_f(statevals, addnontunableparams(par, stateevolutionmodel), t),
- :∂g => par -> grad_full(grad_g, obsstates, obs[1], par, nr, ns))
+ derivatives = (∂f = par -> jac_f(statevals, addnontunableparams(par, stateevolutionmodel), t),
+ ∂g = par -> grad_full(grad_g, obsstates, obs[1], par, nr, ns))

μθ_pr = vecparam(OrderedDict(priors.name .=> priors.mean)) # note: μθ_po is posterior and μθ_pr is prior
Σθ_pr = diagm(vecparam(OrderedDict(priors.name .=> priors.variance)))
@@ -598,7 +601,7 @@ function setup_sDCM(data, stateevolutionmodel, observationmodel, initcond, csdse
0, # iter
-4, # log ascent rate
[-Inf], # free energy
- [], # delta free energy
+ Float64[], # delta free energy
8*ones(nh), # metaparameter, initial condition. TODO: why are we not just using the prior mean?
zeros(np), # parameter estimation error ϵ_θ
[zeros(np), 8*ones(nh)], # memorize reset state
@@ -622,13 +625,7 @@ function setup_sDCM(data, stateevolutionmodel, observationmodel, initcond, csdse
end

function run_sDCM_iteration!(state::VLState, setup::VLSetup)
- μθ_po = state.μθ_po
-
- λ = state.λ
- v = state.v
- ϵ_θ = state.ϵ_θ
- dFdθ = state.dFdθ
- dFdθθ = state.dFdθθ
+ (;μθ_po, λ, v, ϵ_θ, dFdθ, dFdθθ) = state

f = setup.model_at_x0
y = setup.y_csd # cross-spectral density
@@ -637,7 +634,6 @@ function run_sDCM_iteration!(state::VLState, setup::VLSetup)
(Πθ_pr, Πλ_pr) = setup.systemmatrices
# Πθ_pr = deserialize("tmp.dat")[vcat(1:20, 24), :]' *Πθ_pr* deserialize("tmp.dat")[vcat(1:20, 24), :]
Q = setup.Q

dfdp = jacobian(f, μθ_po)# * deserialize("tmp.dat")[vcat(1:20, 24), :]

norm_dfdp = matlab_norm(dfdp, Inf);
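Several of the hunks above also replace Dict-of-Symbol containers (derivatives, priors, and the csdsetup/hyperpriors inputs exercised in the tests below) with NamedTuples, and collapse the field-by-field unpacking at the top of run_sDCM_iteration! into property destructuring. A small sketch of the difference, with made-up values rather than code from the repository:

# A Dict whose values have different types is a Dict{Symbol, Any}, so every lookup
# returns a value the compiler cannot infer.
d = Dict(:p => 8, :freq => [0.1, 0.2], :dt => 0.5)   # Dict{Symbol, Any}
d[:p] + 1                                            # works, but inferred as Any

# A NamedTuple carries each field's concrete type in its own type, so field access
# is fully inferred.
nt = (p = 8, freq = [0.1, 0.2], dt = 0.5)            # @NamedTuple{p::Int64, freq::Vector{Float64}, dt::Float64}
nt.p + 1                                             # inferred as Int64

# Property destructuring (Julia 1.7+) replaces a block of `x = state.x` assignments.
(; p, dt) = nt
(p, dt)                                              # (8, 0.5)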
8 changes: 4 additions & 4 deletions test/datafitting.jl
@@ -100,11 +100,11 @@ for (k, v) in paramvariance
end

priors = DataFrame(name=[k for k in keys(modelparam)], mean=[m for m in values(modelparam)], variance=[v for v in values(paramvariance)])
- hyperpriors = Dict(:Πλ_pr => vars["ihC"]*ones(1, 1), # prior metaparameter precision, needs to be a matrix
- :μλ_pr => [vars["hE"]] # prior metaparameter mean, needs to be a vector
- );
+ hyperpriors = (Πλ_pr = vars["ihC"]*ones(1, 1), # prior metaparameter precision, needs to be a matrix
+ μλ_pr = [vars["hE"]] # prior metaparameter mean, needs to be a vector
+ );

- csdsetup = Dict(:p => 8, :freq => vec(vars["Hz"]), :dt => vars["dt"]);
+ csdsetup = (p = 8, freq = vec(vars["Hz"]), dt = vars["dt"]);

(state, setup) = setup_sDCM(data, neuronmodel, bold, initcond, csdsetup, priors, hyperpriors, params_idx);
for iter in 1:26
