Skip to content

Commit

Permalink
[ENH] run fit and transform of SurfaceMasker on list of Surface images (
Browse files Browse the repository at this point in the history
  • Loading branch information
Remi-Gau authored Nov 19, 2024
1 parent 875e0dd commit cf6bc7d
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 18 deletions.
2 changes: 2 additions & 0 deletions doc/changes/latest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ Enhancements

- :bdg-dark:`Code` Move SurfaceMasker and SurfaceLabelsMasker from experimental to :class:`nilearn.maskers.SurfaceMasker` and to :class:`nilearn.maskers.SurfaceLabelsMasker` (:gh:`4692` and :gh:`4714` by `Rémi Gau`_).

- :bdg-dark:`Code` Allow list of :obj:`~nilearn.surface.SurfaceImage` as input to :class:`~nilearn.maskers.SurfaceMasker` fit and transform methods (:gh:`4719` by `Rémi Gau`_).

- :bdg-dark:`Code` Move SurfaceImage and associated classes to from experimental to :mod:`nilearn.surface` (:gh:`4723` by `Rémi Gau`_).

- :bdg-dark:`Code` Improved SearchLight with NIfTI Support, Mask Handling, and Reusable Transform Method (:gh:`4652` by `Prakhar Jain`_).
Expand Down
37 changes: 37 additions & 0 deletions nilearn/maskers/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np

from nilearn import image
from nilearn.surface import SurfaceImage


def _check_dims(imgs):
Expand Down Expand Up @@ -85,3 +86,39 @@ def get_min_max_surface_image(img):
vmin = min(min(x.ravel()) for x in img.data.parts.values())
vmax = max(max(x.ravel()) for x in img.data.parts.values())
return vmin, vmax


def concatenate_surface_images(imgs):
"""Concatenate the data of a list or tuple of SurfaceImages.
Assumes all images have same meshes.
Parameters
----------
imgs : :obj:`list` or :obj:`tuple` of SurfaceImage object
Returns
-------
SurfaceImage object
"""
if not isinstance(imgs, (tuple, list)) or any(
not isinstance(x, SurfaceImage) for x in imgs
):
raise TypeError(
"'imgs' must be a list or a tuple of SurfaceImage instances."
)

if len(imgs) == 1:
return imgs[0]

for img in imgs:
check_same_n_vertices(img.mesh, imgs[0].mesh)

output_data = {}
for part in imgs[0].data.parts:
tmp = [img.data.parts[part] for img in imgs]
output_data[part] = np.concatenate(tmp)

output = SurfaceImage(mesh=imgs[0].mesh, data=output_data)

return output
37 changes: 27 additions & 10 deletions nilearn/maskers/surface_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from nilearn.maskers._utils import (
check_same_n_vertices,
compute_mean_surface_image,
concatenate_surface_images,
get_min_max_surface_image,
)
from nilearn.surface import SurfaceImage
Expand All @@ -25,9 +26,11 @@
class SurfaceMasker(TransformerMixin, CacheMixin, BaseEstimator):
"""Extract data from a :obj:`~nilearn.surface.SurfaceImage`.
.. versionadded:: 0.11.0
Parameters
----------
mask_img: :obj:`~nilearn.surface.SurfaceImage` object or None, default=None
mask_img: :obj:`~nilearn.surface.SurfaceImage` or None, default=None
%(smoothing_fwhm)s
This parameter is not implemented yet.
Expand Down Expand Up @@ -145,7 +148,7 @@ def _fit_mask_img(self, img):
Parameters
----------
img : SurfaceImage object or None
img : SurfaceImage object or :obj:`list` of SurfaceImage or None
"""
if self.mask_img is not None:
if img is not None:
Expand All @@ -160,11 +163,15 @@ def _fit_mask_img(self, img):
"or an img when calling fit()."
)

if not isinstance(img, list):
img = [img]
img = concatenate_surface_images(img)

# TODO: don't store a full array of 1 to mean "no masking"; use some
# sentinel value
mask_data = {
k: np.ones(v.n_vertices, dtype=bool)
for (k, v) in img.mesh.parts.items()
part: np.ones(v.n_vertices, dtype=bool)
for (part, v) in img.mesh.parts.items()
}
self.mask_img_ = SurfaceImage(mesh=img.mesh, data=mask_data)

Expand All @@ -173,12 +180,15 @@ def fit(self, img=None, y=None):
Parameters
----------
img : :obj:`~nilearn.surface.SurfaceImage` object or None
img : :obj:`~nilearn.surface.SurfaceImage` or \
:obj:`list` of :obj:`~nilearn.surface.SurfaceImage` or \
:obj:`tuple` of :obj:`~nilearn.surface.SurfaceImage` or None, \
default = None
Mesh and data for both hemispheres.
y : None
This parameter is unused. It is solely included for scikit-learn
compatibility.
This parameter is unused.
It is solely included for scikit-learn compatibility.
Returns
-------
Expand Down Expand Up @@ -218,7 +228,9 @@ def transform(
Parameters
----------
img : :obj:`~nilearn.surface.SurfaceImage` object
img : :obj:`~nilearn.surface.SurfaceImage` or \
:obj:`list` of :obj:`~nilearn.surface.SurfaceImage` or \
:obj:`tuple` of :obj:`~nilearn.surface.SurfaceImage`
Mesh and data for both hemispheres.
confounds : :class:`numpy.ndarray`, :obj:`str`,\
Expand Down Expand Up @@ -262,6 +274,10 @@ def transform(

self._check_fitted()

if not isinstance(img, list):
img = [img]
img = concatenate_surface_images(img)

check_same_n_vertices(self.mask_img_.mesh, img.mesh)

if self.reports:
Expand Down Expand Up @@ -308,7 +324,9 @@ def fit_transform(
Parameters
----------
img : :obj:`~nilearn.surface.SurfaceImage` object
img : :obj:`~nilearn.surface.SurfaceImage` or \
:obj:`list` of :obj:`~nilearn.surface.SurfaceImage` or \
:obj:`tuple` of :obj:`~nilearn.surface.SurfaceImage`
Mesh and data for both hemispheres.
y : None
Expand All @@ -324,7 +342,6 @@ def fit_transform(
sample_mask : None, or any type compatible with numpy-array indexing, \
or :obj:`list` of \
shape: (number of scans - number of volumes removed) \
for explicit index, or (number of scans) for binary mask, \
default=None
sample_mask to pass to :func:`nilearn.signal.clean`.
Expand Down
31 changes: 28 additions & 3 deletions nilearn/maskers/tests/test_surface_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ def test_check_estimator_invalid(estimator, check, name): # noqa: ARG001
check(estimator)


def test_fit_list_surf_images(surf_img):
"""Test fit on list of surface images.
resulting mask should have a single 'timepoint'.
"""
masker = SurfaceMasker()
masker.fit([surf_img((3,)), surf_img((5,))])
assert masker.mask_img_.shape == (surf_img().shape[1],)


# test with only one surface image and with 2 surface images (surface time
# series)
@pytest.mark.parametrize("shape", [(1,), (2,)])
Expand Down Expand Up @@ -70,14 +80,29 @@ def test_none_mask_img(surf_mask):
SurfaceMasker(surf_mask()).fit(None)


def test_transform_list_surf_images(surf_mask, surf_img):
"""Test transform on list of surface images."""
masker = SurfaceMasker(surf_mask()).fit()
signals = masker.transform([surf_img((3,)), surf_img((4,))])
assert signals.shape == (7, masker.output_dimension_)


def test_inverse_transform_list_surf_images(surf_mask, surf_img):
"""Test inverse_transform on list of surface images."""
masker = SurfaceMasker(surf_mask()).fit()
signals = masker.transform([surf_img((3,)), surf_img((4,))])
img = masker.inverse_transform(signals)
assert img.shape == (7, surf_mask().mesh.n_vertices)


def test_unfitted_masker(surf_mask):
masker = SurfaceMasker(surf_mask)
masker = SurfaceMasker(surf_mask())
with pytest.raises(ValueError, match="fitted"):
masker.transform(surf_mask)
masker.transform(surf_mask())


def test_check_is_fitted(surf_mask):
masker = SurfaceMasker(surf_mask)
masker = SurfaceMasker(surf_mask())
assert not masker.__sklearn_is_fitted__()


Expand Down
6 changes: 6 additions & 0 deletions nilearn/maskers/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from nilearn.maskers._utils import (
compute_mean_surface_image,
concatenate_surface_images,
get_min_max_surface_image,
)

Expand Down Expand Up @@ -41,3 +42,8 @@ def test_get_min_max_surface_image(surf_img):

assert vmin == -3.5
assert vmax == 10


def test_concatenate_surface_images(surf_img):
img = concatenate_surface_images([surf_img((3,)), surf_img((5,))])
assert img.shape == (8, 9)
12 changes: 12 additions & 0 deletions nilearn/surface/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,8 @@ class PolyData:
It is a shallow wrapper around the ``parts`` dictionary, which cannot be
empty and whose keys must be a subset of {"left", "right"}.
.. versionadded:: 0.11.0
Parameters
----------
left : :obj:`numpy.ndarray` or :obj:`str` or :obj:`pathlib.Path` or None,\
Expand Down Expand Up @@ -1322,6 +1324,8 @@ class SurfaceMesh(abc.ABC):
"""A surface :term:`mesh` having vertex, \
coordinates and faces (triangles).
.. versionadded:: 0.11.0
Attributes
----------
n_vertices : int
Expand Down Expand Up @@ -1356,6 +1360,8 @@ def to_gifti(self, gifti_file):
class InMemoryMesh(SurfaceMesh):
"""A surface mesh stored as in-memory numpy arrays.
.. versionadded:: 0.11.0
Parameters
----------
coordinates : :obj:`numpy.ndarray`
Expand Down Expand Up @@ -1383,6 +1389,8 @@ def __init__(self, coordinates, faces):
class FileMesh(SurfaceMesh):
"""A surface mesh stored in a Gifti or Freesurfer file.
.. versionadded:: 0.11.0
Parameters
----------
file_path : :obj:`str` or :obj:`pathlib.Path`
Expand Down Expand Up @@ -1434,6 +1442,8 @@ class PolyMesh:
It is a shallow wrapper around the ``parts`` dictionary, which cannot be
empty and whose keys must be a subset of {"left", "right"}.
.. versionadded:: 0.11.0
Parameters
----------
left : :obj:`str` or :obj:`pathlib.Path` \
Expand Down Expand Up @@ -1626,6 +1636,8 @@ def _sanitize_filename(filename):
class SurfaceImage:
"""Surface image containing meshes & data for both hemispheres.
.. versionadded:: 0.11.0
Parameters
----------
mesh : :obj:`nilearn.surface.PolyMesh`, \
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ packages = ["nilearn"]
source = "vcs"

[tool.pytest.ini_options]
addopts = "-ra --strict-config --strict-markers --doctest-modules --showlocals -s -vv --durations=0 --template=maint_tools/templates/index.html -n auto"
addopts = "-ra --strict-config --strict-markers --doctest-modules --showlocals -s -vv --durations=0 --template=maint_tools/templates/index.html"
doctest_optionflags = "NORMALIZE_WHITESPACE ELLIPSIS"
junit_family = "xunit2"
log_cli_level = "INFO"
Expand Down
8 changes: 4 additions & 4 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ description = run tests on latest version of all dependencies (plotting not incl
passenv = {[global_var]passenv}
extras = test
commands =
pytest --cov=nilearn --cov-report=xml --report=report.html {posargs:}
pytest --cov=nilearn --cov-report=xml --report=report.html -n auto {posargs:}

[testenv:test_plotting]
description = run tests on latest version of all dependencies
Expand All @@ -99,9 +99,9 @@ extras = test
deps =
{[plotting]deps}
commands =
pytest doc/_additional_doctests.txt --report=report_doc.html
pytest -n auto doc/_additional_doctests.txt --report=report_doc.html
; TODO find a way to rely on globbing instead of listing a specific folder
pytest --doctest-glob='*.rst' doc/manipulating_images/ --report=report_doc.html
pytest -n auto --doctest-glob='*.rst' doc/manipulating_images/ --report=report_doc.html

[testenv:test_pre]
description = run test_latest and test_doc on pre-release version of all dependencies
Expand Down Expand Up @@ -149,7 +149,7 @@ commands =
pip install git+https://github.com/nipy/nibabel
pip install --pre --upgrade -i {env:nightlies_url} pandas scipy scikit-learn matplotlib
pip install --pre --upgrade -i {env:nightlies_url} numpy
pytest {posargs:}
pytest -n auto {posargs:}

[testenv:doc]
description = build doc with minimum supported version of python and all dependencies (plotting included).
Expand Down

0 comments on commit cf6bc7d

Please sign in to comment.