diff --git a/nilearn/maskers/surface_labels_masker.py b/nilearn/maskers/surface_labels_masker.py index 344761d749..0326bb37d6 100644 --- a/nilearn/maskers/surface_labels_masker.py +++ b/nilearn/maskers/surface_labels_masker.py @@ -20,6 +20,31 @@ from nilearn.surface import SurfaceImage +def _apply_mask(labels_masker, mask_data, labels_data): + """Apply mask to labels data.""" + labels_data[np.logical_not(mask_data.flatten())] = ( + labels_masker.background_label + ) + labels_after_mask = {int(label) for label in np.unique(labels_data)} + labels_before_mask = {int(label) for label in labels_masker._labels_} + labels_diff = labels_before_mask - labels_after_mask + if labels_diff: + warnings.warn( + "After applying mask to the labels image, " + "the following labels were " + f"removed: {labels_diff}. " + f"Out of {len(labels_before_mask)} labels, the " + "masked labels image only contains " + f"{len(labels_after_mask)} labels " + "(including background).", + stacklevel=3, + ) + labels = np.unique(labels_data) + labels = labels[labels != labels_masker.background_label] + + return labels_data, labels + + @fill_doc class SurfaceLabelsMasker(TransformerMixin, CacheMixin, BaseEstimator): """Extract data from a SurfaceImage, averaging over atlas regions. @@ -29,7 +54,8 @@ class SurfaceLabelsMasker(TransformerMixin, CacheMixin, BaseEstimator): Parameters ---------- labels_img : :obj:`~nilearn.surface.SurfaceImage` object - Region definitions, as one image of labels. + Region definitions, as one image of labels. The data for \ + each hemisphere is of shape (n_vertices_per_hemisphere, n_regions). labels : :obj:`list` of :obj:`str`, default=None Full labels corresponding to the labels image. @@ -47,6 +73,11 @@ class SurfaceLabelsMasker(TransformerMixin, CacheMixin, BaseEstimator): This value must be consistent with label values and image provided. + mask_img : :obj:`~nilearn.surface.SurfaceImage` object, optional + Mask to apply to labels_img before extracting signals. Defines the \ + overall area of the brain to consider. The data for each \ + hemisphere is of shape (n_vertices_per_hemisphere, n_regions). + %(smoothing_fwhm)s This parameter is not implemented yet. @@ -98,6 +129,7 @@ def __init__( labels_img=None, labels=None, background_label=0, + mask_img=None, smoothing_fwhm=None, standardize=False, standardize_confounds=True, @@ -116,6 +148,7 @@ def __init__( self.labels_img = labels_img self.labels = labels self.background_label = background_label + self.mask_img = mask_img self.smoothing_fwhm = smoothing_fwhm self.standardize = standardize self.standardize_confounds = standardize_confounds @@ -172,6 +205,13 @@ def fit(self, img=None, y=None): else: self.label_names_ = [self.labels[x] for x in self._labels_] + if self.mask_img is not None: + check_same_n_vertices(self.labels_img.mesh, self.mask_img.mesh) + + if not self.reports: + self._reporting_data = None + return self + self._shelving = False # content to inject in the HTML template self._report_content = { @@ -185,6 +225,7 @@ def fit(self, img=None, y=None): "number_of_regions": self.n_elements_, "summary": {}, } + for part in self.labels_img.data.parts: self._report_content["n_vertices"][part] = ( self.labels_img.mesh.parts[part].n_vertices @@ -285,6 +326,16 @@ def transform(self, img, confounds=None, sample_mask=None): # concatenate data over hemispheres img_data = np.concatenate(list(img.data.parts.values()), axis=0) + labels_data = self._labels_data + labels = self._labels_ + if self.mask_img is not None: + mask_data = np.concatenate( + list(self.mask_img.data.parts.values()), axis=0 + ) + labels_data, labels = _apply_mask( + self, mask_data, self._labels_data + ) + if self.smoothing_fwhm is not None: warnings.warn( "Parameter smoothing_fwhm " @@ -312,9 +363,9 @@ def transform(self, img, confounds=None, sample_mask=None): self.memory = Memory(location=None) n_time_points = 1 if len(img_data.shape) == 1 else img_data.shape[1] - output = np.empty((n_time_points, len(self._labels_))) - for i, label in enumerate(self._labels_): - output[:, i] = img_data[self._labels_data == label].mean(axis=0) + output = np.empty((n_time_points, len(labels))) + for i, label in enumerate(labels): + output[:, i] = img_data[labels_data == label].mean(axis=0) # signal cleaning here output = cache( @@ -374,29 +425,44 @@ def fit_transform(self, img, y=None, confounds=None, sample_mask=None): del y return self.fit().transform(img, confounds, sample_mask) - def inverse_transform(self, masked_img): - """Transform extracted signal back to surface object. + def inverse_transform(self, signals): + """Transform extracted signal back to surface image. Parameters ---------- - masked_img : :obj:`numpy.ndarray` - Extracted signal. + signals : :obj:`numpy.ndarray` + Extracted signal for each region. + If a 1D array is provided, then the shape of each hemisphere's data + should be (number of elements,) in the returned surface image. + If a 2D array is provided, then it would be + (number of scans, number of elements). + Returns ------- :obj:`~nilearn.surface.SurfaceImage` object Mesh and data for both hemispheres. """ + self._check_fitted() + + # we will only get the data back according to the mask that was applied + # so if some labels were removed, we will only get the data for the + # remaining labels, the vertices that were masked out will be set to 0 + labels = self._labels_ + if self.mask_img is not None: + mask_data = np.concatenate( + list(self.mask_img.data.parts.values()), axis=0 + ) + _, labels = _apply_mask(self, mask_data, self._labels_data) + data = {} for part_name, labels_part in self.labels_img.data.parts.items(): data[part_name] = np.zeros( - (labels_part.shape[0], masked_img.shape[0]), - dtype=masked_img.dtype, + (labels_part.shape[0], signals.shape[0]), + dtype=signals.dtype, ) - for label_idx, label in enumerate(self._labels_): - data[part_name][labels_part == label] = masked_img[ - :, label_idx - ].T + for label_idx, label in enumerate(labels): + data[part_name][labels_part == label] = signals[:, label_idx].T return SurfaceImage(mesh=self.labels_img.mesh, data=data) def generate_report(self): diff --git a/nilearn/maskers/tests/test_surface_labels_masker.py b/nilearn/maskers/tests/test_surface_labels_masker.py index 7226187b3e..f87d15e93c 100644 --- a/nilearn/maskers/tests/test_surface_labels_masker.py +++ b/nilearn/maskers/tests/test_surface_labels_masker.py @@ -122,6 +122,37 @@ def test_surface_label_masker_transform( assert signal.shape == (n_timepoints, n_labels) +def test_surface_label_masker_transform_with_mask(surf_mesh, surf_img_2d): + """Test transform extract signals with a mask and check warning.""" + # create a labels image + labels_data = { + "left": np.asarray([1, 1, 1, 2]), + "right": np.asarray([3, 3, 2, 2, 2]), + } + surf_label_img = SurfaceImage(surf_mesh(), labels_data) + + # create a mask image + # we are keeping labels 1 and 2 out of 3 + # so we should only get signals for labels 1 and 2 + # plus masker should throw a warning that label 3 is being removed due to + # mask + mask_data = { + "left": np.asarray([1, 1, 1, 1]), + "right": np.asarray([0, 0, 1, 1, 1]), + } + surf_mask = SurfaceImage(surf_mesh(), mask_data) + masker = SurfaceLabelsMasker(labels_img=surf_label_img, mask_img=surf_mask) + masker = masker.fit() + n_timepoints = 5 + with pytest.warns( + UserWarning, + match="the following labels were removed", + ): + signal = masker.transform(surf_img_2d(n_timepoints)) + assert isinstance(signal, np.ndarray) + assert signal.shape == (n_timepoints, 2) + + def test_surface_label_masker_check_output_1d(surf_mesh, rng): """Check actual content of the transform and inverse_transform. @@ -389,6 +420,51 @@ def test_surface_label_masker_inverse_transform(surf_label_img, surf_img_1d): assert img.shape == (surf_img_1d.shape[0], 1) +def test_surface_label_masker_inverse_transform_with_mask( + surf_mesh, surf_img_2d +): + """Test inverse_transform with mask: inverted image's shape, warning if + mask removes labels and data corresponding to removed labels is zeros. + """ + # create a labels image + labels_data = { + "left": np.asarray([1, 1, 1, 2]), + "right": np.asarray([3, 3, 2, 2, 2]), + } + surf_label_img = SurfaceImage(surf_mesh(), labels_data) + + # create a mask image + # we are keeping labels 1 and 3 out of 3 + # so we should only get signals for labels 1 and 3 + # plus masker should throw a warning that label 2 is being removed due to + # mask + mask_data = { + "left": np.asarray([1, 1, 1, 0]), + "right": np.asarray([1, 1, 0, 0, 0]), + } + surf_mask = SurfaceImage(surf_mesh(), mask_data) + masker = SurfaceLabelsMasker(labels_img=surf_label_img, mask_img=surf_mask) + masker = masker.fit() + n_timepoints = 5 + with pytest.warns( + UserWarning, + match="the following labels were removed", + ): + signal = masker.transform(surf_img_2d(n_timepoints)) + img_inverted = masker.inverse_transform(signal) + assert img_inverted.shape == surf_img_2d(n_timepoints).shape + # the data for label 2 should be zeros + assert np.all(img_inverted.data.parts["left"][-1, :] == 0) + assert np.all(img_inverted.data.parts["right"][2:, :] == 0) + + +def test_surface_label_masker_inverse_transform_before_fit(surf_label_img): + """Test inverse_transform requires masker to be fitted.""" + masker = SurfaceLabelsMasker(labels_img=surf_label_img) + with pytest.raises(ValueError, match="has not been fitted"): + masker.inverse_transform(np.zeros((1, 1))) + + def test_surface_label_masker_transform_list_surf_images( surf_label_img, surf_img_1d, surf_img_2d ):