From d7d2fd366eb2b7eaa35c0868a2a2d111b842ca81 Mon Sep 17 00:00:00 2001
From: Wing Lian <wing.lian@gmail.com>
Date: Wed, 4 Dec 2024 12:26:08 -0500
Subject: [PATCH] update from unsloth-zoo with additional fixes (#2122)

only update tokens seen in the train dataset, log them out explicitly
---
 src/axolotl/core/tokenizer_utils.py | 196 +++++++++++++++++++++++-----
 1 file changed, 162 insertions(+), 34 deletions(-)

diff --git a/src/axolotl/core/tokenizer_utils.py b/src/axolotl/core/tokenizer_utils.py
index 53c44a75c4..8e5c9205b6 100644
--- a/src/axolotl/core/tokenizer_utils.py
+++ b/src/axolotl/core/tokenizer_utils.py
@@ -18,21 +18,79 @@

 import gc
 import itertools
+import logging
+from collections import Counter

+import datasets
 import numpy as np
 import torch

+LOG = logging.getLogger("axolotl.core.tokenizer_utils")

-@torch.inference_mode
-def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
+
+@torch.inference_mode()
+def fix_untrained_tokens(  # pylint: disable=too-many-return-statements
+    model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16
+):
     """
-    Many of the newer models have reserved tokens that are not trained.
+    Llama-3 for eg has untrained vectors in the base model.
+    These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
+    We reset them to the mean of the rest of the tokens
     """
+    # Code licensed under LGPL
     embedding_matrix = model.get_input_embeddings().weight
     lm_head_matrix = model.get_output_embeddings().weight
+    chat_template = getattr(tokenizer, "chat_template", None)
+    tokenizer = tokenizer.tokenizer if hasattr(tokenizer, "tokenizer") else tokenizer
+
+    # Ignore some model checks for now
+    if not ignored_tokenizer_names:
+        ignored_tokenizer_names = []
+    if (
+        model.config._name_or_path  # pylint: disable=protected-access
+        in ignored_tokenizer_names
+    ):
+        return
+
+    # Sometimes the sizes can be different like in vision models
+    # Ie <image> is in input, but not in output
+    min_size = min(embedding_matrix.shape[1], lm_head_matrix.shape[1])
+    embedding_matrix = embedding_matrix[:, :min_size]
+    lm_head_matrix = lm_head_matrix[:, :min_size]

     # Get untrained tokens
-    indicator_untrained = torch.amax(embedding_matrix, axis=1) <= eps
+    indicator_untrained1 = torch.amax(embedding_matrix, axis=1) <= eps
+    # Check lm_head as well
+
+    # Does NOT work for Llama 3.1!!
+    indicator_untrained2 = torch.amax(lm_head_matrix, axis=1) <= eps
+
+    # We instead check for repeated vectors
+    lm_head_where = torch.where(indicator_untrained1)[0]
+    lm_head_bad = lm_head_matrix[lm_head_where]
+    lm_head_bad = lm_head_bad.cpu().float().numpy().round(3)
+    counter = Counter()
+    for row in lm_head_bad:
+        counter[hash(row.data.tobytes())] += 1
+    counter = Counter({k: c for k, c in counter.items() if c >= 2})
+
+    lm_head_where = lm_head_where.cpu().numpy()
+    final_bad_lm_head = []
+    for j, row in enumerate(lm_head_bad):
+        if hash(row.data.tobytes()) in counter:
+            final_bad_lm_head.append(lm_head_where[j])
+    indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained2)
+    indicator_untrained2[final_bad_lm_head] = True
+
+    # Combine both checks
+    indicator_untrained = indicator_untrained1 & indicator_untrained2
+
+    # Remove pad token possibility
+    if hasattr(tokenizer, "pad_token_id"):
+        pad_token_id = tokenizer.pad_token_id
+        if pad_token_id is not None and pad_token_id < indicator_untrained.shape[0]:
+            indicator_untrained[pad_token_id] = False
+
     where_untrained = torch.where(indicator_untrained)[0]
     n_untrained = where_untrained.shape[0]
     n_trained = embedding_matrix.shape[0] - n_untrained
@@ -40,10 +98,9 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
     # Get set and actual tokens
     where_untrained = where_untrained.tolist()
     if len(where_untrained) == 0:
-        return False
+        return

     # Remove untrained indices where it's longer
-    where_untrained_set = frozenset(where_untrained)
     actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained)

     # Remove None items in actual_bad_tokens
@@ -53,10 +110,14 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
     if_bad_first = False
     if_bad_second = False
     # Check tokenizer's chat template for any untrained tokens
-    chat_template = getattr(tokenizer, "chat_template", None)
     if chat_template is not None:
         if_bad_first = any(x in chat_template for x in actual_bad_tokens)

+    if isinstance(train_dataset, datasets.IterableDataset):
+        # Skip the check, since the code below assumes
+        # an indexable dataset
+        return
+
     # Check the first 250, last 250 input_ids
     size_dataset = len(train_dataset)
     size = min(size_dataset, 250)
@@ -83,7 +144,69 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):

     # Check if bad tokens exists!
     if not if_bad_first and not if_bad_second:
-        return False
+        return
+
+    # Check if lm_head / embed_token are trainable!
+    bad_not_trainable = False
+    if not embedding_matrix.requires_grad:
+        bad_not_trainable = True
+    if not lm_head_matrix.requires_grad:
+        bad_not_trainable = True
+
+    if bad_not_trainable:  # pylint: disable=too-many-nested-blocks
+        final_bad_items = []
+
+        # Re-check the first 250, last 250 input_ids
+        size_dataset = len(train_dataset)
+        size = min(size_dataset, 250)
+        for j in range(size):
+            input_ids = train_dataset[j]
+            if "input_ids" in input_ids:
+                input_ids = input_ids["input_ids"]
+                for item in input_ids:
+                    if item in where_untrained_set:
+                        final_bad_items.append(item)
+
+        # Re-check last 250
+        left = max(size_dataset - 250, 0)
+        for j in range(left, size_dataset):
+            input_ids = train_dataset[j]
+            if "input_ids" in input_ids:
+                input_ids = input_ids["input_ids"]
+                for item in input_ids:
+                    if item in where_untrained_set:
+                        final_bad_items.append(item)
+
+        # If no bad tokens, possibly chat template itself has issues?
+        if len(final_bad_items) == 0:
+            # Recheck 2000 and last 2000 items
+            size_dataset = len(train_dataset)
+            size = min(size_dataset, 2000)
+            for j in range(size):
+                input_ids = train_dataset[j]
+                if "input_ids" in input_ids:
+                    input_ids = input_ids["input_ids"]
+                    for item in input_ids:
+                        if item in where_untrained_set:
+                            final_bad_items.append(item)
+
+            # Re-check last 2000
+            left = max(size_dataset - 2000, 0)
+            for j in range(left, size_dataset):
+                input_ids = train_dataset[j]
+                if "input_ids" in input_ids:
+                    input_ids = input_ids["input_ids"]
+                    for item in input_ids:
+                        if item in where_untrained_set:
+                            final_bad_items.append(item)
+
+            # Most likely false signal!
+            if len(final_bad_items) == 0:
+                return
+
+        raise ValueError(
+            f"Untrained tokens of [{list(set(final_bad_items))}] found, but embed_tokens & lm_head not trainable, causing NaNs. "
+        )

     # Count all the possible bad tokens
     final_counts = np.zeros(
@@ -97,6 +220,23 @@ def mapping(examples):

     train_dataset.map(mapping, batched=True, desc="Counting untrained tokens")

+    # Get counts for untrained tokens
+    counts_untrained = final_counts[where_untrained]
+    # Identify untrained tokens seen in train_dataset
+    indices_seen_in_train = np.where(counts_untrained > 0)[0]
+    tokens_to_update = [where_untrained[i] for i in indices_seen_in_train]
+
+    if len(tokens_to_update) == 0:
+        LOG.info(
+            "No untrained tokens found in train_dataset. No embeddings were modified."
+        )
+        return
+
+    # Log the token IDs that are being rescaled
+    LOG.info(
+        f"Rescaling embeddings for tokens seen in train_dataset: {tokens_to_update}"
+    )
+
     # Get sum of all items
     sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
     sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
@@ -113,38 +253,26 @@ def mapping(examples):
     mean_embedding = sum_embedding / n_trained
     mean_lm_head = sum_lm_head / n_trained

-    # Scale each to be equal to 1/max_frequency. Also set some to 0 if none seen
-    scaling = final_counts[where_untrained] / max(final_counts.max(), 1)
+    # Compute scaling for tokens to update
+    scaling = counts_untrained[indices_seen_in_train] / max(final_counts.max(), 1)
     scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1)
-    mean_embedding = (
-        mean_embedding.repeat(
-            (
-                n_untrained,
-                1,
-            )
-        )
-        * scaling
+
+    # Prepare mean embeddings for tokens to update
+    mean_embedding_repeated = (
+        mean_embedding.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
     )
-    mean_lm_head = (
-        mean_lm_head.repeat(
-            (
-                n_untrained,
-                1,
-            )
-        )
-        * scaling
+    mean_lm_head_repeated = (
+        mean_lm_head.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
     )
-    where_null = scaling.ravel() == 0
-    mean_embedding[where_null] = 0
-    mean_lm_head[where_null] = 0

-    # Set them to the mean
-    embedding_matrix[where_untrained] = mean_embedding.to(embedding_matrix.dtype)
-    lm_head_matrix[where_untrained] = mean_lm_head.to(lm_head_matrix.dtype)
+    # Update embeddings only for tokens seen in train_dataset
+    embedding_matrix[tokens_to_update] = mean_embedding_repeated.to(
+        embedding_matrix.dtype
+    )
+    lm_head_matrix[tokens_to_update] = mean_lm_head_repeated.to(lm_head_matrix.dtype)

     # Clean up
     for _ in range(3):
         gc.collect()
         torch.cuda.empty_cache()
-
-    return True
+    return
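
Below is a hedged usage sketch, separate from the patch itself, of how the updated fix_untrained_tokens() might be exercised: it expects an indexable, already-tokenized dataset whose rows carry "input_ids", and after this change it only rescales embeddings for untrained tokens that actually appear in that dataset, logging the affected IDs first. The model checkpoint name and toy dataset below are illustrative assumptions, not part of the commit.

# Hypothetical usage sketch (not part of the patch); names below are assumptions.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from axolotl.core.tokenizer_utils import fix_untrained_tokens

model_name = "meta-llama/Meta-Llama-3-8B"  # assumed checkpoint with untrained reserved tokens
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# The helper expects an indexable dataset of tokenized rows with "input_ids";
# IterableDataset inputs are skipped by the patched code.
texts = ["<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"]
train_dataset = Dataset.from_dict(
    {"input_ids": [tokenizer(text)["input_ids"] for text in texts]}
)

# Only untrained token IDs that occur in train_dataset are reset to the scaled
# mean embedding; the chosen IDs are logged via LOG.info before any write.
fix_untrained_tokens(model, tokenizer, train_dataset)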