From 7d7510297e5a2c895801841f94f338d76f97cd91 Mon Sep 17 00:00:00 2001 From: the-lay Date: Fri, 7 Jan 2022 00:06:06 +0200 Subject: [PATCH] Weights data type now also specifies window data type --- docs/index.html | 12 ++++++------ tests/test_merger.py | 8 +++++++- tiler/merger.py | 6 +++--- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/docs/index.html b/docs/index.html index ff17ab2..da759e2 100644 --- a/docs/index.html +++ b/docs/index.html @@ -2000,7 +2000,7 @@
Return
data_dtype (np.dtype): Specify data type for data buffer that stores cumulative result. Default is `np.float32`. - weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights. + weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights and window array. If you don't need precision but would rather save memory you can use `np.float16`. Likewise, on the opposite, you can use `np.float64`. Default is `np.float32`. @@ -2039,7 +2039,7 @@
Return
np.ndarray: n-dimensional window of the given shape and function """ - w = np.ones(shape) + w = np.ones(shape, dtype=self.weights_dtype) overlap = self.tiler._tile_overlap for axis, length in enumerate(shape): if axis == self.tiler.channel_dimension: @@ -2095,7 +2095,7 @@
Return
raise ValueError( f"Window function must have the same shape as tile shape." ) - self.window = window + self.window = window.astype(self.weights_dtype) else: raise ValueError( f"Unsupported type for window function ({type(window)}), expected str or np.ndarray." @@ -2363,7 +2363,7 @@
Return
data_dtype (np.dtype): Specify data type for data buffer that stores cumulative result. Default is `np.float32`. - weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights. + weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights and window array. If you don't need precision but would rather save memory you can use `np.float16`. Likewise, on the opposite, you can use `np.float64`. Default is `np.float32`. @@ -2424,7 +2424,7 @@
Args
self.data_visits. Can be disabled to save some memory. Default is True.
  • data_dtype (np.dtype): Specify data type for data buffer that stores cumulative result. Default is np.float32.
  • -
  • weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights. +
  • weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights and window array. If you don't need precision but would rather save memory you can use np.float16. Likewise, on the opposite, you can use np.float64. Default is np.float32.
  • @@ -2516,7 +2516,7 @@
    Args
    raise ValueError( f"Window function must have the same shape as tile shape." ) - self.window = window + self.window = window.astype(self.weights_dtype) else: raise ValueError( f"Unsupported type for window function ({type(window)}), expected str or np.ndarray." diff --git a/tests/test_merger.py b/tests/test_merger.py index 6f2bcbc..c602bc7 100644 --- a/tests/test_merger.py +++ b/tests/test_merger.py @@ -29,11 +29,17 @@ def test_init(self): self.assert_(merger3.data_visits is not None) # Check data and weights dtypes - merger4 = Merger(tiler=tiler, data_dtype=np.float32, weights_dtype=np.float32) + merger4 = Merger( + tiler=tiler, + data_dtype=np.float32, + weights_dtype=np.float32, + window="boxcar", + ) self.assertEqual(merger4.data.dtype, np.float32) self.assertEqual(merger4.data_dtype, np.float32) self.assertEqual(merger4.weights_sum.dtype, np.float32) self.assertEqual(merger4.weights_dtype, np.float32) + self.assertEqual(merger4.window.dtype, np.float32) def test_add(self): tiler = Tiler(data_shape=self.data.shape, tile_shape=(10,)) diff --git a/tiler/merger.py b/tiler/merger.py index 2f28e56..5672142 100644 --- a/tiler/merger.py +++ b/tiler/merger.py @@ -94,7 +94,7 @@ def __init__( data_dtype (np.dtype): Specify data type for data buffer that stores cumulative result. Default is `np.float32`. - weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights. + weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights and window array. If you don't need precision but would rather save memory you can use `np.float16`. Likewise, on the opposite, you can use `np.float64`. Default is `np.float32`. @@ -133,7 +133,7 @@ def _generate_window(self, window: str, shape: Union[Tuple, List]) -> np.ndarray np.ndarray: n-dimensional window of the given shape and function """ - w = np.ones(shape) + w = np.ones(shape, dtype=self.weights_dtype) overlap = self.tiler._tile_overlap for axis, length in enumerate(shape): if axis == self.tiler.channel_dimension: @@ -189,7 +189,7 @@ def set_window(self, window: Union[None, str, np.ndarray] = None) -> None: raise ValueError( f"Window function must have the same shape as tile shape." ) - self.window = window + self.window = window.astype(self.weights_dtype) else: raise ValueError( f"Unsupported type for window function ({type(window)}), expected str or np.ndarray."