| 
 | 1 | +# Modified from:  | 
 | 2 | +# https://github.com/nipreps/synthstrip/blob/main/nipreps/synthstrip/cli.py  | 
 | 3 | +# Original copyright (c) 2024, NiPreps developers  | 
 | 4 | +# Licensed under the Apache License, Version 2.0  | 
 | 5 | +# Changes made by the BrainLesion Preprocessing team (2025)  | 
 | 6 | + | 
 | 7 | +from pathlib import Path  | 
 | 8 | +from typing import Optional, Union, cast  | 
 | 9 | + | 
 | 10 | +import nibabel as nib  | 
 | 11 | +import numpy as np  | 
 | 12 | +import scipy  | 
 | 13 | +import torch  | 
 | 14 | +from nibabel.nifti1 import Nifti1Image  | 
 | 15 | +from nipreps.synthstrip.model import StripModel  | 
 | 16 | +from nitransforms.linear import Affine  | 
 | 17 | + | 
 | 18 | +from brainles_preprocessing.brain_extraction.brain_extractor import BrainExtractor  | 
 | 19 | +from brainles_preprocessing.utils.zenodo import fetch_synthstrip  | 
 | 20 | + | 
 | 21 | + | 
 | 22 | +class SynthStripExtractor(BrainExtractor):  | 
 | 23 | + | 
 | 24 | +    def __init__(self, border: int = 1):  | 
 | 25 | +        """  | 
 | 26 | +        Brain extraction using SynthStrip with preprocessing conforming to model requirements.  | 
 | 27 | +
  | 
 | 28 | +        This is an optional dependency - to use this extractor, you need to install the `brainles_preprocessing` package with the `synthstrip` extra: `pip install brainles_preprocessing[synthstrip]`  | 
 | 29 | +
  | 
 | 30 | +        Adapted from https://github.com/nipreps/synthstrip  | 
 | 31 | +
  | 
 | 32 | +        Args:  | 
 | 33 | +            border (int): Mask border threshold in mm. Defaults to 1.  | 
 | 34 | +        """  | 
 | 35 | + | 
 | 36 | +        super().__init__()  | 
 | 37 | +        self.border = border  | 
 | 38 | + | 
 | 39 | +    def _setup_model(self, device: torch.device) -> StripModel:  | 
 | 40 | +        """  | 
 | 41 | +        Load SynthStrip model and prepare it for inference on the specified device.  | 
 | 42 | +
  | 
 | 43 | +        Args:  | 
 | 44 | +            device: Device to load the model onto.  | 
 | 45 | +
  | 
 | 46 | +        Returns:  | 
 | 47 | +            A configured and ready-to-use StripModel.  | 
 | 48 | +        """  | 
 | 49 | +        # necessary for speed gains (according to original nipreps authors)  | 
 | 50 | +        torch.backends.cudnn.benchmark = True  | 
 | 51 | +        torch.backends.cudnn.deterministic = True  | 
 | 52 | + | 
 | 53 | +        with torch.no_grad():  | 
 | 54 | +            model = StripModel()  | 
 | 55 | +            model.to(device)  | 
 | 56 | +            model.eval()  | 
 | 57 | + | 
 | 58 | +        # Load the model weights  | 
 | 59 | +        weights_folder = fetch_synthstrip()  | 
 | 60 | +        weights = weights_folder / "synthstrip.1.pt"  | 
 | 61 | +        checkpoint = torch.load(weights, map_location=device)  | 
 | 62 | +        model.load_state_dict(checkpoint["model_state_dict"])  | 
 | 63 | + | 
 | 64 | +        return model  | 
 | 65 | + | 
 | 66 | +    def _conform(self, input_nii: Nifti1Image) -> Nifti1Image:  | 
 | 67 | +        """  | 
 | 68 | +        Resample the input image to match SynthStrip's expected input space.  | 
 | 69 | +
  | 
 | 70 | +        Args:  | 
 | 71 | +            input_nii (Nifti1Image): Input NIfTI image to conform.  | 
 | 72 | +
  | 
 | 73 | +        Raises:  | 
 | 74 | +            ValueError: If the input NIfTI image does not have a valid affine.  | 
 | 75 | +
  | 
 | 76 | +        Returns:  | 
 | 77 | +            A new NIfTI image with conformed shape and affine.  | 
 | 78 | +        """  | 
 | 79 | + | 
 | 80 | +        shape = np.array(input_nii.shape[:3])  | 
 | 81 | +        affine = input_nii.affine  | 
 | 82 | + | 
 | 83 | +        if affine is None:  | 
 | 84 | +            raise ValueError("Input NIfTI image must have a valid affine.")  | 
 | 85 | + | 
 | 86 | +        # Get corner voxel centers in index coords  | 
 | 87 | +        corner_centers_ijk = (  | 
 | 88 | +            np.array(  | 
 | 89 | +                [  | 
 | 90 | +                    (i, j, k)  | 
 | 91 | +                    for k in (0, shape[2] - 1)  | 
 | 92 | +                    for j in (0, shape[1] - 1)  | 
 | 93 | +                    for i in (0, shape[0] - 1)  | 
 | 94 | +                ]  | 
 | 95 | +            )  | 
 | 96 | +            + 0.5  | 
 | 97 | +        )  | 
 | 98 | + | 
 | 99 | +        # Get corner voxel centers in mm  | 
 | 100 | +        corners_xyz = (  | 
 | 101 | +            affine  | 
 | 102 | +            @ np.hstack((corner_centers_ijk, np.ones((len(corner_centers_ijk), 1)))).T  | 
 | 103 | +        )  | 
 | 104 | + | 
 | 105 | +        # Target affine is 1mm voxels in LIA orientation  | 
 | 106 | +        target_affine = np.diag([-1.0, 1.0, -1.0, 1.0])[:, (0, 2, 1, 3)]  | 
 | 107 | + | 
 | 108 | +        # Target shape  | 
 | 109 | +        extent = corners_xyz.min(1)[:3], corners_xyz.max(1)[:3]  | 
 | 110 | +        target_shape = ((extent[1] - extent[0]) / 1.0 + 0.999).astype(int)  | 
 | 111 | + | 
 | 112 | +        # SynthStrip likes dimensions be multiple of 64 (192, 256, or 320)  | 
 | 113 | +        target_shape = np.clip(  | 
 | 114 | +            np.ceil(np.array(target_shape) / 64).astype(int) * 64, 192, 320  | 
 | 115 | +        )  | 
 | 116 | + | 
 | 117 | +        # Ensure shape ordering is LIA too  | 
 | 118 | +        target_shape[2], target_shape[1] = target_shape[1:3]  | 
 | 119 | + | 
 | 120 | +        # Coordinates of center voxel do not change  | 
 | 121 | +        input_c = affine @ np.hstack((0.5 * (shape - 1), 1.0))  | 
 | 122 | +        target_c = target_affine @ np.hstack((0.5 * (target_shape - 1), 1.0))  | 
 | 123 | + | 
 | 124 | +        # Rebase the origin of the new, plumb affine  | 
 | 125 | +        target_affine[:3, 3] -= target_c[:3] - input_c[:3]  | 
 | 126 | + | 
 | 127 | +        nii = Affine(  | 
 | 128 | +            reference=Nifti1Image(  | 
 | 129 | +                np.zeros(target_shape),  | 
 | 130 | +                target_affine,  | 
 | 131 | +                None,  | 
 | 132 | +            ),  | 
 | 133 | +        ).apply(input_nii)  | 
 | 134 | +        return cast(Nifti1Image, nii)  | 
 | 135 | + | 
 | 136 | +    def _resample_like(  | 
 | 137 | +        self,  | 
 | 138 | +        image: Nifti1Image,  | 
 | 139 | +        target: Nifti1Image,  | 
 | 140 | +        output_dtype: Optional[np.dtype] = None,  | 
 | 141 | +        cval: Union[int, float] = 0,  | 
 | 142 | +    ) -> Nifti1Image:  | 
 | 143 | +        """  | 
 | 144 | +        Resample the input image to match the target's grid using an identity transform.  | 
 | 145 | +
  | 
 | 146 | +        Args:  | 
 | 147 | +            image: The image to be resampled.  | 
 | 148 | +            target: The reference image.  | 
 | 149 | +            output_dtype: Output data type.  | 
 | 150 | +            cval: Value to use for constant padding.  | 
 | 151 | +
  | 
 | 152 | +        Returns:  | 
 | 153 | +            A resampled NIfTI image.  | 
 | 154 | +        """  | 
 | 155 | +        result = Affine(reference=target).apply(  | 
 | 156 | +            image,  | 
 | 157 | +            output_dtype=output_dtype,  | 
 | 158 | +            cval=cval,  | 
 | 159 | +        )  | 
 | 160 | +        return cast(Nifti1Image, result)  | 
 | 161 | + | 
 | 162 | +    def extract(  | 
 | 163 | +        self,  | 
 | 164 | +        input_image_path: Union[str, Path],  | 
 | 165 | +        masked_image_path: Union[str, Path],  | 
 | 166 | +        brain_mask_path: Union[str, Path],  | 
 | 167 | +        device: Union[torch.device, str] = "cuda",  | 
 | 168 | +        num_threads: int = 1,  | 
 | 169 | +        **kwargs,  | 
 | 170 | +    ) -> None:  | 
 | 171 | +        """  | 
 | 172 | +        Extract the brain from an input image using SynthStrip.  | 
 | 173 | +
  | 
 | 174 | +        Args:  | 
 | 175 | +            input_image_path (Union[str, Path]): Path to the input image.  | 
 | 176 | +            masked_image_path (Union[str, Path]): Path to the output masked image.  | 
 | 177 | +            brain_mask_path (Union[str, Path]): Path to the output brain mask.  | 
 | 178 | +            device (Union[torch.device, str], optional): Device to use for computation. Defaults to "cuda".  | 
 | 179 | +            num_threads (int, optional): Number of threads to use for computation in CPU mode. Defaults to 1.  | 
 | 180 | +
  | 
 | 181 | +        Returns:  | 
 | 182 | +            None: The function saves the masked image and brain mask to the specified paths.  | 
 | 183 | +        """  | 
 | 184 | + | 
 | 185 | +        device = torch.device(device) if isinstance(device, str) else device  | 
 | 186 | +        model = self._setup_model(device=device)  | 
 | 187 | + | 
 | 188 | +        if device.type == "cpu" and num_threads > 0:  | 
 | 189 | +            torch.set_num_threads(num_threads)  | 
 | 190 | + | 
 | 191 | +        # normalize intensities  | 
 | 192 | +        image = nib.load(input_image_path)  | 
 | 193 | +        image = cast(Nifti1Image, image)  | 
 | 194 | +        conformed = self._conform(image)  | 
 | 195 | +        in_data = conformed.get_fdata(dtype="float32")  | 
 | 196 | +        in_data -= in_data.min()  | 
 | 197 | +        in_data = np.clip(in_data / np.percentile(in_data, 99), 0, 1)  | 
 | 198 | +        in_data = in_data[np.newaxis, np.newaxis]  | 
 | 199 | + | 
 | 200 | +        # predict the surface distance transform  | 
 | 201 | +        input_tensor = torch.from_numpy(in_data).to(device)  | 
 | 202 | +        with torch.no_grad():  | 
 | 203 | +            sdt = model(input_tensor).cpu().numpy().squeeze()  | 
 | 204 | + | 
 | 205 | +        # unconform the sdt and extract mask  | 
 | 206 | +        sdt_target = self._resample_like(  | 
 | 207 | +            Nifti1Image(sdt, conformed.affine, None),  | 
 | 208 | +            image,  | 
 | 209 | +            output_dtype=np.dtype("int16"),  | 
 | 210 | +            cval=100,  | 
 | 211 | +        )  | 
 | 212 | +        sdt_data = np.asanyarray(sdt_target.dataobj).astype("int16")  | 
 | 213 | + | 
 | 214 | +        # find largest CC (just do this to be safe for now)  | 
 | 215 | +        components = scipy.ndimage.label(sdt_data.squeeze() < self.border)[0]  | 
 | 216 | +        bincount = np.bincount(components.flatten())[1:]  | 
 | 217 | +        mask = components == (np.argmax(bincount) + 1)  | 
 | 218 | +        mask = scipy.ndimage.morphology.binary_fill_holes(mask)  | 
 | 219 | + | 
 | 220 | +        # write the masked output  | 
 | 221 | +        img_data = image.get_fdata()  | 
 | 222 | +        bg = np.min([0, img_data.min()])  | 
 | 223 | +        img_data[mask == 0] = bg  | 
 | 224 | +        Nifti1Image(img_data, image.affine, image.header).to_filename(  | 
 | 225 | +            masked_image_path,  | 
 | 226 | +        )  | 
 | 227 | + | 
 | 228 | +        # write the brain mask  | 
 | 229 | +        hdr = image.header.copy()  | 
 | 230 | +        hdr.set_data_dtype("uint8")  | 
 | 231 | +        Nifti1Image(mask, image.affine, hdr).to_filename(brain_mask_path)  | 
0 commit comments