From ebda81c04785d70d16d7d970362407830134bad9 Mon Sep 17 00:00:00 2001
From: dsethz
Date: Wed, 27 Nov 2024 20:18:01 +0100
Subject: [PATCH] added first version of RandomSpatialPad

---
 src/nuclai/utils/datamodule.py | 60 +++-----------------------------
 src/nuclai/utils/utils.py      | 62 ++++++++++++++++++++++++++++++--
 2 files changed, 64 insertions(+), 58 deletions(-)

diff --git a/src/nuclai/utils/datamodule.py b/src/nuclai/utils/datamodule.py
index 94b541c..3bc3ad5 100644
--- a/src/nuclai/utils/datamodule.py
+++ b/src/nuclai/utils/datamodule.py
@@ -19,7 +19,7 @@
 from monai import transforms
 from torch.utils.data import DataLoader
 
-from nuclai.utils.utils import _get_mask
+from nuclai.utils.utils import RandomSpatialPad, _get_mask
 
 
 class DataSetCls:
@@ -181,7 +181,6 @@ class DataSet:
         path_data: path to CSV file containing image paths and header "image".
         trans: Compose of transforms to apply to each image.
         shape: shape of the input image.
-        bit_depth: bit depth of the input image.
     """
 
     def __init__(
@@ -189,7 +188,6 @@ def __init__(
         self,
         path_data: Union[str, pathlib.PosixPath, pathlib.WindowsPath],
         trans: Optional[transforms.Compose] = None,
         shape: tuple[int, ...] = (30, 300, 300),
-        # bit_depth: int = 8,
     ):
         super().__init__()
@@ -219,30 +217,17 @@ def __init__(
             isinstance(i, int) for i in shape
         ), "values of shape should be of type integer."
 
-        # assert isinstance(
-        #     bit_depth, int
-        # ), f'type of bit_depth should be int instead it is of type: "{type(bit_depth)}".'
-
         self.path_data = path_data
         self.data = pd.read_csv(path_data)
         self.shape = shape
         self.trans = trans
-        self.padder = transforms.SpatialPad(self.shape, method="symmetric")
+        self.padder = RandomSpatialPad(self.shape)
+        # self.padder = transforms.SpatialPad(spatial_size=self.shape, method="symmetric")
 
         assert (
             "image" in self.data.columns
         ), 'The input file requires "image" as header.'
 
-        # if bit_depth == 8:
-        #     self.bit_depth = np.uint8
-        # elif bit_depth == 16:
-        #     self.bit_depth = np.int32
-        # else:
-        #     self.bit_depth = np.uint8
-        #     raise Warning(
-        #         f'bit_depth must be in {8, 16}, but is "{bit_depth}". It will be handled as 8bit and may create an integer overflow.'
-        #     )
-
     def __len__(self):
         return len(self.data)
@@ -285,8 +270,6 @@ def _preprocess(self, img: np.array) -> tuple[torch.Tensor, torch.Tensor]:
             len(img.shape) == 3
         ), f'images are expected to be grayscale and len(img.shape)==3, here it is: "{len(img.shape)}".'
 
-        # img = img.astype(self.bit_depth)
-
         img_t = torch.from_numpy(img).type(torch.FloatTensor)
         img_t = torch.unsqueeze(img_t, 0)
@@ -294,7 +277,7 @@ def _preprocess(self, img: np.array) -> tuple[torch.Tensor, torch.Tensor]:
         if self.trans is not None:
             img_t = self.trans(img_t)
 
-        # TODO: for now apply padding after transformations (see comment in DataModule.setup)
+        # apply padding and get mask
         img_t = self.padder(img_t)
         mask = _get_mask(img=img_t, padder=self.padder)
@@ -332,34 +315,12 @@ def setup(self, stage: Optional[str] = None):
         """
         Instantiate datasets
         """
-
-        # catch image data type
-        # tmp = pd.read_csv(self.path_data)
-        # img = tifffile.imread(tmp.loc[0, "image"])
-
-        # if img.dtype == np.uint8:
-        #     max_intensity = 255.0
-        #     bit_depth = 8
-        # elif img.dtype == np.uint16:
-        #     max_intensity = 65535.0
-        #     bit_depth = 16
-        # else:
-        #     max_intensity = 255.0
-        #     bit_depth = 8
-        #     raise Warning(
-        #         f'Image type "{img.dtype}" is currently not supported and will be converted to "uint8".'
-        #     )
-
         if stage == "fit" or stage is None:
             assert self.path_data_val is not None, "path_data_val is missing."
             # instantiate transforms and datasets
-            # TODO: use translation to not only have centered images (e.g. RandAffine)
             trans = transforms.Compose(
                 [
-                    # transforms.NormalizeIntensity(
-                    #     subtrahend=0, divisor=max_intensity
-                    # ),
                     transforms.NormalizeIntensity(),
                     transforms.RandZoom(keep_size=True),
                     transforms.RandAxisFlip(),
@@ -372,9 +333,6 @@ def setup(self, stage: Optional[str] = None):
 
             trans_val = transforms.Compose(
                 [
-                    # transforms.NormalizeIntensity(
-                    #     subtrahend=0, divisor=max_intensity
-                    # ),
                     transforms.NormalizeIntensity(),
                 ]
             )
@@ -383,22 +341,17 @@ def setup(self, stage: Optional[str] = None):
                 self.path_data,
                 trans=trans,
                 shape=self.shape,
-                # bit_depth=bit_depth,
             )
             self.data_val = DataSet(
                 self.path_data_val,
                 trans=trans_val,
                 shape=self.shape,
-                # bit_depth=bit_depth,
             )
 
         if stage == "test" or stage is None:
             # instantiate transforms and datasets
             trans = transforms.Compose(
                 [
-                    # transforms.NormalizeIntensity(
-                    #     subtrahend=0, divisor=max_intensity
-                    # ),
                     transforms.NormalizeIntensity(),
                 ]
             )
@@ -407,16 +360,12 @@
                 self.path_data,
                 trans=trans,
                 shape=self.shape,
-                # bit_depth=bit_depth,
             )
 
         if stage == "predict" or stage is None:
             # instantiate transforms and datasets
             trans = transforms.Compose(
                 [
-                    # transforms.NormalizeIntensity(
-                    #     subtrahend=0, divisor=max_intensity
-                    # ),
                     transforms.NormalizeIntensity(),
                 ]
             )
@@ -425,7 +374,6 @@
                 self.path_data,
                 trans=trans,
                 shape=self.shape,
-                # bit_depth=bit_depth,
             )
 
     def train_dataloader(self):
diff --git a/src/nuclai/utils/utils.py b/src/nuclai/utils/utils.py
index 19b09ea..dedc5c8 100644
--- a/src/nuclai/utils/utils.py
+++ b/src/nuclai/utils/utils.py
@@ -10,13 +10,71 @@
 import pathlib
 from typing import BinaryIO, Union
 
+import numpy as np
 import torch
-from monai.data.meta_tensor import MetaTensor
+from monai.data import MetaTensor
+from monai.data.meta_obj import get_track_meta
 from monai.transforms import CropForegroundd, SpatialPad
-from monai.utils import TraceKeys
+from monai.transforms.croppad.functional import pad_func
+from monai.utils import TraceKeys, convert_to_tensor, fall_back_tuple
 from skimage import io
 
 
+class RandomSpatialPad(SpatialPad):
+    """
+    Randomly pad the input along its spatial axes to a defined n-dimensional shape.
+
+    Args:
+        spatial_size: Union[int, tuple[int, ...], list[int]]
+            Expected shape of spatial dimensions after padding.
+        **kwargs: Any
+            Additional parameters for parent class.
+
+    TODO:
+        * adapt DataModule to use RandomSpatialPad
+        * add to compose of each data set
+
+    """
+
+    def __init__(
+        self,
+        spatial_size: Union[int, tuple[int, ...], list[int]],
+        **kwargs,
+    ):
+        super().__init__(spatial_size=spatial_size, **kwargs)
+
+    def __call__(self, img: MetaTensor) -> MetaTensor:
+        # get img shape (assume channel-first format)
+        input_spatial_shape = img.shape[1:]
+
+        # compute padding
+        pads = self._compute_pad_width(input_spatial_shape)
+
+        # convert img to MetaTensor if necessary
+        img = convert_to_tensor(data=img, track_meta=get_track_meta())
+
+        return pad_func(img, pads, self.get_transform_info())
+
+    def _compute_pad_width(self, input_spatial_shape):
+        # validate spatial_size against the input's spatial shape
+        spatial_size = fall_back_tuple(self.spatial_size, input_spatial_shape)
+
+        # split the required padding randomly between before/after each axis
+        pads = []
+        for dim_size, target_size in zip(
+            input_spatial_shape, spatial_size, strict=False
+        ):
+            total_pad = max(target_size - dim_size, 0)
+            pad_before = np.random.randint(0, total_pad + 1)
+            pad_after = total_pad - pad_before
+            pads.append((pad_before, pad_after))
+
+        # prepend zero padding for the channel dimension
+        pads = tuple([(0, 0)] + pads)
+
+        return pads
+
+
 def _threshold_at_zero(x: torch.Tensor) -> torch.Tensor:
     """
     Helper function for crop_original.
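
---

A minimal usage sketch of the new transform, assuming a channel-first input that is
smaller than the target spatial shape; the shape values below are illustrative only:

    import torch

    from nuclai.utils.utils import RandomSpatialPad

    padder = RandomSpatialPad(spatial_size=(30, 300, 300))

    # channel-first volume (C, D, H, W), smaller than the target spatial shape
    img = torch.rand(1, 20, 250, 280)

    padded = padder(img)
    print(padded.shape[1:])  # torch.Size([30, 300, 300])
    # the before/after split per axis is drawn uniformly at random, so repeated
    # calls place the image content at different offsets within the padded volume

Compared to the replaced SpatialPad(..., method="symmetric"), which always centers
the content, the random before/after split acts as an implicit translation
augmentation, addressing the removed TODO about not only having centered images.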