diff --git a/docs/index.html b/docs/index.html
index ff17ab2..da759e2 100644
--- a/docs/index.html
+++ b/docs/index.html
@@ -2000,7 +2000,7 @@
Return
data_dtype (np.dtype): Specify data type for data buffer that stores cumulative result.
Default is `np.float32`.
- weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights.
+ weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights and window array.
If you don't need precision but would rather save memory you can use `np.float16`.
Likewise, on the opposite, you can use `np.float64`.
Default is `np.float32`.
@@ -2039,7 +2039,7 @@ Return
np.ndarray: n-dimensional window of the given shape and function
"""
- w = np.ones(shape)
+ w = np.ones(shape, dtype=self.weights_dtype)
overlap = self.tiler._tile_overlap
for axis, length in enumerate(shape):
if axis == self.tiler.channel_dimension:
@@ -2095,7 +2095,7 @@ Return
raise ValueError(
f"Window function must have the same shape as tile shape."
)
- self.window = window
+ self.window = window.astype(self.weights_dtype)
else:
raise ValueError(
f"Unsupported type for window function ({type(window)}), expected str or np.ndarray."
@@ -2363,7 +2363,7 @@ Return
data_dtype (np.dtype): Specify data type for data buffer that stores cumulative result.
Default is `np.float32`.
- weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights.
+ weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights and window array.
If you don't need precision but would rather save memory you can use `np.float16`.
Likewise, on the opposite, you can use `np.float64`.
Default is `np.float32`.
@@ -2424,7 +2424,7 @@ Args
self.data_visits
. Can be disabled to save some memory. Default is True
.
data_dtype (np.dtype): Specify data type for data buffer that stores cumulative result.
Default is np.float32
.
-weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights.
+weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights and window array.
If you don't need precision but would rather save memory you can use np.float16
.
Likewise, on the opposite, you can use np.float64
.
Default is np.float32
.
@@ -2516,7 +2516,7 @@ Args
raise ValueError(
f"Window function must have the same shape as tile shape."
)
- self.window = window
+ self.window = window.astype(self.weights_dtype)
else:
raise ValueError(
f"Unsupported type for window function ({type(window)}), expected str or np.ndarray."
diff --git a/tests/test_merger.py b/tests/test_merger.py
index 6f2bcbc..c602bc7 100644
--- a/tests/test_merger.py
+++ b/tests/test_merger.py
@@ -29,11 +29,17 @@ def test_init(self):
self.assert_(merger3.data_visits is not None)
# Check data and weights dtypes
- merger4 = Merger(tiler=tiler, data_dtype=np.float32, weights_dtype=np.float32)
+ merger4 = Merger(
+ tiler=tiler,
+ data_dtype=np.float32,
+ weights_dtype=np.float32,
+ window="boxcar",
+ )
self.assertEqual(merger4.data.dtype, np.float32)
self.assertEqual(merger4.data_dtype, np.float32)
self.assertEqual(merger4.weights_sum.dtype, np.float32)
self.assertEqual(merger4.weights_dtype, np.float32)
+ self.assertEqual(merger4.window.dtype, np.float32)
def test_add(self):
tiler = Tiler(data_shape=self.data.shape, tile_shape=(10,))
diff --git a/tiler/merger.py b/tiler/merger.py
index 2f28e56..5672142 100644
--- a/tiler/merger.py
+++ b/tiler/merger.py
@@ -94,7 +94,7 @@ def __init__(
data_dtype (np.dtype): Specify data type for data buffer that stores cumulative result.
Default is `np.float32`.
- weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights.
+ weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights and window array.
If you don't need precision but would rather save memory you can use `np.float16`.
Likewise, on the opposite, you can use `np.float64`.
Default is `np.float32`.
@@ -133,7 +133,7 @@ def _generate_window(self, window: str, shape: Union[Tuple, List]) -> np.ndarray
np.ndarray: n-dimensional window of the given shape and function
"""
- w = np.ones(shape)
+ w = np.ones(shape, dtype=self.weights_dtype)
overlap = self.tiler._tile_overlap
for axis, length in enumerate(shape):
if axis == self.tiler.channel_dimension:
@@ -189,7 +189,7 @@ def set_window(self, window: Union[None, str, np.ndarray] = None) -> None:
raise ValueError(
f"Window function must have the same shape as tile shape."
)
- self.window = window
+ self.window = window.astype(self.weights_dtype)
else:
raise ValueError(
f"Unsupported type for window function ({type(window)}), expected str or np.ndarray."