diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7202e520d28e..ec33122064b4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,6 +28,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
   * from {mod}`jax.lib.xla_client`: `_xla` and `bfloat16`.
   * from {mod}`jax.numpy`: `round_`.

+* New Features
+  * {func}`jax.export.export` can be used for device-polymorphic export with
+    shardings constructed with {func}`jax.sharding.AbstractMesh`.
+    See the [jax.export documentation](https://jax.readthedocs.io/en/latest/export/export.html#device-polymorphic-export).
+
 ## jax 0.4.37 (Dec 9, 2024)

 This is a patch release of jax 0.4.36. Only "jax" was released at this version.
diff --git a/docs/export/export.md b/docs/export/export.md
index aa686b03e2b2..bab4723f3ebc 100644
--- a/docs/export/export.md
+++ b/docs/export/export.md
@@ -240,7 +240,7 @@ present on the exporting machine:

 ```

-There is a safety check that will be raise an error when trying to compile
+There is a safety check that will raise an error when trying to compile
 an `Exported` object on a machine that does not have the accelerator
 for which the code was exported.

@@ -326,7 +326,7 @@ combinations of input shapes. See the {ref}`shape_poly` documentation.


-## Device polymorphic export
+## Device-polymorphic export

 An exported artifact may contain sharding annotations for inputs, outputs
 and for some intermediates, but these annotations do not refer
@@ -335,20 +335,28 @@ Instead, the sharding annotations refer to logical devices.
 This means that you can compile and run the exported artifacts on different
 physical devices that were used for exporting.

+The cleanest way to achieve a device-polymorphic export is to
+use shardings constructed with a `jax.sharding.AbstractMesh`,
+which contains only the mesh shape and axis names. But
+you can achieve the same results if you use shardings
+constructed for a mesh with concrete devices, since the actual
+devices in the mesh are ignored for tracing and lowering:
+
 ```python
 >>> import jax
 >>> from jax import export
->>> from jax.sharding import Mesh, NamedSharding
+>>> from jax.sharding import AbstractMesh, Mesh, NamedSharding
 >>> from jax.sharding import PartitionSpec as P
+>>>
+>>> # Use an AbstractMesh for exporting
+>>> export_mesh = AbstractMesh((("a", 4),))

->>> # Use the first 4 devices for exporting.
->>> export_devices = jax.local_devices()[:4]
->>> export_mesh = Mesh(export_devices, ("a",))
 >>> def f(x):
 ...   return x.T

->>> arg = jnp.arange(8 * len(export_devices))
->>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)
+>>> exp = export.export(jax.jit(f))(
+...     jax.ShapeDtypeStruct((32,), dtype=np.int32,
+...                          sharding=NamedSharding(export_mesh, P("a"))))

 >>> # `exp` knows for how many devices it was exported.
 >>> exp.nr_devices
@@ -359,8 +367,20 @@ physical devices that were used for exporting.
 >>> exp.in_shardings_hlo
 ({devices=[4]<=[4]},)

+>>> # You can also use a concrete set of devices for exporting
+>>> concrete_devices = jax.local_devices()[:4]
+>>> concrete_mesh = Mesh(concrete_devices, ("a",))
+>>> exp2 = export.export(jax.jit(f))(
+...     jax.ShapeDtypeStruct((32,), dtype=np.int32,
+...                          sharding=NamedSharding(concrete_mesh, P("a"))))
+
+>>> # You can expect the same results
+>>> assert exp.in_shardings_hlo == exp2.in_shardings_hlo
+
+>>> # When you call an Exported, you must use a concrete set of devices
+>>> arg = jnp.arange(8 * 4)
 >>> res1 = exp.call(jax.device_put(arg,
-...                                NamedSharding(export_mesh, P("a"))))
+...                                NamedSharding(concrete_mesh, P("a"))))

 >>> # Check out the first 2 shards of the result
 >>> [f"device={s.device} index={s.index}" for s in res1.addressable_shards[:2]]
@@ -397,9 +417,11 @@ of devices than it was exported for:
 >>> def f(x):
 ...   return x.T

->>> arg = jnp.arange(4 * len(export_devices))
->>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)
+>>> exp = export.export(jax.jit(f))(
+...     jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
+...                          sharding=NamedSharding(export_mesh, P("a"))))

+>>> arg = jnp.arange(4 * len(export_devices))
 >>> exp.call(arg)  # doctest: +IGNORE_EXCEPTION_DETAIL
 Traceback (most recent call last):
 ValueError: Exported module f was lowered for 8 devices and is called in a context with 1 devices. This is disallowed because: the module was lowered for more than 1 device.
@@ -420,13 +442,16 @@ artifacts using a new mesh constructed at the call site:
 >>> def f(x):
 ...   return x.T

->>> arg = jnp.arange(4 * len(export_devices))
->>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)
+
+>>> exp = export.export(jax.jit(f))(
+...     jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
+...                          sharding=NamedSharding(export_mesh, P("a"))))

 >>> # Prepare the mesh for calling `exp`.
 >>> calling_mesh = Mesh(np.array(export_devices[::-1]), ("b",))

 >>> # Shard the arg according to what `exp` expects.
+>>> arg = jnp.arange(4 * len(export_devices))
 >>> sharded_arg = jax.device_put(arg, exp.in_shardings_jax(calling_mesh)[0])
 >>> res = exp.call(sharded_arg)
diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py
index b20ca4670163..8ba43083d7a3 100644
--- a/jax/_src/export/_export.py
+++ b/jax/_src/export/_export.py
@@ -633,7 +633,7 @@ def _export_lowered(
     jaxpr: core.ClosedJaxpr,
     fun_name: str,
     disabled_checks: Sequence[DisabledSafetyCheck] = (),
-    _device_assignment_for_internal_jax2tf_use_only = None,
+    _device_assignment_for_internal_jax2tf_use_only=None,
 ) -> Exported:
   version = config.jax_export_calling_convention_version.value
   if (version < minimum_supported_calling_convention_version or
@@ -698,7 +698,7 @@ def _export_lowered(
   ordered_effects = tuple(lowering.compile_args["ordered_effects"])
   unordered_effects = tuple(lowering.compile_args["unordered_effects"])

-  nr_devices = len(lowering.compile_args["device_assignment"])
+  nr_devices = lowering.compile_args["num_devices"]

   def export_sharding(s: LoweringSharding, aval: core.ShapedArray) -> HloSharding | None:
     if isinstance(s, sharding_impls.UnspecifiedValue):
@@ -971,7 +971,8 @@ def _check_lowering(lowering) -> None:
       "keepalive", "host_callbacks", "pmap_nreps", "committed",
       "device_assignment", "jaxpr_debug_info", "shape_poly_state",
       "all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info",
-      "pgle_profiler", "intermediate_shardings", "context_mesh"}
+      "pgle_profiler", "intermediate_shardings", "context_mesh",
+      "num_devices"}
   for compile_arg in lowering.compile_args.keys():
     if compile_arg not in allowed_compile_args:
       raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index c80475e181c0..840b9e522968 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -1999,7 +1999,7 @@ def jaxpr_transfer_mem_kinds(
   return out


-def are_all_shardings_default_mem_kind(da_object, shardings):
+def are_all_shardings_default_mem_kind(da_object: xc.DeviceList, shardings):
   if da_object is None:
     return True
   try:
@@ -2084,38 +2084,32 @@ def write(var, val):


 def _get_num_devices(
-    shardings, device_assignment, lowering_platforms, prim_requires_devices
-  ) -> tuple[int, tuple[xc.Device, ...] | None]:
-  ext_abstract_mesh, concrete_sharding = None, False
+    shardings, device_assignment
+  ) -> tuple[int, tuple[xc.Device, ...] | None]:
+  """Number of lowering devices, and the device_assignment to use.
+
+  If all the specified shardings have an abstract mesh, then we are compiling
+  with abstract devices, and the returned device_assignment is None.
+  """
+  abstract_mesh, any_concrete_sharding = None, False
   for s in shardings:
     if isinstance(s, UnspecifiedValue):
       continue
     elif isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
-      if ext_abstract_mesh is not None and ext_abstract_mesh != s.mesh:
+      if abstract_mesh is not None and abstract_mesh != s.mesh:
         raise ValueError("AbstractMesh should be the same across all "
-                         f"shardings. Got {ext_abstract_mesh} and {s.mesh}")
-      ext_abstract_mesh = s.mesh
+                         f"shardings. Got {abstract_mesh} and {s.mesh}")
+      abstract_mesh = s.mesh
     else:
-      concrete_sharding = True
-  if (concrete_sharding and ext_abstract_mesh is not None and
-      len(device_assignment) != ext_abstract_mesh.size):
+      any_concrete_sharding = True
+  if (any_concrete_sharding and abstract_mesh is not None and
+      len(device_assignment) != abstract_mesh.size):
     raise ValueError(
-        f"AbstractMesh size: {ext_abstract_mesh.size} does not match the"
+        f"AbstractMesh size: {abstract_mesh.size} does not match the"
         f" device assignment size: {len(device_assignment)}")
-  if concrete_sharding:
+  if any_concrete_sharding or abstract_mesh is None:
     return len(device_assignment), device_assignment
-  if ext_abstract_mesh is None:
-    return len(device_assignment), device_assignment
-  if lowering_platforms is None:
-    raise ValueError(
-        "Passing lowering_platforms via"
-        " jit(f).trace(*args).lower(lowering_platforms=...) is required when"
-        " only AbstractMesh exists in a jitted computation.")
-  if prim_requires_devices:
-    raise ValueError(
-        "AbstractMesh cannot be used when jaxpr contains primitives that"
-        " require devices to be present during lowering.")
-  return ext_abstract_mesh.size, None
+  return abstract_mesh.size, None


 MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
@@ -2269,7 +2263,17 @@ def lower_sharding_computation(
   num_devices, device_assignment = _get_num_devices(  # type: ignore
       it.chain(unique_in_shardings, unique_out_shardings,
                unique_intermediate_shardings),
-      device_assignment, lowering_platforms, prim_requires_devices)
+      device_assignment)
+  if device_assignment is None:
+    if lowering_platforms is None:
+      raise ValueError(
+          "Passing lowering_platforms via jax.export or "
+          "jit(f).trace(*args).lower(lowering_platforms=...) is required when"
+          " only AbstractMesh exists in a jitted computation.")
+    if prim_requires_devices:
+      raise ValueError(
+          "AbstractMesh cannot be used when jaxpr contains primitives that"
+          " require devices to be present during lowering.")

   committed = bool(
       devices_from_context
@@ -2349,6 +2353,7 @@ def lower_sharding_computation(
       mut=mut,
       backend=backend,
       device_assignment=da_object,
+      num_devices=num_devices,
       committed=committed,
       in_layouts=in_layouts,
       out_layouts=out_layouts,
@@ -2874,6 +2879,7 @@ def from_hlo(name: str,
                in_layouts: MaybeLayout,
                out_layouts: MaybeLayout,
                compiler_options_kvs: tuple[tuple[str, Any], ...],
+               num_devices: int,
                pmap_nreps: int = 1,
                mut: MutationData | None = None,
                shape_poly_state: mlir.ShapePolyLoweringState | None = None,
@@ -2883,6 +2889,7 @@ def from_hlo(name: str,
                intermediate_shardings: Sequence[JSharding] | None = None,
                context_mesh: Mesh | None = None,
   ) -> MeshExecutable:
+    del num_devices  # For compilation, we have an actual device_assignment
     if (device_assignment is None or
         any(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)
             for s in it.chain(in_shardings, out_shardings))):
diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py
index 0ce882626fb6..2351266c6c24 100644
--- a/jax/_src/mesh.py
+++ b/jax/_src/mesh.py
@@ -368,7 +368,7 @@ class AbstractMesh:
   It does not contain concrete devices compared to `jax.sharding.Mesh`. You
   should use this as an input to the sharding passed to with_sharding_constraint
   and mesh passed to shard_map to avoid tracing and lowering cache misses when
-  your mesh shape and names stay the same but the devices change.
+  your mesh shape and axis names stay the same but the devices change.
   See the description of https://github.com/jax-ml/jax/pull/23022 for more
   details.
   """
diff --git a/tests/export_test.py b/tests/export_test.py
index 2946854aa549..da0e9daf2f00 100644
--- a/tests/export_test.py
+++ b/tests/export_test.py
@@ -1156,7 +1156,35 @@ def test_input_shardings_unused_args(self):
     self.assertEqual(res.addressable_shards[0].device, run_devices[0])
     self.assertEqual(res.addressable_shards[1].device, run_devices[1])

-  def test_call_with_different_no_of_devices(self):
+  def test_export_abstract_mesh(self):
+    if jax.local_device_count() < 2:
+      self.skipTest("Need at least 2 devices")
+
+    abs_mesh = jax.sharding.AbstractMesh((("x", 2),))
+    input_sharding = jax.sharding.NamedSharding(abs_mesh, P("x", None))
+    output_sharding = jax.sharding.NamedSharding(abs_mesh, P(None, "x"))
+    @jax.jit
+    def f(a):
+      b = a @ a.T
+      return jax.lax.with_sharding_constraint(b, output_sharding)
+
+    exp = get_exported(f)(
+        jax.ShapeDtypeStruct((16, 16), dtype=np.float32,
+                             sharding=input_sharding))
+
+    # Call the Exported with a concrete Mesh
+    devices = jax.local_devices()[:2]
+    run_mesh = Mesh(devices, ("x",))
+    a_sharding = jax.sharding.NamedSharding(run_mesh, P("x", None))
+    a = jnp.arange(16 * 16, dtype=np.float32).reshape((16, 16))
+    a = jax.device_put(a, a_sharding)
+
+    res = exp.call(a)
+    self.assertAllClose(res, f(a))
+    self.assertLen(res.addressable_shards, 2)
+    self.assertEqual(res.addressable_shards[0].index, (slice(None), slice(0, 8)))
+    self.assertEqual(res.addressable_shards[1].index, (slice(None), slice(8, 16)))
+
+  def test_call_single_device_export_with_different_no_of_devices(self):
     if jax.local_device_count() < 2:
       self.skipTest("Need at least 2 devices")
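
For reference, here is a minimal end-to-end sketch (not part of the patch) of the device-polymorphic export workflow this change enables. It only reuses the API shown in the updated docs and test above (`AbstractMesh`, `jax.ShapeDtypeStruct` with a `sharding=`, `Exported.call`, `in_shardings_jax`); the variable names such as `run_mesh` are illustrative, and it assumes a jax build that includes this change plus at least 4 local devices:

```python
# Illustrative sketch only, mirroring the updated docs/export/export.md example.
# Assumes a jax version with this change and >= 4 local devices.
import jax
import jax.numpy as jnp
import numpy as np
from jax import export
from jax.sharding import AbstractMesh, Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

def f(x):
  return x.T

# Export against an abstract 4-device mesh: no physical devices are needed yet.
export_mesh = AbstractMesh((("a", 4),))
exp = export.export(jax.jit(f))(
    jax.ShapeDtypeStruct((32,), dtype=np.int32,
                         sharding=NamedSharding(export_mesh, P("a"))))
assert exp.nr_devices == 4

# Call the exported artifact with whatever 4 concrete devices are available.
run_mesh = Mesh(np.array(jax.local_devices()[:4]), ("a",))
arg = jax.device_put(jnp.arange(32, dtype=np.int32),
                     exp.in_shardings_jax(run_mesh)[0])
res = exp.call(arg)
print(res.sharding, res.shape)
```

Because the exported artifact refers only to logical devices, the same `exp` could be called with a mesh built from a different set of 4 devices (for example on another host), which is the point of the device-polymorphic export described in the docs change above.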