Skip to content

Commit

Permalink
Remove dead codepaths now that MemorySpaceDescription works in OSS
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 708016987
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Jan 14, 2025
1 parent c72ed26 commit 163e7a3
Showing 1 changed file with 8 additions and 18 deletions.
26 changes: 8 additions & 18 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -2368,9 +2368,6 @@ def lower_sharding_computation(
out_layouts=out_layouts,
pmap_nreps=nreps,
shape_poly_state=shape_poly_state,
# TODO(yashkatariya): Remove `all_default_mem_kind` after
# MemoryDescription works in OSS.
all_default_mem_kind=all_default_mem_kind,
all_args_info=all_args_info,
pgle_profiler=pgle_profiler,
intermediate_shardings=unique_intermediate_shardings,
Expand Down Expand Up @@ -2442,21 +2439,15 @@ def get_out_shardings_from_executable(
device_assignment: Sequence[xc.Device],
num_out_avals: int,
num_ordered_effects: int,
all_default_mem_kind: bool,
) -> Sequence[sharding_impls.GSPMDSharding] | None:
from jax._src import pjit

# TODO(yashkatariya): Remove `all_default_mem_kind` branch after
# MemoryDescription works in OSS.
if all_default_mem_kind:
try:
omk = xla_executable.get_output_memory_kinds()[0]
if num_ordered_effects > 0:
omk = omk[num_ordered_effects:]
except:
omk = [None] * num_out_avals
else:
try:
omk = xla_executable.get_output_memory_kinds()[0]
if num_ordered_effects > 0:
omk = omk[num_ordered_effects:]
except:
omk = [None] * num_out_avals

assert len(omk) == num_out_avals, (len(omk), num_out_avals)

Expand Down Expand Up @@ -2781,11 +2772,11 @@ def _maybe_get_and_check_in_shardings(

def _maybe_get_and_check_out_shardings(
xla_executable, out_shardings, device_assignment, global_out_avals,
num_ordered_effects, all_default_mem_kind
num_ordered_effects
):
out_shardings_xla = get_out_shardings_from_executable(
xla_executable, device_assignment, len(global_out_avals),
num_ordered_effects, all_default_mem_kind)
num_ordered_effects)
if out_shardings_xla is None:
return out_shardings

Expand Down Expand Up @@ -2893,7 +2884,6 @@ def from_hlo(name: str,
pmap_nreps: int = 1,
mut: MutationData | None = None,
shape_poly_state: mlir.ShapePolyLoweringState | None = None,
all_default_mem_kind: bool = True,
all_args_info: AllArgsInfo | None = None,
pgle_profiler: profiler.PGLEProfiler | None = None,
intermediate_shardings: Sequence[JSharding] | None = None,
Expand Down Expand Up @@ -2951,7 +2941,7 @@ def from_hlo(name: str,
len(ordered_effects))
out_shardings = _maybe_get_and_check_out_shardings(
xla_executable, out_shardings, tuple(da), global_out_avals,
len(ordered_effects), all_default_mem_kind)
len(ordered_effects))
else:
in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
xla_executable.local_devices(), len(in_shardings), len(out_shardings))
Expand Down

0 comments on commit 163e7a3

Please sign in to comment.