diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py index 759fb6e21..d19e81ecd 100644 --- a/src/axolotl/utils/bench.py +++ b/src/axolotl/utils/bench.py @@ -4,13 +4,23 @@ import torch -def gpu_memory_usage(device): +def gpu_memory_usage(device=0): + return torch.cuda.memory_allocated(device) / 1024.0**3 + + +def gpu_memory_usage_all(device=0): + usage = torch.cuda.memory_allocated(device) / 1024.0**3 + reserved = torch.cuda.memory_reserved(device) / 1024.0**3 + smi = gpu_memory_usage_smi(device) + return usage, reserved - usage, max(0, smi - reserved) + + +def gpu_memory_usage_smi(device=0): if isinstance(device, torch.device): device = device.index if isinstance(device, str) and device.startswith("cuda:"): device = int(device[5:]) - # NB torch.cuda.memory_usage returns zero so we use lower level api pynvml.nvmlInit() handle = pynvml.nvmlDeviceGetHandleByIndex(device) info = pynvml.nvmlDeviceGetMemoryInfo(handle) @@ -18,6 +28,16 @@ def gpu_memory_usage(device): def log_gpu_memory_usage(log, msg, device): + if not torch.cuda.is_available(): + return (0, 0, 0) + + usage, cache, misc = gpu_memory_usage_all(device) + extras = [] + if cache > 0: + extras.append(f"+{cache:.03f}GB cache") + if misc > 0: + extras.append(f"+{misc:.03f}GB misc") log.info( - f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2 + f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2 ) + return usage, cache, misc diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index f06762b6b..9e54d239f 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -74,14 +74,13 @@ def on_step_end( return control -class PrintGPUStatsCallback( +class GPUStatsCallback( TrainerCallback ): # pylint: disable=too-few-public-methods disable=unused-argument - """Callback to print GPU utilization""" + """Callback to track GPU utilization""" def __init__(self, cfg): self.cfg = cfg - self.logged = False def on_step_end( self, @@ -90,7 +89,31 @@ def on_step_end( control: TrainerControl, **kwargs, ): - if not self.logged: + should_log = ( + state.global_step == 1 + or (state.global_step in range(1, 100) and state.global_step % 10 == 0) + or (state.global_step > 100 and state.global_step % 100 == 0) + ) + if should_log: log_gpu_memory_usage(LOG, "while training", self.cfg.device) - self.logged = True + return control + + def on_train_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + log_gpu_memory_usage(LOG, "after training", self.cfg.device) + return control + + def on_evaluate( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + log_gpu_memory_usage(LOG, "after eval", self.cfg.device) return control diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index a5d2ea74e..4175e429e 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -18,7 +18,7 @@ from transformers.trainer_pt_utils import get_parameter_names from axolotl.utils.callbacks import ( - PrintGPUStatsCallback, + GPUStatsCallback, SaveBetterTransformerModelCallback, SavePeftModelCallback, ) @@ -293,7 +293,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): trainer_kwargs["optimizers"] = (optimizer, lr_scheduler) callbacks = [] - callbacks.append(PrintGPUStatsCallback(cfg)) + callbacks.append(GPUStatsCallback(cfg)) # TODO on_save callback to sync checkpoints to GCP/AWS in background if cfg.early_stopping_patience: early_stop_cb = EarlyStoppingCallback(