diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 54ee195361..66d3d529b9 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1524,7 +1524,7 @@ def build(self, total_num_steps):
             else max(min(int(0.005 * total_num_steps), 10), 1)
         )

-        training_arguments_kwargs = {}
+        training_arguments_kwargs = self.cfg.get("extra_hf_training_args") or {}
         if self.cfg.bf16 == "full":
             training_arguments_kwargs["bf16_full_eval"] = True
         else:
@@ -2046,7 +2046,7 @@ def get_post_trainer_create_callbacks(self, trainer):
         return callbacks

     def build_training_arguments(self, total_num_steps):
-        training_args_kwargs = {}
+        training_args_kwargs = self.cfg.get("extra_hf_training_args") or {}
         for arg in [
             "adam_beta1",
             "adam_beta2",
diff --git a/src/axolotl/logging_config.py b/src/axolotl/logging_config.py
index 2ddf89a8c4..a83c152946 100644
--- a/src/axolotl/logging_config.py
+++ b/src/axolotl/logging_config.py
@@ -54,11 +54,17 @@ def format(self, record):
             "filters": [],
             "stream": sys.stdout,
         },
+        "file": {
+            "class": "logging.FileHandler",
+            "formatter": "simple",
+            "filename": "train.log",
+            "mode": "w",
+        },
     },
     "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
     "loggers": {
         "axolotl": {
-            "handlers": ["color_console"],
+            "handlers": ["color_console", "file"],
             "level": "DEBUG",
             "propagate": False,
         },
diff --git a/src/axolotl/prompt_strategies/alpaca_w_system.py b/src/axolotl/prompt_strategies/alpaca_w_system.py
index 8c8cc07435..b844cc08b8 100644
--- a/src/axolotl/prompt_strategies/alpaca_w_system.py
+++ b/src/axolotl/prompt_strategies/alpaca_w_system.py
@@ -49,6 +49,12 @@ def tokenize_prompt(self, prompt):
             tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
             tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]

+        if "num_tokens_pre_truncation" in tokenized_prompt:
+            tokenized_prompt["num_tokens_pre_truncation"] = (
+                tokenized_prompt["num_tokens_pre_truncation"]
+                + tokenized_res_prompt["num_tokens_pre_truncation"]
+            )
+
         return tokenized_prompt
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index 5b12130d75..f77a99ab78 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -67,19 +67,25 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None):
                 images=images,
                 return_tensors="pt",
             )
+            # dict_keys(['input_ids', 'attention_mask', 'pixel_values'])
             # workaround since processor works in batches instead of single examples
             for k, val in batch.items():
                 if k in ["pixel_values"]:
                     batch[k] = val.tolist()
                 else:
                     batch[k] = val.squeeze().tolist()
+            batch["num_tokens_pre_truncation"] = len(batch["input_ids"])
             return batch

-        return self.tokenizer.apply_chat_template(
+        input_ids = self.tokenizer.apply_chat_template(
             conversation,
             add_generation_prompt=add_generation_prompt,
             chat_template=self.chat_template,
         )
+        return {
+            "input_ids": input_ids,
+            "num_tokens_pre_truncation": len(input_ids),
+        }

     def get_offsets_for_train_detail(
         self, text: str, train_details: List[Dict], mask_untrainable: bool = True
@@ -230,20 +236,29 @@ def tokenize_prompt(self, prompt):
         ):
             turns = self.get_conversation_thread(prompt)
             images = self.get_images(prompt)
-            prompt_ids = self.prompter.build_prompt(
+            prompt_tokenized = self.prompter.build_prompt(
                 turns[:-1],
                 add_generation_prompt=True,
                 images=images,
             )
-            tokenized_res = self.prompter.build_prompt(turns, images=images)
+            all_turns_tokenized = self.prompter.build_prompt(turns, images=images)
+            prompt_ids = prompt_tokenized["input_ids"]
             tokenized_prompt = {}
-            if isinstance(tokenized_res, list):
-                input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
+            if "attention_mask" not in all_turns_tokenized:
+                input_ids = (
+                    prompt_ids + all_turns_tokenized["input_ids"][len(prompt_ids) :]
+                )
                 tokenized_prompt["input_ids"] = input_ids
+                num_tokens_pre_truncation = all_turns_tokenized[
+                    "num_tokens_pre_truncation"
+                ]
                 tokenized_prompt["attention_mask"] = [1] * len(input_ids)
             else:
-                input_ids = tokenized_res["input_ids"]
-                tokenized_prompt = tokenized_res
+                input_ids = all_turns_tokenized["input_ids"]
+                num_tokens_pre_truncation = all_turns_tokenized[
+                    "num_tokens_pre_truncation"
+                ]
+                tokenized_prompt = all_turns_tokenized

             if not self.train_on_inputs:
                 user_prompt_len = len(prompt_ids)
@@ -252,11 +267,14 @@ def tokenize_prompt(self, prompt):
                 labels = input_ids

             tokenized_prompt["labels"] = labels
+            tokenized_prompt["num_tokens_pre_truncation"] = num_tokens_pre_truncation

             return tokenized_prompt

         turns = self.get_conversation_thread(prompt)
-        input_ids = self.prompter.build_prompt(turns)
+        tokenized_res = self.prompter.build_prompt(turns)
+        input_ids = tokenized_res["input_ids"]
+        num_tokens_pre_truncation = tokenized_res["num_tokens_pre_truncation"]
         labels = [IGNORE_TOKEN_ID] * len(input_ids)

         last_eos_idx = -1
@@ -333,6 +351,7 @@ def tokenize_prompt(self, prompt):
             "input_ids": input_ids,
             "labels": labels,
             "attention_mask": [1] * len(input_ids),
+            "num_tokens_pre_truncation": num_tokens_pre_truncation,
         }

     def find_first_eos_token(self, input_ids, start_idx):
@@ -369,10 +388,10 @@ def find_turn(self, turns: list[dict], turn_idx: int):
         turns_with_content = turns[: turn_idx + 1]

         # Generate the conversation up to the turn, with final turn replaced with dummy content
-        dummy_ids = self.prompter.build_prompt(turns_with_empty)  # type: ignore
+        dummy_ids = self.prompter.build_prompt(turns_with_empty)["input_ids"]  # type: ignore

         # Generate the conversation up to the turn, with final turn included
-        full_ids = self.prompter.build_prompt(turns_with_content)  # type: ignore
+        full_ids = self.prompter.build_prompt(turns_with_content)["input_ids"]  # type: ignore

         if not full_ids or not dummy_ids:
             LOG.warning(f"Empty template generated for turn {turn_idx}")
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index bd6e3f9dce..6b3efe1a1f 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -1,6 +1,7 @@
 """Module containing PromptTokenizingStrategy and Prompter classes"""

 import abc
+import functools
 import logging
 from typing import Dict, List, Tuple, Union

@@ -60,18 +61,23 @@ def supports_batched(self):
     def _tokenize(
         self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
     ) -> BatchEncoding:
-        empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
+        empty = BatchEncoding(
+            data={"input_ids": [], "attention_mask": [], "num_tokens_pre_truncation": 0}
+        )
         if not prompt:
             LOG.warning("Empty text requested for tokenization.")
             return empty

-        result = self.tokenizer(
-            prompt,
-            truncation=True,
+        _tokenize = functools.partial(
+            self.tokenizer,
             max_length=self.max_length,
             padding=False,
             return_tensors=None,
         )
+        result = _tokenize(
+            prompt,
+            truncation=True,
+        )
         if len(result["input_ids"]) == 0:
             LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
You may want to audit your dataset") return empty @@ -89,6 +95,20 @@ def _tokenize( result["attention_mask"] = result["attention_mask"][1:] result["labels"] = result["input_ids"].copy() + + _all_tokens = _tokenize(prompt, truncation=False) + num_tokens_pre_truncation = len(_all_tokens["input_ids"]) + if ( + _all_tokens["input_ids"][-1] != self.tokenizer.eos_token_id + and add_eos_token + ): + num_tokens_pre_truncation += 1 + if ( + _all_tokens["input_ids"][0] == self.tokenizer.bos_token_id + and strip_bos_token + ): + num_tokens_pre_truncation -= 1 + result["num_tokens_pre_truncation"] = num_tokens_pre_truncation return result diff --git a/src/axolotl/train.py b/src/axolotl/train.py index dc7289b093..84e5094383 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -22,6 +22,7 @@ from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module fix_untrained_tokens, ) +from axolotl.integrations.base import PluginManager from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault from axolotl.utils.freeze import freeze_layers_except @@ -95,6 +96,8 @@ def train( model, peft_config = load_model( cfg, tokenizer, processor=processor, inference=cli_args.inference ) + plugin_manager = PluginManager.get_instance() + plugin_manager.post_model_load(cfg, model) if model.generation_config is not None: model.generation_config.do_sample = True @@ -144,7 +147,7 @@ def train( model.config.save_pretrained(str(Path(cfg.output_dir))) # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model - if cfg.local_rank == 0: + if cfg.local_rank == 0 and cfg.get("save_model_on_interrupt", True): def terminate_handler(_, __, model_weakref): if model_weakref() is not None: @@ -290,6 +293,11 @@ def terminate_handler(_, __, model_weakref): # defensively push to the hub to ensure the model card is updated trainer.push_to_hub() + if cfg.deepspeed: + trainer.deepspeed.destroy() + trainer.accelerator.free_memory() + trainer.model, trainer.model_wrapped, trainer.optimizer = None, None, None + return model, tokenizer diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 286e5f2d70..8ea5aa16a9 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -48,7 +48,11 @@ retry_on_request_exceptions, ) from axolotl.utils.dict import DictDefault -from axolotl.utils.distributed import is_local_main_process, zero_first +from axolotl.utils.distributed import ( + compute_and_broadcast, + is_local_main_process, + zero_first, +) from axolotl.utils.trainer import ( calculate_total_num_steps, process_datasets_for_packing, @@ -125,9 +129,15 @@ def prepare_dataset(cfg, tokenizer, processor=None): if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False: total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False) if total_eval_steps == 0: - raise ValueError( - "eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. " + LOG.warning( + "eval dataset split is too small for sample_packing. Setting `eval_sample_packing to False`." 
             )
+            if cfg.world_size > 1:
+                _eval_sample_packing = compute_and_broadcast(lambda: 0)
+                if _eval_sample_packing < 1:
+                    cfg.eval_sample_packing = False
+            else:
+                cfg.eval_sample_packing = False

     if cfg.max_steps:
         total_num_steps = min(
diff --git a/src/axolotl/utils/samplers/utils.py b/src/axolotl/utils/samplers/utils.py
index e4af4e5f35..62a8299ea6 100755
--- a/src/axolotl/utils/samplers/utils.py
+++ b/src/axolotl/utils/samplers/utils.py
@@ -15,3 +15,24 @@ def get_dataset_lengths(dataset):
         lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
         return lengths
     return lengths
+
+
+def plot_ascii_lengths_histogram(data, title, logger):
+    max_value = max(data)
+    bucket_width = 512
+    bins = np.arange(0, max_value + bucket_width, bucket_width)
+    histogram, _ = np.histogram(data, bins=bins)
+    top = " ".join(("-" * 10, title, "-" * 10))
+    bottom = "-" * len(top)
+    logger.info(top)
+    scale_factor = 40 / max(histogram)
+    for i, value in enumerate(histogram):
+        lower_bound = i * bucket_width
+        upper_bound = (i + 1) * bucket_width - 1
+        if value:
+            hist_bar = "□" * int(value * scale_factor)
+        else:
+            hist_bar = "x"
+        logger.info(f"{hist_bar} ({lower_bound}-{upper_bound} tokens, Count: {value})")
+    logger.info(bottom)
+    logger.info("\n")
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 32e54c9a86..795019147b 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -19,6 +19,7 @@
 from axolotl.utils.distributed import reduce_and_broadcast
 from axolotl.utils.environment import check_cuda_p2p_ib_support
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
+from axolotl.utils.samplers.utils import plot_ascii_lengths_histogram

 LOG = get_logger("axolotl")

@@ -203,6 +204,17 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     if eval_dataset and "token_type_ids" in eval_dataset.column_names:
         eval_dataset = eval_dataset.remove_columns("token_type_ids")

+    # TODO (chiragjn): The validation of sequence lengths should be done at the caller of this function
+    # This function is only called when `cfg.sample_packing` is True
+    drop_long = (
+        _validate_datasets_sequence_lengths(
+            cfg=cfg,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+        )
+        or drop_long
+    )
+
     prior_len = len(train_dataset)
     train_dataset = train_dataset.filter(
         drop_long,
@@ -226,6 +238,11 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
         if dropped:
             LOG.warning(f"Dropped {dropped} long samples from eval dataset")

+    train_dataset, eval_dataset = _drop_num_tokens_pre_truncation(
+        train_dataset,
+        eval_dataset,
+    )
+
     # drop samples with where the number of elements with labels not equal to -100 is zero
     def drop_no_trainable_tokens(sample):
         return np.sum(np.array(sample["labels"]) != -100) > 0
@@ -526,3 +543,82 @@ def setup_trainer(
     trainer_builder.eval_dataset = eval_dataset

     return trainer_builder.build(total_num_steps)
+
+
+def _drop_long_seq(sample, max_sequence_len, min_sequence_len):
+    return min_sequence_len <= sample["num_tokens_pre_truncation"] <= max_sequence_len
+
+
+def _validate_dataset_sequence_lengths(
+    dataset,
+    dataset_type,
+    sequence_len,
+    long_sequences_strategy,
+):
+    if "num_tokens_pre_truncation" not in dataset.features:
+        raise ValueError(
+            f"`long_sequences_strategy` is set to {long_sequences_strategy} but `num_tokens_pre_truncation` is missing from {dataset_type} dataset"
+        )
+    plot_ascii_lengths_histogram(
+        data=dataset["num_tokens_pre_truncation"],
+        title=f"{dataset_type} Dataset lengths",
lengths", + logger=LOG, + ) + num_longer_seqs = sum( + 1 for seq_len in dataset["num_tokens_pre_truncation"] if seq_len > sequence_len + ) + max_len = max(dataset["num_tokens_pre_truncation"]) + if num_longer_seqs > 0: + message = f"""\ +Found {num_longer_seqs}/{len(dataset)} sequences longer than {sequence_len} tokens in {dataset_type} Dataset. +Longest sequence is {max_len} tokens.""" + if long_sequences_strategy == "error": + raise ValueError( + f"{message}\n" + f"Please either increase --sequence_len or set --long_sequences_strategy to `drop` to drop and ignore such sequences." + ) + + LOG.warning(f"{message}\n" f"These sequences will be dropped.") + + +def _validate_datasets_sequence_lengths( + cfg, + train_dataset, + eval_dataset, +): + long_sequences_strategy = cfg.get("long_sequences_strategy", "truncate") + if long_sequences_strategy in ["drop", "error"]: + _validate_dataset_sequence_lengths( + dataset=train_dataset, + dataset_type="Train", + sequence_len=cfg.sequence_len, + long_sequences_strategy=long_sequences_strategy, + ) + if eval_dataset: + _validate_dataset_sequence_lengths( + dataset=eval_dataset, + dataset_type="Eval", + sequence_len=cfg.sequence_len, + long_sequences_strategy=long_sequences_strategy, + ) + if long_sequences_strategy == "drop": + drop_long = partial( + _drop_long_seq, + min_sequence_len=cfg.min_sample_len or 2, + max_sequence_len=cfg.sequence_len, + ) + return drop_long + + return None + + +def _drop_num_tokens_pre_truncation( + train_dataset, + eval_dataset, +): + if "num_tokens_pre_truncation" in train_dataset.features: + train_dataset = train_dataset.remove_columns(["num_tokens_pre_truncation"]) + if eval_dataset and "num_tokens_pre_truncation" in eval_dataset.features: + eval_dataset = eval_dataset.remove_columns(["num_tokens_pre_truncation"]) + + return train_dataset, eval_dataset