From da8ff7dc25db83f6e70832bdda651c1a2d0f4270 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 15 Dec 2024 23:28:02 -0500 Subject: [PATCH 1/3] use axolotl contribs for fix_untrained_tokens --- requirements.txt | 2 ++ src/axolotl/__init__.py | 4 ++++ src/axolotl/train.py | 4 +++- 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0373548d1..92da0163f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -60,3 +60,5 @@ antlr4-python3-runtime==4.13.2 torchao==0.7.0 schedulefree==1.3.0 + +axolotl-contribs-lgpl==0.0.1b2 diff --git a/src/axolotl/__init__.py b/src/axolotl/__init__.py index 9a3174598..8b0ba0532 100644 --- a/src/axolotl/__init__.py +++ b/src/axolotl/__init__.py @@ -1,3 +1,7 @@ """Axolotl - Train and fine-tune large language models""" +import pkgutil + +__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package + __version__ = "0.6.0" diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 851a71e54..d925afc25 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -19,7 +19,9 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from axolotl.common.cli import TrainerCliArgs -from axolotl.core.tokenizer_utils import fix_untrained_tokens +from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module + fix_untrained_tokens, +) from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault from axolotl.utils.freeze import freeze_layers_except From f68487a8059057fd066f11c1062fef1ee9b6e7cd Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 16 Dec 2024 09:01:25 -0500 Subject: [PATCH 2/3] remove the module we're replacing --- src/axolotl/core/tokenizer_utils.py | 272 ---------------------------- 1 file changed, 272 deletions(-) delete mode 100644 src/axolotl/core/tokenizer_utils.py diff --git a/src/axolotl/core/tokenizer_utils.py b/src/axolotl/core/tokenizer_utils.py deleted file mode 100644 index 1b86b9c49..000000000 --- a/src/axolotl/core/tokenizer_utils.py +++ /dev/null @@ -1,272 +0,0 @@ -""" -helper functions for fixing the embeddings/tokenizer -""" - -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. -# GNU LESSER GENERAL PUBLIC LICENSE -# Version 3, 29 June 2007 -# -# Copyright (C) 2007 Free Software Foundation, Inc. -# Everyone is permitted to copy and distribute verbatim copies -# of this license document, but changing it is not allowed. - -import gc -import itertools -import logging -from collections import Counter - -import datasets -import numpy as np -import torch - -LOG = logging.getLogger("axolotl.core.tokenizer_utils") - - -@torch.inference_mode() -def fix_untrained_tokens( # pylint: disable=too-many-return-statements - model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16 -): - """ - Llama-3 for eg has untrained vectors in the base model. - These include <|eot_id|>, <|start_header_id|>, <|end_header_id|> - We reset them to the mean of the rest of the tokens - """ - # Code licensed under LGPL - embedding_matrix = model.get_input_embeddings().weight - lm_head_matrix = model.get_output_embeddings().weight - chat_template = getattr(tokenizer, "chat_template", None) - tokenizer = tokenizer.tokenizer if hasattr(tokenizer, "tokenizer") else tokenizer - - # Ignore some model checks for now - if not ignored_tokenizer_names: - ignored_tokenizer_names = [] - if ( - model.config._name_or_path # pylint: disable=protected-access - in ignored_tokenizer_names - ): - return - - # Sometimes the sizes can be different like in vision models - # Ie is in input, but not in output - min_size = min(embedding_matrix.shape[1], lm_head_matrix.shape[1]) - embedding_matrix = embedding_matrix[:, :min_size] - lm_head_matrix = lm_head_matrix[:, :min_size] - - # Get untrained tokens - indicator_untrained1 = torch.amax(embedding_matrix, axis=1) <= eps - # Check lm_head as well - - # Does NOT work for Llama 3.1!! - indicator_untrained2 = torch.amax(lm_head_matrix, axis=1) <= eps - - # We instead check for repeated vectors - lm_head_where = torch.where(indicator_untrained1)[0] - lm_head_bad = lm_head_matrix[lm_head_where] - lm_head_bad = lm_head_bad.cpu().float().numpy().round(3) - counter = Counter() - for row in lm_head_bad: - counter[hash(row.data.tobytes())] += 1 - counter = Counter({k: c for k, c in counter.items() if c >= 2}) - - lm_head_where = lm_head_where.cpu().numpy() - final_bad_lm_head = [] - for j, row in enumerate(lm_head_bad): - if hash(row.data.tobytes()) in counter: - final_bad_lm_head.append(lm_head_where[j]) - indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained2) - indicator_untrained2[final_bad_lm_head] = True - - # Combine both checks - indicator_untrained = indicator_untrained1 & indicator_untrained2 - - # Remove pad token possibility - if hasattr(tokenizer, "pad_token_id"): - pad_token_id = tokenizer.pad_token_id - if pad_token_id is not None and pad_token_id < indicator_untrained.shape[0]: - indicator_untrained[pad_token_id] = False - - where_untrained = torch.where(indicator_untrained)[0] - n_untrained = where_untrained.shape[0] - n_trained = embedding_matrix.shape[0] - n_untrained - - # Get set and actual tokens - where_untrained = where_untrained.tolist() - if len(where_untrained) == 0: - return - - # Remove untrained indices where it's longer - where_untrained_set = frozenset(where_untrained) - actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained) - # Remove None items in actual_bad_tokens - actual_bad_tokens = [x for x in actual_bad_tokens if x is not None] - - # Check if tokenizer and training datasets have bad tokens - if_bad_first = False - if_bad_second = False - # Check tokenizer's chat template for any untrained tokens - if chat_template is not None: - if_bad_first = any(x in chat_template for x in actual_bad_tokens) - - if isinstance(train_dataset, datasets.IterableDataset): - # Skip the check, since the code below assumes - # an indexable dataset - return - - # Check the first 250, last 250 input_ids - size_dataset = len(train_dataset) - size = min(size_dataset, 250) - for j in range(size): - input_ids = train_dataset[j] - if "input_ids" in input_ids: - input_ids = input_ids["input_ids"] - if_bad = any(item in where_untrained_set for item in input_ids) - if if_bad: - if_bad_second = True - break - - # Check last 250 - if not if_bad_second: - left = max(size_dataset - 250, 0) - for j in range(left, size_dataset): - input_ids = train_dataset[j] - if "input_ids" in input_ids: - input_ids = input_ids["input_ids"] - if_bad = any(item in where_untrained_set for item in input_ids) - if if_bad: - if_bad_second = True - break - - # Check if bad tokens exists! - if not if_bad_first and not if_bad_second: - return - - # Check if lm_head / embed_token are trainable! - bad_not_trainable = False - if not embedding_matrix.requires_grad: - bad_not_trainable = True - if not lm_head_matrix.requires_grad: - bad_not_trainable = True - - if bad_not_trainable: # pylint: disable=too-many-nested-blocks - final_bad_items = [] - - # Re-check the first 250, last 250 input_ids - size_dataset = len(train_dataset) - size = min(size_dataset, 250) - for j in range(size): - input_ids = train_dataset[j] - if "input_ids" in input_ids: - input_ids = input_ids["input_ids"] - for item in input_ids: - if item in where_untrained_set: - final_bad_items.append(item) - - # Re-check last 250 - left = max(size_dataset - 250, 0) - for j in range(left, size_dataset): - input_ids = train_dataset[j] - if "input_ids" in input_ids: - input_ids = input_ids["input_ids"] - for item in input_ids: - if item in where_untrained_set: - final_bad_items.append(item) - - # If no bad tokens, possibly chat template itself has issues? - if len(final_bad_items) == 0: - # Recheck 2000 and last 2000 items - size_dataset = len(train_dataset) - size = min(size_dataset, 2000) - for j in range(size): - input_ids = train_dataset[j] - if "input_ids" in input_ids: - input_ids = input_ids["input_ids"] - for item in input_ids: - if item in where_untrained_set: - final_bad_items.append(item) - - # Re-check last 2000 - left = max(size_dataset - 2000, 0) - for j in range(left, size_dataset): - input_ids = train_dataset[j] - if "input_ids" in input_ids: - input_ids = input_ids["input_ids"] - for item in input_ids: - if item in where_untrained_set: - final_bad_items.append(item) - - # Most likely false signal! - if len(final_bad_items) == 0: - return - - raise ValueError( - f"Untrained tokens of [{list(set(final_bad_items))}] found, but embed_tokens & lm_head not trainable, causing NaNs. " - ) - - # Count all the possible bad tokens - final_counts = np.zeros( - max(len(tokenizer), embedding_matrix.shape[0]), dtype=np.int64 - ) - - def mapping(examples): - input_ids = examples["input_ids"] - counter = np.fromiter(itertools.chain.from_iterable(input_ids), dtype=np.int32) - np.add.at(final_counts, counter, 1) - - train_dataset.map(mapping, batched=True, desc="Counting untrained tokens") - - # Get counts for untrained tokens - counts_untrained = final_counts[where_untrained] - # Identify untrained tokens seen in train_dataset - indices_seen_in_train = np.where(counts_untrained > 0)[0] - tokens_to_update = [where_untrained[i] for i in indices_seen_in_train] - - if len(tokens_to_update) == 0: - LOG.info( - "No untrained tokens found in train_dataset. No embeddings were modified." - ) - return - - # Log the token IDs that are being rescaled - LOG.info( - f"Rescaling embeddings for tokens seen in train_dataset: {tokens_to_update}" - ) - - # Get sum of all items - sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0) - sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0) - - # Remove bad tokens - sum_embedding -= torch.sum( - embedding_matrix[where_untrained], dtype=torch.float32, axis=0 - ) - sum_lm_head -= torch.sum( - lm_head_matrix[where_untrained], dtype=torch.float32, axis=0 - ) - - # Find correct average by dividing by sum of trained tokens - mean_embedding = sum_embedding / n_trained - mean_lm_head = sum_lm_head / n_trained - - # Compute scaling for tokens to update - scaling = counts_untrained[indices_seen_in_train] / max(final_counts.max(), 1) - scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1) - - # Prepare mean embeddings for tokens to update - mean_embedding_repeated = ( - mean_embedding.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling - ) - mean_lm_head_repeated = ( - mean_lm_head.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling - ) - - # Update embeddings only for tokens seen in train_dataset - embedding_matrix[tokens_to_update] = mean_embedding_repeated.to( - embedding_matrix.dtype - ) - lm_head_matrix[tokens_to_update] = mean_lm_head_repeated.to(lm_head_matrix.dtype) - - # Clean up - for _ in range(3): - gc.collect() - torch.cuda.empty_cache() - return From df584fc6d61d377a461053cf9d0f2f75774f7eec Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 17 Dec 2024 11:36:40 -0500 Subject: [PATCH 3/3] Add check for using fix_untrained_tokens --- tests/e2e/test_llama.py | 52 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 46 insertions(+), 6 deletions(-) diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py index 4e885a76d..33d12157a 100644 --- a/tests/e2e/test_llama.py +++ b/tests/e2e/test_llama.py @@ -4,7 +4,6 @@ import logging import os -import unittest from pathlib import Path from axolotl.cli import load_datasets @@ -13,18 +12,15 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import with_temp_dir - LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" -class TestLlama(unittest.TestCase): +class TestLlama: """ Test case for Llama models """ - @with_temp_dir def test_fft_trust_remote_code(self, temp_dir): # pylint: disable=duplicate-code cfg = DictDefault( @@ -46,7 +42,8 @@ def test_fft_trust_remote_code(self, temp_dir): }, ], "num_epochs": 1, - "micro_batch_size": 8, + "max_steps": 5, + "micro_batch_size": 2, "gradient_accumulation_steps": 1, "output_dir": temp_dir, "learning_rate": 0.00001, @@ -64,3 +61,46 @@ def test_fft_trust_remote_code(self, temp_dir): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(temp_dir) / "model.safetensors").exists() + + def test_fix_untrained_tokens(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "fix_untrained_tokens": True, + "sequence_len": 512, + "val_set_size": 0.0, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "chat_template": "chatml", + "datasets": [ + { + "path": "mlabonne/FineTome-100k", + "type": "chat_template", + "split": "train[:10%]", + "field_messages": "conversations", + "message_field_role": "from", + "message_field_content": "value", + }, + ], + "num_epochs": 1, + "max_steps": 5, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_8bit", + "lr_scheduler": "cosine", + "flash_attention": True, + "sample_packing": True, + "bf16": True, + "save_safetensors": True, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "model.safetensors").exists()