Skip to content

Commit

Permalink
[MAINT] check compatibility of (multi)nifti masker with scikitlearn (n…
Browse files Browse the repository at this point in the history
  • Loading branch information
Remi-Gau authored Dec 18, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 8ef2050 commit d57c0d8
Showing 7 changed files with 127 additions and 39 deletions.
4 changes: 2 additions & 2 deletions doc/manipulating_images/masker_objects.rst
Original file line number Diff line number Diff line change
@@ -193,10 +193,10 @@ preparation::
>>> from nilearn import maskers
>>> masker = maskers.NiftiMasker()
>>> masker # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
NiftiMasker(detrend=False, dtype=None, high_pass=None,
NiftiMasker(cmap='CMRmap_r', detrend=False, dtype=None, high_pass=None,
high_variance_confounds=False, low_pass=None, mask_args=None,
mask_img=None, mask_strategy='background',
memory=Memory(location=None), memory_level=1, reports=True,
memory=None, memory_level=1, reports=True,
runs=None, smoothing_fwhm=None, standardize=False,
standardize_confounds=True, t_r=None,
target_affine=None, target_shape=None, verbose=0)
2 changes: 1 addition & 1 deletion nilearn/glm/tests/test_second_level.py
Original file line number Diff line number Diff line change
@@ -567,7 +567,7 @@ def test_high_level_glm_with_paths_errors(tmp_path):
X = pd.DataFrame([[1]] * 4, columns=["intercept"])

# Provide a masker as mask_img
masker = NiftiMasker(mask)
masker = NiftiMasker(mask).fit()
with pytest.warns(
UserWarning, match="Parameter memory of the masker overridden"
):
41 changes: 34 additions & 7 deletions nilearn/maskers/multi_nifti_masker.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
import warnings
from functools import partial

from joblib import Memory, Parallel, delayed
from joblib import Parallel, delayed

from nilearn import image, masking
from nilearn._utils import (
@@ -77,17 +77,26 @@ class MultiNiftiMasker(NiftiMasker, CacheMixin):
Mask of the data. If not given, a mask is computed in the fit step.
Optional parameters can be set using mask_args and mask_strategy to
fine tune the mask extraction.
%(smoothing_fwhm)s
%(standardize_maskers)s
%(standardize_confounds)s
high_variance_confounds : :obj:`bool`, default=False
If True, high variance confounds are computed on provided image with
:func:`nilearn.image.high_variance_confounds` and default parameters
and regressed out.
%(detrend)s
%(low_pass)s
%(high_pass)s
%(t_r)s
target_affine : 3x3 or 4x4 :obj:`numpy.ndarray`, optional
This parameter is passed to image.resample_img. Please see the
related documentation for details.
@@ -117,10 +126,15 @@ class MultiNiftiMasker(NiftiMasker, CacheMixin):
Data type toward which the data should be converted. If "auto", the
data will be converted to int32 if dtype is discrete and float32 if it
is continuous.
%(memory)s
%(memory_level)s
%(n_jobs)s
%(verbose0)s
%(masker_kwargs)s
Attributes
@@ -165,10 +179,9 @@ def __init__(
memory_level=0,
n_jobs=1,
verbose=0,
cmap="CMRmap_r",
**kwargs,
):
if memory is None:
memory = Memory(location=None)
super().__init__(
# Mask is provided or computed
mask_img=mask_img,
@@ -185,16 +198,13 @@ def __init__(
mask_strategy=mask_strategy,
mask_args=mask_args,
dtype=dtype,
clean_kwargs={
k[7:]: v for k, v in kwargs.items() if k.startswith("clean__")
},
memory=memory,
memory_level=memory_level,
verbose=verbose,
cmap=cmap,
**kwargs,
)
self.n_jobs = n_jobs
self._shelving = False

def fit(
self,
@@ -215,6 +225,23 @@ def fit(
compatibility.
"""
if getattr(self, "_shelving", None) is None:
self._shelving = False

self._report_content = {
"description": (
"This report shows the input Nifti image overlaid "
"with the outlines of the mask (in green). We "
"recommend to inspect the report for the overlap "
"between the mask and its input image. "
),
"warning_message": None,
}
self._overlay_text = (
"\n To see the input Nifti image before resampling, "
"hover over the displayed image."
)

# Load data (if filenames are given, load them)
logger.log(
f"Loading data from {repr_niimgs(imgs, shorten=False)}.",
69 changes: 47 additions & 22 deletions nilearn/maskers/nifti_masker.py
Original file line number Diff line number Diff line change
@@ -171,24 +171,34 @@ class NiftiMasker(BaseMasker):
runs : :obj:`numpy.ndarray`, optional
Add a run level to the preprocessing. Each run will be
detrended independently. Must be a 1D array of n_samples elements.
%(smoothing_fwhm)s
%(standardize_maskers)s
%(standardize_confounds)s
high_variance_confounds : :obj:`bool`, default=False
If True, high variance confounds are computed on provided image with
:func:`nilearn.image.high_variance_confounds` and default parameters
and regressed out.
%(detrend)s
%(low_pass)s
%(high_pass)s
%(t_r)s
target_affine : 3x3 or 4x4 :obj:`numpy.ndarray`, optional
This parameter is passed to image.resample_img. Please see the
related documentation for details.
target_shape : 3-:obj:`tuple` of :obj:`int`, optional
This parameter is passed to image.resample_img. Please see the
related documentation for details.
%(mask_strategy)s
.. note::
@@ -210,12 +220,20 @@ class NiftiMasker(BaseMasker):
Data type toward which the data should be converted. If "auto", the
data will be converted to int32 if dtype is discrete and float32 if it
is continuous.
%(memory)s
%(memory_level1)s
%(verbose0)s
reports : :obj:`bool`, default=True
If set to True, data is saved in order to produce a report.
%(cmap)s
default="CMRmap_r"
Only relevant for the report figures.
%(masker_kwargs)s
Attributes
@@ -263,10 +281,9 @@ def __init__(
memory=None,
verbose=0,
reports=True,
cmap="CMRmap_r",
**kwargs,
):
if memory is None:
memory = Memory(location=None)
# Mask is provided or computed
self.mask_img = mask_img
self.runs = runs
@@ -283,30 +300,12 @@ def __init__(
self.mask_strategy = mask_strategy
self.mask_args = mask_args
self.dtype = dtype

self.memory = memory
self.memory_level = memory_level
self.verbose = verbose
self.reports = reports
self._report_content = {
"description": (
"This report shows the input Nifti image overlaid "
"with the outlines of the mask (in green). We "
"recommend to inspect the report for the overlap "
"between the mask and its input image. "
),
"warning_message": None,
}
self._overlay_text = (
"\n To see the input Nifti image before resampling, "
"hover over the displayed image."
)
self._shelving = False
self.clean_kwargs = {
k[7:]: v for k, v in kwargs.items() if k.startswith("clean__")
}

self.cmap = kwargs.get("cmap", "CMRmap_r")
self.cmap = cmap
self.clean_kwargs = kwargs

def generate_report(self):
"""Generate a report of the masker."""
@@ -429,6 +428,32 @@ def fit(
compatibility.
"""
self._report_content = {
"description": (
"This report shows the input Nifti image overlaid "
"with the outlines of the mask (in green). We "
"recommend to inspect the report for the overlap "
"between the mask and its input image. "
),
"warning_message": None,
}
self._overlay_text = (
"\n To see the input Nifti image before resampling, "
"hover over the displayed image."
)

if getattr(self, "_shelving", None) is None:
self._shelving = False

if self.memory is None:
self.memory = Memory(location=None)

self.clean_kwargs = {
k[7:]: v
for k, v in self.clean_kwargs.items()
if k.startswith("clean__")
}

# Load data (if filenames are given, load them)
logger.log(
f"Loading data from {_utils.repr_niimgs(imgs, shorten=False)}",
10 changes: 6 additions & 4 deletions nilearn/maskers/tests/test_multi_nifti_masker.py
Original file line number Diff line number Diff line change
@@ -14,20 +14,21 @@
from nilearn._utils.exceptions import DimensionError
from nilearn._utils.testing import write_imgs_to_path
from nilearn.image import get_data
from nilearn.maskers import MultiNiftiMasker, NiftiMasker
from nilearn.maskers import MultiNiftiMasker

extra_valid_checks = [
"check_estimators_unfitted",
"check_get_params_invariance",
"check_transformer_n_iter",
"check_transformers_unfitted",
"check_parameters_default_constructible",
]


@pytest.mark.parametrize(
"estimator, check, name",
check_estimator(
estimator=[MultiNiftiMasker(), NiftiMasker()],
estimator=[MultiNiftiMasker()],
extra_valid_checks=extra_valid_checks,
),
)
@@ -40,7 +41,7 @@ def test_check_estimator(estimator, check, name): # noqa: ARG001
@pytest.mark.parametrize(
"estimator, check, name",
check_estimator(
estimator=[MultiNiftiMasker(), NiftiMasker()],
estimator=[MultiNiftiMasker()],
extra_valid_checks=extra_valid_checks,
valid=False,
),
@@ -208,9 +209,10 @@ def test_shelving():
memory=Memory(location=cachedir, mmap_mode="r", verbose=0),
)
masker_shelved._shelving = True
masker = MultiNiftiMasker(mask_img=mask_img)
epis_shelved = masker_shelved.fit_transform([epi_img1, epi_img2])
masker = MultiNiftiMasker(mask_img=mask_img)
epis = masker.fit_transform([epi_img1, epi_img2])

for epi_shelved, epi in zip(epis_shelved, epis):
epi_shelved = epi_shelved.get()
assert_array_equal(epi_shelved, epi)
36 changes: 35 additions & 1 deletion nilearn/maskers/tests/test_nifti_masker.py
Original file line number Diff line number Diff line change
@@ -17,12 +17,46 @@
from numpy.testing import assert_array_equal

from nilearn._utils import data_gen, exceptions, testing
from nilearn._utils.class_inspect import get_params
from nilearn._utils.class_inspect import check_estimator, get_params
from nilearn._utils.helpers import is_matplotlib_installed
from nilearn.image import get_data, index_img
from nilearn.maskers import NiftiMasker
from nilearn.maskers.nifti_masker import _filter_and_mask

extra_valid_checks = [
"check_parameters_default_constructible",
"check_estimators_unfitted",
"check_get_params_invariance",
"check_transformer_n_iter",
"check_transformers_unfitted",
]


@pytest.mark.parametrize(
"estimator, check, name",
check_estimator(
estimator=[NiftiMasker()],
extra_valid_checks=extra_valid_checks,
),
)
def test_check_estimator(estimator, check, name): # noqa: ARG001
"""Check compliance with sklearn estimators."""
check(estimator)


@pytest.mark.xfail(reason="invalid checks should fail")
@pytest.mark.parametrize(
"estimator, check, name",
check_estimator(
estimator=[NiftiMasker()],
extra_valid_checks=extra_valid_checks,
valid=False,
),
)
def test_check_estimator_invalid(estimator, check, name): # noqa: ARG001
"""Check compliance with sklearn estimators."""
check(estimator)


def test_auto_mask(img_3d_rand_eye):
"""Perform a smoke test on the auto-mask option."""
4 changes: 2 additions & 2 deletions nilearn/reporting/tests/test_html_report.py
Original file line number Diff line number Diff line change
@@ -115,7 +115,7 @@ def input_parameters(
def test_report_empty_fit(masker_class, input_parameters):
"""Test minimal report generation."""
masker = masker_class(**input_parameters)
masker.fit()
masker = masker.fit()
_check_html(masker.generate_report())


@@ -177,8 +177,8 @@ def test_warning_in_report_after_empty_fit(masker_class, input_parameters):
if no images were provided to fit.
"""
masker = masker_class(**input_parameters)
assert masker._report_content["warning_message"] is None
masker.fit()
assert masker._report_content["warning_message"] is None
warn_message = f"No image provided to fit in {masker_class.__name__}."
with pytest.warns(UserWarning, match=warn_message):
html = masker.generate_report()

0 comments on commit d57c0d8

Please sign in to comment.