Skip to content

Commit

Permalink
Merge pull request #1 from dsethz/random_spatial_pad
Browse files Browse the repository at this point in the history
added first version of RandomSpatialPad
  • Loading branch information
dsethz authored Dec 20, 2024
2 parents 6dcfa5b + ebda81c commit 412cae7
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 58 deletions.
60 changes: 4 additions & 56 deletions src/nuclai/utils/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from monai import transforms
from torch.utils.data import DataLoader

from nuclai.utils.utils import _get_mask
from nuclai.utils.utils import RandomSpatialPad, _get_mask


class DataSetCls:
Expand Down Expand Up @@ -181,15 +181,13 @@ class DataSet:
path_data: path to CSV file containing image paths and header "image".
trans: Compose of transforms to apply to each image.
shape: shape of the input image.
(remove) bit_depth: bit depth of the input image.
"""

def __init__(
self,
path_data: Union[str, pathlib.PosixPath, pathlib.WindowsPath],
trans: Optional[transforms.Compose] = None,
shape: tuple[int, ...] = (30, 300, 300),
# bit_depth: int = 8,
):
super().__init__()

Expand Down Expand Up @@ -219,30 +217,17 @@ def __init__(
isinstance(i, int) for i in shape
), "values of shape should be of type integer."

# assert isinstance(
# bit_depth, int
# ), f'type of bit_depth should be int instead it is of type: "{type(bit_depth)}".'

self.path_data = path_data
self.data = pd.read_csv(path_data)
self.shape = shape
self.trans = trans
self.padder = transforms.SpatialPad(self.shape, method="symmetric")
self.padder = RandomSpatialPad(self.shape)
# self.padder = transforms.SpatialPad(spatial_size=self.shape, method="symmetric")

assert (
"image" in self.data.columns
), 'The input file requires "image" as header.'

# if bit_depth == 8:
# self.bit_depth = np.uint8
# elif bit_depth == 16:
# self.bit_depth = np.int32
# else:
# self.bit_depth = np.uint8
# raise Warning(
# f'bit_depth must be in {8, 16}, but is "{bit_depth}". It will be handled as 8bit and may create an integer overflow.'
# )

def __len__(self):
return len(self.data)

Expand Down Expand Up @@ -285,16 +270,14 @@ def _preprocess(self, img: np.array) -> tuple[torch.Tensor, torch.Tensor]:
len(img.shape) == 3
), f'images are expected to be grayscale and len(img.shape)==3, here it is: "{len(img.shape)}".'

# img = img.astype(self.bit_depth)

img_t = torch.from_numpy(img).type(torch.FloatTensor)
img_t = torch.unsqueeze(img_t, 0)

# apply transforms
if self.trans is not None:
img_t = self.trans(img_t)

# TODO: for now apply padding after transformations (see comment in DataModule.setup)
# apply padding and get mask
img_t = self.padder(img_t)
mask = _get_mask(img=img_t, padder=self.padder)

Expand Down Expand Up @@ -332,34 +315,12 @@ def setup(self, stage: Optional[str] = None):
"""
Instantiate datasets
"""

# catch image data type
# tmp = pd.read_csv(self.path_data)
# img = tifffile.imread(tmp.loc[0, "image"])

# if img.dtype == np.uint8:
# max_intensity = 255.0
# bit_depth = 8
# elif img.dtype == np.uint16:
# max_intensity = 65535.0
# bit_depth = 16
# else:
# max_intensity = 255.0
# bit_depth = 8
# raise Warning(
# f'Image type "{img.dtype}" is currently not supported and will be converted to "uint8".'
# )

if stage == "fit" or stage is None:
assert self.path_data_val is not None, "path_data_val is missing."

# instantiate transforms and datasetst
# TODO: use translation to not only have centered images (e.g. RandAffine)
trans = transforms.Compose(
[
# transforms.NormalizeIntensity(
# subtrahend=0, divisor=max_intensity
# ),
transforms.NormalizeIntensity(),
transforms.RandZoom(keep_size=True),
transforms.RandAxisFlip(),
Expand All @@ -372,9 +333,6 @@ def setup(self, stage: Optional[str] = None):

trans_val = transforms.Compose(
[
# transforms.NormalizeIntensity(
# subtrahend=0, divisor=max_intensity
# ),
transforms.NormalizeIntensity(),
]
)
Expand All @@ -383,22 +341,17 @@ def setup(self, stage: Optional[str] = None):
self.path_data,
trans=trans,
shape=self.shape,
# bit_depth=bit_depth,
)
self.data_val = DataSet(
self.path_data_val,
trans=trans_val,
shape=self.shape,
# bit_depth=bit_depth,
)

if stage == "test" or stage is None:
# instantiate transforms and datasets
trans = transforms.Compose(
[
# transforms.NormalizeIntensity(
# subtrahend=0, divisor=max_intensity
# ),
transforms.NormalizeIntensity(),
]
)
Expand All @@ -407,16 +360,12 @@ def setup(self, stage: Optional[str] = None):
self.path_data,
trans=trans,
shape=self.shape,
# bit_depth=bit_depth,
)

if stage == "predict" or stage is None:
# instantiate transforms and datasets
trans = transforms.Compose(
[
# transforms.NormalizeIntensity(
# subtrahend=0, divisor=max_intensity
# ),
transforms.NormalizeIntensity(),
]
)
Expand All @@ -425,7 +374,6 @@ def setup(self, stage: Optional[str] = None):
self.path_data,
trans=trans,
shape=self.shape,
# bit_depth=bit_depth,
)

def train_dataloader(self):
Expand Down
62 changes: 60 additions & 2 deletions src/nuclai/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,71 @@
import pathlib
from typing import BinaryIO, Union

import numpy as np
import torch
from monai.data.meta_tensor import MetaTensor
from monai.data import MetaTensor
from monai.data.meta_obj import get_track_meta
from monai.transforms import CropForegroundd, SpatialPad
from monai.utils import TraceKeys
from monai.transforms.croppad.functional import pad_func
from monai.utils import TraceKeys, convert_to_tensor, fall_back_tuple
from skimage import io


class RandomSpatialPad(SpatialPad):
"""
Randomly pad input along the last n axes given a defined n-dimensional shape.
Args:
spatial_size: Union[int, tuple[int, ...], list[int, ...]]
Expected shape of spatial dimensions after padding.
**kwargs: Any
Additional parameters for parent class.
TODO:
* adapt DataModule to use RandomSpatialPad
* add to compose of each data set
"""

def __init__(
self,
spatial_size: Union[int, tuple[int, ...], list[int, ...]],
**kwargs,
):
super().__init__(spatial_size=spatial_size, **kwargs)

def __call__(self, img: MetaTensor) -> MetaTensor:
# get img shape
input_spatial_shape = img.shape[1:] # assume channel first format

# compute padding
pads = self._compute_pad_width(input_spatial_shape)

# Convert img to metatensor if necessary
img = convert_to_tensor(data=img, track_meta=get_track_meta())

return pad_func(img, pads, self.get_transform_info())

def _compute_pad_width(self, input_spatial_shape):
# validate spatial_size
spatial_size = fall_back_tuple(self.spatial_size, input_spatial_shape)

# compute padding
pads = []
for dim_size, target_size in zip(
input_spatial_shape, spatial_size, strict=False
):
total_pad = max(target_size - dim_size, 0)
pad_before = np.random.randint(0, total_pad + 1)
pad_after = total_pad - pad_before
pads.append((pad_before, pad_after))

# add channel dimension
pads = tuple([(0, 0)] + pads)

return pads


def _threshold_at_zero(x: torch.Tensor) -> torch.Tensor:
"""
Helper function for crop_original.
Expand Down

0 comments on commit 412cae7

Please sign in to comment.