diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py
index cb384aa10c..4413bdfd00 100644
--- a/docs/examples/te_llama/te_llama.py
+++ b/docs/examples/te_llama/te_llama.py
@@ -102,8 +102,11 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k
         Custom method adapted from `from_pretrained` method in HuggingFace Transformers repo:
         https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
         """
-        vanilla_model = cls(config).to(kwargs["torch_dtype"])
-        is_local = os.path.isdir(pretrained_model_name_or_path)
+        # Before loading the model, set the default dtype for torch
+        torch.set_default_dtype(kwargs["torch_dtype"])
+
+        # Load the vanilla model weights
+        vanilla_model = cls(config)
         subfolder = ""
         variant = None
         if os.path.isfile(
@@ -133,7 +136,7 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k
         else:
             raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")
 
-        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
+        resolved_archive_file, _ = get_checkpoint_shard_files(
             pretrained_model_name_or_path,
             archive_file,
         )
diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
index 57c1bf6601..7013e85ec6 100644
--- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
+++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
@@ -247,15 +247,24 @@
     "restart_jupyter_notebook()\n",
     "\n",
     "\n",
-    "# Import necessary packages and methods\n",
+    "# Import necessary packages, methods and variables\n",
     "from utils import *\n",
     "\n",
     "\n",
-    "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
-    "## !!! `model_name` attr must point to the location of the model weights !!!\n",
-    "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
-    "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
-    "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
+    "# Provide Huggingface Access Token\n",
+    "hyperparams.hf_access_token = \"\"\n",
+    "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
+    "\n",
+    "# Provide a directory to cache weights in to avoid downloading them every time.\n",
+    "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
+    "hyperparams.weights_cache_dir = \"\"\n",
+    "\n",
+    "# For Llama 2, uncomment this line (also set by default)\n",
+    "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
+    "\n",
+    "# For Llama 3, uncomment this line\n",
+    "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
+    "\n",
     "hyperparams.mixed_precision = \"bf16\"\n",
     "\n",
     "\n",
@@ -554,7 +563,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 1,
    "id": "bdb34b91",
    "metadata": {},
    "outputs": [
@@ -573,15 +582,24 @@
     "restart_jupyter_notebook()\n",
     "\n",
     "\n",
-    "# Import necessary packages and methods\n",
+    "# Import necessary packages, methods and variables\n",
     "from utils import *\n",
     "\n",
     "\n",
-    "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
-    "## !!! `model_name` attr must point to the location of the model weights !!!\n",
-    "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
-    "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
-    "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
+    "# Provide Huggingface Access Token\n",
+    "hyperparams.hf_access_token = \"\"\n",
+    "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
+    "\n",
+    "# Provide a directory to cache weights in to avoid downloading them every time.\n",
+    "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
+    "hyperparams.weights_cache_dir = \"\"\n",
+    "\n",
+    "# For Llama 2, uncomment this line (also set by default)\n",
+    "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
+    "\n",
+    "# For Llama 3, uncomment this line\n",
+    "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
+    "\n",
     "hyperparams.mixed_precision = \"bf16\"\n",
     "\n",
     "\n",
@@ -653,15 +671,24 @@
     "restart_jupyter_notebook()\n",
     "\n",
     "\n",
-    "# Import necessary packages and methods\n",
+    "# Import necessary packages, methods and variables\n",
     "from utils import *\n",
     "\n",
     "\n",
-    "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
-    "## !!! `model_name` attr must point to the location of the model weights !!!\n",
-    "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
-    "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
-    "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
+    "# Provide Huggingface Access Token\n",
+    "hyperparams.hf_access_token = \"\"\n",
+    "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
+    "\n",
+    "# Provide a directory to cache weights in to avoid downloading them every time.\n",
+    "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
+    "hyperparams.weights_cache_dir = \"\"\n",
+    "\n",
+    "# For Llama 2, uncomment this line (also set by default)\n",
+    "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
+    "\n",
+    "# For Llama 3, uncomment this line\n",
+    "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
+    "\n",
     "hyperparams.mixed_precision = \"fp8\"\n",
     "\n",
     "\n",
diff --git a/docs/examples/te_llama/utils.py b/docs/examples/te_llama/utils.py
index b6b3683d4c..1aebe13afb 100644
--- a/docs/examples/te_llama/utils.py
+++ b/docs/examples/te_llama/utils.py
@@ -25,7 +25,10 @@ class HyperParameters:
     def __init__(self):
         self.mixed_precision = "bf16"
-        # self.model_name = "" # <== Add model weight location here
+
+        # Set to Meta Llama 2 by default.
+        self.model_name = "meta-llama/Llama-2-7b-hf"
+
         self.dataset_name = "timdettmers/openassistant-guanaco"
         self.dataset_text_field = "text"
         self.learning_rate = 1.41e-5
@@ -35,6 +38,10 @@ def __init__(self):
         self.num_warmup_steps = 5
         self.num_training_steps = 10
 
+        # This is either provided by the user or it will be set when the
+        # model weights are downloaded.
+        self.weights_cache_dir = ""
+
 
 hyperparams = HyperParameters()
 
@@ -76,13 +83,49 @@ def tokenize(element):
     return train_dataloader
 
 
+def ensure_model_is_downloaded(hyperparams):
+    assert hyperparams.model_name in [
+        "meta-llama/Meta-Llama-3-8B",
+        "meta-llama/Llama-2-7b-hf",
+    ], "Only Meta Llama 2 7B and Meta Llama 3 8B models are supported!"
+
+    # Login using Huggingface Hub API
+    from huggingface_hub import login
+
+    try:
+        login(hyperparams.hf_access_token)
+    except Exception as e:
+        if "Invalid token passed!" in str(e):
+            print(
+                "Please pass a valid HF Access Token! More info at"
+                " https://huggingface.co/docs/hub/en/security-tokens."
+            )
+        else:
+            print(f"Exception is {e}")
+
+    # Download the model if it doesn't exist
+    from huggingface_hub import snapshot_download
+
+    supplied_cache_dir = (
+        hyperparams.weights_cache_dir if hyperparams.weights_cache_dir != "" else None
+    )
+    hyperparams.weights_cache_dir = snapshot_download(
+        repo_id=hyperparams.model_name, cache_dir=supplied_cache_dir
+    )
+
+    print(f"Model cache directory : {hyperparams.weights_cache_dir}")
+
+
 def init_baseline_model(hyperparams):
+    # Download and cache the weights
+    ensure_model_is_downloaded(hyperparams)
+
     # Init the model
-    config = AutoConfig.from_pretrained(hyperparams.model_name)
+    config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir)
     # make sure to use flash_attention to do iso comparison with TELlamaModel
     config._attn_implementation = "flash_attention_2"
     model = AutoModelForCausalLM.from_pretrained(
-        hyperparams.model_name,
+        hyperparams.weights_cache_dir,
         config=config,
         torch_dtype=torch.bfloat16,
     )
@@ -94,13 +137,16 @@
 
 
 def init_te_llama_model(hyperparams):
+    # Download and cache the weights
+    ensure_model_is_downloaded(hyperparams)
+
     # Init the model
     from te_llama import TELlamaForCausalLM
 
-    config = AutoConfig.from_pretrained(hyperparams.model_name)
+    config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir)
     config._attn_implementation = "flash_attention_2"
     model = TELlamaForCausalLM.from_pretrained_local(
-        hyperparams.model_name,
+        hyperparams.weights_cache_dir,
         config=config,
         torch_dtype=torch.bfloat16,
     )