diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py
index 31e3b2c02..489b7a147 100644
--- a/MaxText/maxtext_utils.py
+++ b/MaxText/maxtext_utils.py
@@ -92,6 +92,9 @@ def get_train_input_output_trees(func, input_args, input_kwargs):
   p_train_step = deserialize_and_load(serialized_compiled, in_tree, out_tree)
   return p_train_step
 
+def calculate_tokens_training_per_device(config):
+  """Calculate training tokens per device."""
+  return config.max_target_length * config.per_device_batch_size
 
 def calculate_tflops_training_per_device(config, log=True):
   """Calculate training TFLOP"""
diff --git a/MaxText/train.py b/MaxText/train.py
index 3f38869ed..3a6464450 100644
--- a/MaxText/train.py
+++ b/MaxText/train.py
@@ -95,11 +95,13 @@ def load_next_batch(train_iter, example_batch, config):
   return next(train_iter)
 
 
-def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr):
+def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr, per_device_tokens):
   """Records scalar metrics to be written to tensorboard"""
   metrics["scalar"].update({"perf/step_time_seconds": step_time_delta.total_seconds()})
   metrics["scalar"].update({"perf/per_device_tflops": per_device_tflops})
   metrics["scalar"].update({"perf/per_device_tflops_per_sec": per_device_tflops / step_time_delta.total_seconds()})
+  metrics["scalar"].update({"perf/per_device_tokens": per_device_tokens})
+  metrics["scalar"].update({"perf/per_device_tokens_per_sec": per_device_tokens / step_time_delta.total_seconds()})
   metrics["scalar"].update({"learning/current_learning_rate": lr})
 
 
@@ -147,6 +149,7 @@ def write_metrics_to_tensorboard(writer, metrics, step, config):
     max_logging.log(
         f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
         f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
+        f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
         f"loss: {metrics['scalar']['learning/loss']:.3f}"
     )
 
@@ -483,6 +486,7 @@ def train_loop(config, state=None):
   num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params)
   max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion")
   per_device_tflops, _, _ = maxtext_utils.calculate_tflops_training_per_device(config)
+  per_device_tokens = maxtext_utils.calculate_tokens_training_per_device(config)
 
   # Write train config params, num model params, and XLA flags to tensorboard
   max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), writer)
@@ -542,7 +546,7 @@ def train_loop(config, state=None):
       state, metrics = p_train_step(state, example_batch, nextrng)
 
     new_time = datetime.datetime.now()
-    record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step))
+    record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step), per_device_tokens)
     last_step_completion = new_time
 
     if checkpoint_manager is not None:
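
For reviewers: a minimal, runnable sketch of how the two new scalars compose, outside MaxText. The SimpleNamespace config stand-in and the 2048/4/2.5s values are hypothetical; only the field names max_target_length and per_device_batch_size and the tokens-per-second arithmetic are taken from the patch above.

# Illustrative sketch only (not part of the patch). SimpleNamespace stands in
# for MaxText's config object; the concrete numbers are hypothetical.
import datetime
from types import SimpleNamespace

def calculate_tokens_training_per_device(config):
  """Same arithmetic as the function added to maxtext_utils.py."""
  return config.max_target_length * config.per_device_batch_size

config = SimpleNamespace(max_target_length=2048, per_device_batch_size=4)
per_device_tokens = calculate_tokens_training_per_device(config)   # 8192 tokens per step

step_time_delta = datetime.timedelta(seconds=2.5)                  # hypothetical step time
tokens_per_sec = per_device_tokens / step_time_delta.total_seconds()
print(f"Tokens/s/device: {tokens_per_sec:.3f}")                    # -> Tokens/s/device: 3276.800

Note the design mirrors the existing per_device_tflops metric: the tokens-per-step count is computed once before the training loop, so each step only pays for a single division by the measured step time inside record_scalar_metrics.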