From 53fbe83986941887ad5ba4ec8824de75fcfc0d63 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Mon, 16 Dec 2024 17:44:19 +0100 Subject: [PATCH] Fix gradient issue with PFFT3D --- src/jaxdecomp/fft.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/jaxdecomp/fft.py b/src/jaxdecomp/fft.py index 1651cfd..18fb915 100644 --- a/src/jaxdecomp/fft.py +++ b/src/jaxdecomp/fft.py @@ -9,6 +9,8 @@ from jaxdecomp._src.fft_utils import FftType from jaxdecomp._src.jax import fftfreq as _fftfreq from jaxdecomp._src.jax.fft import pfft as _jax_pfft +from jax import lax +from jax._src import dtypes Shape = Sequence[int] @@ -115,6 +117,13 @@ def _do_pfft(func_name: str, typ = fft_type else: raise TypeError(f"Unknown FFT type value '{fft_type}'") + + match typ: + case FftType.FFT | FftType.IFFT: + arr = lax.convert_element_type(arr, dtypes.to_complex_dtype(dtypes.dtype(arr))) + case FftType.RFFT | FftType.IRFFT: + raise ValueError("Not implemented wait (SOON)") + if backend.lower() == "cudecomp": transformed = _cudecomp_pfft(arr, typ) elif backend.lower() == "jax":