Skip to content

Commit

Permalink
Sort manual axes when lowering jax.shard_map to `sdy.manual_computation`,
Browse files Browse the repository at this point in the history
which ensures determinism in the generated `sdy.manual_computation`.

PiperOrigin-RevId: 712664646
  • Loading branch information
ZixuanJiang authored and Google-ML-Automation committed Jan 7, 2025
1 parent 00c363e commit 8aafae5
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
9 changes: 5 additions & 4 deletions jax/experimental/shard_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,10 +652,11 @@ def _shard_map_lowering_shardy(
sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
args = (*ctx.dim_var_values, *in_nodes)

manual_axes = sub_ctx.axis_context.manual_axes
mesh_shape = mesh.shape
manual_axes_size = np.prod([mesh_shape[a] for a in manual_axes])
if manual_axes_size == 1:
# The order of manual axes should match the order of mesh.axis_names to avoid
# non-determinism issues.
manual_axes = [a for a in mesh.axis_names
if a in sub_ctx.axis_context.manual_axes]
if np.prod([mesh.shape[a] for a in manual_axes]) == 1:
# No need for a `ManualComputationOp` if all manual axes are size 1.
with core.extend_axis_env_nd(tuple(mesh.shape.items())):
out_nodes, _ = mlir.jaxpr_subcomp(
Expand Down
9 changes: 5 additions & 4 deletions tests/shard_map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1925,7 +1925,7 @@ def f(x):
self.assertAllClose(v*v, f(v), check_dtypes=False)

def test_partial_auto_propagate_through(self):
mesh = jtu.create_mesh((2, 2), ('i', 'j'))
mesh = jtu.create_mesh((2, 2, 2), ('i', 'j', 'k'))
sharding = jax.sharding.NamedSharding(mesh, P('i'))

def g(x):
Expand All @@ -1943,16 +1943,17 @@ def f(x):
)(x)

v = jnp.arange(32.0).reshape(4, 8)
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i')))
v = jax.device_put(v, sharding)
if config.use_shardy_partitioner.value:
self.assertIn(
'in_shardings=[<@mesh, [{?}, {?}]>]'
' out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={"j"}',
' out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={"j", "k"}',
f.lower(v).as_text(),
)
else:
self.assertIn(
'sharding={devices=[1,1,2,2]<=[2,2]T(1,0) last_tile_dims={manual, replicated}}',
'sharding={devices=[1,1,4,2]<=[2,4]T(1,0) last_tile_dims={manual,'
' replicated}}',
f.lower(v).as_text('hlo'),
)
actual = f(v)
Expand Down

0 comments on commit 8aafae5

Please sign in to comment.