diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index 599e4ab0e56c..910fa4728d69 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -652,10 +652,11 @@ def _shard_map_lowering_shardy(
   sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
   args = (*ctx.dim_var_values, *in_nodes)
 
-  manual_axes = sub_ctx.axis_context.manual_axes
-  mesh_shape = mesh.shape
-  manual_axes_size = np.prod([mesh_shape[a] for a in manual_axes])
-  if manual_axes_size == 1:
+  # The order of manual axes should match the order of mesh.axis_names to avoid
+  # non-determinism issues.
+  manual_axes = [a for a in mesh.axis_names
+                 if a in sub_ctx.axis_context.manual_axes]
+  if np.prod([mesh.shape[a] for a in manual_axes]) == 1:
     # No need for a `ManualComputationOp` if all manual axes are size 1.
     with core.extend_axis_env_nd(tuple(mesh.shape.items())):
       out_nodes, _ = mlir.jaxpr_subcomp(
diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index 74fdb7a47888..ec846a32a903 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -1925,7 +1925,7 @@ def f(x):
     self.assertAllClose(v*v, f(v), check_dtypes=False)
 
   def test_partial_auto_propagate_through(self):
-    mesh = jtu.create_mesh((2, 2), ('i', 'j'))
+    mesh = jtu.create_mesh((2, 2, 2), ('i', 'j', 'k'))
     sharding = jax.sharding.NamedSharding(mesh, P('i'))
 
     def g(x):
@@ -1943,16 +1943,17 @@ def f(x):
       )(x)
 
     v = jnp.arange(32.0).reshape(4, 8)
-    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i')))
+    v = jax.device_put(v, sharding)
     if config.use_shardy_partitioner.value:
       self.assertIn(
           'in_shardings=[<@mesh, [{?}, {?}]>]'
-          ' out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={"j"}',
+          ' out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={"j", "k"}',
           f.lower(v).as_text(),
       )
     else:
       self.assertIn(
-          'sharding={devices=[1,1,2,2]<=[2,2]T(1,0) last_tile_dims={manual, replicated}}',
+          'sharding={devices=[1,1,4,2]<=[2,4]T(1,0) last_tile_dims={manual,'
+          ' replicated}}',
           f.lower(v).as_text('hlo'),
       )
     actual = f(v)
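The first hunk replaces direct iteration over the unordered set of manual axes with a list built by filtering the mesh's ordered axis-name tuple, so the resulting order (and anything derived from it) is stable across runs. Below is a minimal, self-contained sketch of that ordering pattern, not the JAX implementation itself; `axis_names`, `axis_sizes`, and `manual` are illustrative stand-ins for `mesh.axis_names`, `mesh.shape`, and `sub_ctx.axis_context.manual_axes`.

```python
# Illustrative sketch of the deterministic-ordering fix; names are stand-ins, not JAX APIs.
import numpy as np

axis_names = ('i', 'j', 'k')             # stands in for mesh.axis_names
axis_sizes = {'i': 2, 'j': 2, 'k': 2}    # stands in for mesh.shape
manual = frozenset({'k', 'j'})           # stands in for the unordered manual-axes set

# Deterministic: order follows axis_names, not set iteration order.
ordered_manual = [a for a in axis_names if a in manual]
assert ordered_manual == ['j', 'k']

# Size of the manual portion of the mesh; a product of 1 would mean
# no ManualComputationOp is needed.
manual_size = int(np.prod([axis_sizes[a] for a in ordered_manual]))
assert manual_size == 4
```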