Commit ec4e68f

rebase

RdoubleA committed Mar 1, 2024
1 parent 3a7e8af commit ec4e68f
Showing 11 changed files with 32 additions and 21 deletions.
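
In short, this commit moves the Llama2 model and tokenizer builders under a dedicated torchtune.models.llama2 submodule, and every config, doc, and test that referenced the old flat paths is updated to match. A sketch of the path change (the rename itself is taken from the diff; call signatures are not shown here):

    # Old component paths (before this commit):
    #   torchtune.models.llama2_7b
    #   torchtune.models.llama2_tokenizer
    #
    # New component paths (after this commit):
    #   torchtune.models.llama2.llama2_7b
    #   torchtune.models.llama2.llama2_tokenizer

    # Equivalent Python imports against the new layout:
    from torchtune.models.llama2 import llama2_7b, llama2_tokenizer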
docs/source/examples/configs.rst (2 changes: 1 addition & 1 deletion)

@@ -102,7 +102,7 @@ keyword arguments not specified in the config if we'd like:

 # Tokenizer is needed for the dataset, configure it first
 tokenizer:
-  _component_: torchtune.models.llama2_tokenizer
+  _component_: torchtune.models.llama2.llama2_tokenizer
   path: /tmp/tokenizer.model
 dataset:
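
The _component_ key in these configs names the callable to instantiate; sibling keys, plus any overrides per the surrounding docs text, are passed to it as keyword arguments. A minimal sketch of that resolution, assuming a generic importlib-based helper (torchtune ships its own config utilities; this one is illustrative only):

    from importlib import import_module

    def instantiate(cfg: dict, **overrides):
        # Resolve the dotted _component_ path to a callable, then call it
        # with the config's remaining keys merged with any overrides.
        module_path, _, attr = cfg["_component_"].rpartition(".")
        component = getattr(import_module(module_path), attr)
        kwargs = {k: v for k, v in cfg.items() if k != "_component_"}
        return component(**{**kwargs, **overrides})

    # e.g. instantiate({"_component_": "torchtune.models.llama2.llama2_tokenizer",
    #                   "path": "/tmp/tokenizer.model"})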
docs/source/examples/finetune_llm.rst (4 changes: 2 additions & 2 deletions)

@@ -30,7 +30,7 @@ An example config for training the Llama 7B model using the Alpaca dataset looks

 # Tokenizer
 tokenizer:
-  _component_: torchtune.models.llama2_tokenizer
+  _component_: torchtune.models.llama2.llama2_tokenizer
   path: /tmp/tokenizer.model

 # Dataset

@@ -40,7 +40,7 @@ An example config for training the Llama 7B model using the Alpaca dataset looks

 # Model Arguments
 model:
-  _component_: torchtune.models.llama2_7b
+  _component_: torchtune.models.llama2.llama2_7b
 model_checkpoint: /tmp/llama2-7b

 # Fine-tuning arguments
docs/source/examples/first_finetune_tutorial.rst (4 changes: 2 additions & 2 deletions)

@@ -97,7 +97,7 @@ lowering the epochs to 1 so you can see results sooner, and updating the learning rate

 # Tokenizer
 tokenizer:
-  _component_: torchtune.models.llama2_tokenizer
+  _component_: torchtune.models.llama2.llama2_tokenizer
   path: /tmp/tokenizer.model

 # Dataset

@@ -108,7 +108,7 @@ lowering the epochs to 1 so you can see results sooner, and updating the learning rate

 # Model Arguments
 model:
-  _component_: torchtune.models.llama2_7b
+  _component_: torchtune.models.llama2.llama2_7b
 model_checkpoint: /tmp/llama2/native_pytorch_model.pt

 # Fine-tuning arguments
recipes/configs/alpaca_llama2_full_finetune.yaml (4 changes: 2 additions & 2 deletions)

@@ -5,7 +5,7 @@

 # Tokenizer
 tokenizer:
-  _component_: torchtune.models.llama2_tokenizer
+  _component_: torchtune.models.llama2.llama2_tokenizer
   path: /tmp/llama2/tokenizer.model

 # Dataset

@@ -17,7 +17,7 @@ shuffle: True

 # Model Arguments
 model:
-  _component_: torchtune.models.llama2_7b
+  _component_: torchtune.models.llama2.llama2_7b
 model_checkpoint: /tmp/llama2_native

 # Fine-tuning arguments
recipes/configs/alpaca_llama2_generate.yaml (4 changes: 2 additions & 2 deletions)

@@ -5,12 +5,12 @@

 # Model arguments
 model:
-  _component_: torchtune.models.llama2_7b
+  _component_: torchtune.models.llama2.llama2_7b
 model_checkpoint: /tmp/llama2_native

 # Tokenizer arguments
 tokenizer:
-  _component_: torchtune.models.llama2_tokenizer
+  _component_: torchtune.models.llama2.llama2_tokenizer
   path: /tmp/llama2/tokenizer.model

 # Generation arguments
recipes/configs/alpaca_llama2_lora_finetune.yaml (4 changes: 2 additions & 2 deletions)

@@ -5,7 +5,7 @@

 # Model Arguments
 model:
-  _component_: torchtune.models.lora_llama2_7b
+  _component_: torchtune.models.llama2.lora_llama2_7b
   lora_attn_modules: ['q_proj', 'v_proj']
   lora_rank: 8
   lora_alpha: 16

@@ -15,7 +15,7 @@ lora_checkpoint: null

 # Tokenizer
 tokenizer:
-  _component_: torchtune.models.llama2_tokenizer
+  _component_: torchtune.models.llama2.llama2_tokenizer
   path: /tmp/llama2/tokenizer.model

 # Dataset and Sampler
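
The LoRA keys under model: above map directly onto keyword arguments of the builder at its new path. A hedged sketch of the equivalent direct call (assuming the builder accepts exactly the keys shown in this config; anything beyond them is not in this diff):

    from torchtune.models.llama2 import lora_llama2_7b

    # Keyword arguments mirror the YAML keys under model: in the config above.
    model = lora_llama2_7b(
        lora_attn_modules=["q_proj", "v_proj"],  # attention projections that receive LoRA
        lora_rank=8,                             # rank of the low-rank update matrices
        lora_alpha=16,                           # scaling applied to the update
    )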
recipes/tests/test_alpaca_generate.py (6 changes: 3 additions & 3 deletions)

@@ -36,19 +36,19 @@ class TestAlpacaGenerateRecipe:

     def _fetch_ckpt_model_path(self, ckpt) -> str:
         if ckpt == "small_test_ckpt":
             return "/tmp/test-artifacts/small-ckpt-01242024"
-        if ckpt == "llama2_7b":
+        if ckpt == "llama2.llama2_7b":
             return "/tmp/test-artifacts/llama2-7b-01242024"
         raise ValueError(f"Unknown ckpt {ckpt}")

     def test_alpaca_generate(self, capsys, pytestconfig):
         large_scale = pytestconfig.getoption("--large-scale")
-        ckpt = "llama2_7b" if large_scale else "small_test_ckpt"
+        ckpt = "llama2.llama2_7b" if large_scale else "small_test_ckpt"

         kwargs_values = {
             "model": {"_component_": f"torchtune.models.{ckpt}"},
             "model_checkpoint": self._fetch_ckpt_model_path(ckpt),
             "tokenizer": {
-                "_component_": "torchtune.models.llama2_tokenizer",
+                "_component_": "torchtune.models.llama2.llama2_tokenizer",
                 "path": "/tmp/test-artifacts/tokenizer.model",
             },
             "instruction": "Answer the question.",
recipes/tests/test_full_finetune.py (10 changes: 5 additions & 5 deletions)

@@ -51,13 +51,13 @@ def _fetch_expected_loss_values(self, ckpt) -> Dict[str, float]:
         }
         if ckpt == "small_test_ckpt":
             return small_test_ckpt_loss_values
-        if ckpt == "llama2_7b":
+        if ckpt == "llama2.llama2_7b":
             return llama2_7b_ckpt_loss_values
         raise ValueError(f"Unknown ckpt {ckpt}")

     def test_loss(self, capsys, pytestconfig):
         large_scale = pytestconfig.getoption("--large-scale")
-        ckpt = "llama2_7b" if large_scale else "small_test_ckpt"
+        ckpt = "llama2.llama2_7b" if large_scale else "small_test_ckpt"
         expected_loss_values = self._fetch_expected_loss_values(ckpt)

         kwargs_values = default_recipe_kwargs(ckpt)

@@ -93,7 +93,7 @@ def test_training_state_on_resume(self):
             "model": {"_component_": f"torchtune.models.{model_ckpt}"},
             "model_checkpoint": fetch_ckpt_model_path(model_ckpt),
             "tokenizer": {
-                "_component_": "torchtune.models.llama2_tokenizer",
+                "_component_": "torchtune.models.llama2.llama2_tokenizer",
                 "path": "/tmp/test-artifacts/tokenizer.model",
             },
             "epochs": 4,

@@ -127,7 +127,7 @@ def test_training_state_on_resume(self):
             "model": {"_component_": f"torchtune.models.{model_ckpt}"},
             "model_checkpoint": os.path.join(tmpdirname, "model_2.ckpt"),
             "tokenizer": {
-                "_component_": "torchtune.models.llama2_tokenizer",
+                "_component_": "torchtune.models.llama2.llama2_tokenizer",
                 "path": "/tmp/test-artifacts/tokenizer.model",
             },
             "epochs": 4,

@@ -228,7 +228,7 @@ def test_gradient_accumulation(
             "model": {"_component_": f"torchtune.models.{model_ckpt}"},
             "model_checkpoint": None,
             "tokenizer": {
-                "_component_": "torchtune.models.llama2_tokenizer",
+                "_component_": "torchtune.models.llama2.llama2_tokenizer",
                 "path": "/tmp/test-artifacts/tokenizer.model",
             },
             "batch_size": full_batch_size,
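
Note why the checkpoint id itself changes from "llama2_7b" to "llama2.llama2_7b": these tests splice the id into a component path with an f-string, so the id has to carry the new submodule segment:

    large_scale = False  # stands in for pytest's --large-scale option
    ckpt = "llama2.llama2_7b" if large_scale else "small_test_ckpt"
    component = f"torchtune.models.{ckpt}"
    # large_scale=True  -> "torchtune.models.llama2.llama2_7b"
    # large_scale=False -> "torchtune.models.small_test_ckpt"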
recipes/tests/utils.py (2 changes: 1 addition & 1 deletion)

@@ -89,7 +89,7 @@ def default_recipe_kwargs(ckpt):
         "model": {"_component_": f"torchtune.models.{ckpt}"},
         "model_checkpoint": fetch_ckpt_model_path(ckpt),
         "tokenizer": {
-            "_component_": "torchtune.models.llama2_tokenizer",
+            "_component_": "torchtune.models.llama2.llama2_tokenizer",
             "path": "/tmp/test-artifacts/tokenizer.model",
         },
         "batch_size": 8,
tests/torchtune/config/test_utils.py (2 changes: 1 addition & 1 deletion)

@@ -13,7 +13,7 @@ def test_get_component_from_path(self):
         good_paths = [
             "torchtune",  # Test single module without dot
             "torchtune.models",  # Test dotpath for a module
-            "torchtune.models.llama2_7b",  # Test dotpath for an object
+            "torchtune.models.llama2.llama2_7b",  # Test dotpath for an object
         ]
         for path in good_paths:
             _ = _get_component_from_path(path)
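
The updated good_paths cover a bare module, a dotted module, and a module attribute. Not torchtune's actual implementation, but a minimal sketch of a resolver that accepts all three shapes:

    import importlib

    def get_component_from_path(path: str):
        # Try the path as an importable module first ("torchtune",
        # "torchtune.models", "torchtune.models.llama2"); otherwise treat
        # the last segment as an attribute of the parent module
        # ("torchtune.models.llama2.llama2_7b").
        try:
            return importlib.import_module(path)
        except ModuleNotFoundError:
            module_path, _, attr = path.rpartition(".")
            return getattr(importlib.import_module(module_path), attr)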
torchtune/models/__init__.py (11 changes: 11 additions & 0 deletions)

@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchtune.models import llama2
+
+__all__ = [
+    "llama2",
+]
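
With the re-export above in place, the submodule is reachable directly off torchtune.models. A small usage sketch of the new layout (builder names come from the paths in this diff; no arguments are shown because none appear here):

    from torchtune import models
    from torchtune.models import llama2

    # Both names refer to the same re-exported submodule.
    assert models.llama2 is llama2

    # Builders now resolve one level down, matching the updated configs.
    tokenizer_builder = llama2.llama2_tokenizer
    model_builder = llama2.llama2_7b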
