Skip to content

Commit

Permalink
Overhauled tiling to move away from as_strided; added 3D overlap-tile…
Browse files Browse the repository at this point in the history
… example; fixed sampling shapes; changed 2D overlap-tile example to add a sanity check
  • Loading branch information
the-lay committed Jan 19, 2021
1 parent d929d66 commit 3a817f7
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 365 deletions.
17 changes: 5 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ images do not fit into GPU memory (2D hyperspectral satellite images,
Implemented features
-------------
- Data reader agnostic: works with numpy arrays
- Optimized to avoid unnecessary memory copies: numpy views for all tiles,
except border tiles that require padding
- N-dimensional array tiling
(but for now tiles must have the same number of dimensions as the array)
(note: currently tile shape must have the same number of dimensions as the array)
- Optional in-place tiling (without creating copies)
- Supports channel dimension: dimension that will not be tiled
- Overlapping support: you can specify tile percentage or directly overlap size
- Window functions: Merger accepts weights for the tile as an array or a scipy window
Expand Down Expand Up @@ -116,10 +115,9 @@ However, other libraries might fit you better than `tiler`:

- Do you know any other similar packages?
- [Please make a PR](https://github.com/the-lay/tiler/pulls) or [open a new issue](https://github.com/the-lay/tiler/issues).

Academic references
-------------
[Introducing Hann windows for reducing edge-effects in patch-based image segmentation](https://doi.org/10.1371/journal.pone.0229839

Moreover, some approaches have been described in the literature:
- [Introducing Hann windows for reducing edge-effects in patch-based image segmentation](https://doi.org/10.1371/journal.pone.0229839
), Pielawski and Wählby, March 2020


Expand Down Expand Up @@ -158,8 +156,3 @@ https://gist.github.com/npielawski/7e77d23209a5c415f55b95d4aba914f6
https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0229839#pone.0229839.ref005
https://arxiv.org/pdf/1803.02786.pdf
-->





21 changes: 16 additions & 5 deletions examples/2d_overlap_tile.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# 2D RGB Overlap-tile strategy tiling/merging example
# 2D RGB overlap-tile strategy tiling/merging example
#
# "This strategy allows the seamless segmentation of arbitrarily large images by an overlap-tile strategy.
# To predict the pixels in the border region of the image, the missing context is extrapolated by mirroring
Expand Down Expand Up @@ -26,17 +26,27 @@
overlap=(64, 64, 0), channel_dimension=2)

# Window function for merging
# We also need to generate a window for function
window = np.zeros((128, 128, 3))
window[32:-32, 32:-32, :] = 1

# Specifying merging
merger = Merger(tiler=tiler, window=window)

# Let's define a function that will be applied to each tile
# For this example, we use PIL to adjust color balance
# In practice, this can be a neural network or any kind of processing
def process(patch: np.ndarray) -> np.ndarray:
def process(patch: np.ndarray, sanity_check: bool = True) -> np.ndarray:

# One example can be a sanity check
# Make the parts that should be remove black
# There should not appear any black spots in the final merged image
if sanity_check:
patch[:32, :, :] = 0
patch[-32:, :, :] = 0
patch[:, :32, :] = 0
patch[:, -32:, :] = 0
return patch

# Another example can be to just modify the whole patch
# Using PIL, we adjust the color balance
enhancer = ImageEnhance.Color(Image.fromarray(patch))
return np.array(enhancer.enhance(5.0))

Expand All @@ -45,6 +55,7 @@ def process(patch: np.ndarray) -> np.ndarray:
for tile_id, tile in tiler(padded_image):
processed_tile = process(tile)
merger.add(tile_id, processed_tile)

final_image = merger.merge().astype(np.uint8)
final_unpadded_image = final_image[32:-32, 32:-32, :]

Expand Down
65 changes: 65 additions & 0 deletions examples/3d_overlap_tile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# 3D grayscale overlap-tile stratefy tiling/merging example
#
# "This strategy allows the seamless segmentation of arbitrarily large images by an overlap-tile strategy.
# To predict the pixels in the border region of the image, the missing context is extrapolated by mirroring
# the input image. This tiling strategy is important to apply the network to large images,
# since otherwise the resolution would be limited by the GPU memory." - Ronneberger et al 2015, U-Net paper

# We will use napari for showing 3D volumes interactively
import numpy as np
from tiler import Tiler, Merger
import napari

# Example "checkerboard"-like volume with some variation for visualization
# https://stackoverflow.com/a/51715491
volume = (np.indices((150, 462, 462)).sum(axis=0) % 50).astype(np.float32)
volume[:75] *= np.linspace(3, 10, 75)[:, None, None]
volume[75:] *= np.linspace(10, 3, 75)[:, None, None]

# Let's assume we want to use tiles of size 48x48x48 and only the middle 20x20x20 for the final image
# That means we need to pad the image by 14 from each side
# To extrapolate missing context let's use reflect mode
padded_volume = np.pad(volume, 14, mode='reflect')

# Specifying tiling
# The overlap should be 28 voxels
tiler = Tiler(image_shape=padded_volume.shape,
tile_shape=(48, 48, 48),
overlap=(28, 28, 28))

# Window function for merging
window = np.zeros((48, 48, 48))
window[14:-14, 14:-14, 14:-14] = 1

# Specifying merging
merger = Merger(tiler=tiler, window=window)

# Let's define a function that will be applied to each tile
# For this example, let's multiple the sides that should be "cropped" by window function
# by some huge number, as a way to confirm that only the middle parts are being merged
def process(patch: np.ndarray) -> np.ndarray:
patch[:14, :, :] = 0
patch[-14:, :, :] = 0
patch[:, :14, :] = 0
patch[:, -14:, :] = 0
patch[:, :, :14] = 0
patch[:, :, -14:] = 0
return patch

# Iterate through all the tiles and apply the processing function and merge everything back
for tile_id, tile in tiler(padded_volume, progress_bar=True):
processed_tile = process(tile)
merger.add(tile_id, processed_tile)

final_volume = merger.merge()
final_unpadded_volume = final_volume[14:-14, 14:-14, 14:-14]

# Show all the
with napari.gui_qt():
v = napari.Viewer()
v.add_image(volume, name='Original volume')
v.add_image(padded_volume, name='Padded volume')
v.add_image(final_volume, name='Final volume')
v.add_image(final_unpadded_volume, name='Final unpadded volume')
v.add_image(merger.weights_sum, name='Merger weights sum')
v.add_image(merger.data_visits, name='Merger data visits')
2 changes: 1 addition & 1 deletion tests/test_merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,4 @@ def test_generate_window(self):
for t_id, t in tiler(self.data):
merger.add(t_id, t)
np.testing.assert_equal(merger.merge(),
[i if i % 10 else 0 for i in range(100)])
[i if i % 10 else 0 for i in range(100)])
52 changes: 26 additions & 26 deletions tests/test_tiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ def test_repr(self):
channel_dimension=0,
mode='irregular')

expected_repr = '[3, 15, 300] tiler for data of shape [3, 300, 300]:' \
'\n\tNew shape: [3, 300, 300]' \
'\n\tOverlap: 0.0' \
'\n\tStep: [0, 15, 300]' \
expected_repr = 'Tiler split [3, 300, 300] data into 20 tiles of [3, 15, 300].' \
'\n\tMosaic shape: [1, 20, 1]' \
'\n\tTileable shape: [3, 300, 300]' \
'\n\tTile overlap: 0' \
'\n\tElement step: [0, 15, 300]' \
'\n\tMode: irregular' \
'\n\tChannel dimension: 0'

Expand Down Expand Up @@ -179,16 +180,30 @@ def test_get_tile(self):
with self.assertRaises(IndexError):
tiler.get_tile(self.data, -1)

tiles = tiler.view_in_tiles(self.data)
np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], tiler.get_tile(None, 0, tiles))
# copy test
t = tiler.get_tile(self.data, 0, copy=True)
t[9] = 0
np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], tiler.get_tile(self.data, 0))
np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 0], t)

t = tiler.get_tile(self.data, 0, copy=False)
t[9] = 0
np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 0], tiler.get_tile(self.data, 0))
np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 0], t)
t[9] = 9

# copy test with iterator
t = list(tiler(self.data, copy_data=True))
t[0][1][9] = 0
np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], tiler.get_tile(self.data, 0))
np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 0], t[0][1])
self.assertNotEqual(t[0][1][9], self.data[9])

def test_view_in_tiles(self):
tiler = Tiler(image_shape=self.data.shape,
tile_shape=(10, ))
with self.assertRaises(ValueError):
tiler.view_in_tiles(self.data.reshape((10, 10)))
t = [tile for _, tile in tiler(self.data, copy_data=False)]
t[0][9] = 0
np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 0], tiler.get_tile(self.data, 0))
np.testing.assert_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 0], t[0])
self.assertEqual(t[0][9], self.data[9])

def test_overlap(self):
# Case 1
Expand Down Expand Up @@ -241,21 +256,6 @@ def test_overlap(self):
self.assertEqual(len(tiler), len(expected_split))
np.testing.assert_equal(expected_split, calculated_split)

def test_tile_sample_shape(self):
tile_size = 10
tiler = Tiler(image_shape=self.data.shape,
tile_shape=(tile_size, ),
channel_dimension=None)
tiler2 = Tiler(image_shape=(3, ) + self.data.shape,
tile_shape=(3, tile_size),
channel_dimension=0)

with self.assertRaises(IndexError):
tiler.get_tile_sample_shape(len(tiler))

np.testing.assert_equal([tile_size], tiler.get_tile_sample_shape(len(tiler) - 1))
np.testing.assert_equal([tile_size], tiler2.get_tile_sample_shape(len(tiler) - 1))

def test_tile_mosaic_position(self):
tile_size = 10
tiler = Tiler(image_shape=self.data.shape, tile_shape=(tile_size, ))
Expand Down
2 changes: 1 addition & 1 deletion tiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.1.1'
__version__ = '0.1.2'

from tiler.tiler import Tiler
from tiler.merger import Merger
2 changes: 1 addition & 1 deletion tiler/merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def reset(self) -> None:

# Image holds sum of all processed tiles multiplied by the window
if self.logits:
self.data = np.zeros([self.logits] + padded_data_shape)
self.data = np.zeros(np.hstack((self.logits, padded_data_shape)))
else:
self.data = np.zeros(padded_data_shape)

Expand Down
Loading

0 comments on commit 3a817f7

Please sign in to comment.