Skip to content

Commit

Permalink
make get_kernel a private method
Browse files Browse the repository at this point in the history
  • Loading branch information
sbiquard committed Dec 11, 2024
1 parent 87acc0f commit 28ddaa7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
22 changes: 11 additions & 11 deletions src/furax/preprocessing/gap_filling.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,6 @@ def __call__(self, key: PRNGKeyArray, x: Float[Array, ' *shape']) -> Float[Array
y = y.at[p.indices].set(p(x))
return y

@staticmethod
def get_kernel(n_tt: Float[Array, ' _'], size: int) -> Float[Array, ' {size}']:
lagmax = n_tt.size - 1
padding_size = size - (2 * lagmax + 1)
if padding_size < 0:
msg = f'The maximum lag ({lagmax}) is too large for the required kernel size ({size}).'
raise ValueError(msg)
kernel = jnp.concatenate((n_tt, jnp.zeros(padding_size), n_tt[-1:0:-1]))
return kernel

@staticmethod
def folded_psd(
n_tt: Float[ArrayLike, ' _'], fft_size: int
Expand All @@ -68,12 +58,22 @@ def folded_psd(
fft_size: The size of the FFT to use (at least twice the size of ``n_tt``).
"""
n_tt = jnp.asarray(n_tt)
kernel = GapFillingOperator.get_kernel(n_tt, fft_size)
kernel = GapFillingOperator._get_kernel(n_tt, fft_size)
psd = jnp.abs(jnp.fft.rfft(kernel, n=fft_size))
# zero out DC value
psd = psd.at[0].set(0)
return psd

@staticmethod
def _get_kernel(n_tt: Float[Array, ' _'], size: int) -> Float[Array, ' {size}']:
lagmax = n_tt.size - 1
padding_size = size - (2 * lagmax + 1)
if padding_size < 0:
msg = f'The maximum lag ({lagmax}) is too large for the required kernel size ({size}).'
raise ValueError(msg)
kernel = jnp.concatenate((n_tt, jnp.zeros(padding_size), n_tt[-1:0:-1]))
return kernel

def _generate_realization_for(
self, x: Float[Array, ' *shape'], key: PRNGKeyArray
) -> Float[Array, ' *shape']:
Expand Down
4 changes: 2 additions & 2 deletions tests/preprocessing/test_gap_filling.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, num: int | tuple[int, ...]) -> None:
def test_get_kernel(n_tt: list[int], fft_size: int, expected_kernel: list[int]):
n_tt = jnp.array(n_tt)
expected_kernel = np.array(expected_kernel)
actual_kernel = GapFillingOperator.get_kernel(n_tt, fft_size)
actual_kernel = GapFillingOperator._get_kernel(n_tt, fft_size)
assert_allclose(actual_kernel, expected_kernel)


Expand All @@ -37,7 +37,7 @@ def test_get_kernel_fail_lagmax(n_tt: list[int], fft_size: int):
# This test should fail because the maximum lag is too large for the required fft_size
n_tt = jnp.array(n_tt)
with pytest.raises(ValueError):
_ = GapFillingOperator.get_kernel(n_tt, fft_size)
_ = GapFillingOperator._get_kernel(n_tt, fft_size)


@pytest.mark.parametrize('do_jit', [False, True])
Expand Down

0 comments on commit 28ddaa7

Please sign in to comment.