Skip to content

Commit

Permalink
Merge pull request #834 from google:integrate-badput-monitor
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 676236509
  • Loading branch information
maxtext authors committed Sep 19, 2024
2 parents 46d704a + d82ec42 commit c2de646
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions MaxText/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -416,14 +416,15 @@ def create_goodput_recorder(config):
return None


def record_goodput(recorder, config, step=None, job_start=False, job_end=False):
def record_goodput(
recorder,
config,
record_func,
*args,
):
"""Record data for Goodput and Badput computation."""
if recorder and config.enable_goodput_recording:
if job_start and step is None:
recorder.record_job_start_time()
if job_end and step is None:
recorder.record_job_end_time()
if step is not None:
recorder.record_step_start_time(step)
record_func(*args)

def check_example_batch(config, example_batch):
if config.max_checkify:
Expand Down Expand Up @@ -511,7 +512,11 @@ def setup_train_loop(config):
data_iterator:
state: the initialized train state
"""
recorder = create_goodput_recorder(config)
record_goodput(recorder, config, recorder.record_tpu_init_start_time if recorder else None)
init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx = setup_mesh_and_model(config)
record_goodput(recorder, config, recorder.record_tpu_init_end_time if recorder else None)
record_goodput(recorder, config, recorder.record_training_preparation_start_time if recorder else None)
data_iterator, eval_data_iterator = create_data_iterator(config, mesh)

state, state_mesh_annotations, data_iterator = max_utils.setup_training_state(
Expand All @@ -521,7 +526,7 @@ def setup_train_loop(config):
if not config.using_pipeline_parallelism:
# The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage
maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh, tolerance=0.02)

record_goodput(recorder, config, recorder.record_training_preparation_end_time if recorder else None)
return (
init_rng,
writer,
Expand All @@ -546,7 +551,7 @@ def train_loop(config, state=None):
"""
# Create a GoodputRecorder to log information
recorder = create_goodput_recorder(config)
record_goodput(recorder, config, job_start=True)
record_goodput(recorder, config, recorder.record_job_start_time if recorder else None)

(
init_rng,
Expand Down Expand Up @@ -634,10 +639,12 @@ def train_loop(config, state=None):
prof.activate()

with jax.profiler.StepTraceAnnotation("train", step_num=step):
record_goodput(recorder, config, recorder.record_data_loading_start_time if recorder else None)
example_batch = load_next_batch(data_iterator, example_batch, config)
record_goodput(recorder, config, recorder.record_data_loading_end_time if recorder else None)
check_example_batch(config, example_batch=example_batch)
nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
record_goodput(recorder, config, step=step)
record_goodput(recorder, config, recorder.record_step_start_time if recorder else None, step)
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
state, metrics = p_train_step(state, example_batch, nextrng)

Expand Down Expand Up @@ -693,7 +700,7 @@ def train_loop(config, state=None):
checkpoint_manager.wait_until_finished()
write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, config.steps - 1, config) # final step metrics
max_utils.close_summary_writer(writer)
record_goodput(recorder, config, job_end=True)
record_goodput(recorder, config, recorder.record_job_end_time if recorder else None)
clear_buffered_metrics()
return state

Expand All @@ -719,7 +726,8 @@ def main(argv: Sequence[str]) -> None:
logger_name=logger_name,
tensorboard_dir=config.tensorboard_dir,
upload_interval=config.goodput_upload_interval_seconds,
monitoring_enabled=True
monitoring_enabled=True,
include_badput_breakdown=True,
)
goodput_monitor.start_goodput_uploader()
max_logging.log("Started Goodput upload to Tensorboard in the background!")
Expand Down

0 comments on commit c2de646

Please sign in to comment.