From c7e548cff8f05e389ec050e80ca9588e6e11a970 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Thu, 17 Oct 2024 11:10:50 -0400 Subject: [PATCH] expose FFTs --- jaxdecomp/fft.py | 86 ++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 68 insertions(+), 18 deletions(-) diff --git a/jaxdecomp/fft.py b/jaxdecomp/fft.py index 465ca73..a7b7f7d 100644 --- a/jaxdecomp/fft.py +++ b/jaxdecomp/fft.py @@ -3,11 +3,13 @@ import jax.numpy as jnp from jax import jit +from jax._src.numpy.util import promote_dtypes_inexact from jax._src.typing import Array, ArrayLike from jax.lib import xla_client import jaxdecomp -from jaxdecomp._src import pfft as _pfft +from jaxdecomp._src.cudecomp.fft import pfft as _cudecomp_pfft +from jaxdecomp._src.jax.fft import pfft as _jax_pfft Shape = Sequence[int] @@ -16,8 +18,41 @@ "pifft3d", ] +FftType = xla_client.FftType -def _fft_norm(s: Array, func_name: str, norm: str) -> Array: + +def _str_to_fft_type(s: str) -> xla_client.FftType | int: + """ + Convert a string to an FFT type enum. + + Parameters + ---------- + s : str + String representation of FFT type. + + Returns + ------- + xla_client.FftType + Corresponding FFT type enum. + + Raises + ------ + ValueError + If the string `s` does not match known FFT types. + """ + if s in ("fft", "FFT"): + return xla_client.FftType.FFT + elif s in ("ifft", "IFFT"): + return xla_client.FftType.IFFT + elif s in ("rfft", "RFFT"): + return xla_client.FftType.RFFT + elif s in ("irfft", "IRFFT"): + return xla_client.FftType.IRFFT + else: + raise ValueError(f"Unknown FFT type '{s}'") + + +def _fft_norm(s: Array, func_name: str, norm: Optional[str]) -> Array: """ Compute the normalization factor for FFT operations. @@ -43,23 +78,21 @@ def _fft_norm(s: Array, func_name: str, norm: str) -> Array: if norm == "backward": return 1 / jnp.prod(s) if func_name.startswith("i") else jnp.array(1) elif norm == "ortho": - return (1 / jnp.sqrt(jnp.prod(s)) if func_name.startswith("i") else 1 / - jnp.sqrt(jnp.prod(s))) + return (1 / jnp.sqrt(jnp.prod(s))) elif norm == "forward": - return jnp.prod(s) if func_name.startswith("i") else 1 / jnp.prod(s)**2 + return jnp.array(1) if func_name.startswith("i") else 1 / jnp.prod(s) raise ValueError(f'Invalid norm value {norm}; should be "backward",' '"ortho" or "forward".') # Has to be jitted here because _fft_norm will act on non fully addressable global array # Which means this should be jit wrapped -@partial(jit, static_argnums=(0, 1, 3)) -def _do_pfft( - func_name: str, - fft_type: xla_client.FftType, - arr: ArrayLike, - norm: Optional[str], -) -> Array: +@partial(jit, static_argnums=(0, 1, 3, 4)) +def _do_pfft(func_name: str, + fft_type: xla_client.FftType, + arr: Array, + norm: Optional[str], + backend: str = "JAX") -> Array: """ Perform 3D FFT or inverse 3D FFT on the input array. @@ -79,14 +112,28 @@ def _do_pfft( Array Transformed array after FFT or inverse FFT. """ - local_transpose = jaxdecomp.config.transpose_axis_contiguous - transformed = _pfft(arr, fft_type, False, local_transpose) + if isinstance(fft_type, str): + typ = _str_to_fft_type(fft_type) + elif isinstance(fft_type, xla_client.FftType): + typ = fft_type + else: + raise TypeError(f"Unknown FFT type value '{fft_type}'") + print(f"Backend is {backend}") + if backend.lower() == "cudecomp": + transformed = _cudecomp_pfft(arr, typ) + elif backend.lower() == "jax": + transformed = _jax_pfft(arr, typ) + else: + raise ValueError(f"Unknown backend value '{backend}'") + transformed *= _fft_norm( jnp.array(arr.shape, dtype=transformed.dtype), func_name, norm) return transformed -def pfft3d(a: ArrayLike, norm: Optional[str] = "backward") -> Array: +def pfft3d(a: ArrayLike, + norm: Optional[str] = "backward", + backend: str = "JAX") -> Array: """ Perform 3D FFT on the input array. @@ -102,10 +149,12 @@ def pfft3d(a: ArrayLike, norm: Optional[str] = "backward") -> Array: Array Transformed array after 3D FFT. """ - return _do_pfft("fft", xla_client.FftType.FFT, a, norm=norm) + return _do_pfft("fft", xla_client.FftType.FFT, a, norm=norm, backend=backend) -def pifft3d(a: ArrayLike, norm: Optional[str] = "backward") -> Array: +def pifft3d(a: ArrayLike, + norm: Optional[str] = "backward", + backend: str = "JAX") -> Array: """ Perform inverse 3D FFT on the input array. @@ -121,4 +170,5 @@ def pifft3d(a: ArrayLike, norm: Optional[str] = "backward") -> Array: Array Transformed array after inverse 3D FFT. """ - return _do_pfft("ifft", xla_client.FftType.IFFT, a, norm=norm) + return _do_pfft( + "ifft", xla_client.FftType.IFFT, a, norm=norm, backend=backend)