Skip to content

Commit

Permalink
fix: use sample_mask of nilearn.signal.clean for scrubbing in fMRIPre…
Browse files Browse the repository at this point in the history
…pConfoundRemover
  • Loading branch information
synchon committed Jan 24, 2025
1 parent 1332737 commit 2080792
Showing 1 changed file with 43 additions and 16 deletions.
59 changes: 43 additions & 16 deletions junifer/preprocess/confounds/fmriprep_confound_remover.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,6 @@ def _process_fmriprep_spec(
"Check if this file is really an fmriprep confounds file. "
"You can also deactivate spike (set spike = None)."
)
# Add std_dvars
if self.std_dvars_threshold is not None:
if "std_dvars" not in available_vars:
raise_error(
"Invalid confounds file. Missing std_dvars "
"(standardized DVARS) confound. "
"Check if this file is really an fMRIPrep confounds file. "
)
out = to_select, squares_to_compute, derivatives_to_compute, spike_name
return out

Expand Down Expand Up @@ -423,14 +415,6 @@ def _pick_confounds(self, input: dict[str, Any]) -> pd.DataFrame:
out_df["spike"] = fd
to_select.append("spike")

# add binary std_dvars regressor if needed at given threshold
if self.std_dvars_threshold is not None:
std_dvars = out_df["std_dvars"].copy()
std_dvars.loc[std_dvars > self.std_dvars_threshold] = 1
std_dvars.loc[std_dvars != 1] = 0
out_df["std_dvars"] = std_dvars
to_select.append("std_dvars")

# Now pick all the relevant confounds
out_df = out_df[to_select]

Expand All @@ -442,6 +426,41 @@ def _pick_confounds(self, input: dict[str, Any]) -> pd.DataFrame:

return out_df

def _get_scrub_mask(self, input: dict[str, Any]) -> np.ndarray:
"""Get boolean mask for scrubbing.
Parameters
----------
input : dict
Dictionary containing the ``BOLD.confounds`` value from the
Junifer Data object.
Returns
-------
numpy.ndarray
Index of volumes to be kept.
Raises
------
RuntimeError
If ``std_dvars`` is not found in the confounds file.
"""
confounds_df = input["data"]
# Check confounds file
if "std_dvars" not in confounds_df.columns:
raise_error(
"Invalid confounds file. Missing std_dvars "
"(standardized DVARS) confound. "
"Check if this file is really an fMRIPrep confounds file. "
)
# Make first row 0 and then threshold
return np.flatnonzero(
(
confounds_df["std_dvars"].fillna(0) > self.std_dvars_threshold
).to_numpy()
)

def _validate_data(
self,
input: dict[str, Any],
Expand Down Expand Up @@ -605,6 +624,13 @@ def preprocess(
}
}
)

# Set up scrubbing mask if needed
if self.std_dvars_threshold is not None:
sample_mask = self._get_scrub_mask(input["confounds"])
else:
sample_mask = None

# Clean image
logger.info("Cleaning image using nilearn")
logger.debug(f"\tdetrend: {self.detrend}")
Expand All @@ -622,6 +648,7 @@ def preprocess(
high_pass=self.high_pass,
t_r=t_r,
mask_img=mask_img,
clean__sample_mask=sample_mask,
)
# Fix t_r as nilearn messes it up
cleaned_img.header["pixdim"][4] = t_r
Expand Down

0 comments on commit 2080792

Please sign in to comment.