-
Notifications
You must be signed in to change notification settings - Fork 240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Standardize Axes in Random Transforms. Add Random Axis to RandomMotion #1185
base: main
Are you sure you want to change the base?
Changes from all commits
57797c6
ce86ccb
a6d641c
b5131f2
4bd32a5
b938154
111d1e2
c1cc856
08357f0
1d8761c
a7eb171
aaa7be9
671a373
88f0e94
6624576
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,11 +14,9 @@ | |
from ..typing import TypeData | ||
from ..typing import TypeDataAffine | ||
from ..typing import TypeDirection | ||
from ..typing import TypeDoubletInt | ||
from ..typing import TypePath | ||
from ..typing import TypeQuartetInt | ||
from ..typing import TypeTripletFloat | ||
from ..typing import TypeTripletInt | ||
|
||
|
||
# Matrices used to switch between LPS and RAS | ||
|
@@ -87,26 +85,43 @@ def _read_dicom(directory: TypePath): | |
|
||
|
||
def read_shape(path: TypePath) -> TypeQuartetInt: | ||
reader = sitk.ImageFileReader() | ||
reader.SetFileName(str(path)) | ||
reader.ReadImageInformation() | ||
num_channels = reader.GetNumberOfComponents() | ||
num_dimensions = reader.GetDimension() | ||
try: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are the changes in this method related to the goal of this PR? |
||
reader = sitk.ImageFileReader() | ||
reader.SetFileName(str(path)) | ||
reader.ReadImageInformation() | ||
num_channels = reader.GetNumberOfComponents() | ||
num_dimensions = reader.GetDimension() | ||
shape = reader.GetSize() | ||
except RuntimeError as e: # try with NiBabel | ||
message = f'Error loading image with SimpleITK:\n{e}\n\nTrying NiBabel...' | ||
warnings.warn(message, stacklevel=2) | ||
try: | ||
obj: SpatialImage = nib.load(str(path)) # type: ignore[assignment] | ||
except nib.loadsave.ImageFileError as e: | ||
message = ( | ||
f'File "{path}" not understood.' | ||
' Check supported formats by at' | ||
' https://simpleitk.readthedocs.io/en/master/IO.html#images' | ||
' and https://nipy.org/nibabel/api.html#file-formats' | ||
) | ||
raise RuntimeError(message) from e | ||
num_dimensions = obj.ndim | ||
shape = obj.shape | ||
num_channels = 1 if num_dimensions < 4 else shape[-1] | ||
assert 2 <= num_dimensions <= 4 | ||
if num_dimensions == 2: | ||
spatial_shape_2d: TypeDoubletInt = reader.GetSize() | ||
assert len(spatial_shape_2d) == 2 | ||
si, sj = spatial_shape_2d | ||
assert len(shape) == 2 | ||
si, sj = shape | ||
sk = 1 | ||
elif num_dimensions == 4: | ||
# We assume bad NIfTI file (channels encoded as spatial dimension) | ||
spatial_shape_4d: TypeQuartetInt = reader.GetSize() | ||
assert len(spatial_shape_4d) == 4 | ||
si, sj, sk, num_channels = spatial_shape_4d | ||
assert len(shape) == 4 | ||
si, sj, sk, num_channels = shape | ||
elif num_dimensions == 3: | ||
spatial_shape_3d: TypeTripletInt = reader.GetSize() | ||
assert len(spatial_shape_3d) == 3 | ||
si, sj, sk = spatial_shape_3d | ||
assert len(shape) == 3 | ||
si, sj, sk = shape | ||
else: | ||
raise ValueError(f'Unsupported number of dimensions: {num_dimensions}') | ||
shape = num_channels, si, sj, sk | ||
return shape | ||
|
||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,6 +1,5 @@ | ||||||
from collections import defaultdict | ||||||
from typing import Dict | ||||||
from typing import Iterable | ||||||
from typing import Tuple | ||||||
from typing import Union | ||||||
|
||||||
|
@@ -60,16 +59,7 @@ def __init__( | |||||
**kwargs, | ||||||
): | ||||||
super().__init__(**kwargs) | ||||||
if not isinstance(axes, tuple): | ||||||
try: | ||||||
axes = tuple(axes) # type: ignore[arg-type] | ||||||
except TypeError: | ||||||
axes = (axes,) # type: ignore[assignment] | ||||||
assert isinstance(axes, Iterable) | ||||||
for axis in axes: | ||||||
if not isinstance(axis, str) and axis not in (0, 1, 2): | ||||||
raise ValueError(f'Axes must be in (0, 1, 2), not "{axes}"') | ||||||
self.axes = axes | ||||||
self.axes = self.parse_axes(axes) | ||||||
self.num_ghosts_range = self._parse_range( | ||||||
num_ghosts, | ||||||
'num_ghosts', | ||||||
|
@@ -84,16 +74,13 @@ def __init__( | |||||
self.restore = _parse_restore(restore) | ||||||
|
||||||
def apply_transform(self, subject: Subject) -> Subject: | ||||||
axes = self.ensure_axes_indices(subject, self.axes) | ||||||
arguments: Dict[str, dict] = defaultdict(dict) | ||||||
if any(isinstance(n, str) for n in self.axes): | ||||||
subject.check_consistent_orientation() | ||||||
for name, image in self.get_images_dict(subject).items(): | ||||||
is_2d = image.is_2d() | ||||||
axes = [a for a in self.axes if a != 2] if is_2d else self.axes | ||||||
for name, _ in self.get_images_dict(subject).items(): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
min_ghosts, max_ghosts = self.num_ghosts_range | ||||||
params = self.get_params( | ||||||
axes, | ||||||
(int(min_ghosts), int(max_ghosts)), | ||||||
axes, # type: ignore[arg-type] | ||||||
self.intensity_range, | ||||||
) | ||||||
num_ghosts_param, axis_param, intensity_param = params | ||||||
|
@@ -108,8 +95,8 @@ def apply_transform(self, subject: Subject) -> Subject: | |||||
|
||||||
def get_params( | ||||||
self, | ||||||
num_ghosts_range: Tuple[int, int], | ||||||
axes: Tuple[int, ...], | ||||||
num_ghosts_range: Tuple[int, int], | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the point of this change? |
||||||
intensity_range: Tuple[float, float], | ||||||
) -> Tuple: | ||||||
ng_min, ng_max = num_ghosts_range | ||||||
|
@@ -118,6 +105,17 @@ def get_params( | |||||
intensity = self.sample_uniform(*intensity_range) | ||||||
return num_ghosts, axis, intensity | ||||||
|
||||||
@staticmethod | ||||||
def parse_restore(restore): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems unrelated to this PR. |
||||||
try: | ||||||
restore = float(restore) | ||||||
except ValueError as e: | ||||||
raise TypeError(f'Restore must be a float, not "{restore}"') from e | ||||||
if not 0 <= restore <= 1: | ||||||
message = f'Restore must be a number between 0 and 1, not {restore}' | ||||||
raise ValueError(message) | ||||||
return restore | ||||||
|
||||||
|
||||||
class Ghosting(IntensityTransform, FourierTransform): | ||||||
r"""Add MRI ghosting artifact. | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.