diff --git a/src/furax/core/_toeplitz.py b/src/furax/core/_toeplitz.py index 41fb7f9..d5b724c 100644 --- a/src/furax/core/_toeplitz.py +++ b/src/furax/core/_toeplitz.py @@ -170,7 +170,7 @@ def _apply_overlap_save(self, x: Array, band_values: Array) -> Array: x_padding_start = overlap x_padding_end = total_length - overlap - l x_padded = jnp.pad(x, (x_padding_start, x_padding_end), mode='constant') - y = jnp.zeros(l + x_padding_end) + y = jnp.zeros(l + x_padding_end, dtype=x.dtype) def func(iblock, y): # type: ignore[no-untyped-def] position = iblock * step_size @@ -235,7 +235,7 @@ def _overlap_add_jax(x, H, fft_size, b): # type: ignore[no-untyped-def] # pad x so that its size is a multiple of m x_padding = 0 if l % m == 0 else m - (l % m) x_padded = jnp.pad(x, (x_padding,), mode='constant') - y = jnp.zeros(l + b) + y = jnp.zeros(l + b, dtype=x.dtype) def func(j, y): # type: ignore[no-untyped-def] i = j * m diff --git a/tests/core/test_toeplitz.py b/tests/core/test_toeplitz.py index c6989da..bce20a0 100644 --- a/tests/core/test_toeplitz.py +++ b/tests/core/test_toeplitz.py @@ -1,4 +1,5 @@ import itertools +from math import prod import jax import jax.numpy as jnp @@ -88,7 +89,7 @@ def test(method: str, do_jit: bool) -> None: def test_multidimensional( in_shape: tuple[int, ...], band_shape: tuple[int, ...], method: str ) -> None: - band_values = jnp.arange(np.prod(band_shape)).reshape(band_shape) + band_values = jnp.arange(prod(band_shape), dtype=jnp.float64).reshape(band_shape) in_structure = jax.ShapeDtypeStruct(in_shape, jnp.float64) op = SymmetricBandToeplitzOperator(band_values, in_structure, method=method) broadcast_band_values = jnp.broadcast_to(band_values, in_shape[:-1] + (4,))