Skip to content

Commit 7e0123c

Browse files
authored
Merge pull request #133 from BrainLesion/78-feature-request-n4-bias-correction-support
78 feature request n4 bias correction support
2 parents 6632b5e + cfcebce commit 7e0123c

File tree

12 files changed

+273
-99
lines changed

12 files changed

+273
-99
lines changed

brainles_preprocessing/brain_extraction/brain_extractor.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Optional, Union
66
from enum import Enum
77

8-
from auxiliary.nifti.io import read_nifti, write_nifti
8+
from auxiliary.io import read_image, write_image
99
from brainles_hd_bet import run_hd_bet
1010

1111

@@ -55,8 +55,8 @@ def apply_mask(
5555

5656
try:
5757
# Read data
58-
input_data = read_nifti(str(input_image_path))
59-
mask_data = read_nifti(str(mask_path))
58+
input_data = read_image(str(input_image_path))
59+
mask_data = read_image(str(mask_path))
6060
except FileNotFoundError as e:
6161
raise FileNotFoundError(f"File not found: {e.filename}") from e
6262
except Exception as e:
@@ -70,10 +70,10 @@ def apply_mask(
7070
masked_data = input_data * mask_data
7171

7272
try:
73-
write_nifti(
73+
write_image(
7474
input_array=masked_data,
75-
output_nifti_path=str(bet_image_path),
76-
reference_nifti_path=str(input_image_path),
75+
output_path=str(bet_image_path),
76+
reference_path=str(input_image_path),
7777
create_parent_directory=True,
7878
)
7979
except Exception as e:

brainles_preprocessing/constants.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ class PreprocessorSteps(IntEnum):
66
COREGISTERED = 1
77
ATLAS_REGISTERED = 2
88
ATLAS_CORRECTED = 3
9-
BET = 4
10-
DEFACED = 5
9+
N4_BIAS_CORRECTED = 4
10+
BET = 5
11+
DEFACED = 6
1112

1213

1314
class Atlas(str, Enum):

brainles_preprocessing/defacing/defacer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pathlib import Path
33
from typing import Union
44

5-
from auxiliary.nifti.io import read_nifti, write_nifti
5+
from auxiliary.io import read_image, write_image
66

77

88
class Defacer(ABC):
@@ -52,8 +52,8 @@ def apply_mask(
5252

5353
try:
5454
# Read data
55-
input_data = read_nifti(str(input_image_path))
56-
mask_data = read_nifti(str(mask_path))
55+
input_data = read_image(str(input_image_path))
56+
mask_data = read_image(str(mask_path))
5757
except Exception as e:
5858
raise RuntimeError(
5959
f"An error occurred while reading input files: {e}"
@@ -67,9 +67,9 @@ def apply_mask(
6767
masked_data = input_data * mask_data
6868

6969
# Save the defaced image
70-
write_nifti(
70+
write_image(
7171
input_array=masked_data,
72-
output_nifti_path=str(defaced_image_path),
73-
reference_nifti_path=str(input_image_path),
72+
output_path=str(defaced_image_path),
73+
reference_path=str(input_image_path),
7474
create_parent_directory=True,
7575
)

brainles_preprocessing/defacing/quickshear/quickshear.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from typing import Union
33

44
import nibabel as nib
5+
from auxiliary.io import read_image, write_image
56

67
from brainles_preprocessing.defacing.defacer import Defacer
78
from brainles_preprocessing.defacing.quickshear.nipy_quickshear import run_quickshear
8-
from auxiliary.nifti.io import write_nifti
99

1010

1111
class QuickshearDefacer(Defacer):
@@ -56,8 +56,11 @@ def deface(
5656

5757
bet_img = nib.load(str(input_image_path))
5858
mask = run_quickshear(bet_img=bet_img, buffer=self.buffer)
59-
write_nifti(
59+
60+
# transpose to match simpletik order
61+
mask = mask.transpose(2, 1, 0)
62+
write_image(
6063
input_array=mask,
61-
output_nifti_path=str(mask_image_path),
62-
reference_nifti_path=str(input_image_path),
64+
output_path=str(mask_image_path),
65+
reference_path=str(input_image_path),
6366
)

brainles_preprocessing/modality.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from pathlib import Path
55
from typing import Dict, Optional, Union
66

7-
from auxiliary.nifti.io import read_nifti, write_nifti
7+
from auxiliary.io import read_image, write_image
88

99
from brainles_preprocessing.brain_extraction.brain_extractor import BrainExtractor
1010
from brainles_preprocessing.constants import PreprocessorSteps
@@ -34,6 +34,7 @@ class Modality:
3434
normalized_skull_output_path (str or Path, optional): Path to save the normalized modality data with skull. Requires a normalizer.
3535
normalized_defaced_output_path (str or Path, optional): Path to save the normalized defaced modality data. Requires a normalizer.
3636
atlas_correction (bool, optional): Indicates whether atlas correction should be performed.
37+
n4_bias_correction (bool, optional): Indicates whether N4 bias correction should be performed.
3738
3839
Attributes:
3940
modality_name (str): Name of the modality.
@@ -47,6 +48,7 @@ class Modality:
4748
normalized_defaced_output_path (str or Path, optional): Path to save the normalized defaced modality data. Requires a normalizer.
4849
bet (bool): Indicates whether brain extraction is enabled.
4950
atlas_correction (bool): Indicates whether atlas correction should be performed.
51+
n4_bias_correction (bool): Indicates whether N4 bias correction should be performed.
5052
coregistration_transform_path (str or None): Path to the coregistration transformation matrix, will be set after coregistration.
5153
5254
Example:
@@ -72,13 +74,15 @@ def __init__(
7274
normalized_skull_output_path: Optional[Union[str, Path]] = None,
7375
normalized_defaced_output_path: Optional[Union[str, Path]] = None,
7476
atlas_correction: bool = True,
77+
n4_bias_correction: bool = False,
7578
) -> None:
7679
# Basics
7780
self.modality_name = modality_name
7881
self.input_path = Path(input_path)
7982
self.current = self.input_path
8083
self.normalizer = normalizer
8184
self.atlas_correction = atlas_correction
85+
self.n4_bias_correction = n4_bias_correction
8286
self.transformation_paths: Dict[PreprocessorSteps, Path | None] = {}
8387

8488
# Check that atleast one output is generated
@@ -195,12 +199,12 @@ def normalize(
195199

196200
# Normalize the image
197201
if self.normalizer:
198-
image = read_nifti(str(self.current))
202+
image = read_image(str(self.current))
199203
normalized_image = self.normalizer.normalize(image=image)
200-
write_nifti(
204+
write_image(
201205
input_array=normalized_image,
202-
output_nifti_path=str(self.current),
203-
reference_nifti_path=str(self.current),
206+
output_path=str(self.current),
207+
reference_path=str(self.current),
204208
)
205209
else:
206210
logger.info("No normalizer specified; skipping normalization.")
@@ -505,12 +509,12 @@ def save_current_image(
505509
if normalization:
506510
if self.normalizer is None:
507511
raise ValueError("Normalizer is required for normalization.")
508-
image = read_nifti(str(self.current))
512+
image = read_image(str(self.current))
509513
normalized_image = self.normalizer.normalize(image=image)
510-
write_nifti(
514+
write_image(
511515
input_array=normalized_image,
512-
output_nifti_path=str(output_path),
513-
reference_nifti_path=str(self.current),
516+
output_path=str(output_path),
517+
reference_path=str(self.current),
514518
)
515519
else:
516520
shutil.copyfile(
@@ -534,6 +538,7 @@ class CenterModality(Modality):
534538
normalized_skull_output_path (str or Path, optional): Path to save the normalized modality data with skull. Requires a normalizer.
535539
normalized_defaced_output_path (str or Path, optional): Path to save the normalized defaced modality data. Requires a normalizer.
536540
atlas_correction (bool, optional): Indicates whether atlas correction should be performed.
541+
n4_bias_correction (bool, optional): Indicates whether N4 bias correction should be performed.
537542
bet_mask_output_path (str or Path, optional): Path to save the brain extraction mask.
538543
defacing_mask_output_path (str or Path, optional): Path to save the defacing mask.
539544
@@ -549,6 +554,7 @@ class CenterModality(Modality):
549554
normalized_defaced_output_path (str or Path, optional): Path to save the normalized defaced modality data. Requires a normalizer.
550555
bet (bool): Indicates whether brain extraction is enabled.
551556
atlas_correction (bool): Indicates whether atlas correction should be performed.
557+
n4_bias_correction (bool): Indicates whether N4 bias correction should be performed.
552558
bet_mask_output_path (Path, optional): Path to save the brain extraction mask.
553559
defacing_mask_output_path (Path, optional): Path to save the defacing mask.
554560
@@ -575,6 +581,7 @@ def __init__(
575581
normalized_skull_output_path: Optional[Union[str, Path]] = None,
576582
normalized_defaced_output_path: Optional[Union[str, Path]] = None,
577583
atlas_correction: bool = True,
584+
n4_bias_correction: bool = False,
578585
bet_mask_output_path: Optional[Union[str, Path]] = None,
579586
defacing_mask_output_path: Optional[Union[str, Path]] = None,
580587
) -> None:
@@ -589,6 +596,7 @@ def __init__(
589596
normalized_skull_output_path=normalized_skull_output_path,
590597
normalized_defaced_output_path=normalized_defaced_output_path,
591598
atlas_correction=atlas_correction,
599+
n4_bias_correction=n4_bias_correction,
592600
)
593601
# Only for CenterModality
594602
self.bet_mask_output_path = (
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .n4_bias_corrector import N4BiasCorrector
2+
from .sitk.sitk_n4_bias_corrector import SitkN4BiasCorrector
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any
3+
4+
5+
class N4BiasCorrector(ABC):
6+
7+
@abstractmethod
8+
def correct(
9+
self,
10+
input_img_path: Any,
11+
output_img_path: Any,
12+
) -> None:
13+
pass

brainles_preprocessing/n4_bias_correction/sitk/__init__.py

Whitespace-only changes.
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from pathlib import Path
2+
from typing import Callable, Optional, Union
3+
4+
import SimpleITK as sitk
5+
from auxiliary.io import write_image
6+
7+
from brainles_preprocessing.n4_bias_correction.n4_bias_corrector import N4BiasCorrector
8+
9+
10+
class SitkN4BiasCorrector(N4BiasCorrector):
11+
12+
def __init__(
13+
self,
14+
mask_func: Optional[Callable[[sitk.Image], sitk.Image]] = None,
15+
n_max_iterations: Optional[int] = None,
16+
n_fitting_levels: int = 3,
17+
) -> None:
18+
"""
19+
N4 Bias Corrector using SimpleITK.
20+
21+
Args:
22+
mask_func (Optional[Callable[[sitk.Image], sitk.Image]], optional):
23+
Function that generates a mask from an image.
24+
Defaults to: `lambda img_itk: sitk.OtsuThreshold(img_itk, 0, 1, 200)`.
25+
n_max_iterations (Optional[int], optional):
26+
Maximum number of iterations for bias field correction.
27+
n_fitting_levels (int, optional):
28+
Number of fitting levels. Default is 3.
29+
"""
30+
31+
if mask_func is None:
32+
mask_func = lambda img_itk: sitk.OtsuThreshold(img_itk, 0, 1, 200)
33+
self.mask_func = mask_func
34+
self.n_max_iterations = n_max_iterations
35+
self.n_fitting_levels = n_fitting_levels
36+
37+
def compute_mask(self, img_itk: sitk.Image) -> sitk.Image:
38+
"""
39+
Compute the mask for the input image using the provided mask function.
40+
41+
Args:
42+
img_itk (SimpleITK.Image): The input image in SimpleITK format.
43+
44+
Returns:
45+
SimpleITK.Image: The computed mask.
46+
"""
47+
48+
return self.mask_func(img_itk)
49+
50+
def correct(
51+
self,
52+
input_img_path: Union[str, Path],
53+
output_img_path: Union[str, Path],
54+
) -> None:
55+
"""
56+
Correct the bias field of the input image using SimpleITK.
57+
58+
Args:
59+
input_img_path (Union[str, Path]): Path to the input image.
60+
output_img_path (Union[str, Path]): Path where the corrected image will be saved.
61+
62+
Returns:
63+
None
64+
"""
65+
img_itk = sitk.ReadImage(str(input_img_path))
66+
67+
mask_itk = self.compute_mask(img_itk)
68+
69+
corrector = sitk.N4BiasFieldCorrectionImageFilter()
70+
if self.n_max_iterations is not None:
71+
corrector.SetMaximumNumberOfIterations(
72+
[self.n_max_iterations] * self.n_fitting_levels
73+
)
74+
75+
corrected_img = corrector.Execute(img_itk, mask_itk)
76+
corrected_img = sitk.GetArrayFromImage(corrected_img)
77+
78+
write_image(
79+
input_array=corrected_img,
80+
output_path=str(output_img_path),
81+
reference_path=str(input_img_path),
82+
)

0 commit comments

Comments
 (0)