Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added first version of RandomSpatialPad #1

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 4 additions & 56 deletions src/nuclai/utils/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from monai import transforms
from torch.utils.data import DataLoader

from nuclai.utils.utils import _get_mask
from nuclai.utils.utils import RandomSpatialPad, _get_mask


class DataSetCls:
Expand Down Expand Up @@ -181,15 +181,13 @@ class DataSet:
path_data: path to CSV file containing image paths and header "image".
trans: Compose of transforms to apply to each image.
shape: shape of the input image.
(remove) bit_depth: bit depth of the input image.
"""

def __init__(
self,
path_data: Union[str, pathlib.PosixPath, pathlib.WindowsPath],
trans: Optional[transforms.Compose] = None,
shape: tuple[int, ...] = (30, 300, 300),
# bit_depth: int = 8,
):
super().__init__()

Expand Down Expand Up @@ -219,30 +217,17 @@ def __init__(
isinstance(i, int) for i in shape
), "values of shape should be of type integer."

# assert isinstance(
# bit_depth, int
# ), f'type of bit_depth should be int instead it is of type: "{type(bit_depth)}".'

self.path_data = path_data
self.data = pd.read_csv(path_data)
self.shape = shape
self.trans = trans
self.padder = transforms.SpatialPad(self.shape, method="symmetric")
self.padder = RandomSpatialPad(self.shape)
# self.padder = transforms.SpatialPad(spatial_size=self.shape, method="symmetric")

assert (
"image" in self.data.columns
), 'The input file requires "image" as header.'

# if bit_depth == 8:
# self.bit_depth = np.uint8
# elif bit_depth == 16:
# self.bit_depth = np.int32
# else:
# self.bit_depth = np.uint8
# raise Warning(
# f'bit_depth must be in {8, 16}, but is "{bit_depth}". It will be handled as 8bit and may create an integer overflow.'
# )

def __len__(self):
return len(self.data)

Expand Down Expand Up @@ -285,16 +270,14 @@ def _preprocess(self, img: np.array) -> tuple[torch.Tensor, torch.Tensor]:
len(img.shape) == 3
), f'images are expected to be grayscale and len(img.shape)==3, here it is: "{len(img.shape)}".'

# img = img.astype(self.bit_depth)

img_t = torch.from_numpy(img).type(torch.FloatTensor)
img_t = torch.unsqueeze(img_t, 0)

# apply transforms
if self.trans is not None:
img_t = self.trans(img_t)

# TODO: for now apply padding after transformations (see comment in DataModule.setup)
# apply padding and get mask
img_t = self.padder(img_t)
mask = _get_mask(img=img_t, padder=self.padder)

Expand Down Expand Up @@ -332,34 +315,12 @@ def setup(self, stage: Optional[str] = None):
"""
Instantiate datasets
"""

# catch image data type
# tmp = pd.read_csv(self.path_data)
# img = tifffile.imread(tmp.loc[0, "image"])

# if img.dtype == np.uint8:
# max_intensity = 255.0
# bit_depth = 8
# elif img.dtype == np.uint16:
# max_intensity = 65535.0
# bit_depth = 16
# else:
# max_intensity = 255.0
# bit_depth = 8
# raise Warning(
# f'Image type "{img.dtype}" is currently not supported and will be converted to "uint8".'
# )

if stage == "fit" or stage is None:
assert self.path_data_val is not None, "path_data_val is missing."

# instantiate transforms and datasets
# TODO: use translation to not only have centered images (e.g. RandAffine)
trans = transforms.Compose(
[
# transforms.NormalizeIntensity(
# subtrahend=0, divisor=max_intensity
# ),
transforms.NormalizeIntensity(),
transforms.RandZoom(keep_size=True),
transforms.RandAxisFlip(),
Expand All @@ -372,9 +333,6 @@ def setup(self, stage: Optional[str] = None):

trans_val = transforms.Compose(
[
# transforms.NormalizeIntensity(
# subtrahend=0, divisor=max_intensity
# ),
transforms.NormalizeIntensity(),
]
)
Expand All @@ -383,22 +341,17 @@ def setup(self, stage: Optional[str] = None):
self.path_data,
trans=trans,
shape=self.shape,
# bit_depth=bit_depth,
)
self.data_val = DataSet(
self.path_data_val,
trans=trans_val,
shape=self.shape,
# bit_depth=bit_depth,
)

if stage == "test" or stage is None:
# instantiate transforms and datasets
trans = transforms.Compose(
[
# transforms.NormalizeIntensity(
# subtrahend=0, divisor=max_intensity
# ),
transforms.NormalizeIntensity(),
]
)
Expand All @@ -407,16 +360,12 @@ def setup(self, stage: Optional[str] = None):
self.path_data,
trans=trans,
shape=self.shape,
# bit_depth=bit_depth,
)

if stage == "predict" or stage is None:
# instantiate transforms and datasets
trans = transforms.Compose(
[
# transforms.NormalizeIntensity(
# subtrahend=0, divisor=max_intensity
# ),
transforms.NormalizeIntensity(),
]
)
Expand All @@ -425,7 +374,6 @@ def setup(self, stage: Optional[str] = None):
self.path_data,
trans=trans,
shape=self.shape,
# bit_depth=bit_depth,
)

def train_dataloader(self):
Expand Down
62 changes: 60 additions & 2 deletions src/nuclai/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,71 @@
import pathlib
from typing import BinaryIO, Union

import numpy as np
import torch
from monai.data.meta_tensor import MetaTensor
from monai.data import MetaTensor
from monai.data.meta_obj import get_track_meta
from monai.transforms import CropForegroundd, SpatialPad
from monai.utils import TraceKeys
from monai.transforms.croppad.functional import pad_func
from monai.utils import TraceKeys, convert_to_tensor, fall_back_tuple
from skimage import io


class RandomSpatialPad(SpatialPad):
    """
    Randomly pad input along the last n axes given a defined n-dimensional shape.

    For every spatial axis that is smaller than the requested size, the missing
    amount is split at a random position into a "before" and an "after" part, so
    the original content ends up at a random offset inside the padded output.
    Axes that are already at least as large as requested are left untouched
    (this transform never crops).

    Args:
        spatial_size: Union[int, tuple[int, ...], list[int]]
            Expected shape of spatial dimensions after padding.
        **kwargs: Any
            Additional parameters for parent class (e.g. ``mode`` and the
            options of the underlying padding function). They are stored by
            the parent and applied on every call.

    Note:
        * The split is drawn from NumPy's global random state
          (``np.random.randint``); seed it with ``np.random.seed`` for
          reproducible results.
        * The parent's ``method`` argument has no effect here, because the
          position of the padding is always chosen randomly.
        * The ``lazy`` flag is not forwarded to ``pad_func`` (presumably
          padding is then executed eagerly - confirm against the MONAI
          version in use).

    TODO:
        * adapt DataModule to use RandomSpatialPad
        * add to compose of each data set

    """

    def __init__(
        self,
        spatial_size: Union[int, tuple[int, ...], list[int]],
        **kwargs,
    ):
        super().__init__(spatial_size=spatial_size, **kwargs)

    def __call__(self, img: MetaTensor) -> MetaTensor:
        """
        Pad ``img`` to the configured spatial size with a random offset.

        Args:
            img: channel-first image; axis 0 is treated as the channel axis
                and is never padded.

        Returns:
            The padded image as returned by ``pad_func``.
        """
        # get img shape
        input_spatial_shape = img.shape[1:]  # assume channel first format

        # compute padding
        pads = self._compute_pad_width(input_spatial_shape)

        # Convert img to metatensor if necessary
        img = convert_to_tensor(data=img, track_meta=get_track_meta())

        # Forward the padding mode and extra options collected by the parent
        # constructor; otherwise anything passed through **kwargs would be
        # silently ignored.
        return pad_func(
            img,
            pads,
            self.get_transform_info(),
            mode=self.mode,
            **self.kwargs,
        )

    def _compute_pad_width(self, input_spatial_shape):
        """
        Draw a random (before, after) padding for every spatial axis.

        Args:
            input_spatial_shape: sizes of the spatial axes (without channel).

        Returns:
            Tuple of ``(before, after)`` pairs, starting with ``(0, 0)`` for
            the channel axis, followed by one pair per spatial axis.
        """
        # validate spatial_size
        spatial_size = fall_back_tuple(self.spatial_size, input_spatial_shape)

        # compute padding
        pads = []
        for dim_size, target_size in zip(
            input_spatial_shape, spatial_size, strict=False
        ):
            # never negative: axes larger than the target are not cropped
            total_pad = max(int(target_size) - int(dim_size), 0)
            # upper bound is exclusive, hence +1 to allow "all padding before";
            # cast to a plain int so the recorded padding holds no NumPy scalars
            pad_before = int(np.random.randint(0, total_pad + 1))
            pad_after = total_pad - pad_before
            pads.append((pad_before, pad_after))

        # add channel dimension
        return tuple([(0, 0)] + pads)


def _threshold_at_zero(x: torch.Tensor) -> torch.Tensor:
"""
Helper function for crop_original.
Expand Down
Loading