Skip to content

Commit 6a1da81

Browse files
authored
Fix API usage of jax.make_mesh (#1051)
Signed-off-by: Lihao Ran <[email protected]>
1 parent e05d4b1 commit 6a1da81

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

examples/disagg/multi_proc_per_worker.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,11 @@ def get_uuid(pid: int) -> int:
8080

8181
def get_mesh() -> Mesh:
8282
sharding_size = jax.device_count()
83-
return jax.make_mesh((sharding_size, ), ("model", ))
83+
return jax.make_mesh(
84+
(sharding_size, ),
85+
("model", ),
86+
axis_types=(jax.sharding.AxisType.Auto, ) * len(("model", )),
87+
)
8488

8589

8690
def get_kv_shape(mesh: Mesh, per_shard: bool = False) -> tuple[int, ...]:

examples/disagg/single_proc_per_worker.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,11 @@ def get_uuid() -> int:
7474

7575
def get_mesh() -> Mesh:
7676
sharding_size = jax.device_count()
77-
return jax.make_mesh((sharding_size, ), ("model", ))
77+
return jax.make_mesh(
78+
(sharding_size, ),
79+
("model", ),
80+
axis_types=(jax.sharding.AxisType.Auto, ) * len(("model", )),
81+
)
7882

7983

8084
def get_kv_spec(mesh: Mesh) -> list[int]:

0 commit comments

Comments
 (0)