From f7284543652dc825bd84913699526868292c94df Mon Sep 17 00:00:00 2001 From: Pibe97 <55357900+Pibe97@users.noreply.github.com> Date: Wed, 11 Dec 2024 15:10:31 +0100 Subject: [PATCH] Fix 2/3 rule filter Currently, the 2/3 filter retains more negative than positive frequencies in the kx direction. For example, working on a 128 by 128 grid (i.e. 128 by 65 in rfft2 space), in the kx direction (i.e. rows) the current filter retains the kx = 0 and 41 positive frequencies, but 43 negative frequencies. This error can be spotted by checking that applying jnp.fft.irfft2 followed by jnp.fft.rfft2 to a trajectory output, which are inverse operations, changes the output. Alternatively, the 0th column (ky=0) of a trajectory snapshot (in rfft2 space), should have complex conjugate symmetry between the positive and negative kx frequencies. The code has been adjusted to retain the kx = 0 frequency (hence the + 1 in line 132), int(2 / 3 * n) // 2 positive and int(2 / 3 * n) // 2 negative frequencies (fixed by adding parentheses in line 133, to have correct order of operations) in the kx direction. For the 128 by 128 grid, we have int(2 / 3 * n) // 2 = 42. --- jax_cfd/spectral/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_cfd/spectral/utils.py b/jax_cfd/spectral/utils.py index 97a6b49..4cc9cad 100644 --- a/jax_cfd/spectral/utils.py +++ b/jax_cfd/spectral/utils.py @@ -129,8 +129,8 @@ def brick_wall_filter_2d(grid: grids.Grid): """Implements the 2/3 rule.""" n, m = grid.shape filter_ = jnp.zeros((n, m // 2 + 1)) - filter_ = filter_.at[:int(2 / 3 * n) // 2, :int(2 / 3 * (m // 2 + 1))].set(1) - filter_ = filter_.at[-int(2 / 3 * n) // 2:, :int(2 / 3 * (m // 2 + 1))].set(1) + filter_ = filter_.at[:(int(2 / 3 * n) // 2 + 1), :int(2 / 3 * (m // 2 + 1))].set(1) + filter_ = filter_.at[-(int(2 / 3 * n) // 2):, :int(2 / 3 * (m // 2 + 1))].set(1) return filter_