Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Jun 12, 2024
2 parents d1b35cf + f4454de commit 8afa1b0
Showing 1 changed file with 41 additions and 53 deletions.
94 changes: 41 additions & 53 deletions mne_connectivity/spectral/epochs_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,10 @@ def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank):
)

# Eq. 33 (Ewald et al.)
C_bar_aa = np.matmul(U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_aa, U_bar_aa))
C_bar_ab = np.matmul(U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_ab, U_bar_bb))
C_bar_bb = np.matmul(U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_bb, U_bar_bb))
C_bar_ba = np.matmul(U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_ba, U_bar_aa))
C_bar_aa = U_bar_aa.transpose(0, 1, 3, 2) @ (C_aa @ U_bar_aa)
C_bar_ab = U_bar_aa.transpose(0, 1, 3, 2) @ (C_ab @ U_bar_bb)
C_bar_bb = U_bar_bb.transpose(0, 1, 3, 2) @ (C_bb @ U_bar_bb)
C_bar_ba = U_bar_bb.transpose(0, 1, 3, 2) @ (C_ba @ U_bar_aa)
C_bar = np.append(
np.append(C_bar_aa, C_bar_ab, axis=3),
np.append(C_bar_ba, C_bar_bb, axis=3),
Expand Down Expand Up @@ -412,7 +412,7 @@ def _compute_e(self, C, n_seeds):
T = self._compute_t(C_r, n_seeds)

# Eq. 4
D = np.matmul(T, np.matmul(C, T))
D = T @ (C @ T)

# E as imag. part of D between seeds and targets
return np.imag(D[..., :n_seeds, n_seeds:])
Expand All @@ -426,8 +426,8 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, U_bar_aa, U_bar_bb, con_i):
freqs = np.arange(self.n_freqs)

# Eigendecomp. to find spatial filters for seeds and targets
w_seeds, V_seeds = np.linalg.eigh(np.matmul(E, E.transpose(0, 1, 3, 2)))
w_targets, V_targets = np.linalg.eigh(np.matmul(E.transpose(0, 1, 3, 2), E))
w_seeds, V_seeds = np.linalg.eigh(E @ E.transpose(0, 1, 3, 2))
w_targets, V_targets = np.linalg.eigh(E.transpose(0, 1, 3, 2) @ E)
if len(seed_idcs) == len(target_idcs) and np.all(
np.sort(seed_idcs) == np.sort(target_idcs)
):
Expand Down Expand Up @@ -464,19 +464,17 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, U_bar_aa, U_bar_bb, con_i):

# Eq. 46 (seed spatial patterns)
self.patterns[0, con_i, :n_seeds] = (
np.matmul(np.real(C[..., :n_seeds, :n_seeds]), alpha_Ubar)
np.real(C[..., :n_seeds, :n_seeds]) @ alpha_Ubar
)[..., 0].T

# Eq. 47 (target spatial patterns)
self.patterns[1, con_i, :n_targets] = (
np.matmul(np.real(C[..., n_seeds:, n_seeds:]), beta_Ubar)
np.real(C[..., n_seeds:, n_seeds:]) @ beta_Ubar
)[..., 0].T

# Eq. 7
self.con_scores[con_i] = (
np.einsum(
"ijk,ijk->ij", alpha, np.matmul(E, np.expand_dims(beta, axis=3))[..., 0]
)
np.einsum("ijk,ijk->ij", alpha, (E @ np.expand_dims(beta, axis=3))[..., 0])
/ np.linalg.norm(alpha, axis=2)
* np.linalg.norm(beta, axis=2)
).T
Expand All @@ -488,9 +486,7 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, U_bar_aa, U_bar_bb, con_i):
def _compute_mim(self, E, seed_idcs, target_idcs, con_i):
"""Compute MIM (a.k.a. GIM if seeds == targets) for one connection."""
# Eq. 14
self.con_scores[con_i] = (
np.matmul(E, E.transpose(0, 1, 3, 2)).trace(axis1=2, axis2=3).T
)
self.con_scores[con_i] = (E @ E.transpose(0, 1, 3, 2)).trace(axis1=2, axis2=3).T

# Eq. 15
if len(seed_idcs) == len(target_idcs) and np.all(
Expand Down Expand Up @@ -626,16 +622,14 @@ def _compute_cacoh(self, phis, C_ab, T_aa, T_bb):
C_ab = np.real(np.exp(-1j * np.expand_dims(phis, axis=(2, 3))) * C_ab)

# Eq. 9; T_aa/bb is sqrt(inv(real(C_aa/bb)))
D = np.matmul(T_aa, np.matmul(C_ab, T_bb))
D = T_aa @ (C_ab @ T_bb)

# Eq. 12
a = np.linalg.eigh(np.matmul(D, D.transpose(0, 1, 3, 2)))[1][..., -1]
b = np.linalg.eigh(np.matmul(D.transpose(0, 1, 3, 2), D))[1][..., -1]
a = np.linalg.eigh(D @ D.transpose(0, 1, 3, 2))[1][..., -1]
b = np.linalg.eigh(D.transpose(0, 1, 3, 2) @ D)[1][..., -1]

# Eq. 8
numerator = np.einsum(
"ijk,ijk->ij", a, np.matmul(D, np.expand_dims(b, axis=3))[..., 0]
)
numerator = np.einsum("ijk,ijk->ij", a, (D @ np.expand_dims(b, axis=3))[..., 0])
denominator = np.sqrt(
np.einsum("ijk,ijk->ij", a, a) * np.einsum("ijk,ijk->ij", b, b)
)
Expand All @@ -657,13 +651,13 @@ def _compute_patterns(
):
"""Compute CaCoh spatial patterns for the optimised phi."""
C_bar_ab = np.real(np.exp(-1j * np.expand_dims(phis, axis=(2, 3))) * C_bar_ab)
D = np.matmul(T_aa, np.matmul(C_bar_ab, T_bb))
a = np.linalg.eigh(np.matmul(D, D.transpose(0, 1, 3, 2)))[1][..., -1]
b = np.linalg.eigh(np.matmul(D.transpose(0, 1, 3, 2), D))[1][..., -1]
D = T_aa @ (C_bar_ab @ T_bb)
a = np.linalg.eigh(D @ D.transpose(0, 1, 3, 2))[1][..., -1]
b = np.linalg.eigh(D.transpose(0, 1, 3, 2) @ D)[1][..., -1]

# Eq. 7 rearranged - multiply both sides by sqrt(inv(real(C_aa/bb)))
alpha = np.matmul(T_aa, np.expand_dims(a, axis=3)) # filter for seeds
beta = np.matmul(T_bb, np.expand_dims(b, axis=3)) # filter for targets
alpha = T_aa @ np.expand_dims(a, axis=3) # filter for seeds
beta = T_bb @ np.expand_dims(b, axis=3) # filter for targets

# Eqs. 46 & 47 of Ewald et al. (2012); i.e. transform filters to channel space
alpha_Ubar = np.matmul(U_bar_aa, alpha)
Expand All @@ -672,11 +666,11 @@ def _compute_patterns(
# Eq. 14
# seed spatial patterns
self.patterns[0, con_i, :n_seeds] = (
np.matmul(np.real(C[..., :n_seeds, :n_seeds]), alpha_Ubar)
np.real(C[..., :n_seeds, :n_seeds]) @ alpha_Ubar
)[..., 0].T
# target spatial patterns
self.patterns[1, con_i, :n_targets] = (
np.matmul(np.real(C[..., n_seeds:, n_seeds:]), beta_Ubar)
np.real(C[..., n_seeds:, n_seeds:]) @ beta_Ubar
)[..., 0].T

if self.store_filters:
Expand Down Expand Up @@ -780,10 +774,10 @@ def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank):
C_bb = csd[..., n_seeds:, n_seeds:]
C_ba = csd[..., n_seeds:, :n_seeds]

C_bar_aa = np.matmul(U_bar_aa.transpose(1, 0), np.matmul(C_aa, U_bar_aa))
C_bar_ab = np.matmul(U_bar_aa.transpose(1, 0), np.matmul(C_ab, U_bar_bb))
C_bar_bb = np.matmul(U_bar_bb.transpose(1, 0), np.matmul(C_bb, U_bar_bb))
C_bar_ba = np.matmul(U_bar_bb.transpose(1, 0), np.matmul(C_ba, U_bar_aa))
C_bar_aa = U_bar_aa.transpose(1, 0) @ (C_aa @ U_bar_aa)
C_bar_ab = U_bar_aa.transpose(1, 0) @ (C_ab @ U_bar_bb)
C_bar_bb = U_bar_bb.transpose(1, 0) @ (C_bb @ U_bar_bb)
C_bar_ba = U_bar_bb.transpose(1, 0) @ (C_ba @ U_bar_aa)
C_bar = np.append(
np.append(C_bar_aa, C_bar_ab, axis=3),
np.append(C_bar_ba, C_bar_bb, axis=3),
Expand Down Expand Up @@ -896,18 +890,18 @@ def _whittle_lwr_recursion(self, G):

# Perform recursion
for k in np.arange(2, q + 1):
var_A = G_b[:, (r - 1) * n : r * n, :] - np.matmul(
A_f[:, :, k_f], G_b[:, k_b, :]
var_A = G_b[:, (r - 1) * n : r * n, :] - (
A_f[:, :, k_f] @ G_b[:, k_b, :]
)
var_B = cov - np.matmul(A_b[:, :, k_b], G_b[:, k_b, :])
var_B = cov - (A_b[:, :, k_b] @ G_b[:, k_b, :])
AA_f = np.linalg.solve(var_B, var_A.transpose(0, 2, 1)).transpose(
0, 2, 1
)

var_A = G_f[:, (k - 1) * n : k * n, :] - np.matmul(
A_b[:, :, k_b], G_f[:, k_f, :]
var_A = G_f[:, (k - 1) * n : k * n, :] - (
A_b[:, :, k_b] @ G_f[:, k_f, :]
)
var_B = cov - np.matmul(A_f[:, :, k_f], G_f[:, k_f, :])
var_B = cov - (A_f[:, :, k_f] @ G_f[:, k_f, :])
AA_b = np.linalg.solve(var_B, var_A.transpose(0, 2, 1)).transpose(
0, 2, 1
)
Expand All @@ -919,20 +913,16 @@ def _whittle_lwr_recursion(self, G):
k_f = np.arange(k * n)
k_b = np.arange(r * n, qn)

A_f[:, :, k_f] = np.dstack(
(A_f_previous - np.matmul(AA_f, A_b_previous), AA_f)
)
A_b[:, :, k_b] = np.dstack(
(AA_b, A_b_previous - np.matmul(AA_b, A_f_previous))
)
A_f[:, :, k_f] = np.dstack((A_f_previous - (AA_f @ A_b_previous), AA_f))
A_b[:, :, k_b] = np.dstack((AA_b, A_b_previous - (AA_b @ A_f_previous)))
except np.linalg.LinAlgError as np_error:
raise RuntimeError(
"the autocovariance matrix is singular; check if your data is "
"rank deficient and specify an appropriate rank argument <= "
"the rank of the seeds and targets"
) from np_error

V = cov - np.matmul(A_f, G_f)
V = cov - (A_f @ G_f)
A_f = np.reshape(A_f, (t, n, n, q), order="F")

return A_f, V
Expand Down Expand Up @@ -976,11 +966,11 @@ def _iss_to_ugc(self, A, C, K, V, seeds, targets):

H = self._iss_to_tf(A, C, K, z) # spectral transfer function
V_22_1 = np.linalg.cholesky(self._partial_covar(V, seeds, targets))
HV = np.matmul(H, np.linalg.cholesky(V))
S = np.matmul(HV, HV.conj().transpose(0, 1, 3, 2)) # Eq. 6
HV = H @ np.linalg.cholesky(V)
S = HV @ HV.conj().transpose(0, 1, 3, 2) # Eq. 6
S_11 = S[np.ix_(freqs, times, targets, targets)]
HV_12 = np.matmul(H[np.ix_(freqs, times, targets, seeds)], V_22_1)
HVH = np.matmul(HV_12, HV_12.conj().transpose(0, 1, 3, 2))
HV_12 = H[np.ix_(freqs, times, targets, seeds)] @ V_22_1
HVH = HV_12 @ HV_12.conj().transpose(0, 1, 3, 2)

# Eq. 11
return np.real(np.log(np.linalg.det(S_11)) - np.log(np.linalg.det(S_11 - HVH)))
Expand Down Expand Up @@ -1035,7 +1025,7 @@ def _partial_covar(self, V, seeds, targets):
np.linalg.cholesky(V[np.ix_(times, targets, targets)]),
V[np.ix_(times, targets, seeds)],
)
W = np.matmul(W.transpose(0, 2, 1), W)
W = W.transpose(0, 2, 1) @ W

return V[np.ix_(times, seeds, seeds)] - W

Expand All @@ -1050,9 +1040,7 @@ def _gc_compute_H(A, C, K, z_k, I_n, I_m):

H = np.zeros((A.shape[0], C.shape[1], C.shape[1]), dtype=np.complex128)
for t in range(A.shape[0]):
H[t] = I_n + np.matmul(
C[t], linalg.lu_solve(linalg.lu_factor(z_k * I_m - A[t]), K[t])
)
H[t] = I_n + (C[t] @ linalg.lu_solve(linalg.lu_factor(z_k * I_m - A[t]), K[t]))

return H

Expand Down

0 comments on commit 8afa1b0

Please sign in to comment.