Allow AdapterModels to have custom tokens (#306)
---------

Co-authored-by: Clémentine Fourrier <[email protected]>
mapmeld and clefourrier authored Dec 12, 2024
1 parent f907a34 · commit 54244b3
Showing 2 changed files with 14 additions and 2 deletions.
4 changes: 2 additions & 2 deletions — src/lighteval/main_accelerate.py

@@ -169,11 +169,11 @@ def accelerate( # noqa C901
     # Keeping only non null params
     args_dict = {k: v for k, v in args_dict.items() if v is not None}

-    if config["merged_weights"]["delta_weights"]:
+    if config["merged_weights"].get("delta_weights", False):
         if config["merged_weights"]["base_model"] is None:
             raise ValueError("You need to specify a base model when using delta weights")
         model_config = DeltaModelConfig(**args_dict)
-    elif config["merged_weights"]["adapter_weights"]:
+    elif config["merged_weights"].get("adapter_weights", False):
         if config["merged_weights"]["base_model"] is None:
             raise ValueError("You need to specify a base model when using adapter weights")
         model_config = AdapterModelConfig(**args_dict)
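
Note: the practical effect of switching from direct indexing to .get(...) is that configs whose merged_weights section omits the delta_weights or adapter_weights keys no longer raise a KeyError; the branch simply evaluates to False. A minimal sketch of the difference (the config dict below is hypothetical, not taken from the repository):

    config = {"merged_weights": {"adapter_weights": True, "base_model": "gpt2"}}

    # Direct indexing fails when a key is absent:
    # config["merged_weights"]["delta_weights"]  -> KeyError: 'delta_weights'

    # .get() falls back to False, so control moves on to the next branch:
    if config["merged_weights"].get("delta_weights", False):
        print("building a DeltaModelConfig")
    elif config["merged_weights"].get("adapter_weights", False):
        print("building an AdapterModelConfig")  # this branch runs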
12 changes: 12 additions & 0 deletions — src/lighteval/models/transformers/adapter_model.py

@@ -84,6 +84,18 @@ def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig)
         base = AutoModelForCausalLM.from_pretrained(
             config.base_model, torch_dtype=torch.float16, low_cpu_mem_usage=True, token=env_config.token
         )
+        # resize model for adapters with added tokens
+        token_diff = len(self._tokenizer) - base.config.vocab_size
+        if token_diff != 0:
+            if token_diff > 0:
+                logger.info(
+                    f"You're using the adapter model's tokenizer, which has more tokens than the base model. Adding {token_diff} token(s)."
+                )
+            else:
+                logger.info(
+                    f"You're using the adapter model's tokenizer, which has fewer tokens than the base model. Removing {abs(token_diff)} token(s)."
+                )
+            base.resize_token_embeddings(len(self._tokenizer))
         # Should pass revision
         model = PeftModel.from_pretrained(base, adapter_weights)
         model = model.merge_and_unload()
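
Note: the added block follows the standard transformers recipe for reconciling a tokenizer that gained tokens during adapter fine-tuning with the base model's embedding matrix. A minimal, self-contained sketch of the same idea (the gpt2 checkpoint and the added token below are illustrative assumptions, not taken from the commit):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    base = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Simulate an adapter tokenizer that introduced a new token during fine-tuning
    tokenizer.add_tokens(["<custom_token>"])

    token_diff = len(tokenizer) - base.config.vocab_size  # 1 in this example
    if token_diff != 0:
        # Resize the embedding matrix to match the tokenizer, keeping
        # existing rows and initializing any new ones
        base.resize_token_embeddings(len(tokenizer))

Once the embedding sizes line up, PeftModel.from_pretrained can merge adapter weights that reference the added token ids without a shape mismatch.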
