diff --git a/distributed_shampoo/examples/hybrid_shard_cifar10_example.py b/distributed_shampoo/examples/hybrid_shard_cifar10_example.py index 4c94dcd..d006111 100644 --- a/distributed_shampoo/examples/hybrid_shard_cifar10_example.py +++ b/distributed_shampoo/examples/hybrid_shard_cifar10_example.py @@ -199,7 +199,7 @@ def create_model_and_optimizer_and_loss_fn( # initialize device_mesh for hybrid shard data parallel device_mesh: DeviceMesh = init_device_mesh( "cuda", - (args.dp_replicate_degree, WORLD_RANK // args.dp_replicate_degree), + (args.dp_replicate_degree, WORLD_SIZE // args.dp_replicate_degree), mesh_dim_names=("dp_replicate", "dp_shard"), )