diff --git a/paxml/tasks/lm/params/lm_cloud.py b/paxml/tasks/lm/params/lm_cloud.py index 93cb11825..f2b560b8b 100644 --- a/paxml/tasks/lm/params/lm_cloud.py +++ b/paxml/tasks/lm/params/lm_cloud.py @@ -37,7 +37,7 @@ def _dataset_common( self, is_training ) -> pax_fiddle.Config[base_input.BaseInput]: num_local_devices = jax.local_device_count() - batch_size = round(self.PERCORE_BATCH_SIZE * num_local_devices) + batch_size = round(self.PERCORE_BATCH_SIZE * num_local_devices * jax.process_count()) input_p = input_generator.SyntheticLmData.HParams() if is_training: input_p.batch_size = batch_size