Skip to content

Commit

Permalink
Fix gradient issue with PFFT3D
Browse files Browse the repository at this point in the history
  • Loading branch information
ASKabalan committed Dec 16, 2024
1 parent ecaadb8 commit 53fbe83
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/jaxdecomp/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from jaxdecomp._src.fft_utils import FftType
from jaxdecomp._src.jax import fftfreq as _fftfreq
from jaxdecomp._src.jax.fft import pfft as _jax_pfft
from jax import lax
from jax._src import dtypes

Shape = Sequence[int]

Expand Down Expand Up @@ -115,6 +117,13 @@ def _do_pfft(func_name: str,
typ = fft_type
else:
raise TypeError(f"Unknown FFT type value '{fft_type}'")

match typ:
case FftType.FFT | FftType.IFFT:
arr = lax.convert_element_type(arr, dtypes.to_complex_dtype(dtypes.dtype(arr)))
case FftType.RFFT | FftType.IRFFT:
raise ValueError("Not implemented wait (SOON)")

if backend.lower() == "cudecomp":
transformed = _cudecomp_pfft(arr, typ)
elif backend.lower() == "jax":
Expand Down

0 comments on commit 53fbe83

Please sign in to comment.