Skip to content

Commit

Permalink
Add plotting and enhance performance and add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
HippocampusGirl committed May 17, 2024
1 parent 8740440 commit 5e18f5e
Show file tree
Hide file tree
Showing 27 changed files with 472 additions and 2,635 deletions.
90 changes: 39 additions & 51 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,23 @@ name = "wonkyconn"
description = "Evaluating the residual motion in fMRI connectome and visualise reports."
readme = "README.md"
requires-python = ">=3.11"
license = { file="LICENSE" }
authors = [
{ name="Hao-Ting Wang", email="[email protected]" },
]
license = { file = "LICENSE" }
authors = [{ name = "Hao-Ting Wang", email = "[email protected]" }]
classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
dependencies = [
"nilearn[plotting] >=0.10.3",
"pybids >=0.15.0, <0.16.0",
"templateflow < 23.0.0",
"setuptools",
"jinja2 >= 2.0",
"numpy",
"scipy",
"patsy",
"pandas",
"rich",
"numba",
"seaborn",
"matplotlib",
]
dynamic = ["version"]

Expand All @@ -35,22 +36,14 @@ dev = [
"flake8",
"pre-commit",
"wonkyconn[test]",
'tox',
'mypy',
'types-all',
'pandas-stubs',
'types-tqdm'
]
test = [
"pytest",
"pytest-cov",
]
docs = [
"sphinx",
"sphinx_rtd_theme",
"myst-parser",
"sphinx-argparse"
"tox",
"mypy",
"types-all",
"pandas-stubs",
"types-tqdm",
]
test = ["nibabel", "nilearn", "pytest", "pytest-cov", "templateflow < 23.0.0"]
docs = ["sphinx", "sphinx_rtd_theme", "myst-parser", "sphinx-argparse"]
# Aliases
tests = ["wonkyconn[test]"]

Expand All @@ -69,13 +62,10 @@ exclude = [".git_archival.txt"]

[tool.hatch.build.targets.wheel]
packages = ["wonkyconn"]
exclude = [
".github",
"wonkyconn/data/test_data"
]
exclude = [".github", "wonkyconn/data/test_data"]

[tool.black]
target-version = ['py311']
target-version = ["py311"]
exclude = "wonkyconn/_version.py"
line-length = 79

Expand All @@ -84,7 +74,7 @@ check_untyped_defs = true
disallow_any_generics = true
disallow_incomplete_defs = true
disallow_untyped_defs = true
enable_error_code = ["ignore-without-code", "redundant-expr"] # "truthy-bool"
enable_error_code = ["ignore-without-code", "redundant-expr"] # "truthy-bool"
no_implicit_optional = true
show_error_codes = true
# strict = true
Expand All @@ -95,36 +85,34 @@ warn_unused_ignores = true
[[tool.mypy.overrides]]
ignore_missing_imports = true
module = [
"bids.*",
"h5py.*",
"nibabel.*",
"nilearn.*",
"nilearn.connectome.*",
"nilearn.image.*",
"nilearn.interfaces.*",
"nilearn.maskers.*",
"nilearn.masking.*",
"patsy.*",
"rich.*",
"scipy.*",
"statsmodels.*",
"templateflow.*",
"bids.*",
"matplotlib.*",
"numba.*",
"patsy.*",
"rich.*",
"scipy.*",
"seaborn.*",
"statsmodels.*",
"templateflow.*",
]

[[tool.mypy.overrides]]
ignore_errors = true
module = [
'wonkyconn.tests.*',
'conf',
]
module = ["wonkyconn.tests.*", "conf"]

[tool.pytest.ini_options]
markers = [
"smoke: smoke tests that will run on a simulated dataset (deselect with '-m \"not smoke\"')",
]
# filterwarnings = ["error"]
minversion = "7"
log_cli_level = "INFO"
xfail_strict = true
testpaths = ["wonkyconn/tests"]
addopts = ["-ra", "--strict-config", "--strict-markers", "--doctest-modules", "-v"]
markers = [
"smoke: smoke tests that will run on a downsampled real dataset (deselect with '-m \"not smoke\"')",
addopts = [
"-ra",
"--strict-config",
"--strict-markers",
"--doctest-modules",
"-v",
]
# filterwarnings = ["error"]
51 changes: 47 additions & 4 deletions wonkyconn/atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@

@dataclass
class Atlas(ABC):
"""
Abstract base class representing a brain atlas.
Attributes:
seg (str): The "seg" value that the atlas corresponds to. A "seg" uniquely
identifies an atlas in a given space and resolution.
image (nib.nifti1.Nifti1Image): The Nifti1Image object for the atlas file.
"""

seg: str
image: nib.nifti1.Nifti1Image

Expand All @@ -20,23 +30,56 @@ class Atlas(ABC):

@abstractmethod
def get_centroid_points(self) -> npt.NDArray[np.float64]:
"""
Returns the centroid points of the atlas regions.
Returns:
npt.NDArray[np.float64]: An array of centroid indices.
"""
raise NotImplementedError

def get_centroids(self) -> npt.NDArray[np.float64]:
"""
Returns the centroid coordinates of the atlas regions.
Returns:
npt.NDArray[np.float64]: An array of centroid coordinates.
"""
centroid_points = self.get_centroid_points()
centroid_coordinates = nib.affines.apply_affine(
self.image.affine, centroid_points
)
return centroid_coordinates

def get_distance_matrix(self) -> npt.NDArray[np.float64]:
"""
Calculates the pairwise distance matrix between the centroids
of the atlas regions.
Returns:
npt.NDArray[np.float64]: The distance matrix.
"""
centroids = self.get_centroids()
return scipy.spatial.distance.squareform(
scipy.spatial.distance.pdist(centroids)
)

@staticmethod
def create(seg: str, path: Path) -> "Atlas":
"""
Create an Atlas object based based on it's "seg" value and path.
Parameters:
seg (str): The "seg" value.
path (Path): The path to the image.
Returns:
Atlas: An instance of the Atlas class.
Raises:
None
"""
image = nib.nifti1.load(path)

if image.ndim <= 3 or image.shape[3] == 1:
Expand All @@ -50,7 +93,7 @@ class DsegAtlas(Atlas):
def get_array(self) -> npt.NDArray[np.int64]:
return np.asarray(self.image.dataobj, dtype=np.int64)

def check_single_connected_component(self, array: npt.NDArray[np.int64]) -> None:
def _check_single_connected_component(self, array: npt.NDArray[np.int64]) -> None:
for i in range(1, array.max() + 1):
mask = array == i
_, num_features = scipy.ndimage.label(mask, structure=self.structure)
Expand All @@ -61,7 +104,7 @@ def check_single_connected_component(self, array: npt.NDArray[np.int64]) -> None

def get_centroid_points(self) -> npt.NDArray[np.float64]:
array = self.get_array()
self.check_single_connected_component(array)
self._check_single_connected_component(array)
return np.asarray(
scipy.ndimage.center_of_mass(
input=array > 0, labels=array, index=np.arange(1, array.max() + 1)
Expand All @@ -73,7 +116,7 @@ def get_centroid_points(self) -> npt.NDArray[np.float64]:
class ProbsegAtlas(Atlas):
epsilon: float = 1e-6

def get_centroid_point(
def _get_centroid_point(
self, i: int, array: npt.NDArray[np.float64]
) -> tuple[float, ...]:
mask = array > self.epsilon
Expand All @@ -87,7 +130,7 @@ def get_centroid_point(
def get_centroid_points(self) -> npt.NDArray[np.float64]:
return np.asarray(
[
self.get_centroid_point(i, image.get_fdata())
self._get_centroid_point(i, image.get_fdata())
for i, image in enumerate(nib.funcs.four_to_three(self.image))
]
)
25 changes: 25 additions & 0 deletions wonkyconn/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import Any

Expand All @@ -8,8 +9,32 @@

@dataclass
class ConnectivityMatrix:
"""
Represents a connectivity matrix.
Attributes:
path (Path): The path to the ".tsv" file containing the connectivity matrix.
metadata (dict[str, Any]): Additional metadata associated with the connectivity matrix.
"""

path: Path
metadata: dict[str, Any]

def load(self) -> npt.NDArray[np.float64]:
"""
Load the connectivity matrix from the file.
Returns:
ndarray: The loaded connectivity matrix as a NumPy array.
"""
return np.loadtxt(self.path, delimiter="\t", skiprows=1)

@cached_property
def region_count(self) -> int:
"""
Get the number of regions in the connectivity matrix.
Returns:
int: The number of regions.
"""
return self.load().shape[0]
44 changes: 44 additions & 0 deletions wonkyconn/correlation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np
import scipy
from numba import guvectorize
from numpy import typing as npt


def correlation_p_value(r: npt.NDArray[np.float64], m: int) -> npt.NDArray[np.float64]:
ab = m / 2 - 1
distribution = scipy.stats.beta(ab, ab, loc=-1, scale=2)
pvalue = 2 * (distribution.sf(np.abs(r)))
return pvalue


@guvectorize(
["void(float64[:], float64[:], float64[:, :], float64[:])"],
"(n),(n),(n,m)->()",
nopython=True,
)
def partial_correlation(
x: npt.NDArray[np.float64],
y: npt.NDArray[np.float64],
cov: npt.NDArray[np.float64],
out: npt.NDArray[np.float64],
) -> None:
"""A minimal implementation of partial correlation.
Parameters
----------
x, y : np.ndarray
Variable of interest.
cov : np.ndarray
Variable to be removed from variable of interest.
Returns
-------
dict
Correlation and p-value.
"""
beta_cov_x, _, _, _ = np.linalg.lstsq(cov, x)
beta_cov_y, _, _, _ = np.linalg.lstsq(cov, y)
resid_x = x - cov @ beta_cov_x
resid_y = y - cov @ beta_cov_y
out[0] = np.corrcoef(resid_x, resid_y)[0, 1]
14 changes: 14 additions & 0 deletions wonkyconn/features/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from dataclasses import dataclass
from typing import Self

import numpy as np


@dataclass
class MeanAndSEMResult:
mean: float
sem: float

@classmethod
def empty(cls) -> Self:
return cls(mean=np.nan, sem=np.nan)
Loading

0 comments on commit 5e18f5e

Please sign in to comment.