Skip to content

Commit

Permalink
Fix dtype in toeplitz operator
Browse files Browse the repository at this point in the history
  • Loading branch information
Wuhyun Sohn authored and pchanial committed Jan 22, 2025
1 parent 26723e4 commit 693df6f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/furax/core/_toeplitz.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def _apply_overlap_save(self, x: Array, band_values: Array) -> Array:
x_padding_start = overlap
x_padding_end = total_length - overlap - l
x_padded = jnp.pad(x, (x_padding_start, x_padding_end), mode='constant')
y = jnp.zeros(l + x_padding_end)
y = jnp.zeros(l + x_padding_end, dtype=x.dtype)

def func(iblock, y): # type: ignore[no-untyped-def]
position = iblock * step_size
Expand Down Expand Up @@ -235,7 +235,7 @@ def _overlap_add_jax(x, H, fft_size, b): # type: ignore[no-untyped-def]
# pad x so that its size is a multiple of m
x_padding = 0 if l % m == 0 else m - (l % m)
x_padded = jnp.pad(x, (x_padding,), mode='constant')
y = jnp.zeros(l + b)
y = jnp.zeros(l + b, dtype=x.dtype)

def func(j, y): # type: ignore[no-untyped-def]
i = j * m
Expand Down
3 changes: 2 additions & 1 deletion tests/core/test_toeplitz.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
from math import prod

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -88,7 +89,7 @@ def test(method: str, do_jit: bool) -> None:
def test_multidimensional(
in_shape: tuple[int, ...], band_shape: tuple[int, ...], method: str
) -> None:
band_values = jnp.arange(np.prod(band_shape)).reshape(band_shape)
band_values = jnp.arange(prod(band_shape), dtype=jnp.float64).reshape(band_shape)
in_structure = jax.ShapeDtypeStruct(in_shape, jnp.float64)
op = SymmetricBandToeplitzOperator(band_values, in_structure, method=method)
broadcast_band_values = jnp.broadcast_to(band_values, in_shape[:-1] + (4,))
Expand Down

0 comments on commit 693df6f

Please sign in to comment.