diff --git a/wonkyconn/features/quality_control_connectivity.py b/wonkyconn/features/quality_control_connectivity.py index d36c738..16a50ab 100644 --- a/wonkyconn/features/quality_control_connectivity.py +++ b/wonkyconn/features/quality_control_connectivity.py @@ -1,3 +1,4 @@ +from itertools import chain from typing import Iterable import numpy as np @@ -33,29 +34,34 @@ def calculate_qcfc( pd.DataFrame: The QCFC values between connectivity matrices and the metric. """ - metrics = pd.Series( + metrics = np.asarray( [ connectivity_matrix.metadata.get(metric_key, np.nan) for connectivity_matrix in connectivity_matrices ] ) - covariates = dmatrix("age + gender", data_frame) + covariates = np.asarray(dmatrix("age + gender", data_frame)) - connectivity_array = np.concatenate( - [ - connectivity_matrix.load()[:, :, np.newaxis] - for connectivity_matrix in tqdm( - connectivity_matrices, desc="Loading connectivity matrices", leave=False - ) - ], - axis=2, - ) - n, _, m = connectivity_array.shape + connectivity_arrays = [ + connectivity_matrix.load() + for connectivity_matrix in tqdm( + connectivity_matrices, desc="Loading connectivity matrices", leave=False + ) + ] + + # Ensure that all arrays are square and have the same shape + (n,) = set(chain.from_iterable(a.shape for a in connectivity_arrays)) + # Extract the lower triangles i, j = np.tril_indices(n, k=-1) - correlation = partial_correlation( - connectivity_array[i, j], metrics.to_numpy(), np.asarray(covariates) + connectivity_array = np.concatenate( + [a[i, j, np.newaxis] for a in connectivity_arrays], + axis=1, ) + + _, m = connectivity_array.shape + correlation = partial_correlation(connectivity_array, metrics, covariates) + p_value = correlation_p_value(correlation, m) qcfc = pd.DataFrame(dict(i=i, j=j, correlation=correlation, p_value=p_value))