Skip to content

Commit

Permalink
[ENH] Accept mask_img in SurfaceLabelsMasker (nilearn#4937)
Browse files Browse the repository at this point in the history
Co-authored-by: Remi Gau <[email protected]>
  • Loading branch information
man-shu and Remi-Gau authored Dec 12, 2024
1 parent 3d90178 commit 547d238
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 14 deletions.
94 changes: 80 additions & 14 deletions nilearn/maskers/surface_labels_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,31 @@
from nilearn.surface import SurfaceImage


def _apply_mask(labels_masker, mask_data, labels_data):
"""Apply mask to labels data."""
labels_data[np.logical_not(mask_data.flatten())] = (
labels_masker.background_label
)
labels_after_mask = {int(label) for label in np.unique(labels_data)}
labels_before_mask = {int(label) for label in labels_masker._labels_}
labels_diff = labels_before_mask - labels_after_mask
if labels_diff:
warnings.warn(
"After applying mask to the labels image, "
"the following labels were "
f"removed: {labels_diff}. "
f"Out of {len(labels_before_mask)} labels, the "
"masked labels image only contains "
f"{len(labels_after_mask)} labels "
"(including background).",
stacklevel=3,
)
labels = np.unique(labels_data)
labels = labels[labels != labels_masker.background_label]

return labels_data, labels


@fill_doc
class SurfaceLabelsMasker(TransformerMixin, CacheMixin, BaseEstimator):
"""Extract data from a SurfaceImage, averaging over atlas regions.
Expand All @@ -29,7 +54,8 @@ class SurfaceLabelsMasker(TransformerMixin, CacheMixin, BaseEstimator):
Parameters
----------
labels_img : :obj:`~nilearn.surface.SurfaceImage` object
Region definitions, as one image of labels.
Region definitions, as one image of labels. The data for \
each hemisphere is of shape (n_vertices_per_hemisphere, n_regions).
labels : :obj:`list` of :obj:`str`, default=None
Full labels corresponding to the labels image.
Expand All @@ -47,6 +73,11 @@ class SurfaceLabelsMasker(TransformerMixin, CacheMixin, BaseEstimator):
This value must be consistent with label values
and image provided.
mask_img : :obj:`~nilearn.surface.SurfaceImage` object, optional
Mask to apply to labels_img before extracting signals. Defines the \
overall area of the brain to consider. The data for each \
hemisphere is of shape (n_vertices_per_hemisphere, n_regions).
%(smoothing_fwhm)s
This parameter is not implemented yet.
Expand Down Expand Up @@ -98,6 +129,7 @@ def __init__(
labels_img=None,
labels=None,
background_label=0,
mask_img=None,
smoothing_fwhm=None,
standardize=False,
standardize_confounds=True,
Expand All @@ -116,6 +148,7 @@ def __init__(
self.labels_img = labels_img
self.labels = labels
self.background_label = background_label
self.mask_img = mask_img
self.smoothing_fwhm = smoothing_fwhm
self.standardize = standardize
self.standardize_confounds = standardize_confounds
Expand Down Expand Up @@ -172,6 +205,13 @@ def fit(self, img=None, y=None):
else:
self.label_names_ = [self.labels[x] for x in self._labels_]

if self.mask_img is not None:
check_same_n_vertices(self.labels_img.mesh, self.mask_img.mesh)

if not self.reports:
self._reporting_data = None
return self

self._shelving = False
# content to inject in the HTML template
self._report_content = {
Expand All @@ -185,6 +225,7 @@ def fit(self, img=None, y=None):
"number_of_regions": self.n_elements_,
"summary": {},
}

for part in self.labels_img.data.parts:
self._report_content["n_vertices"][part] = (
self.labels_img.mesh.parts[part].n_vertices
Expand Down Expand Up @@ -285,6 +326,16 @@ def transform(self, img, confounds=None, sample_mask=None):
# concatenate data over hemispheres
img_data = np.concatenate(list(img.data.parts.values()), axis=0)

labels_data = self._labels_data
labels = self._labels_
if self.mask_img is not None:
mask_data = np.concatenate(
list(self.mask_img.data.parts.values()), axis=0
)
labels_data, labels = _apply_mask(
self, mask_data, self._labels_data
)

if self.smoothing_fwhm is not None:
warnings.warn(
"Parameter smoothing_fwhm "
Expand Down Expand Up @@ -312,9 +363,9 @@ def transform(self, img, confounds=None, sample_mask=None):
self.memory = Memory(location=None)

n_time_points = 1 if len(img_data.shape) == 1 else img_data.shape[1]
output = np.empty((n_time_points, len(self._labels_)))
for i, label in enumerate(self._labels_):
output[:, i] = img_data[self._labels_data == label].mean(axis=0)
output = np.empty((n_time_points, len(labels)))
for i, label in enumerate(labels):
output[:, i] = img_data[labels_data == label].mean(axis=0)

# signal cleaning here
output = cache(
Expand Down Expand Up @@ -374,29 +425,44 @@ def fit_transform(self, img, y=None, confounds=None, sample_mask=None):
del y
return self.fit().transform(img, confounds, sample_mask)

def inverse_transform(self, masked_img):
"""Transform extracted signal back to surface object.
def inverse_transform(self, signals):
"""Transform extracted signal back to surface image.
Parameters
----------
masked_img : :obj:`numpy.ndarray`
Extracted signal.
signals : :obj:`numpy.ndarray`
Extracted signal for each region.
If a 1D array is provided, then the shape of each hemisphere's data
should be (number of elements,) in the returned surface image.
If a 2D array is provided, then it would be
(number of scans, number of elements).
Returns
-------
:obj:`~nilearn.surface.SurfaceImage` object
Mesh and data for both hemispheres.
"""
self._check_fitted()

# we will only get the data back according to the mask that was applied
# so if some labels were removed, we will only get the data for the
# remaining labels, the vertices that were masked out will be set to 0
labels = self._labels_
if self.mask_img is not None:
mask_data = np.concatenate(
list(self.mask_img.data.parts.values()), axis=0
)
_, labels = _apply_mask(self, mask_data, self._labels_data)

data = {}
for part_name, labels_part in self.labels_img.data.parts.items():
data[part_name] = np.zeros(
(labels_part.shape[0], masked_img.shape[0]),
dtype=masked_img.dtype,
(labels_part.shape[0], signals.shape[0]),
dtype=signals.dtype,
)
for label_idx, label in enumerate(self._labels_):
data[part_name][labels_part == label] = masked_img[
:, label_idx
].T
for label_idx, label in enumerate(labels):
data[part_name][labels_part == label] = signals[:, label_idx].T
return SurfaceImage(mesh=self.labels_img.mesh, data=data)

def generate_report(self):
Expand Down
76 changes: 76 additions & 0 deletions nilearn/maskers/tests/test_surface_labels_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,37 @@ def test_surface_label_masker_transform(
assert signal.shape == (n_timepoints, n_labels)


def test_surface_label_masker_transform_with_mask(surf_mesh, surf_img_2d):
    """Test transform extract signals with a mask and check warning."""
    # Labels image with three regions (1, 2 and 3) across both hemispheres.
    label_img = SurfaceImage(
        surf_mesh(),
        {
            "left": np.asarray([1, 1, 1, 2]),
            "right": np.asarray([3, 3, 2, 2, 2]),
        },
    )

    # The mask keeps only the vertices of labels 1 and 2 (out of 3), so
    # the masker should extract signals for those two regions only and
    # warn that label 3 was removed by the mask.
    mask_img = SurfaceImage(
        surf_mesh(),
        {
            "left": np.asarray([1, 1, 1, 1]),
            "right": np.asarray([0, 0, 1, 1, 1]),
        },
    )

    masker = SurfaceLabelsMasker(
        labels_img=label_img, mask_img=mask_img
    ).fit()
    n_timepoints = 5
    with pytest.warns(
        UserWarning,
        match="the following labels were removed",
    ):
        signal = masker.transform(surf_img_2d(n_timepoints))
    assert isinstance(signal, np.ndarray)
    assert signal.shape == (n_timepoints, 2)


def test_surface_label_masker_check_output_1d(surf_mesh, rng):
"""Check actual content of the transform and inverse_transform.
Expand Down Expand Up @@ -389,6 +420,51 @@ def test_surface_label_masker_inverse_transform(surf_label_img, surf_img_1d):
assert img.shape == (surf_img_1d.shape[0], 1)


def test_surface_label_masker_inverse_transform_with_mask(
    surf_mesh, surf_img_2d
):
    """Test inverse_transform with mask: inverted image's shape, warning if
    mask removes labels and data corresponding to removed labels is zeros.
    """
    # Labels image with three regions (1, 2 and 3) across both hemispheres.
    label_img = SurfaceImage(
        surf_mesh(),
        {
            "left": np.asarray([1, 1, 1, 2]),
            "right": np.asarray([3, 3, 2, 2, 2]),
        },
    )

    # The mask keeps only labels 1 and 3 (out of 3); label 2 is fully
    # masked out, so the masker should warn about its removal and the
    # inverted image should contain zeros at its vertices.
    mask_img = SurfaceImage(
        surf_mesh(),
        {
            "left": np.asarray([1, 1, 1, 0]),
            "right": np.asarray([1, 1, 0, 0, 0]),
        },
    )

    masker = SurfaceLabelsMasker(
        labels_img=label_img, mask_img=mask_img
    ).fit()
    n_timepoints = 5
    with pytest.warns(
        UserWarning,
        match="the following labels were removed",
    ):
        signal = masker.transform(surf_img_2d(n_timepoints))
        img_inverted = masker.inverse_transform(signal)
    assert img_inverted.shape == surf_img_2d(n_timepoints).shape
    # vertices belonging to the removed label 2 must come back as zeros
    assert np.all(img_inverted.data.parts["left"][-1, :] == 0)
    assert np.all(img_inverted.data.parts["right"][2:, :] == 0)


def test_surface_label_masker_inverse_transform_before_fit(surf_label_img):
    """Test inverse_transform requires masker to be fitted."""
    unfitted_masker = SurfaceLabelsMasker(labels_img=surf_label_img)
    dummy_signal = np.zeros((1, 1))
    with pytest.raises(ValueError, match="has not been fitted"):
        unfitted_masker.inverse_transform(dummy_signal)


def test_surface_label_masker_transform_list_surf_images(
surf_label_img, surf_img_1d, surf_img_2d
):
Expand Down

0 comments on commit 547d238

Please sign in to comment.