Skip to content

Commit 3430b86

Browse files
authored
Merge pull request #44 from BrainLesion/testing_registrator
Testing registrator
2 parents 65b274a + da440a9 commit 3430b86

File tree

5 files changed

+141
-57
lines changed

5 files changed

+141
-57
lines changed

brainles_preprocessing/modality.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,11 @@ def register(
9494
registered_log = os.path.join(registration_dir, f"{moving_image_name}.log")
9595

9696
registrator.register(
97-
fixed_image=fixed_image_path,
98-
moving_image=self.current,
97+
fixed_image_path=fixed_image_path,
98+
moving_image_path=self.current,
9999
transformed_image=registered,
100-
matrix=registered_matrix,
101-
log_file=registered_log,
100+
matrix_path=registered_matrix,
101+
log_file_path=registered_log,
102102
)
103103
self.current = registered
104104
return registered_matrix
@@ -133,11 +133,11 @@ def transform(
133133
transformed_log = os.path.join(registration_dir, f"{moving_image_name}.log")
134134

135135
registrator.transform(
136-
fixed_image=fixed_image_path,
137-
moving_image=self.current,
138-
transformed_image=transformed,
139-
matrix=transformation_matrix,
140-
log_file=transformed_log,
136+
fixed_image_path=fixed_image_path,
137+
moving_image_path=self.current,
138+
transformed_image_path=transformed,
139+
matrix_path=transformation_matrix,
140+
log_file_path=transformed_log,
141141
)
142142
self.current = transformed
143143

brainles_preprocessing/registration/niftyreg.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -42,25 +42,25 @@ def __init__(
4242

4343
def register(
4444
self,
45-
fixed_image,
46-
moving_image,
47-
transformed_image,
48-
matrix,
49-
log_file,
45+
fixed_image_path: str,
46+
moving_image_path: str,
47+
transformed_image_path: str,
48+
matrix_path: str,
49+
log_file_path: str,
5050
):
5151
"""
5252
Register images using NiftyReg.
5353
5454
Args:
55-
fixed_image (str): Path to the fixed image.
56-
moving_image (str): Path to the moving image.
57-
transformed_image (str): Path to the transformed image (output).
58-
matrix (str): Path to the transformation matrix (output).
59-
log_file (str): Path to the log file.
55+
fixed_image_path (str): Path to the fixed image.
56+
moving_image_path (str): Path to the moving image.
57+
transformed_image_path (str): Path to the transformed image (output).
58+
matrix_path (str): Path to the transformation matrix (output).
59+
log_file_path (str): Path to the log file.
6060
"""
6161
runner = ScriptRunner(
6262
script_path=self.registration_script,
63-
log_path=log_file,
63+
log_path=log_file_path,
6464
)
6565

6666
niftyreg_executable = str(
@@ -69,10 +69,10 @@ def register(
6969

7070
input_params = [
7171
turbopath(niftyreg_executable),
72-
turbopath(fixed_image),
73-
turbopath(moving_image),
74-
turbopath(transformed_image),
75-
turbopath(matrix),
72+
turbopath(fixed_image_path),
73+
turbopath(moving_image_path),
74+
turbopath(transformed_image_path),
75+
turbopath(matrix_path),
7676
]
7777

7878
# Call the run method to execute the script and capture the output in the log file
@@ -85,25 +85,25 @@ def register(
8585

8686
def transform(
8787
self,
88-
fixed_image,
89-
moving_image,
90-
transformed_image,
91-
matrix,
92-
log_file,
88+
fixed_image_path: str,
89+
moving_image_path: str,
90+
transformed_image_path: str,
91+
matrix_path: str,
92+
log_file_path: str,
9393
):
9494
"""
9595
Apply a transformation using NiftyReg.
9696
9797
Args:
98-
fixed_image (str): Path to the fixed image.
99-
moving_image (str): Path to the moving image.
100-
transformed_image (str): Path to the transformed image (output).
101-
matrix (str): Path to the transformation matrix.
102-
log_file (str): Path to the log file.
98+
fixed_image_path (str): Path to the fixed image.
99+
moving_image_path (str): Path to the moving image.
100+
transformed_image_path (str): Path to the transformed image (output).
101+
matrix_path (str): Path to the transformation matrix.
102+
log_file_path (str): Path to the log file.
103103
"""
104104
runner = ScriptRunner(
105105
script_path=self.transformation_script,
106-
log_path=log_file,
106+
log_path=log_file_path,
107107
)
108108

109109
niftyreg_executable = str(
@@ -112,10 +112,10 @@ def transform(
112112

113113
input_params = [
114114
turbopath(niftyreg_executable),
115-
turbopath(fixed_image),
116-
turbopath(moving_image),
117-
turbopath(transformed_image),
118-
turbopath(matrix),
115+
turbopath(fixed_image_path),
116+
turbopath(moving_image_path),
117+
turbopath(transformed_image_path),
118+
turbopath(matrix_path),
119119
]
120120

121121
# Call the run method to execute the script and capture the output in the log file
Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,56 @@
1-
# TODO add typing and docs
21
from abc import ABC, abstractmethod
2+
from typing import Any
33

44

55
class Registrator(ABC):
6-
def __init__(self, backend):
7-
self.backend = backend
6+
# TODO probably the init here should be removed?
7+
# def __init__(self, backend):
8+
# self.backend = backend
89

910
@abstractmethod
1011
def register(
1112
self,
12-
fixed_image,
13-
moving_image,
14-
transformed_image,
15-
matrix,
16-
log_file,
17-
):
13+
fixed_image_path: Any,
14+
moving_image_path: Any,
15+
transformed_image_path: Any,
16+
matrix_path: Any,
17+
log_file_path: str,
18+
) -> None:
19+
"""
20+
Abstract method for registering images.
21+
22+
Args:
23+
fixed_image_path (Any): The fixed image for registration.
24+
moving_image_path (Any): The moving image to be registered.
25+
transformed_image_path (Any): The resulting transformed image after registration.
26+
matrix_path (Any): The transformation matrix applied during registration.
27+
log_file_path (str): The path to the log file for recording registration details.
28+
29+
Returns:
30+
None
31+
"""
1832
pass
1933

2034
@abstractmethod
2135
def transform(
2236
self,
23-
fixed_image,
24-
moving_image,
25-
transformed_image,
26-
matrix,
27-
log_file,
28-
):
37+
fixed_image_path: Any,
38+
moving_image_path: Any,
39+
transformed_image_path: Any,
40+
matrix: Any,
41+
log_file: str,
42+
) -> None:
43+
"""
44+
Abstract method for transforming images.
45+
46+
Args:
47+
fixed_image_path (Any): The fixed image to be transformed.
48+
moving_image_path (Any): The moving image to be transformed.
49+
transformed_image_path (Any): The resulting transformed image.
50+
matrix_path (Any): The transformation matrix applied during transformation.
51+
log_file_path (str): The path to the log file for recording transformation details.
52+
53+
Returns:
54+
None
55+
"""
2956
pass

tests/test_brain_extractor.py renamed to tests/test_hdbet_brain_extractor.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@ class TestHDBetExtractor(unittest.TestCase):
1111
def setUp(self):
1212
test_data_dir = turbopath(__file__).parent + "/test_data"
1313
input_dir = test_data_dir + "/input"
14-
self.output_dir = test_data_dir + "/temp_output"
14+
self.output_dir = test_data_dir + "/temp_output_hdbet"
1515
os.makedirs(self.output_dir, exist_ok=True)
1616

1717
self.brain_extractor = HDBetExtractor()
18+
1819
self.input_image_path = input_dir + "/tcia_example_t1c.nii.gz"
1920
self.input_brain_mask_path = input_dir + "/bet_tcia_example_t1c_mask.nii.gz"
21+
2022
self.masked_image_path = self.output_dir + "/bet_tcia_example_t1c.nii.gz"
2123
self.brain_mask_path = self.output_dir + "/bet_tcia_example_t1c_mask.nii.gz"
2224
self.masked_again_image_path = (
@@ -30,9 +32,8 @@ def tearDown(self):
3032
# Clean up created files if they exist
3133
shutil.rmtree(self.output_dir)
3234

33-
3435
def test_extract_creates_output_files(self):
35-
# we try to run the fastest possible skullstripping on GPU
36+
# we try to run the fastest possible skullstripping on CPU
3637
self.brain_extractor.extract(
3738
input_image_path=self.input_image_path,
3839
masked_image_path=self.masked_image_path,

tests/test_niftyreg_registrator.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import os
2+
import shutil
3+
import unittest
4+
5+
from auxiliary.turbopath import turbopath
6+
7+
from brainles_preprocessing.registration.niftyreg import NiftyRegRegistrator
8+
9+
10+
class TestNiftyRegRegistrator(unittest.TestCase):
11+
def setUp(self):
12+
test_data_dir = turbopath(__file__).parent + "/test_data"
13+
input_dir = test_data_dir + "/input"
14+
self.output_dir = test_data_dir + "/temp_output_niftyreg"
15+
os.makedirs(self.output_dir, exist_ok=True)
16+
17+
self.registrator = NiftyRegRegistrator()
18+
19+
self.fixed_image = input_dir + "/tcia_example_t1c.nii.gz"
20+
self.moving_image = input_dir + "/bet_tcia_example_t1c_mask.nii.gz"
21+
22+
self.transformed_image = self.output_dir + "/transformed_image.nii.gz"
23+
self.matrix = self.output_dir + "/matrix.txt"
24+
self.log_file = self.output_dir + "/registration.log"
25+
26+
def tearDown(self):
27+
# Clean up created files if they exist
28+
shutil.rmtree(self.output_dir)
29+
30+
def test_register_creates_output_files(self):
31+
# we try to run the fastest possible skullstripping on GPU
32+
self.registrator.register(
33+
fixed_image_path=self.fixed_image,
34+
moving_image_path=self.moving_image,
35+
transformed_image_path=self.transformed_image,
36+
matrix_path=self.matrix,
37+
log_file_path=self.log_file,
38+
)
39+
40+
self.assertTrue(
41+
os.path.exists(self.transformed_image),
42+
"transformed file was not created.",
43+
)
44+
45+
self.assertTrue(
46+
os.path.exists(self.matrix),
47+
"matrix file was not created.",
48+
)
49+
50+
self.assertTrue(
51+
os.path.exists(self.log_file),
52+
"log file was not created.",
53+
)
54+
55+
56+
# TODO also test transform

0 commit comments

Comments
 (0)