Skip to content

Commit

Permalink
Merge pull request #39 from DifferentiableUniverseInitiative/38-fix-i…
Browse files Browse the repository at this point in the history
…ssue-with-cyclic-import

38 fix issue with cyclic import
  • Loading branch information
ASKabalan authored Dec 6, 2024
2 parents e2a56b9 + 2ae6021 commit 81e18bf
Show file tree
Hide file tree
Showing 13 changed files with 34 additions and 38 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,4 @@ else()
endif()

set_target_properties(_jaxdecomp PROPERTIES INSTALL_RPATH "$ORIGIN/lib")
install(TARGETS _jaxdecomp LIBRARY DESTINATION . PUBLIC_HEADER DESTINATION .)
install(TARGETS _jaxdecomp LIBRARY DESTINATION jaxdecomplib PUBLIC_HEADER DESTINATION jaxdecomplib)
11 changes: 5 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = ["scikit-build-core>=0.4.0", "pybind11>=2.9.0"]
build-backend = "scikit_build_core.build"
[project]
name = "jaxdecomp"
version = "0.2.0"
version = "0.2.1"
description = "JAX bindings for the cuDecomp library"
authors = [
{ name = "Wassim Kabalan" },
Expand All @@ -23,7 +23,7 @@ classifiers = [
"Intended Audience :: Science/Research",
]
dependencies = [
"jaxtyping>=0.2.33",
"jaxtyping>=0.2.0",
"jax>=0.4.30",
]

Expand All @@ -36,11 +36,10 @@ cmake.version = ">=3.25"
build-dir = "build/{wheel_tag}"
wheel.py-api = "py3"
cmake.build-type = "Release"
wheel.install-dir = "jaxdecomp/_src"

[tool.scikit-build.cmake.define]
CMAKE_LIBRARY_OUTPUT_DIRECTORY = ""

#[tool.cibuildwheel]
#test-extras = "test"
#test-command = "pytest {project}/tests"
[tool.cibuildwheel]
test-extras = "test"
test-command = "pytest {project}/tests"
21 changes: 10 additions & 11 deletions src/jaxdecomp/_src/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from jax.lib import xla_client

from . import _jaxdecomp
from jaxdecomplib import _jaxdecomp

init = _jaxdecomp.init
finalize = _jaxdecomp.finalize
Expand All @@ -9,15 +8,15 @@
make_config = _jaxdecomp.GridConfig

# Loading the comm configuration flags at the top level
from ._jaxdecomp import TransposeCommBackend # yapf: disable
from ._jaxdecomp import (HALO_COMM_MPI, HALO_COMM_MPI_BLOCKING, HALO_COMM_NCCL,
HALO_COMM_NVSHMEM, HALO_COMM_NVSHMEM_BLOCKING,
NO_DECOMP, PENCILS, SLAB_XY, SLAB_YZ,
TRANSPOSE_COMM_MPI_A2A, TRANSPOSE_COMM_MPI_P2P,
TRANSPOSE_COMM_MPI_P2P_PL, TRANSPOSE_COMM_NCCL,
TRANSPOSE_COMM_NCCL_PL, TRANSPOSE_COMM_NVSHMEM,
TRANSPOSE_COMM_NVSHMEM_PL, TRANSPOSE_XY, TRANSPOSE_YX,
TRANSPOSE_YZ, TRANSPOSE_ZY, HaloCommBackend)
from jaxdecomplib._jaxdecomp import HaloCommBackend # yapf: disable
from jaxdecomplib._jaxdecomp import TransposeCommBackend # yapf: disable
from jaxdecomplib._jaxdecomp import ( # dummy line to avoid yapf reformatting
HALO_COMM_MPI, HALO_COMM_MPI_BLOCKING, HALO_COMM_NCCL, HALO_COMM_NVSHMEM,
HALO_COMM_NVSHMEM_BLOCKING, NO_DECOMP, PENCILS, SLAB_XY, SLAB_YZ,
TRANSPOSE_COMM_MPI_A2A, TRANSPOSE_COMM_MPI_P2P, TRANSPOSE_COMM_MPI_P2P_PL,
TRANSPOSE_COMM_NCCL, TRANSPOSE_COMM_NCCL_PL, TRANSPOSE_COMM_NVSHMEM,
TRANSPOSE_COMM_NVSHMEM_PL, TRANSPOSE_XY, TRANSPOSE_YX, TRANSPOSE_YZ,
TRANSPOSE_ZY)

# Registering ops for XLA
for name, fn in _jaxdecomp.registrations().items():
Expand Down
10 changes: 5 additions & 5 deletions src/jaxdecomp/_src/cudecomp/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from jax.core import Primitive, ShapedArray
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxdecomplib import _jaxdecomp
from jaxlib.hlo_helpers import custom_call
from jaxtyping import Array

import jaxdecomp
from jaxdecomp._src import _jaxdecomp
from jaxdecomp._src.fft_utils import FftType, fftn
from jaxdecomp._src.pencil_utils import (get_lowering_args, get_output_specs,
get_pencil_type, get_transpose_order)
Expand Down Expand Up @@ -328,7 +328,7 @@ def pfft_impl(x: Array, fft_type: FftType, adjoint: bool) -> Array:
----------
x : Array
Input array.
fft_type : Union[str, xla_client.FftType]
fft_type : Union[str, lax.FftType]
Type of FFT operation.
adjoint : bool
Whether to compute the adjoint FFT.
Expand All @@ -355,7 +355,7 @@ def pfft(x: Array, fft_type: FftType, adjoint: bool = False) -> Primitive:
----------
x : Array
Input array.
fft_type : Union[str, xla_client.FftType]
fft_type : Union[str, lax.FftType]
Type of FFT operation.
adjoint : bool, optional
Whether to compute the adjoint FFT. Defaults to False.
Expand All @@ -378,7 +378,7 @@ def _pfft_fwd_rule(x: Array, fft_type: FftType,
----------
x : Array
Input array.
fft_type : Union[str, xla_client.FftType]
fft_type : Union[str, lax.FftType]
Type of FFT operation.
adjoint : bool, optional
Whether to compute the adjoint FFT. Defaults to False.
Expand All @@ -398,7 +398,7 @@ def _pfft_bwd_rule(fft_type: FftType, adjoint: bool, _,
Parameters
----------
fft_type : Union[str, xla_client.FftType]
fft_type : Union[str, lax.FftType]
Type of FFT operation.
adjoint : bool
Whether to compute the adjoint FFT.
Expand Down
2 changes: 1 addition & 1 deletion src/jaxdecomp/_src/cudecomp/halo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from jax.core import ShapedArray
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxdecomplib import _jaxdecomp
from jaxlib.hlo_helpers import custom_call

import jaxdecomp
from jaxdecomp._src import _jaxdecomp
from jaxdecomp._src.spmd_ops import (BasePrimitive, get_pdims_from_mesh,
register_primitive)
from jaxdecomp.typing import GdimsType, HaloExtentType, PdimsType, Periodicity
Expand Down
2 changes: 1 addition & 1 deletion src/jaxdecomp/_src/cudecomp/transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from jax.core import ShapedArray
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxdecomplib import _jaxdecomp
from jaxlib.hlo_helpers import custom_call

import jaxdecomp
from jaxdecomp._src import _jaxdecomp
from jaxdecomp._src.spmd_ops import (BasePrimitive, get_pdims_from_sharding,
register_primitive)
from jaxdecomp.typing import GdimsType, TransposablePdimsType
Expand Down
4 changes: 2 additions & 2 deletions src/jaxdecomp/_src/fft_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from math import prod
from typing import Tuple, TypeAlias

from jax import lax
from jax import numpy as jnp
from jax.lib import xla_client
from jaxtyping import Array

FftType: TypeAlias = xla_client.FftType
FftType: TypeAlias = lax.FftType

FORWARD_FFTs = {FftType.FFT, FftType.RFFT}
INVERSE_FFTs = {FftType.IFFT, FftType.IRFFT}
Expand Down
3 changes: 1 addition & 2 deletions src/jaxdecomp/_src/jax/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
from jax._src.api import ShapeDtypeStruct
from jax._src.core import ShapedArray
from jax._src.typing import Array
from jax.lib import xla_client
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxdecomplib import _jaxdecomp

import jaxdecomp
from jaxdecomp._src import _jaxdecomp
from jaxdecomp._src.fft_utils import COMPLEX # yapf: disable
from jaxdecomp._src.fft_utils import FftType # yapf: disable
from jaxdecomp._src.fft_utils import ADJOINT, FORWARD_FFTs, fft, fft2, fftn
Expand Down
4 changes: 2 additions & 2 deletions src/jaxdecomp/_src/jax/fftfreq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from typing import Tuple

import jax
from jax import lax
from jax import numpy as jnp
from jax.experimental.shard_alike import shard_alike
from jax.lib import xla_client
from jaxtyping import Array

FftType = xla_client.FftType
FftType = lax.FftType


@partial(jax.jit, static_argnums=(1,))
Expand Down
2 changes: 1 addition & 1 deletion src/jaxdecomp/_src/jax/halo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from jax.core import ShapedArray
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxdecomplib import _jaxdecomp
from jaxtyping import Array

from jaxdecomp._src import _jaxdecomp
from jaxdecomp._src.spmd_ops import (CustomParPrimitive, get_pencil_type,
register_primitive)
from jaxdecomp.typing import HaloExtentType, Periodicity
Expand Down
8 changes: 4 additions & 4 deletions src/jaxdecomp/_src/pencil_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Optional, Tuple

import jax
from jax.lib import xla_client
from jax import lax
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jaxdecomplib import _jaxdecomp

import jaxdecomp
from jaxdecomp._src import _jaxdecomp
from jaxdecomp._src.fft_utils import FftType, FORWARD_FFTs
from jaxdecomp._src.spmd_ops import get_pdims_from_mesh, get_pencil_type
from jaxdecomp.typing import GdimsType, PdimsType, TransposablePdimsType
Expand Down Expand Up @@ -91,13 +91,13 @@ def get_lowering_args(fft_type: FftType, global_shape: GdimsType,

if jaxdecomp.config.transpose_axis_contiguous:
match fft_type:
case xla_client.FftType.FFT:
case lax.FftType.FFT:
if pencil_type == _jaxdecomp.SLAB_XY:
transpose_back_shape = (1, 2, 0)
pdims = pdims[::-1]
else:
transpose_back_shape = (0, 1, 2)
case xla_client.FftType.IFFT:
case lax.FftType.IFFT:
if pencil_type == _jaxdecomp.SLAB_XY:
transpose_back_shape = (0, 1, 2)
pdims = pdims[::-1]
Expand Down
2 changes: 1 addition & 1 deletion src/jaxdecomp/_src/spmd_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from jax.experimental.custom_partitioning import custom_partitioning
from jax.interpreters import mlir, xla
from jax.sharding import Mesh, NamedSharding
from jaxdecomplib import _jaxdecomp

from jaxdecomp._src import _jaxdecomp
from jaxdecomp.typing import PdimsType, TransposablePdimsType

Specs = Any
Expand Down
1 change: 0 additions & 1 deletion src/jaxdecomp/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import jax.numpy as jnp
from jax import jit
from jax._src.typing import Array, ArrayLike
from jax.lib import xla_client

from jaxdecomp._src.cudecomp.fft import pfft as _cudecomp_pfft
from jaxdecomp._src.fft_utils import FftType
Expand Down

0 comments on commit 81e18bf

Please sign in to comment.