Skip to content

Commit

Permalink
add llama2-qlora job
Browse files Browse the repository at this point in the history
  • Loading branch information
bcsherma committed Jul 16, 2024
1 parent 9adbc45 commit a62fc17
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 0 deletions.
7 changes: 7 additions & 0 deletions jobs/torchtune/Dockerfile.wandb
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime

COPY requirements.txt /tmp/requirements.txt
RUN pip install -r /tmp/requirements.txt && rm /tmp/requirements.txt

WORKDIR /wandb
COPY *.py /wandb/
60 changes: 60 additions & 0 deletions jobs/torchtune/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import importlib.util
import site
import tempfile

from huggingface_hub import snapshot_download
from omegaconf import DictConfig


def get_script_path(recipe_name: str):
"""Get recipe script path from recipe name.
Torchtune installs recipe files including configs and scripts to a `recipes`
directory in the site-packages directory.
"""
site_packages = site.getsitepackages()[0]
return site_packages + f"/recipes/{recipe_name}.py"


def load_recipe(recipe_name: str, recipe_class: str):
"""Load recipe module given the recipe name and class class.
This function loads the recipe module from the recipe script path and returns
the recipe. Direct import is not possible due to an exception raised in
`recipes/__init__.py`.
"""
script_path = get_script_path(recipe_name)
spec = importlib.util.spec_from_file_location(recipe_name, script_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module.__dict__[recipe_class]


def execute(
model: str,
recipe_name: str,
recipe_class: str,
config: DictConfig,
):
"""Run a given recipe, model, and config.
This function downloads the model snapshot, sets the model directory in the
config, creates a temporary output directory, and runs the recipe. The recipe
is loaded from the recipe script path and the recipe class is instantiated with
the config.
Args:
model (str): Hugging Face model hub name.
recipe_name (str): Recipe name.
recipe_class (str): Recipe class name.
config (DictConfig): Recipe configuration.
"""
model_dir = snapshot_download(model)
config["model_dir"] = model_dir
with tempfile.TemporaryDirectory() as outdir:
config["output_dir"] = outdir
recipe_constructor = load_recipe(recipe_name, recipe_class)
recipe = recipe_constructor(config)
recipe.setup(config)
recipe.train()
recipe.cleanup()
81 changes: 81 additions & 0 deletions jobs/torchtune/llama3_8b_qlora_single_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from omegaconf import DictConfig

from helpers import execute

MODEL = "meta-llama/Meta-Llama-3-8B"
CONFIG = DictConfig(
{
"device": "cuda",
"dtype": "bf16",
"log_every_n_steps": None,
"seed": None,
"shuffle": False,
"compile": False,
"max_steps_per_epoch": None,
"gradient_accumulation_steps": 16,
"resume_from_checkpoint": False,
"enable_activation_checkpointing": True,
"epochs": 3,
"batch_size": 2,
"tokenizer": {
"_component_": "torchtune.models.llama3.llama3_tokenizer",
"path": "${model_dir}/original/tokenizer.model",
},
"dataset": {
"_component_": "torchtune.datasets.alpaca_dataset",
"train_on_input": True,
},
"model": {
"_component_": "torchtune.models.llama3.qlora_llama3_8b",
"lora_attn_modules": ["q_proj", "k_proj", "v_proj", "output_proj"],
"apply_lora_to_mlp": True,
"apply_lora_to_output": False,
"lora_rank": 8,
"lora_alpha": 16,
},
"metric_logger": {
"_component_": "torchtune.utils.metric_logging.WandBLogger",
"log_dir": "${output_dir}",
},
"optimizer": {
"_component_": "torch.optim.AdamW",
"lr": 3e-4,
"weight_decay": 0.01,
},
"lr_scheduler": {
"_component_": "torchtune.modules.get_cosine_schedule_with_warmup",
"num_warmup_steps": 100,
},
"loss": {
"_component_": "torch.nn.CrossEntropyLoss",
},
"checkpointer": {
"_component_": "torchtune.utils.FullModelMetaCheckpointer",
"checkpoint_dir": "${model_dir}/original",
"checkpoint_files": [
"consolidated.00.pth"
],
"recipe_checkpoint": None,
"output_dir": "${output_dir}",
"model_type": "LLAMA3",
},
"profiler": {
"_component_": "torchtune.utils.profiler",
"enabled": False,
"output_dir": "${output_dir}",
},
}
)


def main():
execute(
MODEL,
"lora_finetune_single_device",
"LoRAFinetuneRecipeSingleDevice",
CONFIG,
)


if __name__ == "__main__":
main()
81 changes: 81 additions & 0 deletions jobs/torchtune/mistral_7b_qlora_single_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from omegaconf import DictConfig

from helpers import execute

MODEL = "mistralai/Mistral-7B-v0.1"
CONFIG = DictConfig(
{
"device": "cuda",
"dtype": "bf16",
"log_every_n_steps": None,
"seed": None,
"shuffle": False,
"compile": False,
"max_steps_per_epoch": None,
"gradient_accumulation_steps": 4,
"resume_from_checkpoint": False,
"enable_activation_checkpointing": True,
"epochs": 3,
"batch_size": 4,
"tokenizer": {
"_component_": "torchtune.models.mistral.mistral_tokenizer",
"path": "${model_dir}/tokenizer.model",
},
"dataset": {
"_component_": "torchtune.datasets.alpaca_dataset",
"train_on_input": True,
},
"model": {
"_component_": "torchtune.models.mistral.qlora_mistral_7b",
"lora_attn_modules": ["q_proj", "k_proj", "v_proj"],
"apply_lora_to_mlp": True,
"apply_lora_to_output": False,
"lora_rank": 64,
"lora_alpha": 16,
},
"metric_logger": {
"_component_": "torchtune.utils.metric_logging.WandBLogger",
"log_dir": "${output_dir}",
},
"optimizer": {
"_component_": "torch.optim.AdamW",
"lr": 2e-5,
},
"lr_scheduler": {
"_component_": "torchtune.modules.get_cosine_schedule_with_warmup",
"num_warmup_steps": 100,
},
"loss": {
"_component_": "torch.nn.CrossEntropyLoss",
},
"checkpointer": {
"_component_": "torchtune.utils.FullModelHFCheckpointer",
"checkpoint_dir": "${model_dir}",
"checkpoint_files": [
"pytorch_model-00001-of-00002.bin",
"pytorch_model-00002-of-00002.bin",
],
"recipe_checkpoint": None,
"output_dir": "${output_dir}",
"model_type": "MISTRAL",
},
"profiler": {
"_component_": "torchtune.utils.profiler",
"enabled": False,
"output_dir": "${output_dir}",
},
}
)


def main():
execute(
MODEL,
"lora_finetune_single_device",
"LoRAFinetuneRecipeSingleDevice",
CONFIG,
)


if __name__ == "__main__":
main()
4 changes: 4 additions & 0 deletions jobs/torchtune/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
torch==2.3.1
torchao==0.1
torchtune==0.1.1
wandb

0 comments on commit a62fc17

Please sign in to comment.