Skip to content

Commit

Permalink
expose FFTs
Browse files Browse the repository at this point in the history
  • Loading branch information
ASKabalan committed Oct 17, 2024
1 parent 91f4b85 commit c7e548c
Showing 1 changed file with 68 additions and 18 deletions.
86 changes: 68 additions & 18 deletions jaxdecomp/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

import jax.numpy as jnp
from jax import jit
from jax._src.numpy.util import promote_dtypes_inexact
from jax._src.typing import Array, ArrayLike
from jax.lib import xla_client

import jaxdecomp
from jaxdecomp._src import pfft as _pfft
from jaxdecomp._src.cudecomp.fft import pfft as _cudecomp_pfft
from jaxdecomp._src.jax.fft import pfft as _jax_pfft

Shape = Sequence[int]

Expand All @@ -16,8 +18,41 @@
"pifft3d",
]

FftType = xla_client.FftType

def _fft_norm(s: Array, func_name: str, norm: str) -> Array:

def _str_to_fft_type(s: str) -> xla_client.FftType | int:
"""
Convert a string to an FFT type enum.
Parameters
----------
s : str
String representation of FFT type.
Returns
-------
xla_client.FftType
Corresponding FFT type enum.
Raises
------
ValueError
If the string `s` does not match known FFT types.
"""
if s in ("fft", "FFT"):
return xla_client.FftType.FFT
elif s in ("ifft", "IFFT"):
return xla_client.FftType.IFFT
elif s in ("rfft", "RFFT"):
return xla_client.FftType.RFFT
elif s in ("irfft", "IRFFT"):
return xla_client.FftType.IRFFT
else:
raise ValueError(f"Unknown FFT type '{s}'")


def _fft_norm(s: Array, func_name: str, norm: Optional[str]) -> Array:
"""
Compute the normalization factor for FFT operations.
Expand All @@ -43,23 +78,21 @@ def _fft_norm(s: Array, func_name: str, norm: str) -> Array:
if norm == "backward":
return 1 / jnp.prod(s) if func_name.startswith("i") else jnp.array(1)
elif norm == "ortho":
return (1 / jnp.sqrt(jnp.prod(s)) if func_name.startswith("i") else 1 /
jnp.sqrt(jnp.prod(s)))
return (1 / jnp.sqrt(jnp.prod(s)))
elif norm == "forward":
return jnp.prod(s) if func_name.startswith("i") else 1 / jnp.prod(s)**2
return jnp.array(1) if func_name.startswith("i") else 1 / jnp.prod(s)
raise ValueError(f'Invalid norm value {norm}; should be "backward",'
'"ortho" or "forward".')


# Has to be jitted here because _fft_norm will act on non fully addressable global array
# Which means this should be jit wrapped
@partial(jit, static_argnums=(0, 1, 3))
def _do_pfft(
func_name: str,
fft_type: xla_client.FftType,
arr: ArrayLike,
norm: Optional[str],
) -> Array:
@partial(jit, static_argnums=(0, 1, 3, 4))
def _do_pfft(func_name: str,
fft_type: xla_client.FftType,
arr: Array,
norm: Optional[str],
backend: str = "JAX") -> Array:
"""
Perform 3D FFT or inverse 3D FFT on the input array.
Expand All @@ -79,14 +112,28 @@ def _do_pfft(
Array
Transformed array after FFT or inverse FFT.
"""
local_transpose = jaxdecomp.config.transpose_axis_contiguous
transformed = _pfft(arr, fft_type, False, local_transpose)
if isinstance(fft_type, str):
typ = _str_to_fft_type(fft_type)
elif isinstance(fft_type, xla_client.FftType):
typ = fft_type
else:
raise TypeError(f"Unknown FFT type value '{fft_type}'")
print(f"Backend is {backend}")
if backend.lower() == "cudecomp":
transformed = _cudecomp_pfft(arr, typ)
elif backend.lower() == "jax":
transformed = _jax_pfft(arr, typ)
else:
raise ValueError(f"Unknown backend value '{backend}'")

transformed *= _fft_norm(
jnp.array(arr.shape, dtype=transformed.dtype), func_name, norm)
return transformed


def pfft3d(a: ArrayLike, norm: Optional[str] = "backward") -> Array:
def pfft3d(a: ArrayLike,
norm: Optional[str] = "backward",
backend: str = "JAX") -> Array:
"""
Perform 3D FFT on the input array.
Expand All @@ -102,10 +149,12 @@ def pfft3d(a: ArrayLike, norm: Optional[str] = "backward") -> Array:
Array
Transformed array after 3D FFT.
"""
return _do_pfft("fft", xla_client.FftType.FFT, a, norm=norm)
return _do_pfft("fft", xla_client.FftType.FFT, a, norm=norm, backend=backend)


def pifft3d(a: ArrayLike, norm: Optional[str] = "backward") -> Array:
def pifft3d(a: ArrayLike,
norm: Optional[str] = "backward",
backend: str = "JAX") -> Array:
"""
Perform inverse 3D FFT on the input array.
Expand All @@ -121,4 +170,5 @@ def pifft3d(a: ArrayLike, norm: Optional[str] = "backward") -> Array:
Array
Transformed array after inverse 3D FFT.
"""
return _do_pfft("ifft", xla_client.FftType.IFFT, a, norm=norm)
return _do_pfft(
"ifft", xla_client.FftType.IFFT, a, norm=norm, backend=backend)

0 comments on commit c7e548c

Please sign in to comment.