[export] Expand exporting to work with AbstractMesh.
This is a follow-up to jax-ml#25640, which enabled lowering with
AbstractMesh.

This required adding `num_devices` to `lowering.compile_args`,
because in the presence of an AbstractMesh the device_assignment
is not accurate.
gnecula committed Dec 16, 2024
1 parent c7d1c3d commit afcb62e
Showing 6 changed files with 109 additions and 43 deletions.
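For orientation, a minimal sketch of the workflow this commit enables, adapted from the documentation changes below (the `exp.call` part assumes at least 4 local devices at run time):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export
from jax.sharding import AbstractMesh, Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

# Export needs no concrete devices: an AbstractMesh carries only the
# mesh shape and axis names.
export_mesh = AbstractMesh((("a", 4),))

def f(x):
  return x.T

exp = export.export(jax.jit(f))(
    jax.ShapeDtypeStruct((32,), dtype=np.int32,
                         sharding=NamedSharding(export_mesh, P("a"))))
assert exp.nr_devices == 4

# Calling the Exported still requires a concrete mesh of the same size.
run_mesh = Mesh(jax.local_devices()[:4], ("a",))
arg = jax.device_put(jnp.arange(32, dtype=np.int32),
                     NamedSharding(run_mesh, P("a")))
res = exp.call(arg)
```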
CHANGELOG.md: 5 additions & 0 deletions
@@ -28,6 +28,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* from {mod}`jax.lib.xla_client`: `_xla` and `bfloat16`.
* from {mod}`jax.numpy`: `round_`.

* New Features
* {func}`jax.export.export` can be used for device-polymorphic export with
shardings constructed with {func}`jax.sharding.AbstractMesh`.
See the [jax.export documentation](https://jax.readthedocs.io/en/latest/export/export.html#device-polymorphic-export).

## jax 0.4.37 (Dec 9, 2024)

This is a patch release of jax 0.4.36. Only "jax" was released at this version.
docs/export/export.md: 38 additions & 13 deletions
@@ -240,7 +240,7 @@ present on the exporting machine:

```

There is a safety check that will raise an error when trying to compile
an `Exported` object on a machine that does not have the accelerator
for which the code was exported.

@@ -326,7 +326,7 @@ combinations of input shapes.

See the {ref}`shape_poly` documentation.

## Device-polymorphic export

An exported artifact may contain sharding annotations for inputs,
outputs and for some intermediates, but these annotations do not refer
@@ -335,20 +335,28 @@ Instead, the sharding annotations refer to logical devices. This
means that you can compile and run the exported artifacts on different
physical devices than the ones used for exporting.

The cleanest way to achieve a device-polymorphic export is to
use shardings constructed with a `jax.sharding.AbstractMesh`,
which contains only the mesh shape and axis names. However,
you can achieve the same results with shardings constructed
for a mesh with concrete devices, since the actual devices in
the mesh are ignored during tracing and lowering:

```python
>>> import jax
>>> from jax import export
>>> from jax.sharding import AbstractMesh, Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P
>>>
>>> # Use an AbstractMesh for exporting
>>> export_mesh = AbstractMesh((("a", 4),))

>>> def f(x):
...   return x.T

>>> exp = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((32,), dtype=np.int32,
...         sharding=NamedSharding(export_mesh, P("a"))))

>>> # `exp` knows for how many devices it was exported.
>>> exp.nr_devices
@@ -359,8 +367,20 @@ physical devices that were used for exporting.
>>> exp.in_shardings_hlo
({devices=[4]<=[4]},)

>>> # You can also use a concrete set of devices for exporting
>>> concrete_devices = jax.local_devices()[:4]
>>> concrete_mesh = Mesh(concrete_devices, ("a",))
>>> exp2 = export.export(jax.jit(f))(
... jax.ShapeDtypeStruct((32,), dtype=np.int32,
... sharding=NamedSharding(concrete_mesh, P("a"))))

>>> # You can expect the same results
>>> assert exp.in_shardings_hlo == exp2.in_shardings_hlo

>>> # When you call an Exported, you must use a concrete set of devices
>>> arg = jnp.arange(8 * 4)
>>> res1 = exp.call(jax.device_put(arg,
...     NamedSharding(concrete_mesh, P("a"))))

>>> # Check out the first 2 shards of the result
>>> [f"device={s.device} index={s.index}" for s in res1.addressable_shards[:2]]
@@ -397,9 +417,11 @@ of devices than it was exported for:
>>> def f(x):
... return x.T

>>> exp = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
...         sharding=NamedSharding(export_mesh, P("a"))))

>>> arg = jnp.arange(4 * len(export_devices))
>>> exp.call(arg) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
ValueError: Exported module f was lowered for 8 devices and is called in a context with 1 devices. This is disallowed because: the module was lowered for more than 1 device.
@@ -420,13 +442,16 @@ artifacts using a new mesh constructed at the call site:
>>> def f(x):
... return x.T

>>> exp = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
...         sharding=NamedSharding(export_mesh, P("a"))))

>>> # Prepare the mesh for calling `exp`.
>>> calling_mesh = Mesh(np.array(export_devices[::-1]), ("b",))

>>> # Shard the arg according to what `exp` expects.
>>> arg = jnp.arange(4 * len(export_devices))
>>> sharded_arg = jax.device_put(arg, exp.in_shardings_jax(calling_mesh)[0])
>>> res = exp.call(sharded_arg)

jax/_src/export/_export.py: 4 additions & 3 deletions
@@ -633,7 +633,7 @@ def _export_lowered(
jaxpr: core.ClosedJaxpr,
fun_name: str,
disabled_checks: Sequence[DisabledSafetyCheck] = (),
_device_assignment_for_internal_jax2tf_use_only=None,
) -> Exported:
version = config.jax_export_calling_convention_version.value
if (version < minimum_supported_calling_convention_version or
@@ -698,7 +698,7 @@ def _export_lowered(
ordered_effects = tuple(lowering.compile_args["ordered_effects"])
unordered_effects = tuple(lowering.compile_args["unordered_effects"])

nr_devices = lowering.compile_args["num_devices"]
def export_sharding(s: LoweringSharding,
aval: core.ShapedArray) -> HloSharding | None:
if isinstance(s, sharding_impls.UnspecifiedValue):
@@ -971,7 +971,8 @@ def _check_lowering(lowering) -> None:
"keepalive", "host_callbacks", "pmap_nreps", "committed",
"device_assignment", "jaxpr_debug_info", "shape_poly_state",
"all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info",
"pgle_profiler", "intermediate_shardings", "context_mesh"}
"pgle_profiler", "intermediate_shardings", "context_mesh",
"num_devices"}
for compile_arg in lowering.compile_args.keys():
if compile_arg not in allowed_compile_args:
raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
jax/_src/interpreters/pxla.py: 32 additions & 25 deletions
@@ -1999,7 +1999,7 @@ def jaxpr_transfer_mem_kinds(
return out


def are_all_shardings_default_mem_kind(da_object: xc.DeviceList, shardings):
if da_object is None:
return True
try:
@@ -2084,38 +2084,32 @@ def write(var, val):


def _get_num_devices(
shardings, device_assignment
) -> tuple[int, tuple[xc.Device, ...] | None]:
"""Number of lowering devices, and the device_assignment to use.
If all the specified shardings have an abstract mesh, then we are compiling
with abstract devices, and the returned device_assignment is None.
"""
abstract_mesh, any_concrete_sharding = None, False
for s in shardings:
if isinstance(s, UnspecifiedValue):
continue
elif isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
if abstract_mesh is not None and abstract_mesh != s.mesh:
raise ValueError("AbstractMesh should be the same across all "
f"shardings. Got {abstract_mesh} and {s.mesh}")
abstract_mesh = s.mesh
else:
any_concrete_sharding = True
if (any_concrete_sharding and abstract_mesh is not None and
len(device_assignment) != abstract_mesh.size):
raise ValueError(
f"AbstractMesh size: {abstract_mesh.size} does not match the"
f" device assignment size: {len(device_assignment)}")
if any_concrete_sharding or abstract_mesh is None:
return len(device_assignment), device_assignment
return abstract_mesh.size, None


MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
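To make the branching above concrete, here is a standalone toy model of the dispatch implemented by `_get_num_devices` (the function name, the `Optional[int]` encoding of abstract meshes, and the string device stand-ins are illustrative assumptions, not JAX internals):

```python
from typing import Optional, Sequence, Tuple

def num_devices_sketch(
    abstract_sizes: Sequence[Optional[int]],  # None marks a concrete sharding
    device_assignment: Tuple[str, ...],
) -> Tuple[int, Optional[Tuple[str, ...]]]:
  abstract_size, any_concrete = None, False
  for size in abstract_sizes:
    if size is None:
      any_concrete = True
    elif abstract_size is not None and abstract_size != size:
      raise ValueError("AbstractMesh should be the same across all shardings")
    else:
      abstract_size = size
  if (any_concrete and abstract_size is not None
      and len(device_assignment) != abstract_size):
    raise ValueError("AbstractMesh size does not match device assignment size")
  if any_concrete or abstract_size is None:
    # Any concrete sharding (or no sharding at all) pins the lowering
    # to the real device assignment.
    return len(device_assignment), device_assignment
  # Purely abstract shardings: a device count, but no device assignment.
  return abstract_size, None

assert num_devices_sketch([4, 4], ("d0", "d1", "d2", "d3")) == (4, None)
assert num_devices_sketch([None], ("d0",)) == (1, ("d0",))
```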
@@ -2269,7 +2263,17 @@ def lower_sharding_computation(
num_devices, device_assignment = _get_num_devices( # type: ignore
it.chain(unique_in_shardings, unique_out_shardings,
unique_intermediate_shardings),
device_assignment)
if device_assignment is None:
if lowering_platforms is None:
raise ValueError(
"Passing lowering_platforms via jax.export or "
" jit(f).trace(*args).lower(lowering_platforms=...) is required when"
" only AbstractMesh exists in a jitted computation.")
if prim_requires_devices:
raise ValueError(
"AbstractMesh cannot be used when jaxpr contains primitives that"
" require devices to be present during lowering.")

committed = bool(
devices_from_context
@@ -2349,6 +2353,7 @@ def lower_sharding_computation(
mut=mut,
backend=backend,
device_assignment=da_object,
num_devices=num_devices,
committed=committed,
in_layouts=in_layouts,
out_layouts=out_layouts,
@@ -2874,6 +2879,7 @@ def from_hlo(name: str,
in_layouts: MaybeLayout,
out_layouts: MaybeLayout,
compiler_options_kvs: tuple[tuple[str, Any], ...],
num_devices: int,
pmap_nreps: int = 1,
mut: MutationData | None = None,
shape_poly_state: mlir.ShapePolyLoweringState | None = None,
@@ -2883,6 +2889,7 @@ def from_hlo(name: str,
intermediate_shardings: Sequence[JSharding] | None = None,
context_mesh: Mesh | None = None,
) -> MeshExecutable:
del num_devices # For compilation, we have an actual device_assignment
if (device_assignment is None or
any(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)
for s in it.chain(in_shardings, out_shardings))):
jax/_src/mesh.py: 1 addition & 1 deletion
@@ -368,7 +368,7 @@ class AbstractMesh:
It does not contain concrete devices compared to `jax.sharding.Mesh`. You
should use this as an input to the sharding passed to with_sharding_constraint
and mesh passed to shard_map to avoid tracing and lowering cache misses when
your mesh shape and axis names stay the same but the devices change.
See the description of https://github.com/jax-ml/jax/pull/23022 for more
details.
"""
tests/export_test.py: 29 additions & 1 deletion
@@ -1156,7 +1156,35 @@ def test_input_shardings_unused_args(self):
self.assertEqual(res.addressable_shards[0].device, run_devices[0])
self.assertEqual(res.addressable_shards[1].device, run_devices[1])

def test_export_abstract_mesh(self):
if jax.local_device_count() < 2:
self.skipTest("Need at least 2 devices")

abs_mesh = jax.sharding.AbstractMesh((("x", 2),))
input_sharding = jax.sharding.NamedSharding(abs_mesh, P("x", None))
output_sharding = jax.sharding.NamedSharding(abs_mesh, P(None, "x"))
@jax.jit
def f(a):
b = a @ a.T
return jax.lax.with_sharding_constraint(b, output_sharding)

exp = get_exported(f)(
jax.ShapeDtypeStruct((16, 16), dtype=np.float32,
sharding=input_sharding))
# Call the Exported with a concrete Mesh
devices = jax.local_devices()[:2]
run_mesh = Mesh(devices, ("x",))
a_sharding = jax.sharding.NamedSharding(run_mesh, P("x", None))
a = jnp.arange(16 * 16, dtype=np.float32).reshape((16, 16))
a = jax.device_put(a, a_sharding)

res = exp.call(a)
self.assertAllClose(res, f(a))
self.assertLen(res.addressable_shards, 2)
self.assertEqual(res.addressable_shards[0].index, (slice(None), slice(0, 8)))
self.assertEqual(res.addressable_shards[1].index, (slice(None), slice(8, 16)))

def test_call_single_device_export_with_different_no_of_devices(self):
if jax.local_device_count() < 2:
self.skipTest("Need at least 2 devices")

Expand Down

0 comments on commit afcb62e

Please sign in to comment.