diff --git a/docs/config.qmd b/docs/config.qmd index f01a2ce267..120aec8933 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -363,6 +363,10 @@ eval_table_size: # Approximate number of predictions sent to wandb depending on eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128 eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"] +profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir. + # see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information + # snapshots can be visualized @ https://pytorch.org/memory_viz + loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training) loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 73d9e0e655..0f30f511c2 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -65,6 +65,7 @@ log_prediction_callback_factory, ) from axolotl.utils.callbacks.lisa import lisa_callback_factory +from axolotl.utils.callbacks.profiler import PytorchProfilerCallback from axolotl.utils.chat_templates import get_chat_template from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, @@ -1363,6 +1364,13 @@ def get_callbacks(self) -> List[TrainerCallback]: plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model) ) + if self.cfg.profiler_steps: + callbacks.append( + PytorchProfilerCallback( + steps_to_profile=self.cfg.profiler_steps, + ) + ) + if self.cfg.use_wandb: callbacks.append( SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) diff --git a/src/axolotl/utils/callbacks/profiler.py b/src/axolotl/utils/callbacks/profiler.py new file mode 100644 index 0000000000..8616963323 --- /dev/null +++ b/src/axolotl/utils/callbacks/profiler.py @@ -0,0 +1,43 @@ +""" +HF Trainer callback for creating pytorch profiling snapshots +""" +from pathlib import Path +from pickle import dump # nosec B403 + +import torch +from transformers import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) + + +class PytorchProfilerCallback(TrainerCallback): + """ + PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps. + """ + + def __init__(self, steps_to_profile: int = 5): + self.steps_to_profile = steps_to_profile + if self.steps_to_profile: + torch.cuda.memory._record_memory_history( # pylint: disable=protected-access + enabled="all" + ) + + def on_step_end( # pylint: disable=unused-argument + self, + args: TrainingArguments, # pylint: disable=unused-argument + state: TrainerState, + control: TrainerControl, # pylint: disable=unused-argument + **kwargs, # pylint: disable=unused-argument + ): + if state.global_step == self.steps_to_profile: + snapshot = torch.cuda.memory._snapshot() # pylint: disable=protected-access + with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout: + dump(snapshot, fout) + + # tell CUDA to stop recording memory allocations now + torch.cuda.memory._record_memory_history( # pylint: disable=protected-access + enabled=None + ) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 3671e1bb93..d05de2330d 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -762,6 +762,7 @@ class Config: load_best_model_at_end: Optional[bool] = False save_only_model: Optional[bool] = False use_tensorboard: Optional[bool] = None + profiler_steps: Optional[int] = None neftune_noise_alpha: Optional[float] = None