diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index baac94da80..7a6e08c645 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1492,7 +1492,7 @@ def build(self, total_num_steps): else max(min(int(0.005 * total_num_steps), 10), 1) ) - training_arguments_kwargs = {} + training_arguments_kwargs = {"average_tokens_across_devices": True} if self.cfg.bf16 == "full": training_arguments_kwargs["bf16_full_eval"] = True else: @@ -2006,7 +2006,7 @@ def get_post_trainer_create_callbacks(self, trainer): return callbacks def build_training_arguments(self, total_num_steps): - training_args_kwargs = {} + training_args_kwargs = {"average_tokens_across_devices": True} for arg in [ "adam_beta1", "adam_beta2", diff --git a/src/axolotl/logging_config.py b/src/axolotl/logging_config.py index 2ddf89a8c4..a83c152946 100644 --- a/src/axolotl/logging_config.py +++ b/src/axolotl/logging_config.py @@ -54,11 +54,17 @@ def format(self, record): "filters": [], "stream": sys.stdout, }, + "file": { + "class": "logging.FileHandler", + "formatter": "simple", + "filename": "train.log", + "mode": "w", + }, }, "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")}, "loggers": { "axolotl": { - "handlers": ["color_console"], + "handlers": ["color_console", "file"], "level": "DEBUG", "propagate": False, }, diff --git a/src/axolotl/prompt_strategies/alpaca_w_system.py b/src/axolotl/prompt_strategies/alpaca_w_system.py index 8c8cc07435..b844cc08b8 100644 --- a/src/axolotl/prompt_strategies/alpaca_w_system.py +++ b/src/axolotl/prompt_strategies/alpaca_w_system.py @@ -49,6 +49,12 @@ def tokenize_prompt(self, prompt): tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"] tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"] + if "num_tokens_pre_truncation" in tokenized_prompt: + tokenized_prompt["num_tokens_pre_truncation"] = ( + tokenized_prompt["num_tokens_pre_truncation"] + + tokenized_res_prompt["num_tokens_pre_truncation"] + ) + return tokenized_prompt diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 0946a4b8c7..e8d183b965 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -1,7 +1,7 @@ """ HF Chat Templates prompt strategy """ - +import functools import logging from typing import Any, Dict, List, Optional @@ -64,14 +64,16 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None): if self.drop_system_message and turns[0]["role"] == "system": turns = turns[1:] - if self.processor: - text = self.processor.apply_chat_template( - turns, + _apply_chat_template = functools.partial( + self.processor.apply_chat_template, chat_template=self.chat_template, - tokenize=False, add_generation_prompt=add_generation_prompt, ) + text = _apply_chat_template( + turns, + tokenize=False, + ) batch = self.processor( text=text, images=images, @@ -85,15 +87,27 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None): batch[k] = val.tolist() else: batch[k] = val.squeeze().tolist() + batch["num_tokens_pre_truncation"] = len( + _apply_chat_template(turns, tokenize=True) + ) return batch - return self.tokenizer.apply_chat_template( - turns, - truncation=True, + _apply_chat_template = functools.partial( + self.tokenizer.apply_chat_template, max_length=self.max_length, add_generation_prompt=add_generation_prompt, chat_template=self.chat_template, ) + inputs = _apply_chat_template( + turns, + truncation=True, + ) + return { + "input_ids": inputs, + "num_tokens_pre_truncation": len( + _apply_chat_template(turns, truncation=False) + ), + } def get_offsets_for_train_detail( self, text: str, train_details: List[Dict], mask_untrainable: bool = True @@ -237,20 +251,29 @@ def tokenize_prompt(self, prompt): ): turns = self.get_conversation_thread(prompt) images = self.get_images(prompt) - prompt_ids = self.prompter.build_prompt( + prompt_tokenized = self.prompter.build_prompt( turns[:-1], add_generation_prompt=True, images=images, ) - tokenized_res = self.prompter.build_prompt(turns, images=images) + all_turns_tokenized = self.prompter.build_prompt(turns, images=images) tokenized_prompt = {} - if isinstance(tokenized_res, list): - input_ids = prompt_ids + tokenized_res[len(prompt_ids) :] + if "attention_mask" not in all_turns_tokenized: + prompt_ids = prompt_tokenized["input_ids"] + input_ids = ( + prompt_ids + all_turns_tokenized["input_ids"][len(prompt_ids) :] + ) tokenized_prompt["input_ids"] = input_ids + num_tokens_pre_truncation = all_turns_tokenized[ + "num_tokens_pre_truncation" + ] tokenized_prompt["attention_mask"] = [1] * len(input_ids) else: - input_ids = tokenized_res["input_ids"] - tokenized_prompt = tokenized_res + input_ids = all_turns_tokenized["input_ids"] + num_tokens_pre_truncation = all_turns_tokenized[ + "num_tokens_pre_truncation" + ] + tokenized_prompt = all_turns_tokenized if not self.train_on_inputs: user_prompt_len = len(prompt_ids) @@ -259,11 +282,14 @@ def tokenize_prompt(self, prompt): labels = input_ids tokenized_prompt["labels"] = labels + tokenized_prompt["num_tokens_pre_truncation"] = num_tokens_pre_truncation return tokenized_prompt turns = prompt[self.messages] - input_ids = self.prompter.build_prompt(turns) + tokenized_res = self.prompter.build_prompt(turns) + input_ids = tokenized_res["input_ids"] + num_tokens_pre_truncation = tokenized_res["num_tokens_pre_truncation"] labels = [IGNORE_TOKEN_ID] * len(input_ids) last_eos_idx = -1 @@ -342,6 +368,7 @@ def tokenize_prompt(self, prompt): "input_ids": input_ids, "labels": labels, "attention_mask": [1] * len(input_ids), + "num_tokens_pre_truncation": num_tokens_pre_truncation, } def find_eos_token(self, input_ids, start_idx): diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index bd6e3f9dce..6b3efe1a1f 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -1,6 +1,7 @@ """Module containing PromptTokenizingStrategy and Prompter classes""" import abc +import functools import logging from typing import Dict, List, Tuple, Union @@ -60,18 +61,23 @@ def supports_batched(self): def _tokenize( self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False ) -> BatchEncoding: - empty = BatchEncoding(data={"input_ids": [], "attention_mask": []}) + empty = BatchEncoding( + data={"input_ids": [], "attention_mask": [], "num_tokens_pre_truncation": 0} + ) if not prompt: LOG.warning("Empty text requested for tokenization.") return empty - result = self.tokenizer( - prompt, - truncation=True, + _tokenize = functools.partial( + self.tokenizer, max_length=self.max_length, padding=False, return_tensors=None, ) + result = _tokenize( + prompt, + truncation=True, + ) if len(result["input_ids"]) == 0: LOG.warning("Tokenizer result is empty. You may want to audit your dataset") return empty @@ -89,6 +95,20 @@ def _tokenize( result["attention_mask"] = result["attention_mask"][1:] result["labels"] = result["input_ids"].copy() + + _all_tokens = _tokenize(prompt, truncation=False) + num_tokens_pre_truncation = len(_all_tokens["input_ids"]) + if ( + _all_tokens["input_ids"][-1] != self.tokenizer.eos_token_id + and add_eos_token + ): + num_tokens_pre_truncation += 1 + if ( + _all_tokens["input_ids"][0] == self.tokenizer.bos_token_id + and strip_bos_token + ): + num_tokens_pre_truncation -= 1 + result["num_tokens_pre_truncation"] = num_tokens_pre_truncation return result diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 39af9f45c9..4d68e0aa56 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -20,6 +20,7 @@ from axolotl.common.cli import TrainerCliArgs from axolotl.core.tokenizer_utils import fix_untrained_tokens +from axolotl.integrations.base import PluginManager from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault from axolotl.utils.freeze import freeze_layers_except @@ -99,6 +100,8 @@ def train( model, peft_config = load_model( cfg, tokenizer, processor=processor, inference=cli_args.inference ) + plugin_manager = PluginManager.get_instance() + plugin_manager.post_model_load(cfg, model) if model.generation_config is not None: model.generation_config.do_sample = True @@ -148,7 +151,7 @@ def train( model.config.save_pretrained(str(Path(cfg.output_dir))) # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model - if cfg.local_rank == 0: + if cfg.local_rank == 0 and cfg.get("save_model_on_interrupt", True): def terminate_handler(_, __, model_weakref): if model_weakref() is not None: @@ -289,6 +292,11 @@ def terminate_handler(_, __, model_weakref): # defensively push to the hub to ensure the model card is updated trainer.push_to_hub() + if cfg.deepspeed: + trainer.deepspeed.destroy() + trainer.accelerator.free_memory() + trainer.model, trainer.model_wrapped, trainer.optimizer = None, None, None + return model, tokenizer diff --git a/src/axolotl/utils/samplers/utils.py b/src/axolotl/utils/samplers/utils.py index e4af4e5f35..62a8299ea6 100755 --- a/src/axolotl/utils/samplers/utils.py +++ b/src/axolotl/utils/samplers/utils.py @@ -15,3 +15,24 @@ def get_dataset_lengths(dataset): lengths = np.vectorize(len)(np.array(input_ids, dtype=object)) return lengths return lengths + + +def plot_ascii_lengths_histogram(data, title, logger): + max_value = max(data) + bucket_width = 512 + bins = np.arange(0, max_value + bucket_width, bucket_width) + histogram, _ = np.histogram(data, bins=bins) + top = " ".join(("-" * 10, title, "-" * 10)) + bottom = "-" * len(top) + logger.info(top) + scale_factor = 40 / max(histogram) + for i, value in enumerate(histogram): + lower_bound = i * bucket_width + upper_bound = (i + 1) * bucket_width - 1 + if value: + hist_bar = "□" * int(value * scale_factor) + else: + hist_bar = "x" + logger.info(f"{hist_bar} ({lower_bound}-{upper_bound} tokens, Count: {value})") + logger.info(bottom) + logger.info("\n") diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 32e54c9a86..5d3edb5748 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -19,6 +19,7 @@ from axolotl.utils.distributed import reduce_and_broadcast from axolotl.utils.environment import check_cuda_p2p_ib_support from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths +from axolotl.utils.samplers.utils import plot_ascii_lengths_histogram LOG = get_logger("axolotl") @@ -203,6 +204,15 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): if eval_dataset and "token_type_ids" in eval_dataset.column_names: eval_dataset = eval_dataset.remove_columns("token_type_ids") + drop_long = ( + _validate_datasets_sequence_lengths( + cfg=cfg, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + ) + or drop_long + ) + prior_len = len(train_dataset) train_dataset = train_dataset.filter( drop_long, @@ -226,6 +236,11 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): if dropped: LOG.warning(f"Dropped {dropped} long samples from eval dataset") + train_dataset, eval_dataset = _drop_num_tokens_pre_truncation( + train_dataset, + eval_dataset, + ) + # drop samples with where the number of elements with labels not equal to -100 is zero def drop_no_trainable_tokens(sample): return np.sum(np.array(sample["labels"]) != -100) > 0 @@ -526,3 +541,82 @@ def setup_trainer( trainer_builder.eval_dataset = eval_dataset return trainer_builder.build(total_num_steps) + + +def _drop_long_seq(sample, max_sequence_len, min_sequence_len): + return min_sequence_len <= sample["num_tokens_pre_truncation"] <= max_sequence_len + + +def _validate_dataset_sequence_lengths( + dataset, + dataset_type, + sequence_len, + long_sequences_strategy, +): + if "num_tokens_pre_truncation" not in dataset.features: + raise ValueError( + f"`long_sequences_strategy` is set to {long_sequences_strategy} but `num_tokens_pre_truncation` is missing from {dataset_type} dataset" + ) + plot_ascii_lengths_histogram( + data=dataset["num_tokens_pre_truncation"], + title=f"{dataset_type} Dataset lengths", + logger=LOG, + ) + num_longer_seqs = sum( + 1 for seq_len in dataset["num_tokens_pre_truncation"] if seq_len > sequence_len + ) + max_len = max(dataset["num_tokens_pre_truncation"]) + if num_longer_seqs > 0: + message = f"""\ +Found {num_longer_seqs}/{len(dataset)} sequences longer than {sequence_len} tokens in {dataset_type} Dataset. +Longest sequence is {max_len} tokens.""" + if long_sequences_strategy == "error": + raise ValueError( + f"{message}\n" + f"Please either increase --sequence_len or set --long_sequences_strategy to `drop` to drop and ignore such sequences." + ) + + LOG.warning(f"{message}\n" f"These sequences will be dropped.") + + +def _validate_datasets_sequence_lengths( + cfg, + train_dataset, + eval_dataset, +): + long_sequences_strategy = cfg.get("long_sequences_strategy", "truncate") + if long_sequences_strategy in ["drop", "error"]: + _validate_dataset_sequence_lengths( + dataset=train_dataset, + dataset_type="Train", + sequence_len=cfg.sequence_len, + long_sequences_strategy=long_sequences_strategy, + ) + if eval_dataset: + _validate_dataset_sequence_lengths( + dataset=eval_dataset, + dataset_type="Eval", + sequence_len=cfg.sequence_len, + long_sequences_strategy=long_sequences_strategy, + ) + if long_sequences_strategy == "drop": + drop_long = partial( + _drop_long_seq, + min_sequence_len=cfg.min_sample_len or 2, + max_sequence_len=cfg.sequence_len, + ) + return drop_long + + return None + + +def _drop_num_tokens_pre_truncation( + train_dataset, + eval_dataset, +): + if "num_tokens_pre_truncation" in train_dataset.features: + train_dataset = train_dataset.remove_columns(["num_tokens_pre_truncation"]) + if eval_dataset and "num_tokens_pre_truncation" in eval_dataset.features: + eval_dataset = eval_dataset.remove_columns(["num_tokens_pre_truncation"]) + + return train_dataset, eval_dataset