Add a vram reporting script
chiragjn committed Nov 19, 2024
1 parent 53f6f63 commit 8af9c87
Showing 7 changed files with 227 additions and 9 deletions.
4 changes: 2 additions & 2 deletions Dockerfile
@@ -14,9 +14,9 @@ RUN mkdir -p /packages && \
cd /packages && \
git clone https://github.com/truefoundry/axolotl && \
cd axolotl/ && \
-git checkout 57167dd92567f64371286ebf56ab9ca01d0685d7 && \
+git checkout 0011a3969eeceffc5140315a41d64048cca0ac30 && \
cd /packages/axolotl/ && \
-MAX_JOBS=1 NVCC_APPEND_FLAGS="--threads 1" pip install -U --no-build-isolation --no-cache-dir -e .[flash-attn,mamba-ssm,optimizers,lion-pytorch,galore] && \
+MAX_JOBS=1 NVCC_APPEND_FLAGS="--threads 1" pip install -U --use-pep517 --no-build-isolation --no-cache-dir -e .[flash-attn,mamba-ssm,optimizers,lion-pytorch,galore] && \
rm -rf /root/.cache/pip

# Install axolotl_truefoundry plugin with our requirements overrides over axolotl
4 changes: 2 additions & 2 deletions Dockerfile-notebook
@@ -29,9 +29,9 @@ USER jovyan
RUN cd /packages && \
git clone https://github.com/truefoundry/axolotl && \
cd axolotl/ && \
-git checkout 57167dd92567f64371286ebf56ab9ca01d0685d7 && \
+git checkout 0011a3969eeceffc5140315a41d64048cca0ac30 && \
cd /packages/axolotl/ && \
-MAX_JOBS=1 NVCC_APPEND_FLAGS="--threads 1" pip install -U --no-build-isolation --no-cache-dir -e .[flash-attn,mamba-ssm,optimizers,lion-pytorch,galore]
+MAX_JOBS=1 NVCC_APPEND_FLAGS="--threads 1" pip install -U --use-pep517 --no-build-isolation --no-cache-dir -e .[flash-attn,mamba-ssm,optimizers,lion-pytorch,galore]

# Install axolotl_truefoundry plugin with our requirements overrides over axolotl
COPY --chown=jovyan:users plugins/axolotl_truefoundry /packages/axolotl_truefoundry
42 changes: 39 additions & 3 deletions plugins/axolotl_truefoundry/axolotl_truefoundry/__init__.py
@@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import numpy as np
+import orjson
import pynvml
import torch
from axolotl.integrations.base import BasePlugin
@@ -24,6 +25,10 @@
logger = logging.getLogger("axolotl")


+def _json_dumps(data) -> str:
+    return orjson.dumps(data, option=orjson.OPT_NAIVE_UTC | orjson.OPT_SERIALIZE_NUMPY).decode("utf-8")


def _drop_non_finite_values(dct: Dict[str, Any]) -> Dict[str, Any]:
    sanitized = {}
    for k, v in dct.items():
@@ -68,6 +73,10 @@ def get_or_create_run(ml_repo: str, run_name: str, auto_end: bool = False):


class ExtraMetricsCallback(TrainerCallback):
+    def __init__(self, plugin_logger: logging.Logger):
+        super().__init__()
+        self._plugin_logger = plugin_logger
+
    def _add_perplexity(self, logs):
        for loss_key, perplexity_key in [
            ("loss", "train_perplexity"),
@@ -91,7 +100,8 @@ def on_log(self, args, state, control, logs, model=None, **kwargs):

        self._add_perplexity(logs)
        logs.update(get_gpu_metrics())
-        logger.info(f"Metrics: {logs}")
+        logger.info(f"Metrics: {_json_dumps(logs)}")
+        self._plugin_logger.info(_json_dumps(_drop_non_finite_values(logs)))


class TrueFoundryMLCallback(TrainerCallback):
@@ -192,16 +202,42 @@ class TruefoundryMLPluginArgs(BaseModel):
    cleanup_output_dir_on_start: bool = False
    logging_dir: str = "./tensorboard_logs"

+    truefoundry_testing_mode: bool = False


class TrueFoundryMLPlugin(BasePlugin):
+    def __init__(self):
+        super().__init__()
+        plugin_logger = logging.getLogger(__name__)
+        plugin_logger.setLevel(logging.INFO)
+        plugin_file_handler = logging.FileHandler("axolotl_truefoundry.plugin.log", "a")
+        plugin_file_handler.setLevel(logging.INFO)
+        plugin_file_handler.setFormatter(logging.Formatter("%(message)s"))
+        plugin_logger.addHandler(plugin_file_handler)
+        plugin_logger.propagate = False
+        self.plugin_logger = plugin_logger

    def get_input_args(self):
        return "axolotl_truefoundry.TruefoundryMLPluginArgs"

+    def post_model_load(self, cfg, model):
+        if not is_main_process():
+            return None
+        from peft import PeftModel
+
+        if isinstance(model, PeftModel):
+            trainable_params, all_params = model.get_nb_trainable_parameters()
+        else:
+            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+            all_params = sum(p.numel() for p in model.parameters())
+        params = {"trainable_params": trainable_params, "all_params": all_params}
+        self.plugin_logger.info(_json_dumps(params))

    def add_callbacks_post_trainer(self, cfg: TruefoundryMLPluginArgs, trainer: Trainer) -> List[TrainerCallback]:
+        # Note: `cfg` is not really an instance of `TruefoundryMLPluginArgs` but a `DictDefault` object
        if not is_main_process():
            return []
-        logger.info(f"Config: {cfg}")
+        logger.debug(f"Config: {cfg}")
        truefoundry_ml_cb = None
        if cfg.truefoundry_ml_enable_reporting is True:
            run = get_or_create_run(
@@ -215,7 +251,7 @@ def add_callbacks_post_trainer(cfg: TruefoundryMLPluginArgs, trainer: Trainer) -> List[TrainerCallback]:
                checkpoint_artifact_name=cfg.truefoundry_ml_checkpoint_artifact_name,
                log_gpu_metrics=cfg.truefoundry_ml_log_gpu_metrics,
            )
-        extra_metrics_cb = ExtraMetricsCallback()
+        extra_metrics_cb = ExtraMetricsCallback(plugin_logger=self.plugin_logger)
        tensorboard_cb_idx = None
        for i, cb in enumerate(trainer.callback_handler.callbacks):
            if isinstance(cb, TensorBoardCallback):
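Taken together, the plugin changes define a small JSON-lines contract that the new reporting script consumes: post_model_load writes one line with trainable_params/all_params, and on_log writes one line per logging step with metrics and system/gpu.* keys, each passed through _drop_non_finite_values and _json_dumps. A minimal sketch of that round trip (the body of _drop_non_finite_values sits outside this diff, so the filter below is an assumption, and the metric keys are illustrative):

import math

import orjson

def drop_non_finite(dct):
    # Assumed stand-in for the plugin's _drop_non_finite_values: skip NaN/inf
    # floats, which orjson would otherwise serialize as null.
    return {k: v for k, v in dct.items() if not (isinstance(v, float) and not math.isfinite(v))}

logs = {"loss": 1.25, "grad_norm": float("nan"), "system/gpu.0.memory_allocated": 7.5}
line = orjson.dumps(drop_non_finite(logs), option=orjson.OPT_NAIVE_UTC | orjson.OPT_SERIALIZE_NUMPY).decode("utf-8")
print(line)  # {"loss":1.25,"system/gpu.0.memory_allocated":7.5}
assert orjson.loads(line) == {"loss": 1.25, "system/gpu.0.memory_allocated": 7.5}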
1 change: 1 addition & 0 deletions plugins/axolotl_truefoundry/pyproject.toml
@@ -13,6 +13,7 @@ dependencies = [
"pynvml>=11.0.0,<12",
"torch>=2.3.0,<2.4.0",
"pydantic>=2.0.0,<3",
"orjson",
]

[tool.setuptools]
178 changes: 178 additions & 0 deletions reporting.py
@@ -0,0 +1,178 @@
"""
A very hacky script to test out capturing GPU memory usage against tokens and trainable parameters
Later this will be more automated and parallelized across TrueFoundry Jobs
"""
import itertools
import json
import os
import shlex
import subprocess
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import List

from transformers import AutoConfig

ML_REPO = "llm-ft-reporting"

MODELS = [
    "Qwen/Qwen2.5-0.5B-Instruct",
    "Qwen/Qwen2.5-1.5B-Instruct",
    "Qwen/Qwen2.5-3B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct",
    "Qwen/Qwen2.5-14B-Instruct",
    "Qwen/Qwen2.5-32B-Instruct",
]
SEQ_LENS = [512, 1024, 2048, 4096, 8192]
LORA_RS = [8, 16, 32]

COMMAND = """\
accelerate launch
--mixed_precision bf16
--use_deepspeed
train.py
config-base.yaml
--deepspeed ./deepspeed_configs/3_ds_z2_config.json
--base_model {base_model}
--dataset_type chat
--train_data_uri ./sample_data/chatalpaca-openai-1k.jsonl
--val_data_uri None
--val_set_size 0.1
--sequence_len {sequence_len}
--long_sequences_strategy drop
--micro_batch_size 1
--eval_batch_size 1
--num_epochs 1
--max_steps 10
--gradient_accumulation_steps 4
--gradient_checkpointing unsloth
--learning_rate 0.00001
--output_dir ./outputs
--train_on_inputs False
--logging_steps 1
--save_strategy steps
--save_steps 0.5
--evaluation_strategy steps
--eval_steps 0.5
--adapter qlora
--lora_target_linear True
--lora_r {lora_r}
--lora_alpha {lora_alpha}
--resume_from_checkpoint False
--cleanup_output_dir_on_start True
--pad_to_sequence_len True
--truefoundry_ml_enable_reporting True
--truefoundry_ml_repo {ml_repo}
--truefoundry_ml_run_name {run_name}
--truefoundry_ml_log_checkpoints False
--truefoundry_ml_log_gpu_metrics True
--truefoundry_ml_log_merged_model False
--truefoundry_testing_mode True
"""


def stream_output(pipe, prefix=""):
    for line in iter(pipe.readline, ""):
        print(f"{prefix}{line.strip()}")
    pipe.close()


def run_command(command: List[str]):
    print("Running command: ", " ".join(command))
    process = subprocess.Popen(
        shlex.join(command),
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        # Merge stderr into stdout: leaving stderr as an unread PIPE can fill its
        # buffer and block the child process once enough is written to it.
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,  # line-buffered
        env=os.environ,
        shell=True,
    )
    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(stream_output, process.stdout, "STDOUT: "),
        ]
        process.wait()
        for future in futures:
            future.result()
    # Check the return code outside any try/except so this exception is not
    # swallowed and re-wrapped as an unrelated error.
    if process.returncode != 0:
        raise Exception(f"Command failed with return code {process.returncode}")


def main():
    env = {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True,roundup_power2_divisions:16",
        "CUDA_VISIBLE_DEVICES": "0",
        "TORCH_PER_PROCESS_MEMORY_LIMIT": "0.98",
    }
    for k, v in env.items():
        os.environ[k] = v
    for model, seq_len, lora_r in itertools.product(MODELS, SEQ_LENS, LORA_RS):
        if os.path.exists("axolotl_truefoundry.plugin.log"):
            os.remove("axolotl_truefoundry.plugin.log")
        if os.path.exists("train.log"):
            os.remove("train.log")
        print(f"Model: {model}, Seq Len: {seq_len}, LoRA R: {lora_r}")
        run_name = str(uuid.uuid4())
        command = COMMAND.format(
            base_model=model,
            sequence_len=str(seq_len),
            lora_r=str(lora_r),
            lora_alpha=str(lora_r * 2),
            ml_repo=ML_REPO,
            run_name=run_name,
        )
        try:
            run_command(shlex.split(command))
        except Exception as e:
            print(f"Failed to run command: {e}")

        logs = []
        with open("axolotl_truefoundry.plugin.log") as f:
            logs = [json.loads(line) for line in f.readlines()]
        trainable_params = None
        all_params = None
        max_gpu_memory_allocated = -1
        for log in logs:
            if "trainable_params" in log:
                trainable_params = log["trainable_params"]
            if "all_params" in log:
                all_params = log["all_params"]
            if "system/gpu.0.memory_allocated" in log:
                max_gpu_memory_allocated = max(max_gpu_memory_allocated, log["system/gpu.0.memory_allocated"])

        cuda_oom = False
        with open("train.log") as f:
            for line in f.readlines():
                if "CUDA out of memory. Tried to allocate" in line:
                    cuda_oom = True
                    break
        config = AutoConfig.from_pretrained(model, trust_remote_code=True)
        print("=" * 80)
        print(f"Config: {config}")
        print(f"Model: {model}")
        print(f"Seq Len: {seq_len}")
        print(f"LoRA R: {lora_r}")
        print(f"Trainable Params: {trainable_params}")
        print(f"All Params: {all_params}")
        print(f"CUDA OOM: {cuda_oom}")
        print(f"GPU Memory Allocated: {max_gpu_memory_allocated}")
        print("=" * 80)
        if not trainable_params or not all_params:
            raise Exception("Failed to capture params")

        if not cuda_oom and max_gpu_memory_allocated == -1:
            raise Exception("Failed to capture GPU memory usage")


if __name__ == "__main__":
    main()
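For scale, the sweep above covers 6 models × 5 sequence lengths × 3 LoRA ranks = 90 runs of at most 10 steps each. Since the docstring promises later automation, one hedged sketch of a next step is to collect the values main() already extracts into rows rather than only printing them (the helper names and output path below are hypothetical, not part of this commit):

import json

results = []  # one row per (model, seq_len, lora_r) combination, filled inside the sweep loop

def record_result(model, seq_len, lora_r, trainable_params, all_params, cuda_oom, max_gpu_memory_allocated):
    # Hypothetical helper: mirrors the values main() parses out of
    # axolotl_truefoundry.plugin.log and train.log for one combination.
    results.append({
        "model": model,
        "sequence_len": seq_len,
        "lora_r": lora_r,
        "trainable_params": trainable_params,
        "all_params": all_params,
        "cuda_oom": cuda_oom,
        "max_gpu_memory_allocated": max_gpu_memory_allocated,
    })

def write_report(path="vram_report.jsonl"):
    # Hypothetical output: one JSON object per line, ready for later analysis.
    with open(path, "w") as f:
        for row in results:
            f.write(json.dumps(row) + "\n")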
4 changes: 2 additions & 2 deletions sample_run.sh
@@ -31,11 +31,11 @@ accelerate launch \
train.py \
config-base.yaml \
--deepspeed ./deepspeed_configs/3_ds_z2_config.json \
---base_model Qwen/Qwen2.5-7B-Instruct \
+--base_model Qwen/Qwen2.5-0.5B-Instruct \
--dataset_type chat \
--train_data_uri ./sample_data/chatalpaca-openai-1k.jsonl \
--val_data_uri None \
--val_set_size 0.2 \
--dataset_type chat \
--sequence_len 4096 \
--max_steps 0 \
--micro_batch_size 1 \
3 changes: 3 additions & 0 deletions train.py
@@ -244,6 +244,8 @@ def _train_with_truefoundry(config_base: Path = Path("examples/"), **kwargs):
    barrier()
    if is_main_process():
        cfg = load_config_file(path=axolotl_config)
+        if cfg.truefoundry_testing_mode is True:
+            return
        model_dir = cfg.output_dir
        log_step = get_step_for_final_model(
            output_dir=cfg.output_dir, load_best_model_at_end=cfg.load_best_model_at_end
@@ -298,6 +300,7 @@ def train_with_truefoundry(config_base: Path = Path("examples/"), **kwargs):
        error_message = (
            f"Rank {LOCAL_RANK} failed with error: {str(e)}\nPlease see the following traceback for more details."
        )
+        logger.error(error_message)
        c.print(panel.Panel.fit(f"[red]{error_message}[/]", title="Error", border_style="bright_red"))
        raise

