From f4454de7ff5faf9ac93061f49425ac40390ff2d0 Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Wed, 12 Jun 2024 19:13:17 +0200 Subject: [PATCH] [MAINT] Replace np.matmul with @ operator (#201) --- .../spectral/epochs_multivariate.py | 102 ++++++++---------- 1 file changed, 43 insertions(+), 59 deletions(-) diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index a912a17b..a55d6d37 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -265,10 +265,10 @@ def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): ) # Eq. 33 (Ewald et al.) - C_bar_aa = np.matmul(U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_aa, U_bar_aa)) - C_bar_ab = np.matmul(U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_ab, U_bar_bb)) - C_bar_bb = np.matmul(U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_bb, U_bar_bb)) - C_bar_ba = np.matmul(U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_ba, U_bar_aa)) + C_bar_aa = U_bar_aa.transpose(0, 1, 3, 2) @ (C_aa @ U_bar_aa) + C_bar_ab = U_bar_aa.transpose(0, 1, 3, 2) @ (C_ab @ U_bar_bb) + C_bar_bb = U_bar_bb.transpose(0, 1, 3, 2) @ (C_bb @ U_bar_bb) + C_bar_ba = U_bar_bb.transpose(0, 1, 3, 2) @ (C_ba @ U_bar_aa) C_bar = np.append( np.append(C_bar_aa, C_bar_ab, axis=3), np.append(C_bar_ba, C_bar_bb, axis=3), @@ -400,7 +400,7 @@ def _compute_e(self, C, n_seeds): T = self._compute_t(C_r, n_seeds) # Eq. 4 - D = np.matmul(T, np.matmul(C, T)) + D = T @ (C @ T) # E as imag. part of D between seeds and targets return np.imag(D[..., :n_seeds, n_seeds:]) @@ -414,8 +414,8 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, U_bar_aa, U_bar_bb, con_i): freqs = np.arange(self.n_freqs) # Eigendecomp. to find spatial filters for seeds and targets - w_seeds, V_seeds = np.linalg.eigh(np.matmul(E, E.transpose(0, 1, 3, 2))) - w_targets, V_targets = np.linalg.eigh(np.matmul(E.transpose(0, 1, 3, 2), E)) + w_seeds, V_seeds = np.linalg.eigh(E @ E.transpose(0, 1, 3, 2)) + w_targets, V_targets = np.linalg.eigh(E.transpose(0, 1, 3, 2) @ E) if len(seed_idcs) == len(target_idcs) and np.all( np.sort(seed_idcs) == np.sort(target_idcs) ): @@ -448,25 +448,19 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, U_bar_aa, U_bar_bb, con_i): # Eq. 46 (seed spatial patterns) self.patterns[0, con_i, :n_seeds] = ( - np.matmul( - np.real(C[..., :n_seeds, :n_seeds]), - np.matmul(U_bar_aa, np.expand_dims(alpha, axis=3)), - ) + np.real(C[..., :n_seeds, :n_seeds]) + @ (U_bar_aa @ np.expand_dims(alpha, axis=3)) )[..., 0].T # Eq. 47 (target spatial patterns) self.patterns[1, con_i, :n_targets] = ( - np.matmul( - np.real(C[..., n_seeds:, n_seeds:]), - np.matmul(U_bar_bb, np.expand_dims(beta, axis=3)), - ) + np.real(C[..., n_seeds:, n_seeds:]) + @ (U_bar_bb @ np.expand_dims(beta, axis=3)) )[..., 0].T # Eq. 7 self.con_scores[con_i] = ( - np.einsum( - "ijk,ijk->ij", alpha, np.matmul(E, np.expand_dims(beta, axis=3))[..., 0] - ) + np.einsum("ijk,ijk->ij", alpha, (E @ np.expand_dims(beta, axis=3))[..., 0]) / np.linalg.norm(alpha, axis=2) * np.linalg.norm(beta, axis=2) ).T @@ -474,9 +468,7 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, U_bar_aa, U_bar_bb, con_i): def _compute_mim(self, E, seed_idcs, target_idcs, con_i): """Compute MIM (a.k.a. GIM if seeds == targets) for one connection.""" # Eq. 14 - self.con_scores[con_i] = ( - np.matmul(E, E.transpose(0, 1, 3, 2)).trace(axis1=2, axis2=3).T - ) + self.con_scores[con_i] = (E @ E.transpose(0, 1, 3, 2)).trace(axis1=2, axis2=3).T # Eq. 15 if len(seed_idcs) == len(target_idcs) and np.all( @@ -612,16 +604,14 @@ def _compute_cacoh(self, phis, C_ab, T_aa, T_bb): C_ab = np.real(np.exp(-1j * np.expand_dims(phis, axis=(2, 3))) * C_ab) # Eq. 9; T_aa/bb is sqrt(inv(real(C_aa/bb))) - D = np.matmul(T_aa, np.matmul(C_ab, T_bb)) + D = T_aa @ (C_ab @ T_bb) # Eq. 12 - a = np.linalg.eigh(np.matmul(D, D.transpose(0, 1, 3, 2)))[1][..., -1] - b = np.linalg.eigh(np.matmul(D.transpose(0, 1, 3, 2), D))[1][..., -1] + a = np.linalg.eigh(D @ D.transpose(0, 1, 3, 2))[1][..., -1] + b = np.linalg.eigh(D.transpose(0, 1, 3, 2) @ D)[1][..., -1] # Eq. 8 - numerator = np.einsum( - "ijk,ijk->ij", a, np.matmul(D, np.expand_dims(b, axis=3))[..., 0] - ) + numerator = np.einsum("ijk,ijk->ij", a, (D @ np.expand_dims(b, axis=3))[..., 0]) denominator = np.sqrt( np.einsum("ijk,ijk->ij", a, a) * np.einsum("ijk,ijk->ij", b, b) ) @@ -643,22 +633,22 @@ def _compute_patterns( ): """Compute CaCoh spatial patterns for the optimised phi.""" C_bar_ab = np.real(np.exp(-1j * np.expand_dims(phis, axis=(2, 3))) * C_bar_ab) - D = np.matmul(T_aa, np.matmul(C_bar_ab, T_bb)) - a = np.linalg.eigh(np.matmul(D, D.transpose(0, 1, 3, 2)))[1][..., -1] - b = np.linalg.eigh(np.matmul(D.transpose(0, 1, 3, 2), D))[1][..., -1] + D = T_aa @ (C_bar_ab @ T_bb) + a = np.linalg.eigh(D @ D.transpose(0, 1, 3, 2))[1][..., -1] + b = np.linalg.eigh(D.transpose(0, 1, 3, 2) @ D)[1][..., -1] # Eq. 7 rearranged - multiply both sides by sqrt(inv(real(C_aa/bb))) - alpha = np.matmul(T_aa, np.expand_dims(a, axis=3)) # filter for seeds - beta = np.matmul(T_bb, np.expand_dims(b, axis=3)) # filter for targets + alpha = T_aa @ np.expand_dims(a, axis=3) # filter for seeds + beta = T_bb @ np.expand_dims(b, axis=3) # filter for targets # Eq. 14; U_bar inclusion follows Eqs. 46 & 47 of Ewald et al. (2012) # seed spatial patterns self.patterns[0, con_i, :n_seeds] = ( - np.matmul(np.real(C[..., :n_seeds, :n_seeds]), np.matmul(U_bar_aa, alpha)) + np.real(C[..., :n_seeds, :n_seeds]) @ (U_bar_aa @ alpha) )[..., 0].T # target spatial patterns self.patterns[1, con_i, :n_targets] = ( - np.matmul(np.real(C[..., n_seeds:, n_seeds:]), np.matmul(U_bar_bb, beta)) + np.real(C[..., n_seeds:, n_seeds:]) @ (U_bar_bb @ beta) )[..., 0].T @@ -758,10 +748,10 @@ def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): C_bb = csd[..., n_seeds:, n_seeds:] C_ba = csd[..., n_seeds:, :n_seeds] - C_bar_aa = np.matmul(U_bar_aa.transpose(1, 0), np.matmul(C_aa, U_bar_aa)) - C_bar_ab = np.matmul(U_bar_aa.transpose(1, 0), np.matmul(C_ab, U_bar_bb)) - C_bar_bb = np.matmul(U_bar_bb.transpose(1, 0), np.matmul(C_bb, U_bar_bb)) - C_bar_ba = np.matmul(U_bar_bb.transpose(1, 0), np.matmul(C_ba, U_bar_aa)) + C_bar_aa = U_bar_aa.transpose(1, 0) @ (C_aa @ U_bar_aa) + C_bar_ab = U_bar_aa.transpose(1, 0) @ (C_ab @ U_bar_bb) + C_bar_bb = U_bar_bb.transpose(1, 0) @ (C_bb @ U_bar_bb) + C_bar_ba = U_bar_bb.transpose(1, 0) @ (C_ba @ U_bar_aa) C_bar = np.append( np.append(C_bar_aa, C_bar_ab, axis=3), np.append(C_bar_ba, C_bar_bb, axis=3), @@ -874,18 +864,18 @@ def _whittle_lwr_recursion(self, G): # Perform recursion for k in np.arange(2, q + 1): - var_A = G_b[:, (r - 1) * n : r * n, :] - np.matmul( - A_f[:, :, k_f], G_b[:, k_b, :] + var_A = G_b[:, (r - 1) * n : r * n, :] - ( + A_f[:, :, k_f] @ G_b[:, k_b, :] ) - var_B = cov - np.matmul(A_b[:, :, k_b], G_b[:, k_b, :]) + var_B = cov - (A_b[:, :, k_b] @ G_b[:, k_b, :]) AA_f = np.linalg.solve(var_B, var_A.transpose(0, 2, 1)).transpose( 0, 2, 1 ) - var_A = G_f[:, (k - 1) * n : k * n, :] - np.matmul( - A_b[:, :, k_b], G_f[:, k_f, :] + var_A = G_f[:, (k - 1) * n : k * n, :] - ( + A_b[:, :, k_b] @ G_f[:, k_f, :] ) - var_B = cov - np.matmul(A_f[:, :, k_f], G_f[:, k_f, :]) + var_B = cov - (A_f[:, :, k_f] @ G_f[:, k_f, :]) AA_b = np.linalg.solve(var_B, var_A.transpose(0, 2, 1)).transpose( 0, 2, 1 ) @@ -897,12 +887,8 @@ def _whittle_lwr_recursion(self, G): k_f = np.arange(k * n) k_b = np.arange(r * n, qn) - A_f[:, :, k_f] = np.dstack( - (A_f_previous - np.matmul(AA_f, A_b_previous), AA_f) - ) - A_b[:, :, k_b] = np.dstack( - (AA_b, A_b_previous - np.matmul(AA_b, A_f_previous)) - ) + A_f[:, :, k_f] = np.dstack((A_f_previous - (AA_f @ A_b_previous), AA_f)) + A_b[:, :, k_b] = np.dstack((AA_b, A_b_previous - (AA_b @ A_f_previous))) except np.linalg.LinAlgError as np_error: raise RuntimeError( "the autocovariance matrix is singular; check if your data is " @@ -910,7 +896,7 @@ def _whittle_lwr_recursion(self, G): "the rank of the seeds and targets" ) from np_error - V = cov - np.matmul(A_f, G_f) + V = cov - (A_f @ G_f) A_f = np.reshape(A_f, (t, n, n, q), order="F") return A_f, V @@ -954,11 +940,11 @@ def _iss_to_ugc(self, A, C, K, V, seeds, targets): H = self._iss_to_tf(A, C, K, z) # spectral transfer function V_22_1 = np.linalg.cholesky(self._partial_covar(V, seeds, targets)) - HV = np.matmul(H, np.linalg.cholesky(V)) - S = np.matmul(HV, HV.conj().transpose(0, 1, 3, 2)) # Eq. 6 + HV = H @ np.linalg.cholesky(V) + S = HV @ HV.conj().transpose(0, 1, 3, 2) # Eq. 6 S_11 = S[np.ix_(freqs, times, targets, targets)] - HV_12 = np.matmul(H[np.ix_(freqs, times, targets, seeds)], V_22_1) - HVH = np.matmul(HV_12, HV_12.conj().transpose(0, 1, 3, 2)) + HV_12 = H[np.ix_(freqs, times, targets, seeds)] @ V_22_1 + HVH = HV_12 @ HV_12.conj().transpose(0, 1, 3, 2) # Eq. 11 return np.real(np.log(np.linalg.det(S_11)) - np.log(np.linalg.det(S_11 - HVH))) @@ -1013,7 +999,7 @@ def _partial_covar(self, V, seeds, targets): np.linalg.cholesky(V[np.ix_(times, targets, targets)]), V[np.ix_(times, targets, seeds)], ) - W = np.matmul(W.transpose(0, 2, 1), W) + W = W.transpose(0, 2, 1) @ W return V[np.ix_(times, seeds, seeds)] - W @@ -1033,9 +1019,7 @@ def _gc_compute_H(A, C, K, z_k, I_n, I_m): H = np.zeros((A.shape[0], C.shape[1], C.shape[1]), dtype=np.complex128) for t in range(A.shape[0]): - H[t] = I_n + np.matmul( - C[t], linalg.lu_solve(linalg.lu_factor(z_k * I_m - A[t]), K[t]) - ) + H[t] = I_n + (C[t] @ linalg.lu_solve(linalg.lu_factor(z_k * I_m - A[t]), K[t])) return H