diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index a5e02aab9..616ce0a72 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -323,6 +323,8 @@ target_eval_loss: 0. # early stop once reaching target eval_loss
 
 # Goodput parameters
 enable_goodput_recording: False
+monitor_goodput: False # when True, train.py starts a GoodputMonitor uploader on process 0
+goodput_upload_interval_seconds: 60 # passed to GoodputMonitor as upload_interval
 
 # Vertex AI Tensorboard Configurations - https://github.com/google/maxtext/tree/main/getting_started/Use_Vertex_AI_Tensorboard.md
 # Set to True for GCE, False if running via XPK
diff --git a/MaxText/train.py b/MaxText/train.py
index d565793b9..3f38869ed 100644
--- a/MaxText/train.py
+++ b/MaxText/train.py
@@ -64,6 +64,7 @@ from layers import quantizations
 
 
 from ml_goodput_measurement import goodput
+from ml_goodput_measurement import monitoring
 
 Transformer = models.Transformer
 EPS = 1e-8
@@ -597,6 +598,17 @@ def main(argv: Sequence[str]) -> None:
   if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"):
     vertex_tensorboard_manager.configure_vertex_tensorboard(config)
 
+  # jax.process_index is a function and must be called; only process 0 starts the goodput uploader.
+  if config.monitor_goodput and jax.process_index() == 0:
+    logger_name = f'goodput_{config.run_name}'
+    goodput_monitor = monitoring.GoodputMonitor(
+      job_name=config.run_name,
+      logger_name=logger_name,
+      tensorboard_dir=config.tensorboard_dir,
+      upload_interval=config.goodput_upload_interval_seconds,
+      monitoring_enabled=True
+    )
+    goodput_monitor.start_goodput_uploader()
   debug_config = debug_configuration.DebugConfig(
       stack_trace_config=stack_trace_configuration.StackTraceConfig(
           collect_stack_trace=config.collect_stack_trace,