Skip to content

Commit

Permalink
Merge pull request #2 from axolotl-ai-cloud/fix-untrained-manual
Browse files Browse the repository at this point in the history
allow manual token_ids to be passed to fix untrained lrs
  • Loading branch information
winglian authored Dec 23, 2024
2 parents c99307b + 2722306 commit 00da8be
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/axolotl/contribs/lgpl/unsloth.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,16 @@

@torch.inference_mode()
def fix_untrained_tokens( # pylint: disable=too-many-return-statements
- model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16
+ model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16, token_ids_to_fix=None,
):
"""
Llama-3 for eg has untrained vectors in the base model.
These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
We reset them to the mean of the rest of the tokens
"""
# Code licensed under LGPL
+ if not token_ids_to_fix:
+ token_ids_to_fix = []
embedding_matrix = model.get_input_embeddings().weight
lm_head_matrix = model.get_output_embeddings().weight
chat_template = getattr(tokenizer, "chat_template", None)
Expand Down Expand Up @@ -91,6 +93,7 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements

# Get set and actual tokens
where_untrained = where_untrained.tolist()
+ where_untrained = list(set(token_ids_to_fix + where_untrained))
if len(where_untrained) == 0:
return

Expand Down

0 comments on commit 00da8be

Please sign in to comment.