From 3a817f73b1e53987ab910ca16c2a188b3fd43d69 Mon Sep 17 00:00:00 2001 From: the-lay Date: Tue, 19 Jan 2021 23:03:39 +0200 Subject: [PATCH] Overhauled tiling to move away from as_strided; added 3D overlap-tile example; fixed sampling shapes; changed 2D overlap-tile example to add a sanity check --- README.md | 17 +- examples/2d_overlap_tile.py | 21 +- examples/3d_overlap_tile.py | 65 ++++++ tests/test_merger.py | 2 +- tests/test_tiler.py | 52 ++--- tiler/__init__.py | 2 +- tiler/merger.py | 2 +- tiler/tiler.py | 394 +++++++----------------------------- 8 files changed, 190 insertions(+), 365 deletions(-) create mode 100644 examples/3d_overlap_tile.py diff --git a/README.md b/README.md index c406f5a..4c66e9f 100644 --- a/README.md +++ b/README.md @@ -16,10 +16,9 @@ images do not fit into GPU memory (2D hyperspectral satellite images, Implemented features ------------- - Data reader agnostic: works with numpy arrays - - Optimized to avoid unnecessary memory copies: numpy views for all tiles, - except border tiles that require padding - N-dimensional array tiling - (but for now tiles must have the same number of dimensions as the array) + (note: currently tile shape must have the same number of dimensions as the array) + - Optional in-place tiling (without creating copies) - Supports channel dimension: dimension that will not be tiled - Overlapping support: you can specify tile percentage or directly overlap size - Window functions: Merger accepts weights for the tile as an array or a scipy window @@ -116,10 +115,9 @@ However, other libraries might fit you better than `tiler`: - Do you know any other similar packages? - [Please make a PR](https://github.com/the-lay/tiler/pulls) or [open a new issue](https://github.com/the-lay/tiler/issues). 
- -Academic references -------------- -[Introducing Hann windows for reducing edge-effects in patch-based image segmentation](https://doi.org/10.1371/journal.pone.0229839 + +Moreover, some approaches have been described in the literature: + - [Introducing Hann windows for reducing edge-effects in patch-based image segmentation](https://doi.org/10.1371/journal.pone.0229839 ), Pielawski and Wählby, March 2020 @@ -158,8 +156,3 @@ https://gist.github.com/npielawski/7e77d23209a5c415f55b95d4aba914f6 https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0229839#pone.0229839.ref005 https://arxiv.org/pdf/1803.02786.pdf --> - - - - - diff --git a/examples/2d_overlap_tile.py b/examples/2d_overlap_tile.py index 9fbc445..882ac83 100644 --- a/examples/2d_overlap_tile.py +++ b/examples/2d_overlap_tile.py @@ -1,4 +1,4 @@ -# 2D RGB Overlap-tile strategy tiling/merging example +# 2D RGB overlap-tile strategy tiling/merging example # # "This strategy allows the seamless segmentation of arbitrarily large images by an overlap-tile strategy. 
# To predict the pixels in the border region of the image, the missing context is extrapolated by mirroring @@ -26,7 +26,6 @@ overlap=(64, 64, 0), channel_dimension=2) # Window function for merging -# We also need to generate a window for function window = np.zeros((128, 128, 3)) window[32:-32, 32:-32, :] = 1 @@ -34,9 +33,20 @@ merger = Merger(tiler=tiler, window=window) # Let's define a function that will be applied to each tile -# For this example, we use PIL to adjust color balance -# In practice, this can be a neural network or any kind of processing -def process(patch: np.ndarray) -> np.ndarray: +def process(patch: np.ndarray, sanity_check: bool = True) -> np.ndarray: + + # One example can be a sanity check + # Make the parts that should be removed black + # No black spots should appear in the final merged image + if sanity_check: + patch[:32, :, :] = 0 + patch[-32:, :, :] = 0 + patch[:, :32, :] = 0 + patch[:, -32:, :] = 0 + return patch + + # Another example can be to just modify the whole patch + # Using PIL, we adjust the color balance enhancer = ImageEnhance.Color(Image.fromarray(patch)) return np.array(enhancer.enhance(5.0)) @@ -45,6 +55,7 @@ def process(patch: np.ndarray) -> np.ndarray: for tile_id, tile in tiler(padded_image): processed_tile = process(tile) merger.add(tile_id, processed_tile) + final_image = merger.merge().astype(np.uint8) final_unpadded_image = final_image[32:-32, 32:-32, :] diff --git a/examples/3d_overlap_tile.py b/examples/3d_overlap_tile.py new file mode 100644 index 0000000..672b527 --- /dev/null +++ b/examples/3d_overlap_tile.py @@ -0,0 +1,65 @@ +# 3D grayscale overlap-tile strategy tiling/merging example +# +# "This strategy allows the seamless segmentation of arbitrarily large images by an overlap-tile strategy. +# To predict the pixels in the border region of the image, the missing context is extrapolated by mirroring +# the input image. 
This tiling strategy is important to apply the network to large images, +# since otherwise the resolution would be limited by the GPU memory." - Ronneberger et al 2015, U-Net paper + +# We will use napari for showing 3D volumes interactively +import numpy as np +from tiler import Tiler, Merger +import napari + +# Example "checkerboard"-like volume with some variation for visualization +# https://stackoverflow.com/a/51715491 +volume = (np.indices((150, 462, 462)).sum(axis=0) % 50).astype(np.float32) +volume[:75] *= np.linspace(3, 10, 75)[:, None, None] +volume[75:] *= np.linspace(10, 3, 75)[:, None, None] + +# Let's assume we want to use tiles of size 48x48x48 and only the middle 20x20x20 for the final image +# That means we need to pad the image by 14 from each side +# To extrapolate missing context let's use reflect mode +padded_volume = np.pad(volume, 14, mode='reflect') + +# Specifying tiling +# The overlap should be 28 voxels +tiler = Tiler(image_shape=padded_volume.shape, + tile_shape=(48, 48, 48), + overlap=(28, 28, 28)) + +# Window function for merging +window = np.zeros((48, 48, 48)) +window[14:-14, 14:-14, 14:-14] = 1 + +# Specifying merging +merger = Merger(tiler=tiler, window=window) + +# Let's define a function that will be applied to each tile +# For this example, let's zero out the sides that should be "cropped" by the window function, +# as a way to confirm that only the middle parts are being merged +def process(patch: np.ndarray) -> np.ndarray: + patch[:14, :, :] = 0 + patch[-14:, :, :] = 0 + patch[:, :14, :] = 0 + patch[:, -14:, :] = 0 + patch[:, :, :14] = 0 + patch[:, :, -14:] = 0 + return patch + +# Iterate through all the tiles and apply the processing function and merge everything back +for tile_id, tile in tiler(padded_volume, progress_bar=True): + processed_tile = process(tile) + merger.add(tile_id, processed_tile) + +final_volume = merger.merge() +final_unpadded_volume = final_volume[14:-14, 14:-14, 14:-14] + +# Show all the volumes 
+with napari.gui_qt(): + v = napari.Viewer() + v.add_image(volume, name='Original volume') + v.add_image(padded_volume, name='Padded volume') + v.add_image(final_volume, name='Final volume') + v.add_image(final_unpadded_volume, name='Final unpadded volume') + v.add_image(merger.weights_sum, name='Merger weights sum') + v.add_image(merger.data_visits, name='Merger data visits') diff --git a/tests/test_merger.py b/tests/test_merger.py index a153c2b..2f1dd2b 100644 --- a/tests/test_merger.py +++ b/tests/test_merger.py @@ -103,4 +103,4 @@ def test_generate_window(self): for t_id, t in tiler(self.data): merger.add(t_id, t) np.testing.assert_equal(merger.merge(), - [i if i % 10 else 0 for i in range(100)]) \ No newline at end of file + [i if i % 10 else 0 for i in range(100)]) diff --git a/tests/test_tiler.py b/tests/test_tiler.py index 92dda6d..ab8b56a 100644 --- a/tests/test_tiler.py +++ b/tests/test_tiler.py @@ -44,10 +44,11 @@ def test_repr(self): channel_dimension=0, mode='irregular') - expected_repr = '[3, 15, 300] tiler for data of shape [3, 300, 300]:' \ - '\n\tNew shape: [3, 300, 300]' \ - '\n\tOverlap: 0.0' \ - '\n\tStep: [0, 15, 300]' \ + expected_repr = 'Tiler split [3, 300, 300] data into 20 tiles of [3, 15, 300].' 
\ + '\n\tMosaic shape: [1, 20, 1]' \ + '\n\tTileable shape: [3, 300, 300]' \ + '\n\tTile overlap: 0' \ + '\n\tElement step: [0, 15, 300]' \ '\n\tMode: irregular' \ '\n\tChannel dimension: 0' @@ -179,16 +180,30 @@ def test_get_tile(self): with self.assertRaises(IndexError): tiler.get_tile(self.data, -1) - tiles = tiler.view_in_tiles(self.data) - np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], tiler.get_tile(None, 0, tiles)) + # copy test + t = tiler.get_tile(self.data, 0, copy=True) + t[9] = 0 + np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], tiler.get_tile(self.data, 0)) + np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 0], t) + + t = tiler.get_tile(self.data, 0, copy=False) + t[9] = 0 + np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 0], tiler.get_tile(self.data, 0)) + np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 0], t) + t[9] = 9 + # copy test with iterator + t = list(tiler(self.data, copy_data=True)) + t[0][1][9] = 0 np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], tiler.get_tile(self.data, 0)) + np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 0], t[0][1]) + self.assertNotEqual(t[0][1][9], self.data[9]) - def test_view_in_tiles(self): - tiler = Tiler(image_shape=self.data.shape, - tile_shape=(10, )) - with self.assertRaises(ValueError): - tiler.view_in_tiles(self.data.reshape((10, 10))) + t = [tile for _, tile in tiler(self.data, copy_data=False)] + t[0][9] = 0 + np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 0], tiler.get_tile(self.data, 0)) + np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 0], t[0]) + self.assertEqual(t[0][9], self.data[9]) def test_overlap(self): # Case 1 @@ -241,21 +256,6 @@ def test_overlap(self): self.assertEqual(len(tiler), len(expected_split)) np.testing.assert_equal(expected_split, calculated_split) - def test_tile_sample_shape(self): - tile_size = 10 - tiler = Tiler(image_shape=self.data.shape, - tile_shape=(tile_size, ), - channel_dimension=None) - tiler2 = Tiler(image_shape=(3, ) + 
self.data.shape, - tile_shape=(3, tile_size), - channel_dimension=0) - - with self.assertRaises(IndexError): - tiler.get_tile_sample_shape(len(tiler)) - - np.testing.assert_equal([tile_size], tiler.get_tile_sample_shape(len(tiler) - 1)) - np.testing.assert_equal([tile_size], tiler2.get_tile_sample_shape(len(tiler) - 1)) - def test_tile_mosaic_position(self): tile_size = 10 tiler = Tiler(image_shape=self.data.shape, tile_shape=(tile_size, )) diff --git a/tiler/__init__.py b/tiler/__init__.py index ce1fcb0..481adda 100644 --- a/tiler/__init__.py +++ b/tiler/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.1.1' +__version__ = '0.1.2' from tiler.tiler import Tiler from tiler.merger import Merger diff --git a/tiler/merger.py b/tiler/merger.py index bdd532e..0cbe3cc 100644 --- a/tiler/merger.py +++ b/tiler/merger.py @@ -227,7 +227,7 @@ def reset(self) -> None: # Image holds sum of all processed tiles multiplied by the window if self.logits: - self.data = np.zeros([self.logits] + padded_data_shape) + self.data = np.zeros(np.hstack((self.logits, padded_data_shape))) else: self.data = np.zeros(padded_data_shape) diff --git a/tiler/tiler.py b/tiler/tiler.py index fabbbdd..ba26192 100644 --- a/tiler/tiler.py +++ b/tiler/tiler.py @@ -1,5 +1,4 @@ import numpy as np -from numpy.lib.stride_tricks import as_strided as ast from tqdm.auto import tqdm from typing import Tuple, List, Union, Callable, Generator # try: @@ -12,43 +11,13 @@ class Tiler: TILING_MODES = ['constant', 'drop', 'irregular', 'reflect', 'edge', 'wrap'] - # @classmethod - # def overlap_tile(cls): - # pass - # - # @classmethod - # def auto_overlap(cls, - # image_shape: Union[Tuple, List], - # tile_shape: Union[Tuple, List], - # window: str, - # mode: Union[str] = 'constant', - # channel_dimension: Union[int, None] = None, - # offset: Union[int, tuple, List, None] = None, - # constant_value: float = 0.0 - # ): - # """ - # Alternative way to create a Tiler object. 
- # Automatically calculates optimal overlap and padding depending on the window function. - # - # :param image_shape: - # :param tile_shape: - # :param window: - # :param mode: - # :param channel_dimension: - # :param offset: - # :param constant_value: - # :return: - # """ - # - # pass - def __init__(self, image_shape: Union[Tuple, List], tile_shape: Union[Tuple, List], - mode: Union[str] = 'constant', + overlap: Union[int, float, Tuple, List] = 0, channel_dimension: Union[int, None] = None, - constant_value: float = 0.0, - overlap: Union[float, Tuple, List] = 0.0 + mode: str = 'constant', + constant_value: float = 0.0 ): """ Tiler precomputes everything for tiling with specified parameters, without requiring actual data. @@ -62,18 +31,29 @@ def __init__(self, Tile must have same the number of dimensions as data. # TODO: it should be possible to create tiles with less dimensions than data + :param overlap: int, float, tuple, list + If int, the same overlap in each dimension. + If float, percentage of a tile_size to use for overlap (from 0.0 to 1.0). + If tuple or list, size of the overlap in each dimension. Must be smaller than tile_shape. + Default is 0. + + :param channel_dimension: int, None + Used to specify the channel dimension, the dimension that will not be tiled. + Usually it is the last or the first dimension of the array. + Default is None, no channel dimension in the data. + :param mode: str Mode defines how the data will be tiled. # TODO: allow a user supplied function, Callable One of the following string values: `constant` (default) - Pads tile with constant value to match tile_shape. - Set the value with keyword 'constant_value'. + If a tile is smaller than `tile_shape`, pad it with a constant value to match `tile_shape`. + Set the value with the keyword 'constant_value'. 'drop' - Do not return tiles that are smaller than tile_shape. + If a tile is smaller than `tile_shape`, ignore it. 'irregular' - Tiles can be smaller than tile_shape are. 
+ Allow tiles to be smaller than `tile_shape`. `reflect` Pads tile with the reflection of values along each axis. `edge` @@ -96,45 +76,36 @@ def __init__(self, # # The function accepts the tile and returns the padded tile. - :param channel_dimension: int, None - Used to specify the channel dimension, the dimension that will not be tiled. - Usually it is the last or the first dimension of the array. - Default is None, no channel dimension in the data. - :param constant_value: float Used in 'constant' mode. The value to set the padded values for each axis. - :param overlap: int, float, tuple, list - If int, the same overlap in each dimension. - If float, percentage of a tile_size to use for overlap (from 0.0 to 1.0). - If tuple or list, size of the overlap in. Must be smaller than tile_shape. - Default is 0.0. """ # Image and tile shapes self.image_shape = np.asarray(image_shape).astype(int) self.tile_shape = np.asarray(tile_shape).astype(int) + self._n_dim = len(image_shape) if (self.tile_shape <= 0).any() or (self.image_shape <= 0).any(): - raise ValueError('Shapes must be tuple or lists of positive numbers.') + raise ValueError('Tile and data shapes must be tuple or lists of positive numbers.') if self.tile_shape.size != self.image_shape.size: raise ValueError('Tile and data shapes must have the same length.') # Tiling mode self.mode = mode if self.mode not in self.TILING_MODES: - raise ValueError('Unsupported tiling mode, please check docs.') + raise ValueError(f'{self.mode} is an unsupported tiling mode, please check the documentation.') + + # Constant value used for constant tiling mode + self.constant_value = constant_value # Channel dimension self.channel_dimension = channel_dimension - if self.channel_dimension and ((self.channel_dimension < 0) - or (self.channel_dimension > len(self.image_shape))): + if self.channel_dimension and \ + ((self.channel_dimension < 0) or (self.channel_dimension >= len(self.image_shape))): raise ValueError(f'Specified channel 
dimension is out of bounds ' - f'(should be from 0 to {len(self.image_shape)}).') - - # Constant value used for `constant` tiling mode - self.constant_value = constant_value + f'(should be None or an integer from 0 to {len(self.image_shape) - 1}).') - # Overlap + # Overlap and step self.overlap = overlap if isinstance(self.overlap, float) and (self.overlap < 0 or self.overlap > 1.0): raise ValueError('Overlap, if float, must be in range of 0.0 (0%) to 1.0 (100%).') @@ -142,10 +113,8 @@ def __init__(self, and (np.any((self.tile_shape - np.array(self.overlap)) <= 0)): raise ValueError('Overlap size much be smaller than tile_shape.') - # Tiling points and sizes calculations - self._n_dim = len(image_shape) if isinstance(self.overlap, list) or isinstance(self.overlap, tuple): - # overlap is given + # overlap is given directly self._tile_overlap: np.ndarray = np.array(self.overlap).astype(int) elif isinstance(self.overlap, int): # int overlap applies the same overlap to each dimension @@ -157,77 +126,50 @@ def __init__(self, raise ValueError('Unsupported overlap mode (not float, int, list or tuple).') self._tile_step: np.ndarray = (self.tile_shape - self._tile_overlap).astype(int) # tile step - self._tile_slices = [slice(None, None, step) for i, step in enumerate(self._tile_step) if step != 0] - # if channel dimension is given, set tile_step of that dimension to 0 + # Calculate mosaic (collection of tiles) shape + div, mod = np.divmod([image_shape[d] - self._tile_overlap[d] for d in range(self._n_dim)], self._tile_step) + if self.mode == 'drop': + self._indexing_shape = div + else: + self._indexing_shape = div + (mod != 0) if self.channel_dimension is not None: - self._tile_step[self.channel_dimension] = 0 - self._tile_points = [ - list(range(0, image_shape[d] - self._tile_overlap[d], self._tile_step[d])) - if self._tile_step[d] != 0 else [0] - for d in range(self._n_dim) - ] - self._new_shape = [x[-1] + self.tile_shape[i] for i, x in enumerate(self._tile_points)] - 
self._shape_diff = self._new_shape - self.image_shape + self._indexing_shape[self.channel_dimension] = 1 - # Drop mode: delete points that would create patches that are smaller than tile_size - if self.mode == 'drop': - # delete points that would create patches smaller than tile_size - for d, x in enumerate(self._shape_diff): - if 0 < x < self.tile_shape[d]: - del self._tile_points[d][-1] + # Calculate new shape assuming tiles are padded + if self.mode == 'irregular': + self._new_shape = self.image_shape + else: + self._new_shape = (self._indexing_shape * self._tile_step) + self._tile_overlap + self._shape_diff = self._new_shape - self.image_shape + if self.channel_dimension is not None: + self._shape_diff[self.channel_dimension] = 0 - # recalculate new shape and shape diff - self._new_shape = [x[-1] + self.tile_shape[i] for i, x in enumerate(self._tile_points)] - self._shape_diff = self._new_shape - self.image_shape + # If channel dimension is given, set tile_step of that dimension to 0 + if self.channel_dimension is not None: + self._tile_step[self.channel_dimension] = 0 # Tile indexing - # Underneath, the actual tiling is done with numpy's as_strided (returns view = O(1)) - # Returned strided array will be 2n-dimensional with first n being indexing dimensions - # and last n dimensions contain actual data. Reshaping view to a list of patches - # would mean copying data (and losing all benefits of view). To avoid that, we have a proxy array - # that is basically a mapping from 1D (0 to N tiles) to ND tiles. 
- - # context to remove division by zero warnings - with np.errstate(divide='ignore', invalid='ignore'): - self._indexing_shape = ((self._new_shape - self.tile_shape) // self._tile_step) + 1 - - self._tile_index = np.array(np.meshgrid( *[ np.arange(0, x) for x in self._indexing_shape] )) - self._tile_index = self._tile_index.T.reshape(-1, self._n_dim) # reshape into (tile_id, coordinates) + self._tile_index = np.vstack(np.meshgrid(*[np.arange(0, x) for x in self._indexing_shape], indexing='ij')) + self._tile_index = self._tile_index.reshape(self._n_dim, -1).T self.n_tiles = len(self._tile_index) - # Tile sampling - self._tile_sample_shapes = np.tile(self.tile_shape, (*self._indexing_shape, 1)) - # Most of the tiles should be full self.tile_shape, but the ones on the edges will probably be out-of-bounds. - # The problem with view is that there is no OOB checks. We have to keep in mind how many voxels to sample. - # Border tiles can have shape int: """ Returns number of tiles produced by tiling. """ return self.n_tiles def __repr__(self) -> str: - return f'{list(self.tile_shape)} tiler for data of shape {list(self.image_shape)}:' \ - f'\n\tNew shape: {self._new_shape}' \ - f'\n\tOverlap: {self.overlap}' \ - f'\n\tStep: {list(self._tile_step)}' \ + return f'Tiler split {list(self.image_shape)} data into {len(self)} tiles of {list(self.tile_shape)}.' \ + f'\n\tMosaic shape: {list(self._indexing_shape)}' \ + f'\n\tTileable shape: {list(self._new_shape)}' \ + f'\n\tTile overlap: {self.overlap}' \ + f'\n\tElement step: {list(self._tile_step)}' \ f'\n\tMode: {self.mode}' \ f'\n\tChannel dimension: {self.channel_dimension}' def __call__(self, data: np.ndarray, progress_bar: bool = False, - batch_size: int = 1, drop_last: bool = False) -> \ + batch_size: int = 1, drop_last: bool = False, + copy_data: bool = True) -> \ Generator[Tuple[int, np.ndarray], None, None]: """ Iterate through tiles of the given data array. 
@@ -249,6 +191,10 @@ def __call__(self, data: np.ndarray, progress_bar: bool = False, # if n_tiles % batch_size != 0 and drop_last == True, drop the last (incomplete) batch # else, returns incomplete batch + :param copy_data: bool + If true, returned tile is a copy. Otherwise, it is a view. + Default is True. + :return: yields (int, np.ndarray) Returns tuple with int that is the tile_id and np.ndarray tile data. """ @@ -261,45 +207,11 @@ def __call__(self, data: np.ndarray, progress_bar: bool = False, # # actual_batch_size = batch_i # collated_tiles = - tiles = self.view_in_tiles(data) - for tile_i in tqdm(range(self.n_tiles), desc='Tiling', disable=not progress_bar, unit='tile'): - yield tile_i, self.get_tile(None, tile_i, tiles) - - def view_in_tiles(self, data: np.ndarray) -> np.ndarray: - """ - Fast (O(1)) tiling of the data with numpy views. - Slices data into mosaic of tiles. - :param data: np.ndarray - Array to be sliced into tiles. + for tile_i in tqdm(range(self.n_tiles), desc='Tiling', disable=not progress_bar, unit='tile'): + yield tile_i, self.get_tile(data, tile_i, copy=copy_data) - :return: np.ndarray - 2 * data.ndim -dimensional array. - First n dimensions are mosaic coordinates, rest n dimensions are the actual data. 
- """ - if np.any(np.array(data.shape) != self.image_shape): - raise ValueError(f'Data must have the same shape as image_shape ' - f'({data.shape} != {self.image_shape}).') - - # if isinstance(data, np.ndarray): - tile_strides = data.strides - indexing_strides = data[tuple(self._tile_slices)].strides - # elif isinstance(data, torch.Tensor): - # tile_strides = np.multiply(data.stride(), data.element_size()) - # indexing_strides = np.multiply(data[tuple(self._tile_slices)].stride(), data.element_size()) - # else: - # raise ValueError(f'Not np.ndarray, but {type(data)}') - - shape = tuple(list(self._indexing_shape) + list(self.tile_shape)) - strides = tuple(list(indexing_strides) + list(tile_strides)) - - # if isinstance(data, np.ndarray): - tiles = ast(data, shape=shape, strides=strides, writeable=False) - # else: - # tiles = torch.as_strided(data, size=shape, stride=strides) - return tiles - - def get_tile(self, data: Union[np.ndarray, None], tile_id: int, tiles: np.ndarray = None) -> np.ndarray: + def get_tile(self, data: Union[np.ndarray, None], tile_id: int, copy: bool = True) -> np.ndarray: """ Returns tile content. @@ -309,8 +221,9 @@ def get_tile(self, data: Union[np.ndarray, None], tile_id: int, tiles: np.ndarra :param tile_id: int Specify which tile to return. Must be smaller than number of tiles. - :param tiles: np.ndarray - # TODO for inner use + :param copy: bool + If true, returned tile is a copy. Otherwise, it is a view. + Default is True. :return: np.ndarray Content of tile number tile_id. Padded if necessary. @@ -320,26 +233,15 @@ def get_tile(self, data: Union[np.ndarray, None], tile_id: int, tiles: np.ndarra raise IndexError(f'Out of bounds, there is no tile {tile_id}.' 
f'There are {len(self) - 1} tiles, starting from index 0.') - # get tiles view - if tiles is None: - tiles = self.view_in_tiles(data) - - # get the shape that should be sampled from the tile - sample_shape = self.get_tile_sample_shape(tile_id, with_channel_dim=(self.channel_dimension is not None)) - - # get the actual data for the tile - tile_view = tiles[tuple(self._tile_index[tile_id])] - # if isinstance(tile_view, np.ndarray): - tile_data = tile_view[tuple(slice(None, stop) for stop in sample_shape)].copy() - # elif isinstance(tile_view, torch.Tensor): - # tile_data = tile_view[tuple(slice(None, stop) for stop in sample_shape)].clone() - # else: - # raise ValueError(f'Not np.ndarray, but {type(tile_view)}') - - # # if sample_shape is not the same as tile_shape, we need to pad the tile in the given mode - # if self.channel_dimension is not None: - # sample_shape = np.insert(sample_shape, self.channel_dimension, self.tile_shape[self.channel_dimension]) - shape_diff = self.tile_shape - np.array(sample_shape, ndmin=self.tile_shape.ndim) + # get tile data + tile_corner = self._tile_index[tile_id] * self._tile_step + sampling = [slice(tile_corner[d], tile_corner[d] + self.tile_shape[d]) for d in range(self._n_dim)] + tile_data = data[tuple(sampling)] + + if copy: + tile_data = tile_data.copy() + + shape_diff = self.tile_shape - tile_data.shape if (self.mode != 'irregular') and np.any(shape_diff > 0): if self.mode == 'constant': tile_data = np.pad(tile_data, list((0, diff) for diff in shape_diff), mode=self.mode, @@ -349,31 +251,6 @@ def get_tile(self, data: Union[np.ndarray, None], tile_id: int, tiles: np.ndarra return tile_data - def get_tile_sample_shape(self, tile_id: int, with_channel_dim: bool = False) -> np.ndarray: - """ - Returns shape of sample for the tile with number tile_id. - In other words, shape of a sub-hyperrectangle of tile that was sampled from original data. 
- For example if (64, 64) tile was actually padded to that size from (40, 40), - this method will return (40, 40). - - :param tile_id: int - Tile ID for which to return sample shape. - - :param with_channel_dim: bool - Whether to return shape with channel dimension or without. - - :return: np.ndarray - Shape of sample of the tile. - """ - - if (tile_id < 0) or (tile_id >= self.n_tiles): - raise IndexError(f'Out of bounds, there is no tile {tile_id}. ' - f'There are {len(self)} tiles, starting from index 0.') - - if self.channel_dimension is not None and not with_channel_dim: - return self._tile_sample_shapes[tile_id][~(np.arange(self._n_dim) == self.channel_dimension)] - return self._tile_sample_shapes[tile_id] - def get_tile_bbox_position(self, tile_id: int, with_channel_dim: bool = False) -> Tuple[np.ndarray, np.ndarray]: """ Returns diagonal corner coordinates of bounding hyperrectangle of the tile on padded data. @@ -436,124 +313,3 @@ def get_mosaic_shape(self, with_channel_dim: bool = False) -> np.ndarray: if self.channel_dimension is not None and not with_channel_dim: return self._indexing_shape[~(np.arange(self._n_dim) == self.channel_dimension)] return self._indexing_shape - - - - - # - # # Merge (processed) tile into accumulator array - # # Efficient way to accumulate processed images - # # Supports various windows - # def merge(self, accumulator: np.ndarray, tile: np.ndarray, tile_id: int, window: str = 'norm'): - # if accumulator.shape != self.image_shape: - # raise ValueError(f'Accumulator must have the same shape as image_shape ' - # f'({accumulator.shape} != {self.image_shape})') - # if window not in self.__WINDOWS: - # raise ValueError('Unsupported window function, please check docs') - # - # - # # Return border type of the tile - # def get_tile_border_type(self, tile_id: int): - # tile_pos = self._tile_index[tile_id] - # tile_n_around = self._tile_border_types[tuple(tile_pos)] - # min_max = tile_pos == min(tile_pos) - # return - # - # - # - # # 
Return - # def is_corner_tile(self, tile_id: int) -> bool: - # pass - # - # def is_edge_tile(self, tile_id: int) -> bool: - # pass - # - # # corners - # # number of corners: 2^n_dim, permutations of all corners - # # corners direction? - # - # # edges - # # - # # import numpy as np - # # - # # def edge_mask(x): - # # mask = np.ones(x.shape, dtype=bool) - # # mask[x.ndim * (slice(1, -1),)] = False - # # return mask - # # - # # x = np.random.rand(4, 5) - # # edge_mask(x) - # # # array([[ True, True, True, True, True], - # # # [ True, False, False, False, True], - # # # [ True, False, False, False, True], - # # # [ True, True, True, True, True]], dtype=bool) - # - # - # # - # # - # # def _precompute_window_type(self): - # # - # # - # # - # # def _get_corners(self) -> np.ndarray: - # # corners = a[tuple(slice(None, None, j - 1) for j in a.shape)] - # # - # # - # # def how_many_edges_touching(self, tile_id: int) -> int: - # # - # # - # # def is_border_tile(self, tile_id: int): - # # # edge tile will have at least one min or max value in any dimension - # # if self.how_many_edges_touching(tile_id) > 0: - # # return True - # # else: - # # return False - # # - # # def is_corner_tile(self, tile_id: int): - # # # corner tile will have ndim min or max values in any dimensions - # # if self.how_many_edges_touching(tile_id) == self._n_dim: - # # return True - # # else: - # # return False - # # - # # - # # def a(self): - # # # Define types of possible tiles - # # - # # # 2D case - # # # *---------* - # # # |1 5 2| - # # # |7 9 8| - # # # |3 6 4| - # # # *---------* - # # - # # - # # # +--------+ - # # # / /| - # # # / / | - # # # +--------+ | - # # # | | | - # # # | | + - # # # | | / - # # # | |/ - # # # +--------+ - # # - # # # Calculate which tiles are border tiles - # # - # # # self._border_tiles = [tile for tile in self._tiles if tile] - # # self._border_tiles = [] - # # for i in range(len(self._tiles)): - # # for dim, x in enumerate(self._tiles[i]): - # # if x == self._tile 
- # # - # # if np.any([True for x in self._tiles[i] if ]) - # # - # # for tile in self._tiles: - # # # check each dimension and if it is _tile_ends - # # if np.any() - # # [x for x in tile] - # # for dim in tile: - # # if tile[dim] - # - # # def merge(self, images, window: str = None, crop_padding: bool = True): - # # pass