diff --git a/MaxText/train.py b/MaxText/train.py index bfdf8e6d8..3a51d6a16 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -312,7 +312,7 @@ def train_loop(config, state=None): # Start profiling at end of first step to avoid compilation. # Move before for loop to include. - if step == 0: + if step == config.steps - 2: max_utils.activate_profiler(config) max_utils.deactivate_profiler(config)