Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 16, 2024
1 parent 27a8d84 commit 674e499
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions docs/examples/te_llama/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self):

# Set to Meta Llama 2 by default.
self.model_name = "meta-llama/Llama-2-7b-hf"

self.dataset_name = "timdettmers/openassistant-guanaco"
self.dataset_text_field = "text"
self.learning_rate = 1.41e-5
Expand All @@ -37,8 +37,8 @@ def __init__(self):
self.gradient_accumulation_steps = 1
self.num_warmup_steps = 5
self.num_training_steps = 10
# This is either provided by the user or it will be set when the

# This is either provided by the user or it will be set when the
# model weights are downloaded.
self.weights_cache_dir = ""

Expand Down Expand Up @@ -106,8 +106,12 @@ def ensure_model_is_downloaded(hyperparams):
# Download the model if it doesn't exist
from huggingface_hub import snapshot_download

supplied_cache_dir = hyperparams.weights_cache_dir if hyperparams.weights_cache_dir != "" else None
hyperparams.weights_cache_dir = snapshot_download(repo_id=hyperparams.model_name, cache_dir=supplied_cache_dir)
supplied_cache_dir = (
hyperparams.weights_cache_dir if hyperparams.weights_cache_dir != "" else None
)
hyperparams.weights_cache_dir = snapshot_download(
repo_id=hyperparams.model_name, cache_dir=supplied_cache_dir
)

print(f"Model cache directory : {hyperparams.weights_cache_dir}")

Expand Down

0 comments on commit 674e499

Please sign in to comment.