From e2dbfc75deadf75678a0134183b56012cebeee28 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 30 Jul 2024 15:57:35 -0700 Subject: [PATCH] Avoid complex->real casts via jax.numpy.astype This currently issues a warning about implicitly discarding the imaginary part, and it will issue an error in the future. PiperOrigin-RevId: 657760335 --- jax_cfd/base/fast_diagonalization.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/jax_cfd/base/fast_diagonalization.py b/jax_cfd/base/fast_diagonalization.py index 86705c7..af00a31 100644 --- a/jax_cfd/base/fast_diagonalization.py +++ b/jax_cfd/base/fast_diagonalization.py @@ -165,6 +165,13 @@ def apply(rhs: Array) -> Array: return apply +def _cast(x, dtype): + if (np.issubdtype(x.dtype, np.complexfloating) + and not np.issubdtype(dtype, np.complexfloating)): + x = x.real + return x.astype(dtype) + + def _circulant_fft_transform( func: Callable[[Array], Array], operators: Sequence[np.ndarray], @@ -184,7 +191,7 @@ def _circulant_fft_transform( def apply(rhs: Array) -> Array: if rhs.shape != shape: raise ValueError(f'rhs.shape={rhs.shape} does not match shape={shape}') - return jnp.fft.ifftn(diagonals * jnp.fft.fftn(rhs)).astype(dtype) + return _cast(jnp.fft.ifftn(diagonals * jnp.fft.fftn(rhs)), dtype) return apply @@ -213,7 +220,7 @@ def _circulant_rfft_transform( def apply(rhs: Array) -> Array: if rhs.dtype != dtype: raise ValueError(f'rhs.dtype={rhs.dtype} does not match dtype={dtype}') - return jnp.fft.irfftn(diagonals * jnp.fft.rfftn(rhs)).astype(dtype) + return _cast(jnp.fft.irfftn(diagonals * jnp.fft.rfftn(rhs)), dtype) return apply