Skip to content

Commit

Permalink
add
Browse files Browse the repository at this point in the history
  • Loading branch information
sungho.park committed Mar 17, 2023
1 parent f44b609 commit aabe225
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
- LLaMAForCausalLM, LLaMATokenizer,
HfArgumentParser,
Trainer,
TrainingArguments,
Expand Down Expand Up @@ -398,9 +397,9 @@ def main():
"use_auth_token": True if model_args.use_auth_token else None,
}
if model_args.tokenizer_name:
- tokenizer = LLaMATokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
+ tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
elif model_args.model_name_or_path:
- tokenizer = LLaMATokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
+ tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
else:
raise ValueError(
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
Expand All @@ -413,7 +412,7 @@ def main():
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
- model = LLaMAForCausalLM.from_pretrained(
+ model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
Expand All @@ -423,7 +422,7 @@ def main():
torch_dtype=torch_dtype,
)
else:
- model = LLaMAForCausalLM.from_config(config)
+ model = AutoModelForCausalLM.from_config(config)
n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")

Expand Down

0 comments on commit aabe225

Please sign in to comment.