diff --git a/MaxText/train.py b/MaxText/train.py index ab3dd5fcc..6b37faed8 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -295,7 +295,7 @@ def train_loop(config, state=None): # Start profiling at end of first step to avoid compilation. # Move before for loop to include. - if step == config.steps - 2: + if step == 0: max_utils.activate_profiler(config) max_utils.deactivate_profiler(config)