Skip to content

Commit 8b5dd3b

Browse files
authored
Merge pull request #150 from BrainLesion/132-feature-consider-synthstrip-brain-extraction
132 feature consider synthstrip brain extraction
2 parents 813b2be + 23dd8f6 commit 8b5dd3b

File tree

13 files changed

+598
-288
lines changed

13 files changed

+598
-288
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,4 @@ dmypy.json
133133
.DS_Store
134134

135135
brainles_preprocessing/registration/atlases
136+
brainles_preprocessing/brain_extraction/weights

README.md

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,15 @@ We provide a (WIP) documentation. Have a look [here](https://brainles-preprocess
138138
Please credit the authors by citing their work.
139139

140140
### Registration
141-
We currently provide support for [ANTs](https://github.com/ANTsX/ANTs) (default), [Niftyreg](https://github.com/KCL-BMEIS/niftyreg) (Linux). We also offer basic support for [greedy](https://greedy.readthedocs.io/en/latest/reference.html#greedy-usage) and [elastix](https://pypi.org/project/itk-elastix/0.13.0/).
141+
We currently fully support:
142+
- [ANTs](https://github.com/ANTsX/ANTs) (default)
143+
- [Niftyreg](https://github.com/KCL-BMEIS/niftyreg) (Linux)
144+
145+
We also offer basic support for:
146+
- [greedy](https://greedy.readthedocs.io/en/latest/reference.html#greedy-usage) (Optional dependency, install via: `pip install brainles_preprocessing[picsl_greedy]`)
147+
- [elastix](https://pypi.org/project/itk-elastix/0.13.0/) (Optional dependency, install via: `pip install brainles_preprocessing[itk-elastix]`)
148+
149+
As of now we do not offer inverse transforms for greedy and elastix. Please resort to ANTs or Niftyreg for this.
142150

143151
### Atlas Reference
144152
We provide a range of different atlases, namely:
@@ -154,7 +162,9 @@ We also support supplying a custom atlas in NIfTI format
154162
We currently provide support for N4 Bias correction based on [SimpleITK](https://simpleitk.org/)
155163

156164
### Brain extraction
157-
We currently provide support for [HD-BET](https://github.com/MIC-DKFZ/HD-BET).
165+
We currently support:
166+
- [HD-BET](https://github.com/MIC-DKFZ/HD-BET)
167+
- [SynthStrip](https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/) (Optional dependency, install via: `pip install brainles_preprocessing[synthstrip]`)
158168

159169
### Defacing
160170
We currently provide support for [Quickshear](https://github.com/nipy/quickshear).

brainles_preprocessing/brain_extraction/brain_extractor.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ def extract(
2121
input_image_path: Union[str, Path],
2222
masked_image_path: Union[str, Path],
2323
brain_mask_path: Union[str, Path],
24-
log_file_path: Optional[Union[str, Path]],
25-
mode: Union[str, Mode],
2624
**kwargs,
2725
) -> None:
2826
"""
@@ -32,7 +30,6 @@ def extract(
3230
input_image_path (str or Path): Path to the input image.
3331
masked_image_path (str or Path): Path where the brain-extracted image will be saved.
3432
brain_mask_path (str or Path): Path where the brain mask will be saved.
35-
log_file_path (str or Path, Optional): Path to the log file.
3633
mode (str or Mode): Extraction mode.
3734
**kwargs: Additional keyword arguments.
3835
"""
@@ -86,11 +83,10 @@ def extract(
8683
input_image_path: Union[str, Path],
8784
masked_image_path: Union[str, Path],
8885
brain_mask_path: Union[str, Path],
89-
log_file_path: Optional[Union[str, Path]] = None,
90-
# TODO convert mode to enum
9186
mode: Union[str, Mode] = Mode.ACCURATE,
9287
device: Optional[Union[int, str]] = 0,
93-
do_tta: Optional[bool] = True,
88+
do_tta: bool = True,
89+
**kwargs,
9490
) -> None:
9591
# GPU + accurate + TTA
9692
"""
@@ -100,7 +96,6 @@ def extract(
10096
input_image_path (str or Path): Path to the input image.
10197
masked_image_path (str or Path): Path where the brain-extracted image will be saved.
10298
brain_mask_path (str or Path): Path where the brain mask will be saved.
103-
log_file_path (str or Path, Optional): Path to the log file.
10499
mode (str or Mode): Extraction mode ('fast' or 'accurate').
105100
device (str or int): Device to use for computation (e.g., 0 for GPU 0, 'cpu' for CPU).
106101
do_tta (bool): whether to do test time data augmentation by mirroring along all axes.
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
# Modified from:
2+
# https://github.com/nipreps/synthstrip/blob/main/nipreps/synthstrip/cli.py
3+
# Original copyright (c) 2024, NiPreps developers
4+
# Licensed under the Apache License, Version 2.0
5+
# Changes made by the BrainLesion Preprocessing team (2025)
6+
7+
from pathlib import Path
8+
from typing import Optional, Union, cast
9+
10+
import nibabel as nib
11+
import numpy as np
12+
import scipy
13+
import torch
14+
from nibabel.nifti1 import Nifti1Image
15+
from nipreps.synthstrip.model import StripModel
16+
from nitransforms.linear import Affine
17+
18+
from brainles_preprocessing.brain_extraction.brain_extractor import BrainExtractor
19+
from brainles_preprocessing.utils.zenodo import fetch_synthstrip
20+
21+
22+
class SynthStripExtractor(BrainExtractor):
23+
24+
def __init__(self, border: int = 1):
25+
"""
26+
Brain extraction using SynthStrip with preprocessing conforming to model requirements.
27+
28+
This is an optional dependency - to use this extractor, you need to install the `brainles_preprocessing` package with the `synthstrip` extra: `pip install brainles_preprocessing[synthstrip]`
29+
30+
Adapted from https://github.com/nipreps/synthstrip
31+
32+
Args:
33+
border (int): Mask border threshold in mm. Defaults to 1.
34+
"""
35+
36+
super().__init__()
37+
self.border = border
38+
39+
def _setup_model(self, device: torch.device) -> StripModel:
40+
"""
41+
Load SynthStrip model and prepare it for inference on the specified device.
42+
43+
Args:
44+
device: Device to load the model onto.
45+
46+
Returns:
47+
A configured and ready-to-use StripModel.
48+
"""
49+
# necessary for speed gains (according to original nipreps authors)
50+
torch.backends.cudnn.benchmark = True
51+
torch.backends.cudnn.deterministic = True
52+
53+
with torch.no_grad():
54+
model = StripModel()
55+
model.to(device)
56+
model.eval()
57+
58+
# Load the model weights
59+
weights_folder = fetch_synthstrip()
60+
weights = weights_folder / "synthstrip.1.pt"
61+
checkpoint = torch.load(weights, map_location=device)
62+
model.load_state_dict(checkpoint["model_state_dict"])
63+
64+
return model
65+
66+
def _conform(self, input_nii: Nifti1Image) -> Nifti1Image:
67+
"""
68+
Resample the input image to match SynthStrip's expected input space.
69+
70+
Args:
71+
input_nii (Nifti1Image): Input NIfTI image to conform.
72+
73+
Raises:
74+
ValueError: If the input NIfTI image does not have a valid affine.
75+
76+
Returns:
77+
A new NIfTI image with conformed shape and affine.
78+
"""
79+
80+
shape = np.array(input_nii.shape[:3])
81+
affine = input_nii.affine
82+
83+
if affine is None:
84+
raise ValueError("Input NIfTI image must have a valid affine.")
85+
86+
# Get corner voxel centers in index coords
87+
corner_centers_ijk = (
88+
np.array(
89+
[
90+
(i, j, k)
91+
for k in (0, shape[2] - 1)
92+
for j in (0, shape[1] - 1)
93+
for i in (0, shape[0] - 1)
94+
]
95+
)
96+
+ 0.5
97+
)
98+
99+
# Get corner voxel centers in mm
100+
corners_xyz = (
101+
affine
102+
@ np.hstack((corner_centers_ijk, np.ones((len(corner_centers_ijk), 1)))).T
103+
)
104+
105+
# Target affine is 1mm voxels in LIA orientation
106+
target_affine = np.diag([-1.0, 1.0, -1.0, 1.0])[:, (0, 2, 1, 3)]
107+
108+
# Target shape
109+
extent = corners_xyz.min(1)[:3], corners_xyz.max(1)[:3]
110+
target_shape = ((extent[1] - extent[0]) / 1.0 + 0.999).astype(int)
111+
112+
# SynthStrip likes dimensions be multiple of 64 (192, 256, or 320)
113+
target_shape = np.clip(
114+
np.ceil(np.array(target_shape) / 64).astype(int) * 64, 192, 320
115+
)
116+
117+
# Ensure shape ordering is LIA too
118+
target_shape[2], target_shape[1] = target_shape[1:3]
119+
120+
# Coordinates of center voxel do not change
121+
input_c = affine @ np.hstack((0.5 * (shape - 1), 1.0))
122+
target_c = target_affine @ np.hstack((0.5 * (target_shape - 1), 1.0))
123+
124+
# Rebase the origin of the new, plumb affine
125+
target_affine[:3, 3] -= target_c[:3] - input_c[:3]
126+
127+
nii = Affine(
128+
reference=Nifti1Image(
129+
np.zeros(target_shape),
130+
target_affine,
131+
None,
132+
),
133+
).apply(input_nii)
134+
return cast(Nifti1Image, nii)
135+
136+
def _resample_like(
137+
self,
138+
image: Nifti1Image,
139+
target: Nifti1Image,
140+
output_dtype: Optional[np.dtype] = None,
141+
cval: Union[int, float] = 0,
142+
) -> Nifti1Image:
143+
"""
144+
Resample the input image to match the target's grid using an identity transform.
145+
146+
Args:
147+
image: The image to be resampled.
148+
target: The reference image.
149+
output_dtype: Output data type.
150+
cval: Value to use for constant padding.
151+
152+
Returns:
153+
A resampled NIfTI image.
154+
"""
155+
result = Affine(reference=target).apply(
156+
image,
157+
output_dtype=output_dtype,
158+
cval=cval,
159+
)
160+
return cast(Nifti1Image, result)
161+
162+
def extract(
163+
self,
164+
input_image_path: Union[str, Path],
165+
masked_image_path: Union[str, Path],
166+
brain_mask_path: Union[str, Path],
167+
device: Union[torch.device, str] = "cuda",
168+
num_threads: int = 1,
169+
**kwargs,
170+
) -> None:
171+
"""
172+
Extract the brain from an input image using SynthStrip.
173+
174+
Args:
175+
input_image_path (Union[str, Path]): Path to the input image.
176+
masked_image_path (Union[str, Path]): Path to the output masked image.
177+
brain_mask_path (Union[str, Path]): Path to the output brain mask.
178+
device (Union[torch.device, str], optional): Device to use for computation. Defaults to "cuda".
179+
num_threads (int, optional): Number of threads to use for computation in CPU mode. Defaults to 1.
180+
181+
Returns:
182+
None: The function saves the masked image and brain mask to the specified paths.
183+
"""
184+
185+
device = torch.device(device) if isinstance(device, str) else device
186+
model = self._setup_model(device=device)
187+
188+
if device.type == "cpu" and num_threads > 0:
189+
torch.set_num_threads(num_threads)
190+
191+
# normalize intensities
192+
image = nib.load(input_image_path)
193+
image = cast(Nifti1Image, image)
194+
conformed = self._conform(image)
195+
in_data = conformed.get_fdata(dtype="float32")
196+
in_data -= in_data.min()
197+
in_data = np.clip(in_data / np.percentile(in_data, 99), 0, 1)
198+
in_data = in_data[np.newaxis, np.newaxis]
199+
200+
# predict the surface distance transform
201+
input_tensor = torch.from_numpy(in_data).to(device)
202+
with torch.no_grad():
203+
sdt = model(input_tensor).cpu().numpy().squeeze()
204+
205+
# unconform the sdt and extract mask
206+
sdt_target = self._resample_like(
207+
Nifti1Image(sdt, conformed.affine, None),
208+
image,
209+
output_dtype=np.dtype("int16"),
210+
cval=100,
211+
)
212+
sdt_data = np.asanyarray(sdt_target.dataobj).astype("int16")
213+
214+
# find largest CC (just do this to be safe for now)
215+
components = scipy.ndimage.label(sdt_data.squeeze() < self.border)[0]
216+
bincount = np.bincount(components.flatten())[1:]
217+
mask = components == (np.argmax(bincount) + 1)
218+
mask = scipy.ndimage.morphology.binary_fill_holes(mask)
219+
220+
# write the masked output
221+
img_data = image.get_fdata()
222+
bg = np.min([0, img_data.min()])
223+
img_data[mask == 0] = bg
224+
Nifti1Image(img_data, image.affine, image.header).to_filename(
225+
masked_image_path,
226+
)
227+
228+
# write the brain mask
229+
hdr = image.header.copy()
230+
hdr.set_data_dtype("uint8")
231+
Nifti1Image(mask, image.affine, hdr).to_filename(brain_mask_path)

brainles_preprocessing/modality.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pathlib import Path
55
from typing import Dict, Optional, Union
66

7+
import torch
78
from auxiliary.io import read_image, write_image
89
from loguru import logger
910

@@ -16,7 +17,7 @@
1617
NiftyRegRegistrator,
1718
)
1819
from brainles_preprocessing.registration.registrator import Registrator
19-
from brainles_preprocessing.utils.zenodo import verify_or_download_atlases
20+
from brainles_preprocessing.utils.zenodo import fetch_atlases
2021

2122

2223
class Modality:
@@ -598,6 +599,7 @@ def extract_brain_region(
598599
self,
599600
brain_extractor: BrainExtractor,
600601
bet_dir_path: Union[str, Path],
602+
use_gpu: bool = True,
601603
) -> Path:
602604
"""
603605
@@ -606,6 +608,7 @@ def extract_brain_region(
606608
Args:
607609
brain_extractor (BrainExtractor): The brain extractor object.
608610
bet_dir_path (str or Path): Directory to store brain extraction results.
611+
use_gpu (bool): Whether to use GPU for brain extraction if available.
609612
610613
Returns:
611614
Path: Path to the extracted brain mask.
@@ -617,11 +620,16 @@ def extract_brain_region(
617620
bet = bet_dir_path / f"{self.modality_name}_bet.nii.gz"
618621
mask_path = bet_dir_path / f"{self.modality_name}_brain_mask.nii.gz"
619622

623+
device = torch.device(
624+
"cuda" if use_gpu and torch.cuda.is_available() else "cpu"
625+
)
626+
620627
brain_extractor.extract(
621628
input_image_path=self.current,
622629
masked_image_path=bet,
623630
brain_mask_path=mask_path,
624631
log_file_path=bet_log,
632+
device=device,
625633
)
626634

627635
# always temporarily store bet image for center modality, since e.g. quickshear defacing could require it
@@ -666,7 +674,7 @@ def deface(
666674

667675
# resolve atlas image path
668676
if isinstance(defacer.atlas_image_path, Atlas):
669-
atlas_folder = verify_or_download_atlases()
677+
atlas_folder = fetch_atlases()
670678
atlas_image_path = atlas_folder / defacer.atlas_image_path.value
671679
else:
672680
atlas_image_path = Path(defacer.atlas_image_path)

brainles_preprocessing/preprocessor/atlas_centric_preprocessor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from brainles_preprocessing.n4_bias_correction import N4BiasCorrector
1212
from brainles_preprocessing.preprocessor.preprocessor import BasePreprocessor
1313
from brainles_preprocessing.registration.registrator import Registrator
14-
from brainles_preprocessing.utils.zenodo import verify_or_download_atlases
14+
from brainles_preprocessing.utils.zenodo import fetch_atlases
1515

1616

1717
class AtlasCentricPreprocessor(BasePreprocessor):
@@ -42,7 +42,7 @@ def __init__(
4242
atlas_image_path: Union[str, Path, Atlas] = Atlas.BRATS_SRI24,
4343
n4_bias_corrector: Optional[N4BiasCorrector] = None,
4444
temp_folder: Optional[Union[str, Path]] = None,
45-
use_gpu: Optional[bool] = None,
45+
use_gpu: bool = True,
4646
limit_cuda_visible_devices: Optional[str] = None,
4747
):
4848
super().__init__(
@@ -58,7 +58,7 @@ def __init__(
5858
)
5959

6060
if isinstance(atlas_image_path, Atlas):
61-
atlas_folder = verify_or_download_atlases()
61+
atlas_folder = fetch_atlases()
6262
self.atlas_image_path = atlas_folder / atlas_image_path.value
6363
else:
6464
self.atlas_image_path = Path(atlas_image_path)

0 commit comments

Comments
 (0)