Commit
Showing 5 changed files with 283 additions and 0 deletions.
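This commit adds torchtune QLoRA fine-tuning jobs instrumented with Weights & Biases: a Dockerfile for the runtime image, a helpers module that loads torchtune recipes and logs checkpoints as W&B artifacts, single-device fine-tuning scripts for Llama 3 8B and Mistral 7B, and pinned requirements.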
Dockerfile
@@ -0,0 +1,7 @@
# CUDA 12.1 runtime image matching the pinned torch==2.3.1 in requirements.txt.
FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime

# Install pinned dependencies, then drop the copied file to keep the layer clean.
COPY requirements.txt /tmp/requirements.txt
RUN pip install -r /tmp/requirements.txt && rm /tmp/requirements.txt

# Bake the training scripts into the image.
WORKDIR /wandb
COPY *.py /wandb/
helpers.py
@@ -0,0 +1,98 @@
import importlib.util
import site
import tempfile
from pathlib import Path

from huggingface_hub import snapshot_download
from omegaconf import DictConfig

import wandb


def get_script_path(recipe_name: str):
    """Get the recipe script path from the recipe name.

    Torchtune installs recipe files, including configs and scripts, to a
    `recipes` directory in the site-packages directory.
    """
    site_packages = site.getsitepackages()[0]
    return site_packages + f"/recipes/{recipe_name}.py"


def load_recipe(recipe_name: str, recipe_class: str):
    """Load a recipe class given the recipe name and class name.

    This function loads the recipe module from the recipe script path and
    returns the recipe class. Direct import is not possible due to an
    exception raised in `recipes/__init__.py`.
    """
    script_path = get_script_path(recipe_name)
    spec = importlib.util.spec_from_file_location(recipe_name, script_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.__dict__[recipe_class]


def monkeypatch_checkpointing(recipe_class):
    """Monkeypatch checkpointing in the recipe to log checkpoints to W&B."""

    # Keep a reference to the original save_checkpoint method.
    _original_save_checkpoint = recipe_class.save_checkpoint

    def save_checkpoint(self, epoch):
        """Save the checkpoint, then log it as a W&B model artifact."""
        _original_save_checkpoint(self, epoch)
        ckpt_dir = Path(self._checkpointer._output_dir)
        # Collect all files matching *_{epoch}.pt in the checkpoint directory.
        files = [f for f in ckpt_dir.iterdir() if f.name.endswith(f"_{epoch}.pt")]
        artifact = wandb.Artifact(
            name=f"model_{epoch}",
            type="model",
            metadata={
                "seed": self.seed,
                "epochs_run": self.epochs_run,
                "total_epochs": self.total_epochs,
                "max_steps_per_epoch": self.max_steps_per_epoch,
            },
        )
        for file in files:
            artifact.add_file(str(file))  # add_file expects a path string
        wandb.log_artifact(artifact)

    recipe_class.save_checkpoint = save_checkpoint


def execute(
    model: str,
    recipe_name: str,
    recipe_class: str,
    config: DictConfig,
    dataset_artifact: str,
):
    """Run a given recipe, model, and config.

    This function downloads the model snapshot, sets the model directory in
    the config, creates a temporary output directory, and runs the recipe.
    The recipe class is loaded from the recipe script path and instantiated
    with the config.

    Args:
        model (str): Hugging Face model hub name.
        recipe_name (str): Recipe name.
        recipe_class (str): Recipe class name.
        config (DictConfig): Recipe configuration.
        dataset_artifact (str): Dataset artifact name.
    """
    with wandb.init(config={"dataset_artifact": dataset_artifact}) as run:
        dataset = run.use_artifact(run.config.dataset_artifact)
        config.data_dir = dataset.download()
        model_dir = snapshot_download(model)
        config["model_dir"] = model_dir
        with tempfile.TemporaryDirectory() as outdir:
            config["output_dir"] = outdir
            recipe_constructor = load_recipe(recipe_name, recipe_class)
            monkeypatch_checkpointing(recipe_constructor)
            recipe = recipe_constructor(config)
            recipe.setup(config)
            recipe.train()
            recipe.cleanup()
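Because the patched save_checkpoint logs each epoch's weights as a model_{epoch} artifact, checkpoints can later be pulled back down by name. A minimal retrieval sketch, assuming a later run in the same W&B project (the project name and version alias are illustrative):

import wandb

with wandb.init(project="torchtune") as run:  # project name is an assumption
    artifact = run.use_artifact("model_0:latest")  # epoch-0 checkpoint logged above
    ckpt_dir = artifact.download()  # local directory containing the *_0.pt files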
Llama 3 QLoRA fine-tuning script
@@ -0,0 +1,87 @@
from omegaconf import DictConfig

from helpers import execute

MODEL = "meta-llama/Meta-Llama-3-8B"
DATASET = "bcanfieldsherman/torchtune/alpaca:v0"
CONFIG = DictConfig(
    {
        "device": "cuda",
        "dtype": "bf16",
        "log_every_n_steps": None,
        "seed": None,
        "shuffle": False,
        "compile": False,
        "max_steps_per_epoch": None,
        "gradient_accumulation_steps": 16,
        "resume_from_checkpoint": False,
        "enable_activation_checkpointing": True,
        "epochs": 3,
        "batch_size": 2,
        "tokenizer": {
            "_component_": "torchtune.models.llama3.llama3_tokenizer",
            "path": "${model_dir}/original/tokenizer.model",
        },
        "dataset": {
            "_component_": "torchtune.datasets.instruct_dataset",
            "source": "json",
            "data_files": "${data_dir}/train.json",
            "split": "train",
            "template": "AlpacaInstructTemplate",
            "train_on_input": True,
        },
        "model": {
            "_component_": "torchtune.models.llama3.qlora_llama3_8b",
            "lora_attn_modules": ["q_proj", "k_proj", "v_proj", "output_proj"],
            "apply_lora_to_mlp": True,
            "apply_lora_to_output": False,
            "lora_rank": 8,
            "lora_alpha": 16,
        },
        "metric_logger": {
            "_component_": "torchtune.utils.metric_logging.WandBLogger",
            "log_dir": "${output_dir}",
        },
        "optimizer": {
            "_component_": "torch.optim.AdamW",
            "lr": 3e-4,
            "weight_decay": 0.01,
        },
        "lr_scheduler": {
            "_component_": "torchtune.modules.get_cosine_schedule_with_warmup",
            "num_warmup_steps": 100,
        },
        "loss": {
            "_component_": "torch.nn.CrossEntropyLoss",
        },
        "checkpointer": {
            "_component_": "torchtune.utils.FullModelMetaCheckpointer",
            "checkpoint_dir": "${model_dir}/original",
            "checkpoint_files": ["consolidated.00.pth"],
            "recipe_checkpoint": None,
            "output_dir": "${output_dir}",
            "model_type": "LLAMA3",
        },
        "profiler": {
            "_component_": "torchtune.utils.profiler",
            "enabled": False,
            "output_dir": "${output_dir}",
        },
    }
)


def main():
    execute(
        MODEL,
        "lora_finetune_single_device",
        "LoRAFinetuneRecipeSingleDevice",
        CONFIG,
        DATASET,
    )


if __name__ == "__main__":
    main()
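Note that CONFIG leaves ${model_dir}, ${data_dir}, and ${output_dir} unresolved; OmegaConf interpolates them lazily once execute() assigns those keys at runtime. A minimal sketch of that mechanism (the path is illustrative):

from omegaconf import DictConfig

cfg = DictConfig({"path": "${model_dir}/original/tokenizer.model"})
cfg["model_dir"] = "/tmp/llama3"  # as execute() does after snapshot_download
print(cfg.path)  # /tmp/llama3/original/tokenizer.model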
Mistral QLoRA fine-tuning script
@@ -0,0 +1,87 @@
from omegaconf import DictConfig

from helpers import execute

MODEL = "mistralai/Mistral-7B-v0.1"
DATASET = "bcanfieldsherman/torchtune/alpaca:v0"
CONFIG = DictConfig(
    {
        "device": "cuda",
        "dtype": "bf16",
        "log_every_n_steps": None,
        "seed": None,
        "shuffle": False,
        "compile": False,
        "max_steps_per_epoch": None,
        "gradient_accumulation_steps": 4,
        "resume_from_checkpoint": False,
        "enable_activation_checkpointing": True,
        "epochs": 3,
        "batch_size": 4,
        "tokenizer": {
            "_component_": "torchtune.models.mistral.mistral_tokenizer",
            "path": "${model_dir}/tokenizer.model",
        },
        "dataset": {
            "_component_": "torchtune.datasets.instruct_dataset",
            "source": "json",
            "data_files": "${data_dir}/train.json",
            "split": "train",
            "template": "AlpacaInstructTemplate",
            "train_on_input": True,
        },
        "model": {
            "_component_": "torchtune.models.mistral.qlora_mistral_7b",
            "lora_attn_modules": ["q_proj", "k_proj", "v_proj"],
            "apply_lora_to_mlp": True,
            "apply_lora_to_output": False,
            "lora_rank": 64,
            "lora_alpha": 16,
        },
        "metric_logger": {
            "_component_": "torchtune.utils.metric_logging.WandBLogger",
            "log_dir": "${output_dir}",
        },
        "optimizer": {
            "_component_": "torch.optim.AdamW",
            "lr": 2e-5,
        },
        "lr_scheduler": {
            "_component_": "torchtune.modules.get_cosine_schedule_with_warmup",
            "num_warmup_steps": 100,
        },
        "loss": {
            "_component_": "torch.nn.CrossEntropyLoss",
        },
        "checkpointer": {
            "_component_": "torchtune.utils.FullModelHFCheckpointer",
            "checkpoint_dir": "${model_dir}",
            "checkpoint_files": [
                "pytorch_model-00001-of-00002.bin",
                "pytorch_model-00002-of-00002.bin",
            ],
            "recipe_checkpoint": None,
            "output_dir": "${output_dir}",
            "model_type": "MISTRAL",
        },
        "profiler": {
            "_component_": "torchtune.utils.profiler",
            "enabled": False,
            "output_dir": "${output_dir}",
        },
    }
)


def main():
    execute(
        MODEL,
        "lora_finetune_single_device",
        "LoRAFinetuneRecipeSingleDevice",
        CONFIG,
        DATASET,
    )


if __name__ == "__main__":
    main()
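Both scripts consume the same dataset artifact, bcanfieldsherman/torchtune/alpaca:v0, whose download directory becomes ${data_dir}. A hypothetical sketch of how such an artifact could be produced (project and file names are assumptions, not part of this commit):

import wandb

with wandb.init(project="torchtune") as run:  # project name is an assumption
    artifact = wandb.Artifact(name="alpaca", type="dataset")
    artifact.add_file("train.json")  # the configs expect ${data_dir}/train.json
    run.log_artifact(artifact)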
requirements.txt
@@ -0,0 +1,4 @@
torch==2.3.1
torchao==0.1
torchtune==0.1.1
wandb