Skip to content

Commit

Permalink
profile every step
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiyuLi-goog committed Nov 14, 2023
1 parent 00c27ea commit 868ce45
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,12 @@ def train_loop(config, state=None):

# Start profiling at end of first step to avoid compilation.
# Move before for loop to include.
if step == 0:
max_utils.activate_profiler(config)
# if step == 0:
# max_utils.activate_profiler(config)
if jax.process_index() == 0 and config.enable_profiler:
if step > 0:
jax.profiler.stop_trace()
jax.profiler.start_trace(os.path.join(config.tensorboard_dir, str(step)))

max_utils.deactivate_profiler(config)
writer.close()
Expand Down

0 comments on commit 868ce45

Please sign in to comment.