From 5878daa3beec58bf4f4d21a6abd6dba3c40e74f4 Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Tue, 29 Oct 2024 20:37:51 +0000 Subject: [PATCH] Misc improvements Enable callbacks injection from plugins Fix misc issues with axolotl plugins Fix remote code checking Enable loss average across devices Add seq len validation Enhance sequence lens validation Remove legacy code for patching _get_unpad_data Add pre truncation token counting for completion Fix plugin callbacks duplication Enable eval on start Read extra hf args from cfg --- src/axolotl/core/trainer_builder.py | 4 +- src/axolotl/logging_config.py | 8 +- .../prompt_strategies/alpaca_w_system.py | 6 ++ .../prompt_strategies/chat_template.py | 39 ++++++-- src/axolotl/prompt_tokenizers.py | 28 +++++- src/axolotl/train.py | 10 +- src/axolotl/utils/data/sft.py | 16 +++- src/axolotl/utils/samplers/utils.py | 21 ++++ src/axolotl/utils/trainer.py | 96 +++++++++++++++++++ 9 files changed, 207 insertions(+), 21 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 54ee195361..66d3d529b9 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1524,7 +1524,7 @@ def build(self, total_num_steps): else max(min(int(0.005 * total_num_steps), 10), 1) ) - training_arguments_kwargs = {} + training_arguments_kwargs = self.cfg.get("extra_hf_training_args") or {} if self.cfg.bf16 == "full": training_arguments_kwargs["bf16_full_eval"] = True else: @@ -2046,7 +2046,7 @@ def get_post_trainer_create_callbacks(self, trainer): return callbacks def build_training_arguments(self, total_num_steps): - training_args_kwargs = {} + training_args_kwargs = self.cfg.get("extra_hf_training_args") or {} for arg in [ "adam_beta1", "adam_beta2", diff --git a/src/axolotl/logging_config.py b/src/axolotl/logging_config.py index 2ddf89a8c4..a83c152946 100644 --- a/src/axolotl/logging_config.py +++ b/src/axolotl/logging_config.py @@ -54,11 +54,17 @@ def format(self, record): "filters": [], "stream": sys.stdout, }, + "file": { + "class": "logging.FileHandler", + "formatter": "simple", + "filename": "train.log", + "mode": "w", + }, }, "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")}, "loggers": { "axolotl": { - "handlers": ["color_console"], + "handlers": ["color_console", "file"], "level": "DEBUG", "propagate": False, }, diff --git a/src/axolotl/prompt_strategies/alpaca_w_system.py b/src/axolotl/prompt_strategies/alpaca_w_system.py index 8c8cc07435..b844cc08b8 100644 --- a/src/axolotl/prompt_strategies/alpaca_w_system.py +++ b/src/axolotl/prompt_strategies/alpaca_w_system.py @@ -49,6 +49,12 @@ def tokenize_prompt(self, prompt): tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"] tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"] + if "num_tokens_pre_truncation" in tokenized_prompt: + tokenized_prompt["num_tokens_pre_truncation"] = ( + tokenized_prompt["num_tokens_pre_truncation"] + + tokenized_res_prompt["num_tokens_pre_truncation"] + ) + return tokenized_prompt diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 5b12130d75..f77a99ab78 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -67,19 +67,25 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None): images=images, return_tensors="pt", ) + # dict_keys(['input_ids', 'attention_mask', 'pixel_values']) # workaround since 
processor works in batches instead of single examples for k, val in batch.items(): if k in ["pixel_values"]: batch[k] = val.tolist() else: batch[k] = val.squeeze().tolist() + batch["num_tokens_pre_truncation"] = len(batch["input_ids"]) return batch - return self.tokenizer.apply_chat_template( + input_ids = self.tokenizer.apply_chat_template( conversation, add_generation_prompt=add_generation_prompt, chat_template=self.chat_template, ) + return { + "input_ids": input_ids, + "num_tokens_pre_truncation": len(input_ids), + } def get_offsets_for_train_detail( self, text: str, train_details: List[Dict], mask_untrainable: bool = True @@ -230,20 +236,29 @@ def tokenize_prompt(self, prompt): ): turns = self.get_conversation_thread(prompt) images = self.get_images(prompt) - prompt_ids = self.prompter.build_prompt( + prompt_tokenized = self.prompter.build_prompt( turns[:-1], add_generation_prompt=True, images=images, ) - tokenized_res = self.prompter.build_prompt(turns, images=images) + all_turns_tokenized = self.prompter.build_prompt(turns, images=images) tokenized_prompt = {} - if isinstance(tokenized_res, list): - input_ids = prompt_ids + tokenized_res[len(prompt_ids) :] + if "attention_mask" not in all_turns_tokenized: + prompt_ids = prompt_tokenized["input_ids"] + input_ids = ( + prompt_ids + all_turns_tokenized["input_ids"][len(prompt_ids) :] + ) tokenized_prompt["input_ids"] = input_ids + num_tokens_pre_truncation = all_turns_tokenized[ + "num_tokens_pre_truncation" + ] tokenized_prompt["attention_mask"] = [1] * len(input_ids) else: - input_ids = tokenized_res["input_ids"] - tokenized_prompt = tokenized_res + input_ids = all_turns_tokenized["input_ids"] + num_tokens_pre_truncation = all_turns_tokenized[ + "num_tokens_pre_truncation" + ] + tokenized_prompt = all_turns_tokenized if not self.train_on_inputs: user_prompt_len = len(prompt_ids) @@ -252,11 +267,14 @@ def tokenize_prompt(self, prompt): labels = input_ids tokenized_prompt["labels"] = labels + tokenized_prompt["num_tokens_pre_truncation"] = num_tokens_pre_truncation return tokenized_prompt turns = self.get_conversation_thread(prompt) - input_ids = self.prompter.build_prompt(turns) + tokenized_res = self.prompter.build_prompt(turns) + input_ids = tokenized_res["input_ids"] + num_tokens_pre_truncation = tokenized_res["num_tokens_pre_truncation"] labels = [IGNORE_TOKEN_ID] * len(input_ids) last_eos_idx = -1 @@ -333,6 +351,7 @@ def tokenize_prompt(self, prompt): "input_ids": input_ids, "labels": labels, "attention_mask": [1] * len(input_ids), + "num_tokens_pre_truncation": num_tokens_pre_truncation, } def find_first_eos_token(self, input_ids, start_idx): @@ -369,10 +388,10 @@ def find_turn(self, turns: list[dict], turn_idx: int): turns_with_content = turns[: turn_idx + 1] # Generate the conversation up to the turn, with final turn replaced with dummy content - dummy_ids = self.prompter.build_prompt(turns_with_empty) # type: ignore + dummy_ids = self.prompter.build_prompt(turns_with_empty)["input_ids"] # type: ignore # Generate the conversation up to the turn, with final turn included - full_ids = self.prompter.build_prompt(turns_with_content) # type: ignore + full_ids = self.prompter.build_prompt(turns_with_content)["input_ids"] # type: ignore if not full_ids or not dummy_ids: LOG.warning(f"Empty template generated for turn {turn_idx}") diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index bd6e3f9dce..6b3efe1a1f 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -1,6 +1,7 
@@ """Module containing PromptTokenizingStrategy and Prompter classes""" import abc +import functools import logging from typing import Dict, List, Tuple, Union @@ -60,18 +61,23 @@ def supports_batched(self): def _tokenize( self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False ) -> BatchEncoding: - empty = BatchEncoding(data={"input_ids": [], "attention_mask": []}) + empty = BatchEncoding( + data={"input_ids": [], "attention_mask": [], "num_tokens_pre_truncation": 0} + ) if not prompt: LOG.warning("Empty text requested for tokenization.") return empty - result = self.tokenizer( - prompt, - truncation=True, + _tokenize = functools.partial( + self.tokenizer, max_length=self.max_length, padding=False, return_tensors=None, ) + result = _tokenize( + prompt, + truncation=True, + ) if len(result["input_ids"]) == 0: LOG.warning("Tokenizer result is empty. You may want to audit your dataset") return empty @@ -89,6 +95,20 @@ def _tokenize( result["attention_mask"] = result["attention_mask"][1:] result["labels"] = result["input_ids"].copy() + + _all_tokens = _tokenize(prompt, truncation=False) + num_tokens_pre_truncation = len(_all_tokens["input_ids"]) + if ( + _all_tokens["input_ids"][-1] != self.tokenizer.eos_token_id + and add_eos_token + ): + num_tokens_pre_truncation += 1 + if ( + _all_tokens["input_ids"][0] == self.tokenizer.bos_token_id + and strip_bos_token + ): + num_tokens_pre_truncation -= 1 + result["num_tokens_pre_truncation"] = num_tokens_pre_truncation return result diff --git a/src/axolotl/train.py b/src/axolotl/train.py index dc7289b093..84e5094383 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -22,6 +22,7 @@ from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module fix_untrained_tokens, ) +from axolotl.integrations.base import PluginManager from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault from axolotl.utils.freeze import freeze_layers_except @@ -95,6 +96,8 @@ def train( model, peft_config = load_model( cfg, tokenizer, processor=processor, inference=cli_args.inference ) + plugin_manager = PluginManager.get_instance() + plugin_manager.post_model_load(cfg, model) if model.generation_config is not None: model.generation_config.do_sample = True @@ -144,7 +147,7 @@ def train( model.config.save_pretrained(str(Path(cfg.output_dir))) # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model - if cfg.local_rank == 0: + if cfg.local_rank == 0 and cfg.get("save_model_on_interrupt", True): def terminate_handler(_, __, model_weakref): if model_weakref() is not None: @@ -290,6 +293,11 @@ def terminate_handler(_, __, model_weakref): # defensively push to the hub to ensure the model card is updated trainer.push_to_hub() + if cfg.deepspeed: + trainer.deepspeed.destroy() + trainer.accelerator.free_memory() + trainer.model, trainer.model_wrapped, trainer.optimizer = None, None, None + return model, tokenizer diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 286e5f2d70..8ea5aa16a9 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -48,7 +48,11 @@ retry_on_request_exceptions, ) from axolotl.utils.dict import DictDefault -from axolotl.utils.distributed import is_local_main_process, zero_first +from axolotl.utils.distributed import ( + compute_and_broadcast, + is_local_main_process, + zero_first, +) from axolotl.utils.trainer import ( calculate_total_num_steps, process_datasets_for_packing, @@ 
-125,9 +129,15 @@ def prepare_dataset(cfg, tokenizer, processor=None):
     if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
         total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
         if total_eval_steps == 0:
-            raise ValueError(
-                "eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
+            LOG.warning(
+                "eval dataset split is too small for sample_packing. Setting `eval_sample_packing` to False."
             )
+            if cfg.world_size > 1:
+                _eval_sample_packing = compute_and_broadcast(lambda: 0)
+                if _eval_sample_packing < 1:
+                    cfg.eval_sample_packing = False
+            else:
+                cfg.eval_sample_packing = False
 
     if cfg.max_steps:
         total_num_steps = min(
diff --git a/src/axolotl/utils/samplers/utils.py b/src/axolotl/utils/samplers/utils.py
index e4af4e5f35..62a8299ea6 100755
--- a/src/axolotl/utils/samplers/utils.py
+++ b/src/axolotl/utils/samplers/utils.py
@@ -15,3 +15,24 @@ def get_dataset_lengths(dataset):
         lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
         return lengths
     return lengths
+
+
+def plot_ascii_lengths_histogram(data, title, logger):
+    max_value = max(data)
+    bucket_width = 512
+    bins = np.arange(0, max_value + bucket_width, bucket_width)
+    histogram, _ = np.histogram(data, bins=bins)
+    top = " ".join(("-" * 10, title, "-" * 10))
+    bottom = "-" * len(top)
+    logger.info(top)
+    scale_factor = 40 / max(histogram)
+    for i, value in enumerate(histogram):
+        lower_bound = i * bucket_width
+        upper_bound = (i + 1) * bucket_width - 1
+        if value:
+            hist_bar = "□" * int(value * scale_factor)
+        else:
+            hist_bar = "x"
+        logger.info(f"{hist_bar} ({lower_bound}-{upper_bound} tokens, Count: {value})")
+    logger.info(bottom)
+    logger.info("\n")
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 32e54c9a86..795019147b 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -19,6 +19,7 @@
 from axolotl.utils.distributed import reduce_and_broadcast
 from axolotl.utils.environment import check_cuda_p2p_ib_support
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
+from axolotl.utils.samplers.utils import plot_ascii_lengths_histogram
 
 LOG = get_logger("axolotl")
 
@@ -203,6 +204,17 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     if eval_dataset and "token_type_ids" in eval_dataset.column_names:
         eval_dataset = eval_dataset.remove_columns("token_type_ids")
 
+    # TODO (chiragjn): The validation of sequence lengths should be done at the caller of this function
+    # This function is only called when `cfg.sample_packing` is True
+    drop_long = (
+        _validate_datasets_sequence_lengths(
+            cfg=cfg,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+        )
+        or drop_long
+    )
+
     prior_len = len(train_dataset)
     train_dataset = train_dataset.filter(
         drop_long,
@@ -226,6 +238,11 @@
         if dropped:
             LOG.warning(f"Dropped {dropped} long samples from eval dataset")
 
+    train_dataset, eval_dataset = _drop_num_tokens_pre_truncation(
+        train_dataset,
+        eval_dataset,
+    )
+
     # drop samples with where the number of elements with labels not equal to -100 is zero
     def drop_no_trainable_tokens(sample):
         return np.sum(np.array(sample["labels"]) != -100) > 0
@@ -526,3 +543,82 @@ def setup_trainer(
     trainer_builder.eval_dataset = eval_dataset
 
     return trainer_builder.build(total_num_steps)
+
+
+def _drop_long_seq(sample, max_sequence_len, min_sequence_len):
+    return min_sequence_len <= sample["num_tokens_pre_truncation"] <= max_sequence_len
sample["num_tokens_pre_truncation"] <= max_sequence_len + + +def _validate_dataset_sequence_lengths( + dataset, + dataset_type, + sequence_len, + long_sequences_strategy, +): + if "num_tokens_pre_truncation" not in dataset.features: + raise ValueError( + f"`long_sequences_strategy` is set to {long_sequences_strategy} but `num_tokens_pre_truncation` is missing from {dataset_type} dataset" + ) + plot_ascii_lengths_histogram( + data=dataset["num_tokens_pre_truncation"], + title=f"{dataset_type} Dataset lengths", + logger=LOG, + ) + num_longer_seqs = sum( + 1 for seq_len in dataset["num_tokens_pre_truncation"] if seq_len > sequence_len + ) + max_len = max(dataset["num_tokens_pre_truncation"]) + if num_longer_seqs > 0: + message = f"""\ +Found {num_longer_seqs}/{len(dataset)} sequences longer than {sequence_len} tokens in {dataset_type} Dataset. +Longest sequence is {max_len} tokens.""" + if long_sequences_strategy == "error": + raise ValueError( + f"{message}\n" + f"Please either increase --sequence_len or set --long_sequences_strategy to `drop` to drop and ignore such sequences." + ) + + LOG.warning(f"{message}\n" f"These sequences will be dropped.") + + +def _validate_datasets_sequence_lengths( + cfg, + train_dataset, + eval_dataset, +): + long_sequences_strategy = cfg.get("long_sequences_strategy", "truncate") + if long_sequences_strategy in ["drop", "error"]: + _validate_dataset_sequence_lengths( + dataset=train_dataset, + dataset_type="Train", + sequence_len=cfg.sequence_len, + long_sequences_strategy=long_sequences_strategy, + ) + if eval_dataset: + _validate_dataset_sequence_lengths( + dataset=eval_dataset, + dataset_type="Eval", + sequence_len=cfg.sequence_len, + long_sequences_strategy=long_sequences_strategy, + ) + if long_sequences_strategy == "drop": + drop_long = partial( + _drop_long_seq, + min_sequence_len=cfg.min_sample_len or 2, + max_sequence_len=cfg.sequence_len, + ) + return drop_long + + return None + + +def _drop_num_tokens_pre_truncation( + train_dataset, + eval_dataset, +): + if "num_tokens_pre_truncation" in train_dataset.features: + train_dataset = train_dataset.remove_columns(["num_tokens_pre_truncation"]) + if eval_dataset and "num_tokens_pre_truncation" in eval_dataset.features: + eval_dataset = eval_dataset.remove_columns(["num_tokens_pre_truncation"]) + + return train_dataset, eval_dataset