Skip to content

Commit

Permalink
disable step time logging
Browse files: browse the repository at this point in the history
  • Loading branch information
anfals committed Aug 20, 2024
1 parent 8cbea66 commit f9ccbc7
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions MaxText/train.py
Original file line number Diff line number Diff line change
@@ -148,13 +148,13 @@ def write_metrics_to_tensorboard(writer, metrics, step, config):

full_log = step % config.log_period == 0

max_logging.log(
f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
f"total_weights: {metrics['scalar']['learning/total_weights']}, "
f"loss: {metrics['scalar']['learning/loss']:.3f}"
)
# max_logging.log(
# f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
# f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
# f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
# f"total_weights: {metrics['scalar']['learning/total_weights']}, "
# f"loss: {metrics['scalar']['learning/loss']:.3f}"
# )

if full_log and jax.process_index() == 0:
max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")
@@ -642,9 +642,9 @@ def map_fn(key_path, value):
new_time = datetime.datetime.now()
record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step), per_device_tokens)
step_time_delta = new_time - last_step_completion
max_logging.log(f"completed step: {step}, seconds: {step_time_delta.total_seconds()}, "
f"TFLOP/s/device: {per_device_tflops / step_time_delta.total_seconds()}, "
f"loss: {metrics['scalar']['learning/loss']:.3f}")
# max_logging.log(f"completed step: {step}, seconds: {step_time_delta.total_seconds()}, "
# f"TFLOP/s/device: {per_device_tflops / step_time_delta.total_seconds()}, "
# f"loss: {metrics['scalar']['learning/loss']:.3f}")
last_step_completion = new_time

if checkpoint_manager is not None:
Expand Down

0 comments on commit f9ccbc7

Please sign in to comment.