diff --git a/pyproject.toml b/pyproject.toml index a34f41c..68e9864 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,7 +96,6 @@ warn_unused_ignores = true ignore_missing_imports = true module = [ "bids.*", - "wonkyconn._version", "h5py.*", "nibabel.*", "nilearn.*", @@ -105,8 +104,10 @@ module = [ "nilearn.interfaces.*", "nilearn.maskers.*", "nilearn.masking.*", + "patsy.*", "rich.*", - "scipy.ndimage.*", + "scipy.*", + "statsmodels.*", "templateflow.*", ] diff --git a/wonkyconn/atlas.py b/wonkyconn/atlas.py new file mode 100644 index 0000000..3735b5d --- /dev/null +++ b/wonkyconn/atlas.py @@ -0,0 +1,93 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from pathlib import Path +import nibabel as nib +import numpy as np +from numpy import typing as npt +import scipy + +from .logger import gc_log + + +@dataclass +class Atlas(ABC): + seg: str + image: nib.nifti1.Nifti1Image + + structure: npt.NDArray[np.bool_] = field( + default_factory=lambda: np.ones((3, 3, 3), dtype=bool) + ) + + @abstractmethod + def get_centroid_points(self) -> npt.NDArray[np.float64]: + raise NotImplementedError + + def get_centroids(self) -> npt.NDArray[np.float64]: + centroid_points = self.get_centroid_points() + centroid_coordinates = nib.affines.apply_affine( + self.image.affine, centroid_points + ) + return centroid_coordinates + + def get_distance_matrix(self) -> npt.NDArray[np.float64]: + centroids = self.get_centroids() + return scipy.spatial.distance.squareform( + scipy.spatial.distance.pdist(centroids) + ) + + @staticmethod + def create(seg: str, path: Path) -> "Atlas": + image = nib.nifti1.load(path) + + if image.ndim <= 3 or image.shape[3] == 1: + return DsegAtlas(seg, nib.funcs.squeeze_image(image)) + else: + return ProbsegAtlas(seg, image) + + +@dataclass +class DsegAtlas(Atlas): + def get_array(self) -> npt.NDArray[np.int64]: + return np.asarray(self.image.dataobj, dtype=np.int64) + + def check_single_connected_component(self, array: npt.NDArray[np.int64]) -> None: + for i in range(1, array.max() + 1): + mask = array == i + _, num_features = scipy.ndimage.label(mask, structure=self.structure) + if num_features > 1: + gc_log.warning( + f'Atlas "{self.seg}" region {i} has more than a single connected component' + ) + + def get_centroid_points(self) -> npt.NDArray[np.float64]: + array = self.get_array() + self.check_single_connected_component(array) + return np.asarray( + scipy.ndimage.center_of_mass( + input=array > 0, labels=array, index=np.arange(1, array.max() + 1) + ) + ) + + +@dataclass +class ProbsegAtlas(Atlas): + epsilon: float = 1e-6 + + def get_centroid_point( + self, i: int, array: npt.NDArray[np.float64] + ) -> tuple[float, ...]: + mask = array > self.epsilon + _, num_features = scipy.ndimage.label(mask, structure=self.structure) + if num_features > 1: + gc_log.warning( + f'Atlas "{self.seg}" region {i} has more than a single connected component' + ) + return scipy.ndimage.center_of_mass(array) + + def get_centroid_points(self) -> npt.NDArray[np.float64]: + return np.asarray( + [ + self.get_centroid_point(i, image.get_fdata()) + for i, image in enumerate(nib.funcs.four_to_three(self.image)) + ] + ) diff --git a/wonkyconn/base.py b/wonkyconn/base.py new file mode 100644 index 0000000..0250c77 --- /dev/null +++ b/wonkyconn/base.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np +from numpy import typing as npt + + +@dataclass +class ConnectivityMatrix: + path: Path + metadata: 
dict[str, Any] + + def load(self) -> npt.NDArray[np.float64]: + return np.loadtxt(self.path, delimiter="\t", skiprows=1) diff --git a/wonkyconn/features/__init__.py b/wonkyconn/features/__init__.py index 3fda5d4..e69de29 100644 --- a/wonkyconn/features/__init__.py +++ b/wonkyconn/features/__init__.py @@ -1,18 +0,0 @@ -from .quality_control_connectivity import ( - qcfc, - partial_correlation, - significant_level, - calculate_median_absolute, -) -from .distance_dependency import get_atlas_pairwise_distance, get_centroid -from .network_modularity import louvain_modularity - -__all__ = [ - "qcfc", - "significant_level", - "partial_correlation", - "calculate_median_absolute", - "get_atlas_pairwise_distance", - "get_centroid", - "louvain_modularity", -] diff --git a/wonkyconn/features/base.py b/wonkyconn/features/base.py new file mode 100644 index 0000000..e69de29 diff --git a/wonkyconn/features/calculate_degrees_of_freedom.py b/wonkyconn/features/calculate_degrees_of_freedom.py index 2852e9e..d88cde5 100644 --- a/wonkyconn/features/calculate_degrees_of_freedom.py +++ b/wonkyconn/features/calculate_degrees_of_freedom.py @@ -1,125 +1,38 @@ """Calculate degree of freedom""" -import argparse -from pathlib import Path -import pandas as pd +import numpy as np +from ..base import ConnectivityMatrix -from fmriprep_denoise.dataset.timeseries import get_confounds -from fmriprep_denoise.dataset.fmriprep import ( - get_prepro_strategy, - fetch_fmriprep_derivative, - generate_movement_summary, -) +def calculate_degrees_of_freedom_loss( + connectivity_matrices: list[ConnectivityMatrix], +) -> float: + """ + Calculate the proportion of degrees of freedom lost during denoising. + Parameters: + - connectivity_matrices (list[ConnectivityMatrix]): The connectivity matrices for which to calculate the degrees of freedom loss. -def main(): - args = parse_args() - print(vars(args)) - dataset_name = args.dataset_name - fmriprep_specifier = args.specifier - fmriprep_path = Path(args.fmriprep_path) - participant_tsv = Path(args.participants_tsv) - output_root = Path(args.output_path) + Returns: + - float: The proportion of degrees of freedom lost.
- output_root.mkdir(exist_ok=True, parents=True) - path_movement = Path( - output_root / f"dataset-{dataset_name}_desc-movement_phenotype.tsv" - ) + """ - path_dof = Path( - output_root / f"dataset-{dataset_name}_desc-confounds_phenotype.tsv" - ) + values: list[float] = [] - full_data = fetch_fmriprep_derivative( - dataset_name, participant_tsv, fmriprep_path, fmriprep_specifier - ) + for connectivity_matrix in connectivity_matrices: + metadata = connectivity_matrix.metadata - if dataset_name == "ds000030": - # read relevant files, - participants = full_data.phenotypic.copy() - mask_quality = participants["ghost_NoGhost"] == "No_ghost" - participants = participants[mask_quality].index.tolist() - subjects = [p.split("-")[-1] for p in participants] - full_data = fetch_fmriprep_derivative( - dataset_name, - participant_tsv, - fmriprep_path, - fmriprep_specifier, - subject=subjects, - ) - movement = generate_movement_summary(full_data) - movement = movement.sort_index() - movement.to_csv(path_movement, sep="\t") - print("Generate movement stats.") + total = 0 - subjects = [p.split("-")[-1] for p in movement.index] + confound_regressors = metadata.get("ConfoundRegressors", list()) + total += len(confound_regressors) - benchmark_strategies = get_prepro_strategy() - data_aroma = fetch_fmriprep_derivative( - dataset_name, - participant_tsv, - fmriprep_path, - fmriprep_specifier, - aroma=True, - subject=subjects, - ) - data = fetch_fmriprep_derivative( - dataset_name, - participant_tsv, - fmriprep_path, - fmriprep_specifier, - subject=subjects, - ) - info = {} - for strategy_name, parameters in benchmark_strategies.items(): - print(f"Denoising: {strategy_name}") - print(parameters) - func_data = data_aroma.func if "aroma" in strategy_name else data.func - for img in func_data: - sub = img.split("/")[-1].split("_")[0] - reduced_confounds, sample_mask = get_confounds( - strategy_name, parameters, img - ) - full_length = reduced_confounds.shape[0] - ts_length = full_length if sample_mask is None else len(sample_mask) - excised_vol = full_length - ts_length - excised_vol_pro = excised_vol / full_length - regressors = reduced_confounds.columns.tolist() - compcor = sum("comp_cor" in i for i in regressors) - high_pass = sum("cosine" in i for i in regressors) - total = len(regressors) - fixed = total - compcor if "compcor" in strategy_name else len(regressors) + total += metadata.get("NumberOfVolumesDiscardedByMotionScrubbing", 0) + total += metadata.get("NumberOfVolumesDiscardedByNonsteadyStatesDetector", 0) - if "aroma" in strategy_name: - path_aroma_ic = img.split("space-")[0] + "AROMAnoiseICs.csv" - with open(path_aroma_ic, "r") as f: - aroma = len(f.readline().split(",")) - total = fixed + aroma - else: - aroma = 0 + # TODO Support ICA-AROMA - if "scrub" in strategy_name: - total += excised_vol + values.append(total / connectivity_matrix.load().shape[0]) - stats = { - (strategy_name, "excised_vol"): excised_vol, - (strategy_name, "excised_vol_proportion"): excised_vol_pro, - (strategy_name, "high_pass"): high_pass, - (strategy_name, "fixed_regressors"): fixed, - (strategy_name, "compcor"): compcor, - (strategy_name, "aroma"): aroma, - (strategy_name, "total"): total, - (strategy_name, "full_length"): full_length, - } - if info.get(sub): - info[sub].update(stats) - else: - info[sub] = stats - confounds_stats = pd.DataFrame.from_dict(info, orient="index") - confounds_stats = confounds_stats.sort_index() - confounds_stats.to_csv(path_dof, sep="\t") - - -if __name__ == "__main__": - main() + return 
float(np.mean(values)) diff --git a/wonkyconn/features/derivatives.py b/wonkyconn/features/derivatives.py deleted file mode 100644 index 5a0a3e5..0000000 --- a/wonkyconn/features/derivatives.py +++ /dev/null @@ -1,155 +0,0 @@ -import json -import tarfile -from pathlib import Path -import pandas as pd -from nilearn.connectome import ConnectivityMeasure - -from fmriprep_denoise.visualization import tables - - -MOTION_QC_FILE = "motion_qc.json" -project_root = Path(__file__).parents[2] -inputs = project_root / "data" -group_info_column = {"ds000228": "Child_Adult", "ds000030": "diagnosis"} - - -def get_qc_criteria(strategy_name=None): - """ - Select an automatic quality control strategy and associated parameters. - - Parameter - --------- - - strategy_name : None or str - Name of the denoising strategy. See motion_qc.json. - Default to None, returns all strategies. - - Return - ------ - - dict - Motion quality control parameter to pass to filter subjects. - """ - motion_qc_file = Path(__file__).parent / MOTION_QC_FILE - with open(motion_qc_file, "r") as file: - qc_strategies = json.load(file) - - if isinstance(strategy_name, str) and strategy_name not in qc_strategies: - raise NotImplementedError( - f"Strategy '{strategy_name}' is not implemented. Select from the" - f"following: None, {[*qc_strategies]}" - ) - - if strategy_name is None: - print("No motion QC.") - return {"gross_fd": None, "fd_thresh": None, "proportion_thresh": None} - (f"Process strategy '{strategy_name}'.") - return qc_strategies[strategy_name] - - -def compute_connectome( - atlas, - extracted_path, - dataset, - fmriprep_version, - path_root, - file_pattern, - gross_fd=None, - fd_thresh=None, - proportion_thresh=None, -): - """Compute connectome of all valid data. - - Parameters - ---------- - - atlas : str - Atlas name matching keys in fmriprep_denoise.dataset.atlas.ATLAS_METADATA. - - extracted_path : pathlib.Path - Path object to where the time series were saved. - - dataset : str - Name of the dataset. - - fmriprep_version : str {fmrieprep-20.2.1lts, fmrieprep-20.2.5lts} - fMRIPrep version used for preporcessin. - - file_pattern : str - Details about the atlas and description of the file. - - Returns - ------- - pandas.DataFrame, pandas.DataFrame - Flatten connectomes and phenotypes. - """ - _, phenotype, _ = tables.get_descriptive_data( - dataset, fmriprep_version, path_root, gross_fd, fd_thresh, proportion_thresh - ) - participant_id = phenotype.index.tolist() - valid_ids, valid_ts = _load_valid_timeseries( - atlas, extracted_path, participant_id, file_pattern - ) - correlation_measure = ConnectivityMeasure( - kind="correlation", vectorize=True, discard_diagonal=True - ) - subject_conn = correlation_measure.fit_transform(valid_ts) - subject_conn = pd.DataFrame(subject_conn, index=valid_ids) - if subject_conn.shape[0] != phenotype.shape[0]: - print("take conjunction of the phenotype and connectome.") - idx = subject_conn.index.intersection(phenotype.index) - subject_conn, phenotype = ( - subject_conn.loc[idx, :], - phenotype.loc[idx, :], - ) - return subject_conn, phenotype - - -def check_extraction(input_path, extracted_path_root=None): - """Check if the tar.gz of a fmriprep dataset has been extracted. - - Parameters - ---------- - - input_path : pathlib.Path - Location of the tar.gz of the fMRIPrep output. - - extracted_path_root : None, pathlib.Path - Destination of the extraction. - - Returns - ------- - - pathlib.Path - Correct file path of the extracted dataset. 
- """ - dir_name = input_path.name.split(".tar")[0] - extracted_path_root = inputs if extracted_path_root is None else extracted_path_root - - extracted_path = extracted_path_root / dir_name - - if not extracted_path.is_dir() and input_path.is_file(): - print(f"Cannot file extracted file at {extracted_path}. " "Extracting...") - with tarfile.open(input_path, "r:gz") as tar: - tar.extractall(extracted_path_root) - return extracted_path - - -def _load_valid_timeseries(atlas, extracted_path, participant_id, file_pattern): - """Load time series from tsv file.""" - valid_ids, valid_ts = [], [] - for subject in participant_id: - subject_path = extracted_path / f"atlas-{atlas}" / subject - file_path = list( - subject_path.glob(f"{subject}_*_{file_pattern}_timeseries.tsv") - ) - if len(file_path) > 1: - raise ValueError("Found more than one valid file." f"{file_path}") - file_path = file_path[0] - if file_path.stat().st_size > 1: - ts = pd.read_csv(file_path, sep="\t", header=0) - valid_ids.append(subject) - valid_ts.append(ts.values) - else: - continue - return valid_ids, valid_ts diff --git a/wonkyconn/features/distance_dependence.py b/wonkyconn/features/distance_dependence.py new file mode 100644 index 0000000..07966c2 --- /dev/null +++ b/wonkyconn/features/distance_dependence.py @@ -0,0 +1,15 @@ +from pathlib import Path + +import numpy as np +import pandas as pd + +from wonkyconn.atlas import Atlas +from scipy.stats import spearmanr + + +def calculate_distance_dependence(qcfc: pd.DataFrame, atlas: Atlas) -> float: + distance_matrix = atlas.get_distance_matrix() + i, j = map(list, zip(*qcfc.index)) + distance_vector = distance_matrix[i, j] + r, _ = spearmanr(distance_vector, qcfc.correlation) + return r diff --git a/wonkyconn/features/distance_dependency.py b/wonkyconn/features/distance_dependency.py deleted file mode 100644 index 9ccbc94..0000000 --- a/wonkyconn/features/distance_dependency.py +++ /dev/null @@ -1,132 +0,0 @@ -from pathlib import Path - -import numpy as np -import pandas as pd -from scipy.spatial import distance - -from fmriprep_denoise.dataset.atlas import fetch_atlas_path, ATLAS_METADATA -from nilearn.image import index_img -from nilearn.plotting import find_probabilistic_atlas_cut_coords - - -def get_atlas_pairwise_distance(atlas_name, dimension): - """ - Compute pairwise distance of nodes in the atlas. - - Parameters - ---------- - - atlas_name : str - Atlas name. Must be a key in ATLAS_METADATA. - - dimension : str or int - Atlas dimension. - - Returns - ------- - - pandas.DataFrame - Node ID paire and the distnace. - - """ - if atlas_name == "gordon333": - file_dist = "atlas-gordon333_nroi-333_desc-distance.tsv" - return pd.read_csv(Path(__file__).parent / "data" / file_dist, sep="\t") - centroids = get_centroid(atlas_name, dimension) - pairwise_distance = distance.cdist(centroids, centroids) - labels = range(1, pairwise_distance.shape[0] + 1) - - # Transform into pandas dataframe - pairwise_distance = pd.DataFrame(pairwise_distance, index=labels, columns=labels) - # keep lower triangle and flatten match nilearn.connectome.sym_matrix_to_vec - lower_mask = np.tril(np.ones(pairwise_distance.shape), k=-1).astype(np.bool) - pairwise_distance = pairwise_distance.where(lower_mask) - pairwise_distance = pairwise_distance.stack().reset_index() - pairwise_distance.columns = ["row", "column", "distance"] - return pairwise_distance - - -def get_centroid(atlas_name, dimension): - """ - Load parcel centroid for each atlas. 
- - Parameters - ---------- - - atlas_name : str - Atlas name. Must be a key in ATLAS_METADATA. - - dimension : str or int - Atlas dimension. - - Returns - ------- - - pandas.DataFrame - Centroid coordinates. - """ - if atlas_name not in ATLAS_METADATA: - raise NotImplementedError("Selected atlas is not supported.") - - if atlas_name == "schaefer7networks": - url = ( - "https://raw.githubusercontent.com/ThomasYeoLab/CBIG/master/" - "stable_projects/brain_parcellation/Schaefer2018_LocalGlobal/" - "Parcellations/MNI/Centroid_coordinates/" - f"Schaefer2018_{dimension}Parcels_7Networks_order_FSLMNI152_2mm" - ".Centroid_RAS.csv" - ) - return pd.read_csv(url).loc[:, ["R", "S", "A"]].values - if atlas_name == "gordon333": - file_dist = "atlas-gordon333_nroi-333_desc-distance.tsv" - return pd.read_csv(Path(__file__).parent / "data" / file_dist, sep="\t") - - if atlas_name == "mist": - current_atlas = fetch_atlas_path(atlas_name, dimension) - return current_atlas.labels.loc[:, ["x", "y", "z"]].values - if atlas_name == "difumo": - # find files - p = ( - Path(__file__).parent - / "data" - / f"atlas-DiFuMo_nroi-{dimension}_desc-distance.tsv" - ) - if not p.is_file(): - get_difumo_centroids(dimension) - return pd.read_csv(p, sep="\t").loc[:, ["x", "y", "z"]].values - - -def get_difumo_centroids(d): - """ - Compute difumo centroids. - - Parameters - ---------- - - d : int - Atlas dimension. - - """ - current_atlas = fetch_atlas_path("difumo", d) - if d > 256: - # split the map and work on individual maps - n_roi = current_atlas.labels.shape[0] - centroid = [] - for i in range(0, n_roi, 200): - if i == 0: - start = i - continue - img = index_img(current_atlas.maps, slice(start, i)) - c = find_probabilistic_atlas_cut_coords(img) - centroid.append(c) - start = i - img = index_img(current_atlas.maps, slice(start, n_roi)) - c = find_probabilistic_atlas_cut_coords(img) - centroid.append(c) - centroid = np.vstack(centroid) - else: - centroid = find_probabilistic_atlas_cut_coords(current_atlas.maps) - centroid = pd.DataFrame(centroid, columns=["x", "y", "z"]) - centroid = pd.concat([current_atlas.labels, centroid], axis=1) - output = Path(__file__).parent / "data" / f"atlas-DiFuMo_nroi-{d}_desc-distance.tsv" - centroid.to_csv(output, sep="\t") diff --git a/wonkyconn/features/network_modularity.py b/wonkyconn/features/network_modularity.py deleted file mode 100644 index 5ece82a..0000000 --- a/wonkyconn/features/network_modularity.py +++ /dev/null @@ -1,62 +0,0 @@ -import numpy as np -from nilearn.connectome import vec_to_sym_matrix -from bct import modularity_louvain_und_sign -from math import sqrt - - -def louvain_modularity(vect): - """ - Wrapper for `modularity_louvain_und_sign` from the Brain Connectivity - tool box. - - Parameters - ---------- - - vect : np.ndarray - Flatten connetome. - - Returns - ------- - np.ndarray - modularity (qtype dependent) - - """ - vect = np.array(vect) - n = vect.shape[-1] - n_columns = int((sqrt(8 * n + 1) - 1.0) / 2) + 1 # no diagnal - - full_graph = vec_to_sym_matrix(vect, diagonal=np.ones(n_columns)) - _, modularity = compute_commuity(full_graph, num_opt=100) - return modularity - - -def compute_commuity(G, num_opt=100): - """ - Compute community affiliation vector. Wrapper for - `modularity_louvain_und_sign` from the Brain Connectivity tool box. 
- - Parameters - ---------- - - G : np.ndarray - Symmetric Graph - - num_opt : int - Number of Louvain optimizations to perform - - Return - ------ - - np.ndarray - community affiliation vector - - np.ndarray - modularity (qtype dependent) - """ - CI = np.empty((G.shape[0], num_opt)) - Qs = np.empty((num_opt)) - for i in range(num_opt): - P, Q = modularity_louvain_und_sign(G) - CI[:, i] = P - Qs[i] = Q - return CI, Qs.mean() diff --git a/wonkyconn/features/quality_control_connectivity.py b/wonkyconn/features/quality_control_connectivity.py index 6441593..15f0a3a 100644 --- a/wonkyconn/features/quality_control_connectivity.py +++ b/wonkyconn/features/quality_control_connectivity.py @@ -1,46 +1,27 @@ -import pandas as pd +from typing import Iterable, NamedTuple + import numpy as np -from scipy import stats, linalg +import pandas as pd +import scipy +from numpy import typing as npt +from patsy.highlevel import dmatrix +from tqdm.auto import tqdm -from statsmodels.stats import multitest +from ..base import ConnectivityMatrix +from statsmodels.stats.multitest import multipletests -def calculate_median_absolute(x): - """Calculate Absolute median value""" - return x.abs().median() +class PartialCorrelationResult(NamedTuple): + correlation: float + p_value: float -def significant_level(x, alpha=0.05, correction=None): - """ - Apply FDR correction to a pandas.Series p-value object. - Parameters - ---------- - - x : pandas.Series - Uncorrected p-values. - - alpha : float - Alpha threshold. - - method : None or str - Default as None for no multiple comparison - Mutiple comparison methods. - See statsmodels.stats.multitest.multipletests - - Returns - ------- - ndarray, boolean - Mask for data passing multiple comparison test. - """ - if isinstance(correction, str): - res, _, _, _ = multitest.multipletests(x, alpha=alpha, method=correction) - else: - res = x < 0.05 - return res - - -def partial_correlation(x, y, cov=None): +def partial_correlation( + x: npt.NDArray[np.float64], + y: npt.NDArray[np.float64], + cov: npt.NDArray[np.float64] | None = None, +) -> PartialCorrelationResult: """A minimal implementation of partial correlation. Parameters @@ -57,18 +38,22 @@ def partial_correlation(x, y, cov=None): dict Correlation and p-value. """ - if isinstance(cov, np.ndarray): - beta_cov_x = linalg.lstsq(cov, x)[0] - beta_cov_y = linalg.lstsq(cov, y)[0] + if cov is not None: + beta_cov_x = np.linalg.lstsq(cov, x)[0] + beta_cov_y = np.linalg.lstsq(cov, y)[0] resid_x = x - cov.dot(beta_cov_x) resid_y = y - cov.dot(beta_cov_y) - r, p_val = stats.pearsonr(resid_x, resid_y) + r, p_value = scipy.stats.pearsonr(resid_x, resid_y) else: - r, p_val = stats.pearsonr(x, y) - return {"correlation": r, "pvalue": p_val} + r, p_value = scipy.stats.pearsonr(x, y) + return PartialCorrelationResult(r, p_value) -def qcfc(movement, connectomes, covarates=None): +def calculate_qcfc( + data_frame: pd.DataFrame, + connectivity_matrices: Iterable[ConnectivityMatrix], + metric_key: str = "MeanFramewiseDisplacement", +) -> pd.DataFrame: """ metric calculation: quality control / functional connectivity @@ -77,46 +62,98 @@ def qcfc(movement, connectomes, covarates=None): QC-FC relationships were calculated as partial correlations that accounted for participant age and sex + Parameters: + data_frame (pd.DataFrame): The data frame containing the covariates "age" and "gender". It needs to have one row for each connectivity matrix. + connectivity_matrices (Iterable[ConnectivityMatrix]): The connectivity matrices to calculate QCFC for. 
+ metric_key (str, optional): The key of the metric to use for QCFC calculation. Defaults to "MeanFramewiseDisplacement". + + Returns: + pd.DataFrame: The QCFC values between connectivity matrices and the metric. + + """ + metrics = pd.Series( + [ + connectivity_matrix.metadata.get(metric_key, np.nan) + for connectivity_matrix in connectivity_matrices + ] + ) + covariates = dmatrix("age + gender", data_frame) + + connectivity_array = np.concatenate( + [ + connectivity_matrix.load()[:, :, np.newaxis] + for connectivity_matrix in tqdm( + connectivity_matrices, desc="Loading connectivity matrices", leave=False + ) + ], + axis=2, + ) + n, _, _ = connectivity_array.shape + + indices = list(zip(*np.tril_indices(n, k=-1), strict=True)) + + records = list() + for i, j in tqdm(indices, desc="Calculating QC-FC", leave=False): + record = partial_correlation( + connectivity_array[i, j], metrics, covariates + )._asdict() + record["i"] = i + record["j"] = j + records.append(record) + + qcfc = pd.DataFrame.from_records(records, index=["i", "j"]) + return qcfc + + +def calculate_median_absolute(x: pd.Series) -> float: + """Calculate the median absolute value""" + return x.abs().median() + + +def significant_level( + x: pd.Series, alpha: float = 0.05, correction: str | None = None +) -> npt.NDArray[np.bool_]: + """ + Apply FDR correction to a pandas.Series p-value object. + Parameters ---------- - movement: pandas.DataFrame - Containing header: ["mean_framewise_displacement"] + x : pandas.Series + Uncorrected p-values. - connectomes: pandas.DataFrame - Flattened connectome of a whole dataset. - Index: subjets - Columns: ROI-ROI pairs + alpha : float + Alpha threshold. - covariates: pandas.DataFrame or None - Age", Gender + correction : None or str + Default None for no multiple comparison correction. + Multiple comparison methods. + See statsmodels.stats.multitest.multipletests Returns ------- + ndarray, boolean + Mask for data passing multiple comparison test. + """ + if isinstance(correction, str): + res, _, _, _ = multipletests(x, alpha=alpha, method=correction) + else: + res = x < alpha + return res - - List of dict - QC/FC per connectome edge. + +def calculate_qcfc_percentage(qcfc: pd.DataFrame) -> float: + """ + Calculate the percentage of significant QC-FC relationships. + + Parameters + ---------- + qcfc : pd.DataFrame + The QC-FC values between connectivity matrices and the metric. + + Returns + ------- + float + The percentage of significant QC-FC relationships.
""" - # concatenate information to match by subject id - edge_ids = connectomes.columns.tolist() - connectomes = pd.concat((connectomes, movement), axis=1) - - if covarates is not None: - covarates = covarates.apply(stats.zscore) - cov_names = covarates.columns - connectomes = pd.concat((connectomes, covarates), axis=1) - - # drop subject with no edge value - connectomes = connectomes.dropna(axis=0) - - qcfc_edge = [] - for edge_id in edge_ids: - # QC-FC - metric = partial_correlation( - connectomes[edge_id].values, - connectomes["mean_framewise_displacement"].values, - connectomes[cov_names].values, - ) - qcfc_edge.append(metric) - - return qcfc_edge + return 100 * significant_level(qcfc.p_value).mean() diff --git a/wonkyconn/file_index/__init__.py b/wonkyconn/file_index/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/wonkyconn/file_index/base.py b/wonkyconn/file_index/base.py new file mode 100644 index 0000000..3cda4ec --- /dev/null +++ b/wonkyconn/file_index/base.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: + +from __future__ import annotations + +from collections import defaultdict +from hashlib import sha1 +from pathlib import Path +from typing import Container, Mapping + + +def create_defaultdict_of_set() -> defaultdict[str, set[Path]]: + return defaultdict(set) + + +class FileIndex: + def __init__(self) -> None: + self.paths_by_tags: dict[str, dict[str, set[Path]]] = defaultdict( + create_defaultdict_of_set + ) + self.tags_by_paths: dict[Path, dict[str, str]] = defaultdict(dict) + + @property + def hexdigest(self) -> str: + """ + A forty character hash code of the paths in the index, obtained using the `sha1` algorithm. + """ + hash_algorithm = sha1() + + for path in sorted(self.tags_by_paths.keys(), key=str): + path_bytes = str(path).encode() + hash_algorithm.update(path_bytes) + + return hash_algorithm.hexdigest() + + def get(self, **tags: str | None) -> set[Path]: + """ + Find all paths that match the query tags. + + Args: + **tags: A dictionary of tags to match against. The keys are the tag names + and the values are the tag values. Pass a value of `None` to + select paths without that tag. + + Returns: + A set of `Path` objects that match all the specified tags. 
+ """ + + matches: set[Path] | None = None + for key, value in tags.items(): + if key not in self.paths_by_tags: + return set() + + values = self.paths_by_tags[key] + if value is not None: + if value not in values: + return set() + paths: set[Path] = values[value] + else: + paths_in_index = set(self.tags_by_paths.keys()) + paths = paths_in_index.difference(*values.values()) + + if matches is not None: + matches &= paths + else: + matches = paths.copy() + + if not matches: + return set() + else: + return matches + + def get_tags(self, path: Path) -> Mapping[str, str | None]: + if path in self.tags_by_paths: + return self.tags_by_paths[path] + else: + return dict() + + def get_tag_value(self, path: Path, key: str) -> str | None: + return self.get_tags(path).get(key) + + def set_tag_value(self, path: Path, key: str, value: str) -> None: + # remove previous value + if self.get_tag_value(path, key) is not None: + previous_value = self.tags_by_paths[path].pop(key) + self.paths_by_tags[key][previous_value].remove(path) + if value is not None: + self.tags_by_paths[path][key] = value + self.paths_by_tags[key][value].add(path) + + def get_tag_mapping(self, key: str) -> Mapping[str, set[Path]]: + return self.paths_by_tags[key] + + def get_tag_values(self, key: str, paths: set[Path] | None = None) -> set[str]: + if key not in self.paths_by_tags: + return set() + + if paths is None: + return set(self.paths_by_tags[key].keys()) + + return set( + k for k, v in self.paths_by_tags[key].items() if not paths.isdisjoint(v) + ) + + def get_tag_groups( + self, keys: Container[str], paths: set[Path] | None = None + ) -> list[Mapping[str, str]]: + from pyrsistent import pmap + + if paths is None: + paths = set(self.tags_by_paths.keys()) + + groups: set[Mapping[str, str]] = { + pmap({k: v for k, v in self.tags_by_paths[path].items() if k in keys}) + for path in paths + } + + return [dict(group) for group in groups] + + def get_associated_paths(self, path: Path, **tags: str) -> set[Path]: + matches = self.get(**tags) + for key, value in self.get_tags(path).items(): + if key == "extension": + continue + valid = self.get(**{key: value}) | self.get(**{key: None}) + matches &= valid + return matches diff --git a/wonkyconn/file_index/bids.py b/wonkyconn/file_index/bids.py new file mode 100644 index 0000000..96e01a9 --- /dev/null +++ b/wonkyconn/file_index/bids.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: + +import json +from pathlib import Path +from typing import Any, MutableSequence + +from .base import FileIndex + + +def split_ext(path: str | Path) -> tuple[str, str]: + """Splits filename and extension (.gz safe) + >>> split_ext('some/file.nii.gz') + ('file', '.nii.gz') + >>> split_ext('some/other/file.nii') + ('file', '.nii') + >>> split_ext('otherext.tar.gz') + ('otherext', '.tar.gz') + >>> split_ext('text.txt') + ('text', '.txt') + + Adapted from niworkflows + """ + from pathlib import Path + + if isinstance(path, str): + path = Path(path) + + name = str(path.name) + + safe_name = name + for compound_extension in [".gz", ".xz"]: + safe_name = safe_name.removesuffix(compound_extension) + + stem = Path(safe_name).stem + return stem, name[len(stem) :] + + +def parse(path: Path) -> dict[str, str] | None: + """ + Parses a BIDS-formatted filename and returns a dictionary of its tags. + + Args: + path (Path): The path to the file to parse. 
+ + Returns: + dict[str, str] | None: A dictionary of the file's BIDS tags, or None if the + file is not a valid BIDS-formatted file. + """ + if path.is_dir(): + return None # Skip directories + + stem, extension = split_ext(path) + if stem.startswith("."): + return None # Skip hidden files + + tokens = stem.split("_") + + # Parse tokens + keys: MutableSequence[str | None] = list() + values: MutableSequence[str] = list() + for token in tokens: + if "-" in token: # A bids tag + key: str | None = token.split("-")[0] + if key is None: + continue + keys.append(key) + values.append(token[len(key) + 1 :]) + + else: # A suffix + keys.append(None) + values.append(token) + + # Extract bids suffixes + suffixes: list[str] = list() + while keys and keys[-1] is None: + keys.pop(-1) + suffixes.insert(0, values.pop(-1)) + + # Merge other suffixes with their preceding tag value + for i, (key, value) in enumerate(zip(keys, values, strict=False)): + if i < 1: + continue + if key is None: + values[i - 1] += f"_{value}" + + # Build tags + tags = dict( + suffix="_".join(suffixes), + ) + if extension: + tags["extension"] = extension + parent_name = Path(str(path.parent)).name + if parent_name in ("anat", "func", "fmap"): + tags["datatype"] = parent_name + for key, value in zip(keys, values, strict=False): + if key is not None: + tags[key] = value + return tags + + +class BIDSIndex(FileIndex): + def put(self, root: Path) -> None: + for path in root.glob("**/*"): + tags = parse(path) + + if tags is None: + continue # not a valid path + + for key, value in tags.items(): + self.paths_by_tags[key][value].add(path) + + self.tags_by_paths[path] = tags + + def get_metadata(self, path: Path) -> dict[str, Any]: + metadata: dict[str, Any] = dict() + + for metadata_path in self.get_associated_paths(path, extension=".json"): + with metadata_path.open("r") as file: + metadata.update(json.load(file)) + + return metadata diff --git a/wonkyconn/logger.py b/wonkyconn/logger.py index 6a9e9b2..c2b70af 100644 --- a/wonkyconn/logger.py +++ b/wonkyconn/logger.py @@ -19,3 +19,19 @@ def gc_logger(log_level: str = "INFO") -> logging.Logger: ) return logging.getLogger("giga_connectome") + + +gc_log = gc_logger() + + +def set_verbosity(verbosity: int | list[int]) -> None: + if isinstance(verbosity, list): + verbosity = verbosity[0] + if verbosity == 0: + gc_log.setLevel("ERROR") + elif verbosity == 1: + gc_log.setLevel("WARNING") + elif verbosity == 2: + gc_log.setLevel("INFO") + elif verbosity == 3: + gc_log.setLevel("DEBUG") diff --git a/wonkyconn/run.py b/wonkyconn/run.py index 30aab11..2fd6179 100644 --- a/wonkyconn/run.py +++ b/wonkyconn/run.py @@ -5,7 +5,7 @@ from typing import Sequence from . 
import __version__ -from .workflow import workflow +from .workflow import workflow, gc_log def global_parser() -> argparse.ArgumentParser: @@ -43,6 +43,7 @@ def global_parser() -> argparse.ArgumentParser: "--phenotypes", type=str, help="Path to the phenotype file that has the columns `participant_id`, `gender` coded as `M` and `F` and `age` in years.", + required=True, ) parser.add_argument( "--seg-to-atlas", @@ -54,6 +55,7 @@ def global_parser() -> argparse.ArgumentParser: help="Specify the atlas file to use for a segmentation label in the data", ) + parser.add_argument("--debug", action="store_true", default=False) parser.add_argument( "--verbosity", help=""" @@ -69,11 +71,17 @@ def global_parser() -> argparse.ArgumentParser: def main(argv: None | Sequence[str] = None) -> None: - """Entry point.""" parser = global_parser() args = parser.parse_args(argv) - workflow(args) + try: + workflow(args) + except Exception as e: + gc_log.exception("Exception: %s", e, exc_info=True) + if args.debug: + import pdb + + pdb.post_mortem() if __name__ == "__main__": diff --git a/wonkyconn/tests/test_atlas.py b/wonkyconn/tests/test_atlas.py new file mode 100644 index 0000000..0c5f4be --- /dev/null +++ b/wonkyconn/tests/test_atlas.py @@ -0,0 +1,81 @@ +from pathlib import Path + +import numpy as np +import pandas as pd +from pkg_resources import resource_filename +import scipy +from nilearn.plotting import find_probabilistic_atlas_cut_coords +from templateflow.api import get as get_template + +from wonkyconn.atlas import Atlas + + +def test_dseg_atlas() -> None: + url = ( + "https://raw.githubusercontent.com/ThomasYeoLab/CBIG/master/" + "stable_projects/brain_parcellation/Schaefer2018_LocalGlobal/" + "Parcellations/MNI/Centroid_coordinates/" + "Schaefer2018_400Parcels_7Networks_order_FSLMNI152_2mm.Centroid_RAS.csv" + ) + _centroids = pd.read_csv(url).loc[:, ["R", "A", "S"]].values + _distance_matrix = scipy.spatial.distance.squareform( + scipy.spatial.distance.pdist(_centroids) + ) + + path = get_template( + template="MNI152NLin6Asym", + atlas="Schaefer2018", + desc="400Parcels7Networks", + resolution=2, + suffix="dseg", + extension=".nii.gz", + ) + assert isinstance(path, Path) + + atlas = Atlas.create("Schaefer2018400Parcels7Networks", path) + centroids = atlas.get_centroids() + + distance = np.sqrt(np.square(_centroids - centroids).sum(axis=1)) + assert distance.mean() < 2 # mm + + distance_matrix = atlas.get_distance_matrix() + assert np.abs(_distance_matrix - distance_matrix).mean() < 1 # mm + + +def _get_centroids(path: Path): + """ + Compute centroids. + + Parameters + ---------- + + d : int + Atlas dimension. 
+ + """ + centroids = find_probabilistic_atlas_cut_coords(path) + return centroids + + +def test_probseg_atlas() -> None: + path = Path( + resource_filename( + "wonkyconn", + "data/test_data/tpl-MNI152NLin2009cAsym_res-03_atlas-DiFuMo_desc-64dimensionsSegmented_probseg.nii.gz", + ) + ) + assert isinstance(path, Path) + + _centroids = _get_centroids(path) + _distance_matrix = scipy.spatial.distance.squareform( + scipy.spatial.distance.pdist(_centroids) + ) + + atlas = Atlas.create("DiFuMo256dimensions", path) + centroids = atlas.get_centroids() + + distance = np.sqrt(np.square(_centroids - centroids).sum(axis=1)) + assert distance.mean() < 4 # mm + + distance_matrix = atlas.get_distance_matrix() + assert np.abs(_distance_matrix - distance_matrix).mean() < 3 # mm diff --git a/wonkyconn/tests/test_cli.py b/wonkyconn/tests/test_cli.py index 85c79ea..1ac7752 100644 --- a/wonkyconn/tests/test_cli.py +++ b/wonkyconn/tests/test_cli.py @@ -5,13 +5,19 @@ from pathlib import Path import json +import re +from shutil import copyfile +import numpy as np import pytest from pkg_resources import resource_filename import pandas as pd +import scipy +from tqdm.auto import tqdm -from giga_connectome import __version__ -from giga_connectome.run import main +from wonkyconn import __version__ +from wonkyconn.run import global_parser, main +from wonkyconn.workflow import workflow def test_version(capsys): @@ -29,65 +35,116 @@ def test_help(capsys): except SystemExit: pass captured = capsys.readouterr() - assert "Generate denoised timeseries" in captured.out + assert "show program's version number and exit" in captured.out -@pytest.mark.smoke -def test_smoke(tmp_path, capsys): - bids_dir = resource_filename( - "giga_connectome", - "data/test_data/ds000017-fmriprep22.0.1-downsampled-nosurface", - ) - output_dir = tmp_path / "output" - work_dir = tmp_path / "output/work" - - if not Path(output_dir).exists: - Path(output_dir).mkdir() - - main( - [ - "--participant_label", - "1", - "-w", - str(work_dir), - "--atlas", - "Schaefer20187Networks", - "--denoise-strategy", - "simple", - "--reindex-bids", - "--calculate-intranetwork-average-correlation", - str(bids_dir), - str(output_dir), - "participant", - ] - ) +def _copy_file(path: Path, new_path: Path, sub: str) -> None: + new_path = Path(re.sub(r"sub-\d+", f"sub-{sub}", str(new_path))) + new_path.parent.mkdir(parents=True, exist_ok=True) - output_folder = output_dir / "sub-1" / "ses-timepoint1" / "func" + if "relmat" in path.name and path.suffix == ".tsv": + relmat = pd.read_csv(path, sep="\t") + (n,) = set(relmat.shape) - base = ( - "sub-1_ses-timepoint1_task-probabilisticclassification" - "_run-01_space-MNI152NLin2009cAsym_res-2" - "_atlas-Schaefer20187Networks" - ) + array = scipy.spatial.distance.squareform(relmat.to_numpy() - np.eye(n)) + np.random.shuffle(array) + + new_array = scipy.spatial.distance.squareform(array) + np.eye(n) - relmat_file = output_folder / ( - base - + "_meas-PearsonCorrelation" - + "_desc-100Parcels7NetworksSimple_relmat.tsv" + new_relmat = pd.DataFrame(new_array, columns=relmat.columns) + new_relmat.to_csv(new_path, sep="\t", index=False) + elif "timeseries" in path.name and path.suffix == ".json": + with open(path, "r") as f: + content = json.load(f) + content["MeanFramewiseDisplacement"] += np.random.uniform(0, 1) + with open(new_path, "w") as f: + json.dump(content, f) + else: + copyfile(path, new_path) + + +@pytest.mark.smoke +def test_smoke(tmp_path: Path): + data_path = Path( + resource_filename( + "wonkyconn", 
"data/test_data/connectome_Schaefer20187Networks_dev" + ) ) - assert relmat_file.exists() - relmat = pd.read_csv(relmat_file, sep="\t") - assert len(relmat) == 100 - - json_file = relmat_file = output_folder / (base + "_timeseries.json") - assert json_file.exists() - with open(json_file, "r") as f: - content = json.load(f) - assert content.get("SamplingFrequency") == 0.5 - - timeseries_file = relmat_file = output_folder / ( - base + "_desc-100Parcels7NetworksSimple_timeseries.tsv" + + bids_dir = tmp_path / "bids" + bids_dir.mkdir() + output_dir = tmp_path / "output" + output_dir.mkdir() + + subjects = ["2", "3", "4", "5", "6", "7"] + + paths = list(data_path.glob("**/*")) + for path in tqdm(paths, desc="Generating test data"): + if not path.is_file(): + continue + for sub in subjects: + _copy_file(path, bids_dir / path.relative_to(data_path), str(sub)) + + phenotypes = pd.DataFrame( + dict( + participant_id=subjects, + age=np.random.uniform(18, 80, len(subjects)), + gender=np.random.choice(["m", "f"], len(subjects)), + ) ) - assert timeseries_file.exists() - timeseries = pd.read_csv(timeseries_file, sep="\t") - assert len(timeseries.columns) == 100 + phenotypes_path = bids_dir / "participants.tsv" + phenotypes.to_csv(phenotypes_path, sep="\t", index=False) + + seg_to_atlas_args: list[str] = [] + for n in [100, 200, 300, 400, 500, 600, 800]: + seg_to_atlas_args.append("--seg-to-atlas") + seg_to_atlas_args.append(f"Schaefer20187Networks{n}Parcels") + dseg_path = ( + data_path + / "atlases" + / "sub-1" + / "func" + / f"sub-1_seg-Schaefer20187Networks{n}Parcels_dseg.nii.gz" + ) + seg_to_atlas_args.append(str(dseg_path)) + + parser = global_parser() + argv = [ + "--phenotypes", + str(phenotypes_path), + *seg_to_atlas_args, + str(bids_dir), + str(output_dir), + "group", + ] + args = parser.parse_args(argv) + + workflow(args) + + # output_folder = output_dir / "sub-1" / "ses-timepoint1" / "func" + + # base = ( + # "sub-1_ses-timepoint1_task-probabilisticclassification" + # "_run-01_space-MNI152NLin2009cAsym_res-2" + # "_atlas-Schaefer20187Networks" + # ) + + # relmat_file = output_folder / ( + # base + "_meas-PearsonCorrelation" + "_desc-100Parcels7NetworksSimple_relmat.tsv" + # ) + # assert relmat_file.exists() + # relmat = pd.read_csv(relmat_file, sep="\t") + # assert len(relmat) == 100 + + # json_file = relmat_file = output_folder / (base + "_timeseries.json") + # assert json_file.exists() + # with open(json_file, "r") as f: + # content = json.load(f) + # assert content.get("SamplingFrequency") == 0.5 + + # timeseries_file = relmat_file = output_folder / ( + # base + "_desc-100Parcels7NetworksSimple_timeseries.tsv" + # ) + # assert timeseries_file.exists() + # timeseries = pd.read_csv(timeseries_file, sep="\t") + # assert len(timeseries.columns) == 100 diff --git a/wonkyconn/visualization/utils.py b/wonkyconn/visualization/utils.py index f702ee6..b22e1df 100644 --- a/wonkyconn/visualization/utils.py +++ b/wonkyconn/visualization/utils.py @@ -47,8 +47,7 @@ def repo2data_path(): def get_data_root(): """Get motion metric data path root.""" - default_path = Path(__file__).parents[2] / \ - "data" / "fmriprep-denoise-benchmark" + default_path = Path(__file__).parents[2] / "data" / "fmriprep-denoise-benchmark" if not (default_path / "data_requirement.json").exists(): default_path = repo2data_path() return default_path @@ -406,32 +405,6 @@ def _get_qcfc_metric(file_path, metric, group): return qcfc_per_edge -def _get_corr_distance(files_qcfc, labels, group): - """Load correlation of QC/FC with node 
distances.""" - qcfc_per_edge = _get_qcfc_metric(files_qcfc, metric="correlation", group=group) - corr_distance = [] - for df, label in zip(qcfc_per_edge, labels): - atlas_name = label.split("atlas-")[-1].split("_")[0] - dimension = label.split("nroi-")[-1].split("_")[0] - pairwise_distance = get_atlas_pairwise_distance(atlas_name, dimension) - cols = df.columns - df, _ = spearmanr(pairwise_distance.iloc[:, -1], df) - df = pd.DataFrame(df[1:, 0], index=cols, columns=[label]) - corr_distance.append(df) - - if len(corr_distance) == 1: - corr_distance = corr_distance[0] - else: - corr_distance = pd.concat(corr_distance, axis=1) - - return { - "data": corr_distance.T, - "order": list(GRID_LOCATION.values()), - "title": "Correlation between\nnodewise distance and QC-FC", - "label": "Pearson's correlation", - } - - def _corr_modularity_motion(movement, files_network, labels): """Load correlation of mean FD with network modularity.""" mean_corr, mean_modularity = [], [] diff --git a/wonkyconn/workflow.py b/wonkyconn/workflow.py index ac3ae8a..cb958af 100644 --- a/wonkyconn/workflow.py +++ b/wonkyconn/workflow.py @@ -2,98 +2,103 @@ Process fMRIPrep outputs to timeseries based on denoising strategy. """ -from __future__ import annotations - import argparse - -from giga_connectome.mask import generate_gm_mask_atlas -from giga_connectome.atlas import load_atlas_setting -from giga_connectome.denoise import get_denoise_strategy -from giga_connectome import methods, utils -from giga_connectome.postprocess import run_postprocessing_dataset - -from giga_connectome.denoise import is_ica_aroma -from giga_connectome.logger import gc_logger - -gc_log = gc_logger() - - -def set_verbosity(verbosity: int | list[int]) -> None: - if isinstance(verbosity, list): - verbosity = verbosity[0] - if verbosity == 0: - gc_log.setLevel("ERROR") - elif verbosity == 1: - gc_log.setLevel("WARNING") - elif verbosity == 2: - gc_log.setLevel("INFO") - elif verbosity == 3: - gc_log.setLevel("DEBUG") +from collections import defaultdict +from pathlib import Path +from typing import Any + +import pandas as pd +from tqdm.auto import tqdm + +from .atlas import Atlas +from .base import ConnectivityMatrix +from .features.calculate_degrees_of_freedom import calculate_degrees_of_freedom_loss +from .features.distance_dependence import calculate_distance_dependence +from .features.quality_control_connectivity import ( + calculate_median_absolute, + calculate_qcfc, + calculate_qcfc_percentage, +) +from .file_index.bids import BIDSIndex +from .logger import gc_log, set_verbosity def workflow(args: argparse.Namespace) -> None: + set_verbosity(args.verbosity) gc_log.info(vars(args)) - # set file paths + # Check BIDS path bids_dir = args.bids_dir - output_dir = args.output_dir - working_dir = args.work_dir - standardize = True # always standardising the time series - smoothing_fwhm = args.smoothing_fwhm - calculate_average_correlation = ( - args.calculate_intranetwork_average_correlation - ) - bids_filters = utils.parse_bids_filter(args.bids_filter_file) + index = BIDSIndex() + index.put(bids_dir) - subjects = utils.get_subject_lists(args.participant_label, bids_dir) - strategy = get_denoise_strategy(args.denoise_strategy) - - atlas = load_atlas_setting(args.atlas) - - set_verbosity(args.verbosity) - - # check output path + # Check output path + output_dir = args.output_dir output_dir.mkdir(parents=True, exist_ok=True) - working_dir.mkdir(parents=True, exist_ok=True) - # get template information; currently we only support the fmriprep defaults - 
template = ( - "MNI152NLin6Asym" if is_ica_aroma(strategy) else "MNI152NLin2009cAsym" + # Load data frame + data_frame = pd.read_csv( + args.phenotypes, + sep="\t", + index_col="participant_id", + dtype={"participant_id": str}, ) - - gc_log.info(f"Indexing BIDS directory:\n\t{bids_dir}") - - utils.create_ds_description(output_dir) - utils.create_sidecar(output_dir / "meas-PearsonCorrelation_relmat.json") - methods.generate_method_section( - output_dir=output_dir, - atlas=atlas["name"], - smoothing_fwhm=smoothing_fwhm, - standardize="zscore", - strategy=args.denoise_strategy, - mni_space=template, - average_correlation=calculate_average_correlation, + if "gender" not in data_frame.columns: + raise ValueError('Phenotypes file is missing the "gender" column') + if "age" not in data_frame.columns: + raise ValueError('Phenotypes file is missing the "age" column') + + # Load atlases + seg_to_atlas: dict[str, Atlas] = { + seg: Atlas.create(seg, Path(atlas_path_str)) + for seg, atlas_path_str in args.seg_to_atlas + } + + seg_to_connectivity_matrices: defaultdict[str, list[ConnectivityMatrix]] = ( + defaultdict(list) ) - - for subject in subjects: - subj_data, _ = utils.get_bids_images( - [subject], template, bids_dir, args.reindex_bids, bids_filters - ) - group_mask, resampled_atlases = generate_gm_mask_atlas( - working_dir, atlas, template, subj_data["mask"] + for timeseries_path in index.get(suffix="timeseries", extension=".tsv"): + query = dict(**index.get_tags(timeseries_path)) + del query["suffix"] + + metadata = index.get_metadata(timeseries_path) + if not metadata: + gc_log.warning(f"Skipping {timeseries_path} due to missing metadata") + continue + + for relmat_path in index.get(suffix="relmat", **query): + seg = index.get_tag_value(relmat_path, "seg") + if seg not in seg_to_atlas: + gc_log.warning(f"Skipping {relmat_path} due to missing atlas for {seg}") + continue + connectivity_matrix = ConnectivityMatrix(relmat_path, metadata) + seg_to_connectivity_matrices[seg].append(connectivity_matrix) + + if not seg_to_connectivity_matrices: + raise ValueError("No connectivity matrices found") + + records: list[dict[str, Any]] = [] + for seg, connectivity_matrices in tqdm( + seg_to_connectivity_matrices.items(), unit="seg" + ): + seg_subjects = [ + index.get_tag_value(c.path, "sub") for c in connectivity_matrices + ] + seg_data_frame = data_frame.loc[seg_subjects] + + qcfc = calculate_qcfc(seg_data_frame, connectivity_matrices) + + atlas = seg_to_atlas[seg] + record = dict( + seg=seg, + median_absolute_qcfc=calculate_median_absolute(qcfc.correlation), + percentage_significant_qcfc=calculate_qcfc_percentage(qcfc), + distance_dependence=calculate_distance_dependence(qcfc, atlas), + degrees_of_freedom_loss=calculate_degrees_of_freedom_loss( + connectivity_matrices + ), ) + records.append(record) - gc_log.info(f"Generate subject level connectomes: sub-{subject}") - - run_postprocessing_dataset( - strategy, - atlas, - resampled_atlases, - subj_data["bold"], - group_mask, - standardize, - smoothing_fwhm, - output_dir, - calculate_average_correlation, - ) - return + result_frame = pd.DataFrame.from_records(records) + result_frame.to_csv(output_dir / "metrics.tsv", sep="\t", index=False)
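A minimal usage sketch of the API introduced in this diff, mirroring what workflow.py does: it pairs each relmat file with the metadata of its timeseries sidecar and then computes the four summary metrics. The derivatives path, phenotype filename, segmentation label, and atlas filename are hypothetical placeholders; only the class and function names come from the code above.

from pathlib import Path

import pandas as pd

from wonkyconn.atlas import Atlas
from wonkyconn.base import ConnectivityMatrix
from wonkyconn.features.calculate_degrees_of_freedom import calculate_degrees_of_freedom_loss
from wonkyconn.features.distance_dependence import calculate_distance_dependence
from wonkyconn.features.quality_control_connectivity import (
    calculate_median_absolute,
    calculate_qcfc,
    calculate_qcfc_percentage,
)
from wonkyconn.file_index.bids import BIDSIndex

# Hypothetical inputs: a BIDS derivatives tree with *_relmat.tsv matrices and
# *_timeseries.json sidecars, plus a phenotype table with "participant_id",
# "age" and "gender" columns.
bids_dir = Path("derivatives/connectomes")
phenotypes = pd.read_csv(
    "participants.tsv", sep="\t", index_col="participant_id", dtype={"participant_id": str}
)

index = BIDSIndex()
index.put(bids_dir)

# Collect connectivity matrices for one segmentation label, attaching the
# metadata found next to the corresponding timeseries file.
seg = "Schaefer20187Networks400Parcels"  # hypothetical segmentation label
connectivity_matrices: list[ConnectivityMatrix] = []
for timeseries_path in index.get(suffix="timeseries", extension=".tsv"):
    metadata = index.get_metadata(timeseries_path)
    if not metadata:
        continue
    query = dict(index.get_tags(timeseries_path))
    del query["suffix"]
    for relmat_path in index.get(suffix="relmat", **query):
        if index.get_tag_value(relmat_path, "seg") == seg:
            connectivity_matrices.append(ConnectivityMatrix(relmat_path, metadata))

# QC-FC partial correlations controlling for age and gender, then the summary metrics.
subjects = [index.get_tag_value(c.path, "sub") for c in connectivity_matrices]
qcfc = calculate_qcfc(phenotypes.loc[subjects], connectivity_matrices)
atlas = Atlas.create(seg, Path("seg-Schaefer20187Networks400Parcels_dseg.nii.gz"))
print(
    calculate_median_absolute(qcfc.correlation),
    calculate_qcfc_percentage(qcfc),
    calculate_distance_dependence(qcfc, atlas),
    calculate_degrees_of_freedom_loss(connectivity_matrices),
)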