Adding Tokens/s/device to the log.
tonyjohnchen committed Jul 10, 2024
1 parent 0af4ee2 commit cae41ec
Showing 2 changed files with 9 additions and 2 deletions.
MaxText/maxtext_utils.py (3 additions, 0 deletions)
@@ -92,6 +92,9 @@ def get_train_input_output_trees(func, input_args, input_kwargs):
   p_train_step = deserialize_and_load(serialized_compiled, in_tree, out_tree)
   return p_train_step

+def calculate_tokens_training_per_device(config):
+  """Calculate training Tokens per device"""
+  return config.max_target_length * config.per_device_batch_size

 def calculate_tflops_training_per_device(config, log=True):
   """Calculate training TFLOP"""
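The new helper is simple arithmetic: tokens per device per step is sequence length times per-device batch size. A minimal standalone sketch, using hypothetical config values (max_target_length and per_device_batch_size are real MaxText config fields; the SimpleNamespace stand-in for the config object is an assumption for this example):

from types import SimpleNamespace

def calculate_tokens_training_per_device(config):
  """Calculate training tokens per device."""
  return config.max_target_length * config.per_device_batch_size

# Hypothetical values, for illustration only.
config = SimpleNamespace(max_target_length=2048, per_device_batch_size=4.0)
print(calculate_tokens_training_per_device(config))  # 8192.0 tokens/device/step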
MaxText/train.py (6 additions, 2 deletions)
@@ -95,11 +95,13 @@ def load_next_batch(train_iter, example_batch, config):
   return next(train_iter)


-def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr):
+def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr, per_device_tokens):
   """Records scalar metrics to be written to tensorboard"""
   metrics["scalar"].update({"perf/step_time_seconds": step_time_delta.total_seconds()})
   metrics["scalar"].update({"perf/per_device_tflops": per_device_tflops})
   metrics["scalar"].update({"perf/per_device_tflops_per_sec": per_device_tflops / step_time_delta.total_seconds()})
+  metrics["scalar"].update({"perf/per_device_tokens": per_device_tokens})
+  metrics["scalar"].update({"perf/per_device_tokens_per_sec": per_device_tokens / step_time_delta.total_seconds()})
   metrics["scalar"].update({"learning/current_learning_rate": lr})


@@ -147,6 +149,7 @@ def write_metrics_to_tensorboard(writer, metrics, step, config):
   max_logging.log(
       f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
       f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
+      f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
       f"loss: {metrics['scalar']['learning/loss']:.3f}"
   )
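With the new field, a completed-step log line would read something like the following (illustrative values only, derived from the f-string format above):

completed step: 10, seconds: 0.500, TFLOP/s/device: 123.456, Tokens/s/device: 16384.000, loss: 2.345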

@@ -483,6 +486,7 @@ def train_loop(config, state=None):
   num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params)
   max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion")
   per_device_tflops, _, _ = maxtext_utils.calculate_tflops_training_per_device(config)
+  per_device_tokens = maxtext_utils.calculate_tokens_training_per_device(config)

   # Write train config params, num model params, and XLA flags to tensorboard
   max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), writer)
@@ -542,7 +546,7 @@
     state, metrics = p_train_step(state, example_batch, nextrng)

     new_time = datetime.datetime.now()
-    record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step))
+    record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step), per_device_tokens)
     last_step_completion = new_time

     if checkpoint_manager is not None:
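The surrounding timing pattern (measure wall clock between step completions, then hand the delta to record_scalar_metrics) can be reproduced in isolation. A minimal sketch, with time.sleep standing in for the real train step:

import datetime
import time

# Minimal sketch of the per-step timing pattern; time.sleep stands in
# for p_train_step, and each delta would feed record_scalar_metrics.
last_step_completion = datetime.datetime.now()
for step in range(3):
  time.sleep(0.01)  # stand-in for the real train step
  new_time = datetime.datetime.now()
  step_time_delta = new_time - last_step_completion
  print(f"step {step}: {step_time_delta.total_seconds():.3f} seconds")
  last_step_completion = new_time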
