Skip to content

Commit 225af2b

Browse files
authored
Merge pull request #29 from BrainLesion/feat/refactor_brain_extraction
Refactor brain extraction
2 parents a60a1cf + a180348 commit 225af2b

File tree

2 files changed

+69
-31
lines changed

2 files changed

+69
-31
lines changed
Lines changed: 62 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,79 @@
11
# TODO add typing and docs
22
from abc import abstractmethod
3+
import os
34

45
import nibabel as nib
56
import numpy as np
67
from brainles_hd_bet import run_hd_bet
78

9+
from auxiliary.nifti.io import read_nifti, write_nifti
10+
from auxiliary.turbopath import name_extractor
11+
12+
13+
from shutil import copyfile
14+
815

916
class BrainExtractor:
1017
@abstractmethod
1118
def extract(
1219
self,
13-
input_image,
14-
output_image,
15-
log_file,
16-
mode,
17-
):
20+
input_image_path: str,
21+
masked_image_path: str,
22+
brain_mask_path: str,
23+
log_file_path: str,
24+
# TODO convert mode to enum
25+
mode: str,
26+
) -> None:
1827
pass
1928

2029
def apply_mask(
2130
self,
22-
input_image,
23-
mask_image,
24-
output_image,
25-
):
26-
"""masks images with brain masks"""
27-
inputnifti = nib.load(input_image)
28-
mask = nib.load(mask_image)
31+
input_image_path: str,
32+
mask_image_path: str,
33+
masked_image_path: str,
34+
) -> None:
35+
"""
36+
Apply a brain mask to an input image.
2937
30-
# mask it
31-
masked_file = np.multiply(inputnifti.get_fdata(), mask.get_fdata())
32-
masked_file = nib.Nifti1Image(masked_file, inputnifti.affine, inputnifti.header)
38+
Parameters:
39+
- input_image_path (str): Path to the input image (NIfTI format).
40+
- mask_image_path (str): Path to the brain mask image (NIfTI format).
41+
- masked_image_path (str): Path to save the resulting masked image (NIfTI format).
3342
34-
# save it
35-
nib.save(masked_file, output_image)
43+
Returns:
44+
- str: Path to the saved masked image.
45+
"""
46+
47+
# read data
48+
input_data = read_nifti(input_image_path)
49+
mask_data = read_nifti(mask_image_path)
50+
51+
# mask and save it
52+
masked_data = input_data * mask_data
53+
54+
write_nifti(
55+
input_array=masked_data,
56+
output_nifti_path=masked_image_path,
57+
reference_nifti_path=input_image_path,
58+
create_parent_directory=True,
59+
)
3660

3761

3862
class HDBetExtractor(BrainExtractor):
3963
def extract(
4064
self,
41-
input_image,
42-
masked_image,
43-
# TODO implement logging!
44-
log_file,
45-
mode="accurate",
46-
):
65+
input_image_path: str,
66+
masked_image_path: str,
67+
brain_mask_path: str,
68+
log_file_path: str = None,
69+
# TODO convert mode to enum
70+
mode: str = "accurate",
71+
) -> None:
4772
# GPU + accurate + TTA
4873
"""skullstrips images with HD-BET generates a skullstripped file and mask"""
4974
run_hd_bet(
50-
mri_fnames=[input_image],
51-
output_fnames=[masked_image],
75+
mri_fnames=[input_image_path],
76+
output_fnames=[masked_image_path],
5277
# device=0,
5378
# TODO consider postprocessing
5479
# postprocess=False,
@@ -59,3 +84,15 @@ def extract(
5984
keep_mask=True,
6085
overwrite=True,
6186
)
87+
88+
hdbet_mask_path = (
89+
masked_image_path.parent
90+
+ "/"
91+
+ name_extractor(masked_image_path)
92+
+ "_masked.nii.gz"
93+
)
94+
95+
copyfile(
96+
src=hdbet_mask_path,
97+
dst=brain_mask_path,
98+
)

brainles_preprocessing/modality.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,9 @@ def apply_mask(
115115
f"brain_masked__{self.modality_name}.nii.gz",
116116
)
117117
brain_extractor.apply_mask(
118-
input_image=self.current,
119-
mask_image=atlas_mask,
120-
output_image=brain_masked,
118+
input_image_path=self.current,
119+
mask_image_path=atlas_mask,
120+
masked_image_path=brain_masked,
121121
)
122122
self.current = brain_masked
123123

@@ -153,9 +153,10 @@ def extract_brain_region(
153153
)
154154

155155
brain_extractor.extract(
156-
input_image=self.current,
157-
masked_image=atlas_bet_cm,
158-
log_file=bet_log,
156+
input_image_path=self.current,
157+
masked_image_path=atlas_bet_cm,
158+
brain_mask_path=atlas_mask,
159+
log_file_path=bet_log,
159160
)
160161
self.current = atlas_bet_cm
161162
return atlas_mask

0 commit comments

Comments
 (0)