Skip to content

Commit

Permalink
handle slabs from python and handle Y pencil case as start layout
Browse files Browse the repository at this point in the history
  • Loading branch information
ASKabalan committed Jul 24, 2024
1 parent 66b17f2 commit e654d19
Showing 1 changed file with 94 additions and 12 deletions.
106 changes: 94 additions & 12 deletions jaxdecomp/_src/fft.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from enum import Enum
from functools import partial
from typing import Tuple, Union

import jax
import jaxlib.mlir.ir as ir
import numpy as np
from jax import ShapeDtypeStruct
from jax._src import mesh as mesh_lib
from jax._src.interpreters import mlir
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.util import promote_dtypes_complex
Expand All @@ -23,6 +25,23 @@
FftType = xla_client.FftType


def get_pencil_type():
mesh = mesh_lib.thread_resources.env.physical_mesh
if mesh.empty:
pdims = None
else:
pdims = mesh.devices.shape[::-1]

if pdims == (1, 1) or pdims == None:
return _jaxdecomp.NO_DECOMP
elif pdims[0] == 1:
return _jaxdecomp.SLAB_XY
elif pdims[1] == 1:
return _jaxdecomp.SLAB_YZ
else:
return _jaxdecomp.PENCILS


def _str_to_fft_type(s: str) -> xla_client.FftType:
"""
Convert a string to an FFT type enum.
Expand Down Expand Up @@ -93,17 +112,41 @@ def abstract(x: Array, fft_type: xla_client.FftType, pdims: Tuple[int, int],
if global_shape == x.shape:
return FFTPrimitive.outer_abstract(x, fft_type=fft_type, adjoint=adjoint)

pencil_type = get_pencil_type()
match fft_type:
case xla_client.FftType.FFT:
# FFT is X to Y to Z so Z-Pencil is returned
# Except if we are doing a YZ slab in which case we return a Y-Pencil
transpose_shape = (1, 2, 0)
transposed_pdims = pdims
match pencil_type:
case _jaxdecomp.SLAB_YZ:
transpose_shape = (2, 0, 1)
transposed_pdims = pdims[::-1]
case _jaxdecomp.SLAB_XY:
transpose_shape = (1, 2, 0)
transposed_pdims = pdims
case _jaxdecomp.PENCILS:
transpose_shape = (1, 2, 0)
transposed_pdims = pdims
case _jaxdecomp.NO_DECOMP:
transpose_shape = (0, 1, 2)
transposed_pdims = (1, 1)

case xla_client.FftType.IFFT:
# IFFT is Z to X to Y so X-Pencil is returned
# In YZ slab case we only need one transposition back to get the X-Pencil
transpose_shape = (2, 0, 1)
transposed_pdims = pdims
match pencil_type:
case _jaxdecomp.SLAB_YZ:
transpose_shape = (1, 2, 0)
transposed_pdims = pdims
case _jaxdecomp.SLAB_XY:
transpose_shape = (2, 0, 1)
transposed_pdims = pdims[::-1]
case _jaxdecomp.PENCILS:
transpose_shape = (2, 0, 1)
transposed_pdims = pdims
case _jaxdecomp.NO_DECOMP:
transpose_shape = (0, 1, 2)
transposed_pdims = (1, 1)
case _:
raise TypeError(
"only complex FFTs are currently supported through pfft.")
Expand Down Expand Up @@ -134,15 +177,28 @@ def outer_abstract(x: Array, fft_type: xla_client.FftType,
ShapedArray
Shape of the output array.
"""

# TODO(Wassim) we should not get here if we do not have a context mesh
pencil_type = get_pencil_type()
match fft_type:
case xla_client.FftType.FFT:
# FFT is X to Y to Z so Z-Pencil is returned
# Except if we are doing a YZ slab in which case we return a Y-Pencil
transpose_shape = (1, 2, 0)
match pencil_type:
case _jaxdecomp.SLAB_XY | _jaxdecomp.PENCILS:
transpose_shape = (1, 2, 0)
case _jaxdecomp.SLAB_YZ:
transpose_shape = (2, 0, 1)

case xla_client.FftType.IFFT:
# IFFT is Z to X to Y so X-Pencil is returned
# In YZ slab case we only need one transposition back to get the X-Pencil
transpose_shape = (2, 0, 1)
match pencil_type:
case _jaxdecomp.SLAB_XY | _jaxdecomp.PENCILS:
transpose_shape = (2, 0, 1)
case _jaxdecomp.SLAB_YZ:
transpose_shape = (1, 2, 0)

case _:
raise TypeError(
"only complex FFTs are currently supported through pfft.")
Expand Down Expand Up @@ -190,11 +246,20 @@ def lowering(ctx, a: Array, *, fft_type: xla_client.FftType,
is_double = np.finfo(dtype).dtype == np.float64

# Get original global shape
pencil_type = get_pencil_type()
match fft_type:
case xla_client.FftType.FFT:
transpose_back_shape = (0, 1, 2)
if pencil_type == _jaxdecomp.SLAB_XY:
transpose_back_shape = (1, 2, 0)
else:
transpose_back_shape = (0, 1, 2)
case xla_client.FftType.IFFT:
transpose_back_shape = (2, 0, 1)
if pencil_type == _jaxdecomp.SLAB_XY:
transpose_back_shape = (0, 1, 2)
elif pencil_type == _jaxdecomp.SLAB_YZ:
transpose_back_shape = (1, 2, 0)
else:
transpose_back_shape = (2, 0, 1)
case _:
raise TypeError(
"only complex FFTs are currently supported through pfft.")
Expand All @@ -207,7 +272,7 @@ def lowering(ctx, a: Array, *, fft_type: xla_client.FftType,
config.halo_comm_backend = jaxdecomp.config.halo_comm_backend
config.transpose_comm_backend = jaxdecomp.config.transpose_comm_backend
workspace_size, opaque = _jaxdecomp.build_fft_descriptor(
config, forward, is_double, adjoint)
config, forward, is_double, adjoint, pencil_type)

n = len(a_type.shape)
layout = tuple(range(n - 1, -1, -1))
Expand Down Expand Up @@ -261,6 +326,7 @@ def impl(x: Array, fft_type: Union[str, xla_client.FftType], adjoint: bool):
if typ in [xla_client.FftType.RFFT, xla_client.FftType.IRFFT]:
raise TypeError("only complex FFTs are currently supported through pfft.")

# TODO (Wassim) this should call jnp.fft.fftn so it works in single device
pdims = (1, jax.device_count())
global_shape = x.shape

Expand Down Expand Up @@ -330,7 +396,15 @@ def infer_sharding_from_operands(
Sharding information for the result.
"""
input_sharding = arg_infos[0].sharding
return NamedSharding(mesh, P(*input_sharding.spec))
pencil_type = get_pencil_type()
match pencil_type:
case _jaxdecomp.SLAB_XY | _jaxdecomp.SLAB_YZ:
transposed_specs = (input_sharding.spec[1], input_sharding.spec[0],
None)
case _jaxdecomp.PENCILS:
transposed_specs = input_sharding.spec

return NamedSharding(mesh, P(*transposed_specs))

@staticmethod
def partition(
Expand Down Expand Up @@ -360,10 +434,18 @@ def partition(
"""
input_sharding = NamedSharding(mesh, P(*arg_shapes[0].sharding.spec))
output_sharding = NamedSharding(mesh, P(*result_shape.sharding.spec))

pdims = (get_axis_size(input_sharding, 1), get_axis_size(input_sharding, 0))
global_shape = arg_shapes[0].shape

pencil_type = get_pencil_type()
if (pencil_type == _jaxdecomp.SLAB_XY and fft_type == FftType.IFFT) or \
(pencil_type == _jaxdecomp.SLAB_YZ and fft_type == FftType.FFT or \
pencil_type == _jaxdecomp.PENCILS): #yapf: disable
pdims = (get_axis_size(input_sharding,
1), get_axis_size(input_sharding, 0))
else:
pdims = (get_axis_size(input_sharding,
0), get_axis_size(input_sharding, 1))

impl = partial(
FFTPrimitive.per_shard_impl,
fft_type=fft_type,
Expand Down

0 comments on commit e654d19

Please sign in to comment.