From b28a69127ed6dea50538fef53bd4110e438822f5 Mon Sep 17 00:00:00 2001 From: Abhinav Goel Date: Fri, 4 Aug 2023 16:29:53 -0700 Subject: [PATCH] support for fractional per core batch size --- paxml/tasks/lm/params/lm_cloud.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paxml/tasks/lm/params/lm_cloud.py b/paxml/tasks/lm/params/lm_cloud.py index 93cb11825..f2b560b8b 100644 --- a/paxml/tasks/lm/params/lm_cloud.py +++ b/paxml/tasks/lm/params/lm_cloud.py @@ -37,7 +37,7 @@ def _dataset_common( self, is_training ) -> pax_fiddle.Config[base_input.BaseInput]: num_local_devices = jax.local_device_count() - batch_size = round(self.PERCORE_BATCH_SIZE * num_local_devices) + batch_size = round(self.PERCORE_BATCH_SIZE * num_local_devices * jax.process_count()) input_p = input_generator.SyntheticLmData.HParams() if is_training: input_p.batch_size = batch_size