From 70e5d2bda41e20faf8b017ef8dfe06d68530a43d Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Tue, 29 Oct 2024 20:37:51 +0000 Subject: [PATCH] Misc improvements Enable callbacks injection from plugins Fix misc issues with axolotl plugins Fix remote code checking Enable loss average across devices Add seq len validation Enhance sequence lens validation Remove legacy code for patching _get_unpad_data Add pre truncation token counting for completion Fix plugin callbacks duplication --- src/axolotl/core/trainer_builder.py | 4 +- src/axolotl/logging_config.py | 8 +- .../prompt_strategies/alpaca_w_system.py | 6 ++ .../prompt_strategies/chat_template.py | 57 ++++++++--- src/axolotl/prompt_tokenizers.py | 28 +++++- src/axolotl/train.py | 10 +- src/axolotl/utils/samplers/utils.py | 21 +++++ src/axolotl/utils/trainer.py | 94 +++++++++++++++++++ 8 files changed, 205 insertions(+), 23 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 93384189e9..c342275055 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1409,7 +1409,7 @@ def build(self, total_num_steps): else max(min(int(0.005 * total_num_steps), 10), 1) ) - training_arguments_kwargs = {} + training_arguments_kwargs = {"average_tokens_across_devices": True} if self.cfg.bf16 == "full": training_arguments_kwargs["bf16_full_eval"] = True else: @@ -1923,7 +1923,7 @@ def get_post_trainer_create_callbacks(self, trainer): return callbacks def build_training_arguments(self, total_num_steps): - training_args_kwargs = {} + training_args_kwargs = {"average_tokens_across_devices": True} for arg in [ "adam_beta1", "adam_beta2", diff --git a/src/axolotl/logging_config.py b/src/axolotl/logging_config.py index 2ddf89a8c4..a83c152946 100644 --- a/src/axolotl/logging_config.py +++ b/src/axolotl/logging_config.py @@ -54,11 +54,17 @@ def format(self, record): "filters": [], "stream": sys.stdout, }, + "file": { + "class": "logging.FileHandler", + "formatter": "simple", + "filename": "train.log", + "mode": "w", + }, }, "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")}, "loggers": { "axolotl": { - "handlers": ["color_console"], + "handlers": ["color_console", "file"], "level": "DEBUG", "propagate": False, }, diff --git a/src/axolotl/prompt_strategies/alpaca_w_system.py b/src/axolotl/prompt_strategies/alpaca_w_system.py index 8c8cc07435..b844cc08b8 100644 --- a/src/axolotl/prompt_strategies/alpaca_w_system.py +++ b/src/axolotl/prompt_strategies/alpaca_w_system.py @@ -49,6 +49,12 @@ def tokenize_prompt(self, prompt): tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"] tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"] + if "num_tokens_pre_truncation" in tokenized_prompt: + tokenized_prompt["num_tokens_pre_truncation"] = ( + tokenized_prompt["num_tokens_pre_truncation"] + + tokenized_res_prompt["num_tokens_pre_truncation"] + ) + return tokenized_prompt diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 0946a4b8c7..e8d183b965 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -1,7 +1,7 @@ """ HF Chat Templates prompt strategy """ - +import functools import logging from typing import Any, Dict, List, Optional @@ -64,14 +64,16 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None): if self.drop_system_message and turns[0]["role"] == "system": turns = turns[1:] - if self.processor: - text = self.processor.apply_chat_template( - turns, + _apply_chat_template = functools.partial( + self.processor.apply_chat_template, chat_template=self.chat_template, - tokenize=False, add_generation_prompt=add_generation_prompt, ) + text = _apply_chat_template( + turns, + tokenize=False, + ) batch = self.processor( text=text, images=images, @@ -85,15 +87,27 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None): batch[k] = val.tolist() else: batch[k] = val.squeeze().tolist() + batch["num_tokens_pre_truncation"] = len( + _apply_chat_template(turns, tokenize=True) + ) return batch - return self.tokenizer.apply_chat_template( - turns, - truncation=True, + _apply_chat_template = functools.partial( + self.tokenizer.apply_chat_template, max_length=self.max_length, add_generation_prompt=add_generation_prompt, chat_template=self.chat_template, ) + inputs = _apply_chat_template( + turns, + truncation=True, + ) + return { + "input_ids": inputs, + "num_tokens_pre_truncation": len( + _apply_chat_template(turns, truncation=False) + ), + } def get_offsets_for_train_detail( self, text: str, train_details: List[Dict], mask_untrainable: bool = True @@ -237,20 +251,29 @@ def tokenize_prompt(self, prompt): ): turns = self.get_conversation_thread(prompt) images = self.get_images(prompt) - prompt_ids = self.prompter.build_prompt( + prompt_tokenized = self.prompter.build_prompt( turns[:-1], add_generation_prompt=True, images=images, ) - tokenized_res = self.prompter.build_prompt(turns, images=images) + all_turns_tokenized = self.prompter.build_prompt(turns, images=images) tokenized_prompt = {} - if isinstance(tokenized_res, list): - input_ids = prompt_ids + tokenized_res[len(prompt_ids) :] + if "attention_mask" not in all_turns_tokenized: + prompt_ids = prompt_tokenized["input_ids"] + input_ids = ( + prompt_ids + all_turns_tokenized["input_ids"][len(prompt_ids) :] + ) tokenized_prompt["input_ids"] = input_ids + num_tokens_pre_truncation = all_turns_tokenized[ + "num_tokens_pre_truncation" + ] tokenized_prompt["attention_mask"] = [1] * len(input_ids) else: - input_ids = tokenized_res["input_ids"] - tokenized_prompt = tokenized_res + input_ids = all_turns_tokenized["input_ids"] + num_tokens_pre_truncation = all_turns_tokenized[ + "num_tokens_pre_truncation" + ] + tokenized_prompt = all_turns_tokenized if not self.train_on_inputs: user_prompt_len = len(prompt_ids) @@ -259,11 +282,14 @@ def tokenize_prompt(self, prompt): labels = input_ids tokenized_prompt["labels"] = labels + tokenized_prompt["num_tokens_pre_truncation"] = num_tokens_pre_truncation return tokenized_prompt turns = prompt[self.messages] - input_ids = self.prompter.build_prompt(turns) + tokenized_res = self.prompter.build_prompt(turns) + input_ids = tokenized_res["input_ids"] + num_tokens_pre_truncation = tokenized_res["num_tokens_pre_truncation"] labels = [IGNORE_TOKEN_ID] * len(input_ids) last_eos_idx = -1 @@ -342,6 +368,7 @@ def tokenize_prompt(self, prompt): "input_ids": input_ids, "labels": labels, "attention_mask": [1] * len(input_ids), + "num_tokens_pre_truncation": num_tokens_pre_truncation, } def find_eos_token(self, input_ids, start_idx): diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index bd6e3f9dce..6b3efe1a1f 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -1,6 +1,7 @@ """Module containing PromptTokenizingStrategy and Prompter classes""" import abc +import functools import logging from typing import Dict, List, Tuple, Union @@ -60,18 +61,23 @@ def supports_batched(self): def _tokenize( self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False ) -> BatchEncoding: - empty = BatchEncoding(data={"input_ids": [], "attention_mask": []}) + empty = BatchEncoding( + data={"input_ids": [], "attention_mask": [], "num_tokens_pre_truncation": 0} + ) if not prompt: LOG.warning("Empty text requested for tokenization.") return empty - result = self.tokenizer( - prompt, - truncation=True, + _tokenize = functools.partial( + self.tokenizer, max_length=self.max_length, padding=False, return_tensors=None, ) + result = _tokenize( + prompt, + truncation=True, + ) if len(result["input_ids"]) == 0: LOG.warning("Tokenizer result is empty. You may want to audit your dataset") return empty @@ -89,6 +95,20 @@ def _tokenize( result["attention_mask"] = result["attention_mask"][1:] result["labels"] = result["input_ids"].copy() + + _all_tokens = _tokenize(prompt, truncation=False) + num_tokens_pre_truncation = len(_all_tokens["input_ids"]) + if ( + _all_tokens["input_ids"][-1] != self.tokenizer.eos_token_id + and add_eos_token + ): + num_tokens_pre_truncation += 1 + if ( + _all_tokens["input_ids"][0] == self.tokenizer.bos_token_id + and strip_bos_token + ): + num_tokens_pre_truncation -= 1 + result["num_tokens_pre_truncation"] = num_tokens_pre_truncation return result diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 39af9f45c9..4d68e0aa56 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -20,6 +20,7 @@ from axolotl.common.cli import TrainerCliArgs from axolotl.core.tokenizer_utils import fix_untrained_tokens +from axolotl.integrations.base import PluginManager from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault from axolotl.utils.freeze import freeze_layers_except @@ -99,6 +100,8 @@ def train( model, peft_config = load_model( cfg, tokenizer, processor=processor, inference=cli_args.inference ) + plugin_manager = PluginManager.get_instance() + plugin_manager.post_model_load(cfg, model) if model.generation_config is not None: model.generation_config.do_sample = True @@ -148,7 +151,7 @@ def train( model.config.save_pretrained(str(Path(cfg.output_dir))) # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model - if cfg.local_rank == 0: + if cfg.local_rank == 0 and cfg.get("save_model_on_interrupt", True): def terminate_handler(_, __, model_weakref): if model_weakref() is not None: @@ -289,6 +292,11 @@ def terminate_handler(_, __, model_weakref): # defensively push to the hub to ensure the model card is updated trainer.push_to_hub() + if cfg.deepspeed: + trainer.deepspeed.destroy() + trainer.accelerator.free_memory() + trainer.model, trainer.model_wrapped, trainer.optimizer = None, None, None + return model, tokenizer diff --git a/src/axolotl/utils/samplers/utils.py b/src/axolotl/utils/samplers/utils.py index e4af4e5f35..62a8299ea6 100755 --- a/src/axolotl/utils/samplers/utils.py +++ b/src/axolotl/utils/samplers/utils.py @@ -15,3 +15,24 @@ def get_dataset_lengths(dataset): lengths = np.vectorize(len)(np.array(input_ids, dtype=object)) return lengths return lengths + + +def plot_ascii_lengths_histogram(data, title, logger): + max_value = max(data) + bucket_width = 512 + bins = np.arange(0, max_value + bucket_width, bucket_width) + histogram, _ = np.histogram(data, bins=bins) + top = " ".join(("-" * 10, title, "-" * 10)) + bottom = "-" * len(top) + logger.info(top) + scale_factor = 40 / max(histogram) + for i, value in enumerate(histogram): + lower_bound = i * bucket_width + upper_bound = (i + 1) * bucket_width - 1 + if value: + hist_bar = "□" * int(value * scale_factor) + else: + hist_bar = "x" + logger.info(f"{hist_bar} ({lower_bound}-{upper_bound} tokens, Count: {value})") + logger.info(bottom) + logger.info("\n") diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 32e54c9a86..5d3edb5748 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -19,6 +19,7 @@ from axolotl.utils.distributed import reduce_and_broadcast from axolotl.utils.environment import check_cuda_p2p_ib_support from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths +from axolotl.utils.samplers.utils import plot_ascii_lengths_histogram LOG = get_logger("axolotl") @@ -203,6 +204,15 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): if eval_dataset and "token_type_ids" in eval_dataset.column_names: eval_dataset = eval_dataset.remove_columns("token_type_ids") + drop_long = ( + _validate_datasets_sequence_lengths( + cfg=cfg, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + ) + or drop_long + ) + prior_len = len(train_dataset) train_dataset = train_dataset.filter( drop_long, @@ -226,6 +236,11 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): if dropped: LOG.warning(f"Dropped {dropped} long samples from eval dataset") + train_dataset, eval_dataset = _drop_num_tokens_pre_truncation( + train_dataset, + eval_dataset, + ) + # drop samples with where the number of elements with labels not equal to -100 is zero def drop_no_trainable_tokens(sample): return np.sum(np.array(sample["labels"]) != -100) > 0 @@ -526,3 +541,82 @@ def setup_trainer( trainer_builder.eval_dataset = eval_dataset return trainer_builder.build(total_num_steps) + + +def _drop_long_seq(sample, max_sequence_len, min_sequence_len): + return min_sequence_len <= sample["num_tokens_pre_truncation"] <= max_sequence_len + + +def _validate_dataset_sequence_lengths( + dataset, + dataset_type, + sequence_len, + long_sequences_strategy, +): + if "num_tokens_pre_truncation" not in dataset.features: + raise ValueError( + f"`long_sequences_strategy` is set to {long_sequences_strategy} but `num_tokens_pre_truncation` is missing from {dataset_type} dataset" + ) + plot_ascii_lengths_histogram( + data=dataset["num_tokens_pre_truncation"], + title=f"{dataset_type} Dataset lengths", + logger=LOG, + ) + num_longer_seqs = sum( + 1 for seq_len in dataset["num_tokens_pre_truncation"] if seq_len > sequence_len + ) + max_len = max(dataset["num_tokens_pre_truncation"]) + if num_longer_seqs > 0: + message = f"""\ +Found {num_longer_seqs}/{len(dataset)} sequences longer than {sequence_len} tokens in {dataset_type} Dataset. +Longest sequence is {max_len} tokens.""" + if long_sequences_strategy == "error": + raise ValueError( + f"{message}\n" + f"Please either increase --sequence_len or set --long_sequences_strategy to `drop` to drop and ignore such sequences." + ) + + LOG.warning(f"{message}\n" f"These sequences will be dropped.") + + +def _validate_datasets_sequence_lengths( + cfg, + train_dataset, + eval_dataset, +): + long_sequences_strategy = cfg.get("long_sequences_strategy", "truncate") + if long_sequences_strategy in ["drop", "error"]: + _validate_dataset_sequence_lengths( + dataset=train_dataset, + dataset_type="Train", + sequence_len=cfg.sequence_len, + long_sequences_strategy=long_sequences_strategy, + ) + if eval_dataset: + _validate_dataset_sequence_lengths( + dataset=eval_dataset, + dataset_type="Eval", + sequence_len=cfg.sequence_len, + long_sequences_strategy=long_sequences_strategy, + ) + if long_sequences_strategy == "drop": + drop_long = partial( + _drop_long_seq, + min_sequence_len=cfg.min_sample_len or 2, + max_sequence_len=cfg.sequence_len, + ) + return drop_long + + return None + + +def _drop_num_tokens_pre_truncation( + train_dataset, + eval_dataset, +): + if "num_tokens_pre_truncation" in train_dataset.features: + train_dataset = train_dataset.remove_columns(["num_tokens_pre_truncation"]) + if eval_dataset and "num_tokens_pre_truncation" in eval_dataset.features: + eval_dataset = eval_dataset.remove_columns(["num_tokens_pre_truncation"]) + + return train_dataset, eval_dataset