Skip to content

Commit

Permalink
update from unsloth-zoo with additional fixes (axolotl-ai-cloud#2122)
Browse files Browse the repository at this point in the history
only update tokens seen in the train dataset, log them out explicitly
  • Loading branch information
winglian authored Dec 4, 2024
1 parent e2882dd commit d7d2fd3
Showing 1 changed file with 162 additions and 34 deletions.
196 changes: 162 additions & 34 deletions src/axolotl/core/tokenizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,89 @@

import gc
import itertools
import logging
from collections import Counter

import datasets
import numpy as np
import torch

LOG = logging.getLogger("axolotl.core.tokenizer_utils")

@torch.inference_mode
def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):

@torch.inference_mode()
def fix_untrained_tokens( # pylint: disable=too-many-return-statements
model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16
):
"""
Many of the newer models have reserved tokens that are not trained.
Llama-3 for eg has untrained vectors in the base model.
These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
We reset them to the mean of the rest of the tokens
"""
# Code licensed under LGPL
embedding_matrix = model.get_input_embeddings().weight
lm_head_matrix = model.get_output_embeddings().weight
chat_template = getattr(tokenizer, "chat_template", None)
tokenizer = tokenizer.tokenizer if hasattr(tokenizer, "tokenizer") else tokenizer

# Ignore some model checks for now
if not ignored_tokenizer_names:
ignored_tokenizer_names = []
if (
model.config._name_or_path # pylint: disable=protected-access
in ignored_tokenizer_names
):
return

# Sometimes the sizes can be different like in vision models
# Ie <image> is in input, but not in output
min_size = min(embedding_matrix.shape[1], lm_head_matrix.shape[1])
embedding_matrix = embedding_matrix[:, :min_size]
lm_head_matrix = lm_head_matrix[:, :min_size]

# Get untrained tokens
indicator_untrained = torch.amax(embedding_matrix, axis=1) <= eps
indicator_untrained1 = torch.amax(embedding_matrix, axis=1) <= eps
# Check lm_head as well

# Does NOT work for Llama 3.1!!
indicator_untrained2 = torch.amax(lm_head_matrix, axis=1) <= eps

# We instead check for repeated vectors
lm_head_where = torch.where(indicator_untrained1)[0]
lm_head_bad = lm_head_matrix[lm_head_where]
lm_head_bad = lm_head_bad.cpu().float().numpy().round(3)
counter = Counter()
for row in lm_head_bad:
counter[hash(row.data.tobytes())] += 1
counter = Counter({k: c for k, c in counter.items() if c >= 2})

lm_head_where = lm_head_where.cpu().numpy()
final_bad_lm_head = []
for j, row in enumerate(lm_head_bad):
if hash(row.data.tobytes()) in counter:
final_bad_lm_head.append(lm_head_where[j])
indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained2)
indicator_untrained2[final_bad_lm_head] = True

# Combine both checks
indicator_untrained = indicator_untrained1 & indicator_untrained2

# Remove pad token possibility
if hasattr(tokenizer, "pad_token_id"):
pad_token_id = tokenizer.pad_token_id
if pad_token_id is not None and pad_token_id < indicator_untrained.shape[0]:
indicator_untrained[pad_token_id] = False

where_untrained = torch.where(indicator_untrained)[0]
n_untrained = where_untrained.shape[0]
n_trained = embedding_matrix.shape[0] - n_untrained

# Get set and actual tokens
where_untrained = where_untrained.tolist()
if len(where_untrained) == 0:
return False
return

# Remove untrained indices where it's longer

where_untrained_set = frozenset(where_untrained)
actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained)
# Remove None items in actual_bad_tokens
Expand All @@ -53,10 +110,14 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
if_bad_first = False
if_bad_second = False
# Check tokenizer's chat template for any untrained tokens
chat_template = getattr(tokenizer, "chat_template", None)
if chat_template is not None:
if_bad_first = any(x in chat_template for x in actual_bad_tokens)

if isinstance(train_dataset, datasets.IterableDataset):
# Skip the check, since the code below assumes
# an indexable dataset
return

# Check the first 250, last 250 input_ids
size_dataset = len(train_dataset)
size = min(size_dataset, 250)
Expand All @@ -83,7 +144,69 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):

# Check if bad tokens exists!
if not if_bad_first and not if_bad_second:
return False
return

# Check if lm_head / embed_token are trainable!
bad_not_trainable = False
if not embedding_matrix.requires_grad:
bad_not_trainable = True
if not lm_head_matrix.requires_grad:
bad_not_trainable = True

if bad_not_trainable: # pylint: disable=too-many-nested-blocks
final_bad_items = []

# Re-check the first 250, last 250 input_ids
size_dataset = len(train_dataset)
size = min(size_dataset, 250)
for j in range(size):
input_ids = train_dataset[j]
if "input_ids" in input_ids:
input_ids = input_ids["input_ids"]
for item in input_ids:
if item in where_untrained_set:
final_bad_items.append(item)

# Re-check last 250
left = max(size_dataset - 250, 0)
for j in range(left, size_dataset):
input_ids = train_dataset[j]
if "input_ids" in input_ids:
input_ids = input_ids["input_ids"]
for item in input_ids:
if item in where_untrained_set:
final_bad_items.append(item)

# If no bad tokens, possibly chat template itself has issues?
if len(final_bad_items) == 0:
# Recheck 2000 and last 2000 items
size_dataset = len(train_dataset)
size = min(size_dataset, 2000)
for j in range(size):
input_ids = train_dataset[j]
if "input_ids" in input_ids:
input_ids = input_ids["input_ids"]
for item in input_ids:
if item in where_untrained_set:
final_bad_items.append(item)

# Re-check last 2000
left = max(size_dataset - 2000, 0)
for j in range(left, size_dataset):
input_ids = train_dataset[j]
if "input_ids" in input_ids:
input_ids = input_ids["input_ids"]
for item in input_ids:
if item in where_untrained_set:
final_bad_items.append(item)

# Most likely false signal!
if len(final_bad_items) == 0:
return

raise ValueError(
f"Untrained tokens of [{list(set(final_bad_items))}] found, but embed_tokens & lm_head not trainable, causing NaNs. "
)

# Count all the possible bad tokens
final_counts = np.zeros(
Expand All @@ -97,6 +220,23 @@ def mapping(examples):

train_dataset.map(mapping, batched=True, desc="Counting untrained tokens")

# Get counts for untrained tokens
counts_untrained = final_counts[where_untrained]
# Identify untrained tokens seen in train_dataset
indices_seen_in_train = np.where(counts_untrained > 0)[0]
tokens_to_update = [where_untrained[i] for i in indices_seen_in_train]

if len(tokens_to_update) == 0:
LOG.info(
"No untrained tokens found in train_dataset. No embeddings were modified."
)
return

# Log the token IDs that are being rescaled
LOG.info(
f"Rescaling embeddings for tokens seen in train_dataset: {tokens_to_update}"
)

# Get sum of all items
sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
Expand All @@ -113,38 +253,26 @@ def mapping(examples):
mean_embedding = sum_embedding / n_trained
mean_lm_head = sum_lm_head / n_trained

# Scale each to be equal to 1/max_frequency. Also set some to 0 if none seen
scaling = final_counts[where_untrained] / max(final_counts.max(), 1)
# Compute scaling for tokens to update
scaling = counts_untrained[indices_seen_in_train] / max(final_counts.max(), 1)
scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1)
mean_embedding = (
mean_embedding.repeat(
(
n_untrained,
1,
)
)
* scaling

# Prepare mean embeddings for tokens to update
mean_embedding_repeated = (
mean_embedding.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
)
mean_lm_head = (
mean_lm_head.repeat(
(
n_untrained,
1,
)
)
* scaling
mean_lm_head_repeated = (
mean_lm_head.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
)
where_null = scaling.ravel() == 0
mean_embedding[where_null] = 0
mean_lm_head[where_null] = 0

# Set them to the mean
embedding_matrix[where_untrained] = mean_embedding.to(embedding_matrix.dtype)
lm_head_matrix[where_untrained] = mean_lm_head.to(lm_head_matrix.dtype)
# Update embeddings only for tokens seen in train_dataset
embedding_matrix[tokens_to_update] = mean_embedding_repeated.to(
embedding_matrix.dtype
)
lm_head_matrix[tokens_to_update] = mean_lm_head_repeated.to(lm_head_matrix.dtype)

# Clean up
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()

return True
return

0 comments on commit d7d2fd3

Please sign in to comment.