Implement smoke test and make it run
HippocampusGirl committed May 16, 2024
1 parent 7e43bf8 commit 8740440
Showing 20 changed files with 822 additions and 724 deletions.
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -96,7 +96,6 @@ warn_unused_ignores = true
 ignore_missing_imports = true
 module = [
     "bids.*",
-    "wonkyconn._version",
     "h5py.*",
     "nibabel.*",
     "nilearn.*",
@@ -105,8 +104,10 @@ module = [
     "nilearn.interfaces.*",
     "nilearn.maskers.*",
     "nilearn.masking.*",
+    "patsy.*",
     "rich.*",
-    "scipy.ndimage.*",
+    "scipy.*",
+    "statsmodels.*",
     "templateflow.*",
 ]

93 changes: 93 additions & 0 deletions wonkyconn/atlas.py
@@ -0,0 +1,93 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
import nibabel as nib
import numpy as np
from numpy import typing as npt
import scipy.ndimage
import scipy.spatial

from .logger import gc_log


@dataclass
class Atlas(ABC):
seg: str
image: nib.nifti1.Nifti1Image

    # All-True 3x3x3 structuring element, i.e. 26-connectivity for labeling
    structure: npt.NDArray[np.bool_] = field(
        default_factory=lambda: np.ones((3, 3, 3), dtype=bool)
    )

@abstractmethod
def get_centroid_points(self) -> npt.NDArray[np.float64]:
raise NotImplementedError

def get_centroids(self) -> npt.NDArray[np.float64]:
centroid_points = self.get_centroid_points()
centroid_coordinates = nib.affines.apply_affine(
self.image.affine, centroid_points
)
return centroid_coordinates

def get_distance_matrix(self) -> npt.NDArray[np.float64]:
centroids = self.get_centroids()
return scipy.spatial.distance.squareform(
scipy.spatial.distance.pdist(centroids)
)

@staticmethod
def create(seg: str, path: Path) -> "Atlas":
image = nib.nifti1.load(path)

if image.ndim <= 3 or image.shape[3] == 1:
return DsegAtlas(seg, nib.funcs.squeeze_image(image))
else:
return ProbsegAtlas(seg, image)


@dataclass
class DsegAtlas(Atlas):
def get_array(self) -> npt.NDArray[np.int64]:
return np.asarray(self.image.dataobj, dtype=np.int64)

def check_single_connected_component(self, array: npt.NDArray[np.int64]) -> None:
for i in range(1, array.max() + 1):
mask = array == i
_, num_features = scipy.ndimage.label(mask, structure=self.structure)
if num_features > 1:
gc_log.warning(
f'Atlas "{self.seg}" region {i} has more than a single connected component'
)

def get_centroid_points(self) -> npt.NDArray[np.float64]:
array = self.get_array()
self.check_single_connected_component(array)
return np.asarray(
scipy.ndimage.center_of_mass(
input=array > 0, labels=array, index=np.arange(1, array.max() + 1)
)
)


@dataclass
class ProbsegAtlas(Atlas):
    epsilon: float = 1e-6  # probabilities above this count as part of the region

def get_centroid_point(
self, i: int, array: npt.NDArray[np.float64]
) -> tuple[float, ...]:
mask = array > self.epsilon
_, num_features = scipy.ndimage.label(mask, structure=self.structure)
if num_features > 1:
gc_log.warning(
f'Atlas "{self.seg}" region {i} has more than a single connected component'
)
return scipy.ndimage.center_of_mass(array)

def get_centroid_points(self) -> npt.NDArray[np.float64]:
return np.asarray(
[
self.get_centroid_point(i, image.get_fdata())
for i, image in enumerate(nib.funcs.four_to_three(self.image))
]
)
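
For orientation, here is a minimal usage sketch (not part of the commit; the seg label and file name are made up, and the file is assumed to be a 3-D dseg NIfTI):

from pathlib import Path

from wonkyconn.atlas import Atlas

# Atlas.create dispatches on dimensionality: 3-D (or single-volume 4-D)
# images become a DsegAtlas, genuinely 4-D images a ProbsegAtlas.
atlas = Atlas.create("Schaefer400", Path("atlas-Schaefer400_dseg.nii.gz"))

# Pairwise Euclidean distances between region centroids in world
# coordinates (typically millimeters), shape (n_regions, n_regions).
distance_matrix = atlas.get_distance_matrix()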
15 changes: 15 additions & 0 deletions wonkyconn/base.py
@@ -0,0 +1,15 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np
from numpy import typing as npt


@dataclass
class ConnectivityMatrix:
path: Path
metadata: dict[str, Any]

    def load(self) -> npt.NDArray[np.float64]:
        # Tab-separated matrix with a single header row
        return np.loadtxt(self.path, delimiter="\t", skiprows=1)
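
A short sketch of the intended use (hypothetical file name; load() assumes a tab-separated matrix with a single header row):

from pathlib import Path

from wonkyconn.base import ConnectivityMatrix

connectivity_matrix = ConnectivityMatrix(
    path=Path("sub-01_relmat.tsv"),  # hypothetical path
    metadata={"ConfoundRegressors": ["trans_x", "trans_y", "trans_z"]},
)
array = connectivity_matrix.load()  # square region-by-region array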
18 changes: 0 additions & 18 deletions wonkyconn/features/__init__.py
@@ -1,18 +0,0 @@
from .quality_control_connectivity import (
qcfc,
partial_correlation,
significant_level,
calculate_median_absolute,
)
from .distance_dependency import get_atlas_pairwise_distance, get_centroid
from .network_modularity import louvain_modularity

__all__ = [
"qcfc",
"significant_level",
"partial_correlation",
"calculate_median_absolute",
"get_atlas_pairwise_distance",
"get_centroid",
"louvain_modularity",
]
Empty file added wonkyconn/features/base.py
133 changes: 23 additions & 110 deletions wonkyconn/features/calculate_degrees_of_freedom.py
@@ -1,125 +1,38 @@
 """Calculate degree of freedom"""
 
-import argparse
-from pathlib import Path
-import pandas as pd
 import numpy as np
 
-from fmriprep_denoise.dataset.timeseries import get_confounds
-from fmriprep_denoise.dataset.fmriprep import (
-    get_prepro_strategy,
-    fetch_fmriprep_derivative,
-    generate_movement_summary,
-)
-
-
-def main():
-    args = parse_args()
-    print(vars(args))
-    dataset_name = args.dataset_name
-    fmriprep_specifier = args.specifier
-    fmriprep_path = Path(args.fmriprep_path)
-    participant_tsv = Path(args.participants_tsv)
-    output_root = Path(args.output_path)
-
-    output_root.mkdir(exist_ok=True, parents=True)
-    path_movement = Path(
-        output_root / f"dataset-{dataset_name}_desc-movement_phenotype.tsv"
-    )
-    path_dof = Path(
-        output_root / f"dataset-{dataset_name}_desc-confounds_phenotype.tsv"
-    )
-
-    full_data = fetch_fmriprep_derivative(
-        dataset_name, participant_tsv, fmriprep_path, fmriprep_specifier
-    )
-
-    if dataset_name == "ds000030":
-        # read relevant files,
-        participants = full_data.phenotypic.copy()
-        mask_quality = participants["ghost_NoGhost"] == "No_ghost"
-        participants = participants[mask_quality].index.tolist()
-        subjects = [p.split("-")[-1] for p in participants]
-        full_data = fetch_fmriprep_derivative(
-            dataset_name,
-            participant_tsv,
-            fmriprep_path,
-            fmriprep_specifier,
-            subject=subjects,
-        )
-    movement = generate_movement_summary(full_data)
-    movement = movement.sort_index()
-    movement.to_csv(path_movement, sep="\t")
-    print("Generate movement stats.")
-
-    subjects = [p.split("-")[-1] for p in movement.index]
-
-    benchmark_strategies = get_prepro_strategy()
-    data_aroma = fetch_fmriprep_derivative(
-        dataset_name,
-        participant_tsv,
-        fmriprep_path,
-        fmriprep_specifier,
-        aroma=True,
-        subject=subjects,
-    )
-    data = fetch_fmriprep_derivative(
-        dataset_name,
-        participant_tsv,
-        fmriprep_path,
-        fmriprep_specifier,
-        subject=subjects,
-    )
-    info = {}
-    for strategy_name, parameters in benchmark_strategies.items():
-        print(f"Denoising: {strategy_name}")
-        print(parameters)
-        func_data = data_aroma.func if "aroma" in strategy_name else data.func
-        for img in func_data:
-            sub = img.split("/")[-1].split("_")[0]
-            reduced_confounds, sample_mask = get_confounds(
-                strategy_name, parameters, img
-            )
-            full_length = reduced_confounds.shape[0]
-            ts_length = full_length if sample_mask is None else len(sample_mask)
-            excised_vol = full_length - ts_length
-            excised_vol_pro = excised_vol / full_length
-            regressors = reduced_confounds.columns.tolist()
-            compcor = sum("comp_cor" in i for i in regressors)
-            high_pass = sum("cosine" in i for i in regressors)
-            total = len(regressors)
-            fixed = total - compcor if "compcor" in strategy_name else len(regressors)
-
-            if "aroma" in strategy_name:
-                path_aroma_ic = img.split("space-")[0] + "AROMAnoiseICs.csv"
-                with open(path_aroma_ic, "r") as f:
-                    aroma = len(f.readline().split(","))
-                total = fixed + aroma
-            else:
-                aroma = 0
-
-            if "scrub" in strategy_name:
-                total += excised_vol
-
-            stats = {
-                (strategy_name, "excised_vol"): excised_vol,
-                (strategy_name, "excised_vol_proportion"): excised_vol_pro,
-                (strategy_name, "high_pass"): high_pass,
-                (strategy_name, "fixed_regressors"): fixed,
-                (strategy_name, "compcor"): compcor,
-                (strategy_name, "aroma"): aroma,
-                (strategy_name, "total"): total,
-                (strategy_name, "full_length"): full_length,
-            }
-            if info.get(sub):
-                info[sub].update(stats)
-            else:
-                info[sub] = stats
-    confounds_stats = pd.DataFrame.from_dict(info, orient="index")
-    confounds_stats = confounds_stats.sort_index()
-    confounds_stats.to_csv(path_dof, sep="\t")
-
-
-if __name__ == "__main__":
-    main()
+from ..base import ConnectivityMatrix
+
+
+def calculate_degrees_of_freedom_loss(
+    connectivity_matrices: list[ConnectivityMatrix],
+) -> float:
+    """
+    Calculate the percentage of degrees of freedom lost during denoising.
+
+    Parameters:
+    - connectivity_matrices (list[ConnectivityMatrix]): The connectivity
+      matrices to calculate the degrees of freedom loss for.
+
+    Returns:
+    - float: The percentage of degrees of freedom lost.
+    """
+    values: list[float] = []
+
+    for connectivity_matrix in connectivity_matrices:
+        metadata = connectivity_matrix.metadata
+
+        total = 0
+
+        confound_regressors = metadata.get("ConfoundRegressors", list())
+        total += len(confound_regressors)
+
+        total += metadata.get("NumberOfVolumesDiscardedByMotionScrubbing", 0)
+        total += metadata.get("NumberOfVolumesDiscardedByNonsteadyStatesDetector", 0)
+
+        # TODO Support ICA-AROMA
+
+        values.append(total / connectivity_matrix.load().shape[0])
+
+    return float(np.mean(values))
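
A hedged example of calling the new function (the metadata keys are the ones read above; the file name and numbers are made up):

from pathlib import Path

from wonkyconn.base import ConnectivityMatrix
from wonkyconn.features.calculate_degrees_of_freedom import (
    calculate_degrees_of_freedom_loss,
)

connectivity_matrix = ConnectivityMatrix(
    path=Path("sub-01_relmat.tsv"),  # hypothetical 400-by-400 matrix
    metadata={
        "ConfoundRegressors": ["trans_x", "trans_y", "trans_z"],
        "NumberOfVolumesDiscardedByMotionScrubbing": 2,
        "NumberOfVolumesDiscardedByNonsteadyStatesDetector": 1,
    },
)

# With a 400-row matrix this yields (3 + 2 + 1) / 400 = 0.015, averaged
# across all matrices passed in.
loss = calculate_degrees_of_freedom_loss([connectivity_matrix])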
(The remaining 14 of the 20 changed files are not shown.)
