-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement smoke test and make it run
- Loading branch information
1 parent
7e43bf8
commit 8740440
Showing
20 changed files
with
822 additions
and
724 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass, field | ||
from pathlib import Path | ||
import nibabel as nib | ||
import numpy as np | ||
from numpy import typing as npt | ||
import scipy | ||
|
||
from .logger import gc_log | ||
|
||
|
||
@dataclass | ||
class Atlas(ABC): | ||
seg: str | ||
image: nib.nifti1.Nifti1Image | ||
|
||
structure: npt.NDArray[np.bool_] = field( | ||
default_factory=lambda: np.ones((3, 3, 3), dtype=bool) | ||
) | ||
|
||
@abstractmethod | ||
def get_centroid_points(self) -> npt.NDArray[np.float64]: | ||
raise NotImplementedError | ||
|
||
def get_centroids(self) -> npt.NDArray[np.float64]: | ||
centroid_points = self.get_centroid_points() | ||
centroid_coordinates = nib.affines.apply_affine( | ||
self.image.affine, centroid_points | ||
) | ||
return centroid_coordinates | ||
|
||
def get_distance_matrix(self) -> npt.NDArray[np.float64]: | ||
centroids = self.get_centroids() | ||
return scipy.spatial.distance.squareform( | ||
scipy.spatial.distance.pdist(centroids) | ||
) | ||
|
||
@staticmethod | ||
def create(seg: str, path: Path) -> "Atlas": | ||
image = nib.nifti1.load(path) | ||
|
||
if image.ndim <= 3 or image.shape[3] == 1: | ||
return DsegAtlas(seg, nib.funcs.squeeze_image(image)) | ||
else: | ||
return ProbsegAtlas(seg, image) | ||
|
||
|
||
@dataclass | ||
class DsegAtlas(Atlas): | ||
def get_array(self) -> npt.NDArray[np.int64]: | ||
return np.asarray(self.image.dataobj, dtype=np.int64) | ||
|
||
def check_single_connected_component(self, array: npt.NDArray[np.int64]) -> None: | ||
for i in range(1, array.max() + 1): | ||
mask = array == i | ||
_, num_features = scipy.ndimage.label(mask, structure=self.structure) | ||
if num_features > 1: | ||
gc_log.warning( | ||
f'Atlas "{self.seg}" region {i} has more than a single connected component' | ||
) | ||
|
||
def get_centroid_points(self) -> npt.NDArray[np.float64]: | ||
array = self.get_array() | ||
self.check_single_connected_component(array) | ||
return np.asarray( | ||
scipy.ndimage.center_of_mass( | ||
input=array > 0, labels=array, index=np.arange(1, array.max() + 1) | ||
) | ||
) | ||
|
||
|
||
@dataclass | ||
class ProbsegAtlas(Atlas): | ||
epsilon: float = 1e-6 | ||
|
||
def get_centroid_point( | ||
self, i: int, array: npt.NDArray[np.float64] | ||
) -> tuple[float, ...]: | ||
mask = array > self.epsilon | ||
_, num_features = scipy.ndimage.label(mask, structure=self.structure) | ||
if num_features > 1: | ||
gc_log.warning( | ||
f'Atlas "{self.seg}" region {i} has more than a single connected component' | ||
) | ||
return scipy.ndimage.center_of_mass(array) | ||
|
||
def get_centroid_points(self) -> npt.NDArray[np.float64]: | ||
return np.asarray( | ||
[ | ||
self.get_centroid_point(i, image.get_fdata()) | ||
for i, image in enumerate(nib.funcs.four_to_three(self.image)) | ||
] | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
import numpy as np | ||
from numpy import typing as npt | ||
|
||
|
||
@dataclass | ||
class ConnectivityMatrix: | ||
path: Path | ||
metadata: dict[str, Any] | ||
|
||
def load(self) -> npt.NDArray[np.float64]: | ||
return np.loadtxt(self.path, delimiter="\t", skiprows=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +0,0 @@ | ||
from .quality_control_connectivity import ( | ||
qcfc, | ||
partial_correlation, | ||
significant_level, | ||
calculate_median_absolute, | ||
) | ||
from .distance_dependency import get_atlas_pairwise_distance, get_centroid | ||
from .network_modularity import louvain_modularity | ||
|
||
__all__ = [ | ||
"qcfc", | ||
"significant_level", | ||
"partial_correlation", | ||
"calculate_median_absolute", | ||
"get_atlas_pairwise_distance", | ||
"get_centroid", | ||
"louvain_modularity", | ||
] | ||
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,125 +1,38 @@ | ||
"""Calculate degree of freedom""" | ||
|
||
import argparse | ||
from pathlib import Path | ||
import pandas as pd | ||
import numpy as np | ||
from ..base import ConnectivityMatrix | ||
|
||
|
||
from fmriprep_denoise.dataset.timeseries import get_confounds | ||
from fmriprep_denoise.dataset.fmriprep import ( | ||
get_prepro_strategy, | ||
fetch_fmriprep_derivative, | ||
generate_movement_summary, | ||
) | ||
def calculate_degrees_of_freedom_loss( | ||
connectivity_matrices: list[ConnectivityMatrix], | ||
) -> float: | ||
""" | ||
Calculate the percent of degrees of freedom lost during denoising. | ||
Parameters: | ||
- bids_file (BIDSFile): The BIDS file for which to calculate the degrees of freedom. | ||
def main(): | ||
args = parse_args() | ||
print(vars(args)) | ||
dataset_name = args.dataset_name | ||
fmriprep_specifier = args.specifier | ||
fmriprep_path = Path(args.fmriprep_path) | ||
participant_tsv = Path(args.participants_tsv) | ||
output_root = Path(args.output_path) | ||
Returns: | ||
- float: The percentage of degrees of freedom lost. | ||
output_root.mkdir(exist_ok=True, parents=True) | ||
path_movement = Path( | ||
output_root / f"dataset-{dataset_name}_desc-movement_phenotype.tsv" | ||
) | ||
""" | ||
|
||
path_dof = Path( | ||
output_root / f"dataset-{dataset_name}_desc-confounds_phenotype.tsv" | ||
) | ||
values: list[float] = [] | ||
|
||
full_data = fetch_fmriprep_derivative( | ||
dataset_name, participant_tsv, fmriprep_path, fmriprep_specifier | ||
) | ||
for connectivity_matrix in connectivity_matrices: | ||
metadata = connectivity_matrix.metadata | ||
|
||
if dataset_name == "ds000030": | ||
# read relevant files, | ||
participants = full_data.phenotypic.copy() | ||
mask_quality = participants["ghost_NoGhost"] == "No_ghost" | ||
participants = participants[mask_quality].index.tolist() | ||
subjects = [p.split("-")[-1] for p in participants] | ||
full_data = fetch_fmriprep_derivative( | ||
dataset_name, | ||
participant_tsv, | ||
fmriprep_path, | ||
fmriprep_specifier, | ||
subject=subjects, | ||
) | ||
movement = generate_movement_summary(full_data) | ||
movement = movement.sort_index() | ||
movement.to_csv(path_movement, sep="\t") | ||
print("Generate movement stats.") | ||
total = 0 | ||
|
||
subjects = [p.split("-")[-1] for p in movement.index] | ||
confound_regressors = metadata.get("ConfoundRegressors", list()) | ||
total += len(confound_regressors) | ||
|
||
benchmark_strategies = get_prepro_strategy() | ||
data_aroma = fetch_fmriprep_derivative( | ||
dataset_name, | ||
participant_tsv, | ||
fmriprep_path, | ||
fmriprep_specifier, | ||
aroma=True, | ||
subject=subjects, | ||
) | ||
data = fetch_fmriprep_derivative( | ||
dataset_name, | ||
participant_tsv, | ||
fmriprep_path, | ||
fmriprep_specifier, | ||
subject=subjects, | ||
) | ||
info = {} | ||
for strategy_name, parameters in benchmark_strategies.items(): | ||
print(f"Denoising: {strategy_name}") | ||
print(parameters) | ||
func_data = data_aroma.func if "aroma" in strategy_name else data.func | ||
for img in func_data: | ||
sub = img.split("/")[-1].split("_")[0] | ||
reduced_confounds, sample_mask = get_confounds( | ||
strategy_name, parameters, img | ||
) | ||
full_length = reduced_confounds.shape[0] | ||
ts_length = full_length if sample_mask is None else len(sample_mask) | ||
excised_vol = full_length - ts_length | ||
excised_vol_pro = excised_vol / full_length | ||
regressors = reduced_confounds.columns.tolist() | ||
compcor = sum("comp_cor" in i for i in regressors) | ||
high_pass = sum("cosine" in i for i in regressors) | ||
total = len(regressors) | ||
fixed = total - compcor if "compcor" in strategy_name else len(regressors) | ||
total += metadata.get("NumberOfVolumesDiscardedByMotionScrubbing", 0) | ||
total += metadata.get("NumberOfVolumesDiscardedByNonsteadyStatesDetector", 0) | ||
|
||
if "aroma" in strategy_name: | ||
path_aroma_ic = img.split("space-")[0] + "AROMAnoiseICs.csv" | ||
with open(path_aroma_ic, "r") as f: | ||
aroma = len(f.readline().split(",")) | ||
total = fixed + aroma | ||
else: | ||
aroma = 0 | ||
# TODO Support ICA-AROMA | ||
|
||
if "scrub" in strategy_name: | ||
total += excised_vol | ||
values.append(total / connectivity_matrix.load().shape[0]) | ||
|
||
stats = { | ||
(strategy_name, "excised_vol"): excised_vol, | ||
(strategy_name, "excised_vol_proportion"): excised_vol_pro, | ||
(strategy_name, "high_pass"): high_pass, | ||
(strategy_name, "fixed_regressors"): fixed, | ||
(strategy_name, "compcor"): compcor, | ||
(strategy_name, "aroma"): aroma, | ||
(strategy_name, "total"): total, | ||
(strategy_name, "full_length"): full_length, | ||
} | ||
if info.get(sub): | ||
info[sub].update(stats) | ||
else: | ||
info[sub] = stats | ||
confounds_stats = pd.DataFrame.from_dict(info, orient="index") | ||
confounds_stats = confounds_stats.sort_index() | ||
confounds_stats.to_csv(path_dof, sep="\t") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() | ||
return float(np.mean(values)) |
Oops, something went wrong.