Skip to content

Commit 8173030

Browse files
authored
Merge pull request #136 from BrainLesion/135-feature-generalize-back-to-native-class-to-allow-transforms-in-both-directions
135 feature generalize back to native class to allow transforms in both directions
2 parents 7e0123c + 43ec7d6 commit 8173030

File tree

6 files changed

+189
-36
lines changed

6 files changed

+189
-36
lines changed

README.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010

1111
`BrainLes preprocessing` is a comprehensive tool for preprocessing tasks in biomedical imaging, with a focus on (but not limited to) multi-modal brain MRI. It can be used to build modular preprocessing pipelines:
1212

13-
This includes **normalization**, **co-registration**, **atlas registration** and **skulstripping / brain extraction**.
13+
This includes **normalization**, **co-registration**, **atlas registration**, **skullstripping / brain extraction**, **N4 Bias correction** and **defacing**.
14+
We provide means to transform images and segmentations in both directions between native and atlas space.
1415

15-
BrainLes is written `backend-agnostic` meaning it allows to swap the registration, brain extraction tools and defacing tools.
16+
BrainLes is written modular and `backend-agnostic` meaning it allows to skip or swap registration, brain extraction, N4 bias correction and defacing tools.
1617

1718
<!-- TODO include image here -->
1819

@@ -86,7 +87,7 @@ moving_modalities = [
8687
)
8788
]
8889

89-
# instantiate and run the preprocessor using defaults for registration/ brain extraction/ defacing backends
90+
# instantiate and run the preprocessor using defaults for backends (registration, brain extraction, bias correction, defacing)
9091
preprocessor = Preprocessor(
9192
center_modality=center,
9293
moving_modalities=moving_modalities,
@@ -127,6 +128,9 @@ We currently provide support for [ANTs](https://github.com/ANTsX/ANTs) (default)
127128
We provide the SRI-24 atlas from this [publication](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2915788/).
128129
However, custom atlases in NIfTI format are supported.
129130

131+
### N4 Bias correction
132+
We currently provide support for N4 Bias correction based on [SimpleITK](https://simpleitk.org/)
133+
130134
### Brain extraction
131135
We currently provide support for [HD-BET](https://github.com/MIC-DKFZ/HD-BET).
132136

brainles_preprocessing/registration/ANTs/ANTs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def transform(
142142
transformed_image_path: Union[str, Path],
143143
matrix_path: str | Path | List[str | Path],
144144
log_file_path: Union[str, Path],
145-
interpolator: str = "linear",
145+
interpolator: str = "nearestNeighbor",
146146
**kwargs,
147147
) -> None:
148148
"""
@@ -154,7 +154,7 @@ def transform(
154154
transformed_image_path (str or Path): Path to the transformed image (output).
155155
matrix_path (str or Path or List[str | Path]): Path to the transformation matrix or a list of matrices.
156156
log_file_path (str or Path): Path to the log file.
157-
interpolator (str): Interpolator to use for the transformation. Default is 'linear'.
157+
interpolator (str): Interpolator to use for the transformation. Default is 'nearestNeighbor'.
158158
**kwargs: Additional transformation parameters to update the instantiated defaults.
159159
Raises:
160160
AssertionError: If the interpolator is not valid.
@@ -231,7 +231,7 @@ def inverse_transform(
231231
transformed_image_path: Union[str, Path],
232232
matrix_path: str | Path | List[str | Path],
233233
log_file_path: Union[str, Path],
234-
interpolator: str = "linear",
234+
interpolator: str = "nearestNeighbor",
235235
**kwargs,
236236
) -> None:
237237
"""
@@ -243,7 +243,7 @@ def inverse_transform(
243243
transformed_image_path (str or Path): Path to the transformed image (output).
244244
matrix_path (str or Path): Path to the transformation matrix.
245245
log_file_path (str or Path): Path to the log file.
246-
interpolator (str): Interpolator to use for the transformation. Default is 'linear'.
246+
interpolator (str): Interpolator to use for the transformation. Default is 'nearestNeighbor'.
247247
**kwargs: Additional transformation parameters to update the instantiated defaults.
248248
"""
249249
if not isinstance(matrix_path, list):

brainles_preprocessing/registration/niftyreg/niftyreg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def transform(
150150
transformed_image_path: str,
151151
matrix_path: str | Path | List[str | Path],
152152
log_file_path: str,
153-
interpolator: str = "1",
153+
interpolator: str = "0",
154154
**kwargs: dict,
155155
) -> None:
156156
"""
@@ -226,7 +226,7 @@ def inverse_transform(
226226
transformed_image_path: str,
227227
matrix_path: List[str | Path],
228228
log_file_path: str,
229-
interpolator: str = "1",
229+
interpolator: str = "0",
230230
) -> None:
231231
"""
232232
Apply inverse transformation using NiftyReg.
Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,32 @@
11
from pathlib import Path
2-
from typing import List, Optional, Union
2+
from typing import Optional, Union
33

4-
from brainles_preprocessing.constants import Atlas, PreprocessorSteps
5-
from brainles_preprocessing.defacing import Defacer, QuickshearDefacer
6-
from brainles_preprocessing.modality import CenterModality, Modality
74
from brainles_preprocessing.registration import ANTsRegistrator
85
from brainles_preprocessing.registration.registrator import Registrator
96
from brainles_preprocessing.utils.logging_utils import LoggingManager
10-
from brainles_preprocessing.utils.zenodo import verify_or_download_atlases
117

128
logging_man = LoggingManager(name=__name__)
139
logger = logging_man.get_logger()
1410

1511

16-
class BackToNativeSpace:
12+
class Transform:
13+
"""
14+
Class to apply precomputed transformations based on the registration process.
15+
Common use case is to apply inverse transformations to transform e.g. segmentations in atlas space back to native space or transform existing labels in native space to atlas space.
16+
"""
1717

1818
def __init__(
1919
self,
2020
transformations_dir: Union[str, Path],
2121
registrator: Optional[Registrator] = None,
22-
):
22+
) -> None:
23+
"""
24+
Initialize the Transform class.
25+
26+
Args:
27+
transformations_dir (Union[str, Path]): Directory containing precomputed transformations for each modality.
28+
registrator (Optional[Registrator], optional): Registrator instance to use for applying transformations. If None, defaults to ANTsRegistrator.
29+
"""
2330

2431
self.transformations_dir = Path(transformations_dir)
2532

@@ -29,17 +36,18 @@ def __init__(
2936
)
3037
self.registrator: Registrator = registrator or ANTsRegistrator()
3138

32-
def transform(
39+
def apply(
3340
self,
3441
target_modality_name: str,
3542
target_modality_img: Union[str, Path],
3643
moving_image: Union[str, Path],
3744
output_img_path: Union[str, Path],
3845
log_file_path: Union[str, Path],
3946
interpolator: Optional[str] = None,
40-
):
47+
inverse: bool = False,
48+
) -> None:
4149
"""
42-
Apply inverse transformation to a moving image to align it with a target modality.
50+
Apply forward/ inverse transformation to a moving image to align it with a target modality.
4351
4452
Args:
4553
target_modality_name (str): Name of the target modality. Must match the name used to create the transformations.
@@ -48,11 +56,12 @@ def transform(
4856
output_img_path (Union[str, Path]): Path where the transformed image will be saved.
4957
log_file_path (Union[str, Path]): Path to the log file where transformation details will be written.
5058
interpolator (Optional[str]): Interpolation method used during transformation.
59+
inverse (bool): If True, applies the inverse transformation. Default is False.
5160
Available options depend on the chosen registrator:
5261
5362
- **ANTsRegistrator**:
54-
- "linear" (default)
55-
- "nearestNeighbor"
63+
- "linear"
64+
- "nearestNeighbor" (default)
5665
- "multiLabel" (deprecated, prefer "genericLabel")
5766
- "gaussian"
5867
- "bSpline"
@@ -63,16 +72,16 @@ def transform(
6372
- "genericLabel" (recommended for label images)
6473
6574
- **NiftyReg**:
66-
- "0": nearest neighbor
67-
- "1": linear (default)
75+
- "0": nearest neighbor (default)
76+
- "1": linear
6877
- "3": cubic spline
6978
- "4": sinc
7079
7180
Raises:
7281
AssertionError: If the transformations directory for the given modality does not exist.
7382
"""
7483
logger.info(
75-
f"Applying inverse transformation for {target_modality_name} using {self.registrator.__class__.__name__}."
84+
f"Applying {'inverse' if inverse else ''} transformation for {target_modality_name} using {self.registrator.__class__.__name__}."
7685
)
7786

7887
# assert modality name eixsts in transformations_dir
@@ -86,7 +95,9 @@ def transform(
8695

8796
transforms = list(modality_transformations_dir.iterdir())
8897
transforms.sort() # sort by name to get order for forward transform
89-
transforms = transforms[::-1] # inverse order for inverse transform
98+
99+
if inverse:
100+
transforms = transforms[::-1] # inverse order for inverse transform
90101

91102
kwargs = {
92103
"fixed_image_path": target_modality_img,
@@ -97,4 +108,8 @@ def transform(
97108
}
98109
if interpolator is not None:
99110
kwargs["interpolator"] = interpolator
100-
self.registrator.inverse_transform(**kwargs)
111+
112+
if inverse:
113+
self.registrator.inverse_transform(**kwargs)
114+
else:
115+
self.registrator.transform(**kwargs)

tests/test_preprocessor.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from types import SimpleNamespace
2+
3+
import pytest
4+
5+
from brainles_preprocessing.preprocessor import CenterModality, Modality, Preprocessor
6+
7+
8+
# --- Dummy classes for testing ---
9+
class DummyModality:
10+
def __init__(self, name):
11+
self.modality_name = name
12+
13+
14+
class DummyCenterModality(DummyModality):
15+
pass
16+
17+
18+
@pytest.fixture
19+
def dummy_registrator():
20+
return SimpleNamespace()
21+
22+
23+
@pytest.fixture
24+
def dummy_brain_extractor():
25+
return SimpleNamespace()
26+
27+
28+
@pytest.fixture
29+
def dummy_defacer():
30+
return SimpleNamespace()
31+
32+
33+
# --- Tests ---
34+
35+
36+
def test_no_name_conflicts(dummy_registrator, dummy_brain_extractor, dummy_defacer):
37+
center = CenterModality("T1", input_path="", raw_skull_output_path="tmp")
38+
moving = [
39+
Modality("T2", input_path="", raw_skull_output_path="tmp"),
40+
Modality("FLAIR", input_path="", raw_skull_output_path="tmp"),
41+
]
42+
# Should not raise
43+
Preprocessor(
44+
center_modality=center,
45+
moving_modalities=moving,
46+
registrator=dummy_registrator,
47+
brain_extractor=dummy_brain_extractor,
48+
defacer=dummy_defacer,
49+
)
50+
51+
52+
def test_single_duplicate_name_raises(
53+
dummy_registrator, dummy_brain_extractor, dummy_defacer
54+
):
55+
center = CenterModality("T1", input_path="", raw_skull_output_path="tmp")
56+
moving = [
57+
Modality("T1", input_path="", raw_skull_output_path="tmp"), # Duplicate name
58+
Modality("FLAIR", input_path="", raw_skull_output_path="tmp"),
59+
]
60+
61+
with pytest.raises(ValueError, match=r"Duplicate modality names found: T1"):
62+
Preprocessor(
63+
center_modality=center,
64+
moving_modalities=moving,
65+
registrator=dummy_registrator,
66+
brain_extractor=dummy_brain_extractor,
67+
defacer=dummy_defacer,
68+
)
69+
70+
71+
def test_multiple_duplicate_names_raises(
72+
dummy_registrator, dummy_brain_extractor, dummy_defacer
73+
):
74+
center = CenterModality("T1", input_path="", raw_skull_output_path="tmp")
75+
moving = [
76+
Modality("T1", input_path="", raw_skull_output_path="tmp"), # Duplicate
77+
Modality("FLAIR", input_path="", raw_skull_output_path="tmp"),
78+
Modality("FLAIR", input_path="", raw_skull_output_path="tmp"), # Duplicate
79+
]
80+
81+
with pytest.raises(ValueError) as exc_info:
82+
Preprocessor(
83+
center_modality=center,
84+
moving_modalities=moving,
85+
registrator=dummy_registrator,
86+
brain_extractor=dummy_brain_extractor,
87+
defacer=dummy_defacer,
88+
)
89+
90+
# Check that all duplicates are reported in the error message
91+
error_msg = str(exc_info.value)
92+
assert "T1" in error_msg
93+
assert "FLAIR" in error_msg

0 commit comments

Comments
 (0)