From 54060d54e246bfdb7b80cc06470b25cbe4b7401f Mon Sep 17 00:00:00 2001
From: Chirag Jain
Date: Tue, 2 Jul 2024 16:31:24 +0000
Subject: [PATCH] Turn off truncation in chat template strategy to do length checks correctly

---
 .../prompt_strategies/chat_template.py |   2 +-
 src/axolotl/utils/samplers/utils.py    |  21 ++++
 src/axolotl/utils/trainer.py           | 102 ++++++++----------
 3 files changed, 66 insertions(+), 59 deletions(-)

diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index 2c5a4a75e2..e1cf8b59e2 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -51,7 +51,7 @@ def build_prompt(self, conversation, add_generation_prompt=False):
 
         return self.tokenizer.apply_chat_template(
             turns,
-            truncation=True,
+            truncation=False,
             max_length=self.max_length,
             add_generation_prompt=add_generation_prompt,
             chat_template=self.chat_template,
diff --git a/src/axolotl/utils/samplers/utils.py b/src/axolotl/utils/samplers/utils.py
index e4af4e5f35..62a8299ea6 100755
--- a/src/axolotl/utils/samplers/utils.py
+++ b/src/axolotl/utils/samplers/utils.py
@@ -15,3 +15,24 @@ def get_dataset_lengths(dataset):
         lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
         return lengths
     return lengths
+
+
+def plot_ascii_lengths_histogram(data, title, logger):
+    max_value = max(data)
+    bucket_width = 512
+    bins = np.arange(0, max_value + bucket_width, bucket_width)
+    histogram, _ = np.histogram(data, bins=bins)
+    top = " ".join(("-" * 10, title, "-" * 10))
+    bottom = "-" * len(top)
+    logger.info(top)
+    scale_factor = 40 / max(histogram)
+    for i, value in enumerate(histogram):
+        lower_bound = i * bucket_width
+        upper_bound = (i + 1) * bucket_width - 1
+        if value:
+            hist_bar = "□" * int(value * scale_factor)
+        else:
+            hist_bar = "x"
+        logger.info(f"{hist_bar} ({lower_bound}-{upper_bound} tokens, Count: {value})")
+    logger.info(bottom)
+    logger.info("\n")
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index d37acd19ac..04e29961ff 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -17,6 +17,7 @@
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
 from axolotl.utils.distributed import reduce_and_broadcast
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
+from axolotl.utils.samplers.utils import plot_ascii_lengths_histogram
 
 LOG = get_logger("axolotl")
 
@@ -180,18 +181,51 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
     )
 
 
-def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
+def _maybe_drop_sequences(cfg, dataset, ds_split_name: str):
+    _ds_lens = get_dataset_lengths(dataset)
+    plot_ascii_lengths_histogram(
+        data=_ds_lens, title=f"{ds_split_name} Dataset Lengths", logger=LOG
+    )
+    min_len, max_len = np.min(_ds_lens), np.max(_ds_lens)
+    LOG.debug(f"min_input_len: {min_len}", main_process_only=True)
+    LOG.debug(f"max_input_len: {max_len}", main_process_only=True)
     drop_long = partial(
         drop_long_seq,
         sequence_len=cfg.sequence_len,
         min_sequence_len=cfg.min_sample_len or 2,
     )
-    if cfg.is_preprocess:
-        min_input_len = np.min(get_dataset_lengths(train_dataset))
-        LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
-        max_input_len = np.max(get_dataset_lengths(train_dataset))
-        LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
+    len_pre_drop = len(dataset)
+    dataset = dataset.filter(
+        drop_long,
+        num_proc=cfg.dataset_processes,
+        load_from_cache_file=not cfg.is_preprocess,
+        desc=f"Dropping Long Sequences From {ds_split_name} Dataset",
+    )
+    dropped_rows = len_pre_drop - len(dataset)
+    if dropped_rows > 0:
+        LOG.warning(f"Dropped {dropped_rows} rows from {ds_split_name} dataset")
+        if not cfg.drop_long_sequences:
+            raise ValueError(
+                f"Found {dropped_rows} sequences longer than {cfg.sequence_len} tokens in {ds_split_name} Dataset. "
+                f"Longest sequence is {max_len} tokens. "
+                f"Please either increase --sequence_len or set --drop_long_sequences to True to drop and ignore such sequences."
+            )
+    len_pre_drop = len(dataset)
+    dataset = dataset.filter(
+        drop_no_outputs,
+        num_proc=cfg.dataset_processes,
+        load_from_cache_file=not cfg.is_preprocess,
+        desc="Dropping Sequences Without Outputs",
+    )
+    dropped_rows = len_pre_drop - len(dataset)
+    if dropped_rows > 0:
+        LOG.warning(
+            f"Dropped {dropped_rows} rows with no outputs from {ds_split_name} Dataset"
+        )
+    return dataset
+
+
+def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     if (
         cfg.is_mistral_derived_model and cfg.flash_attention
     ) or cfg.model_config_type == "mamba":
@@ -207,62 +241,14 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     if eval_dataset and "token_type_ids" in eval_dataset.column_names:
         eval_dataset = eval_dataset.remove_columns("token_type_ids")
 
-    _len_pre_drop = len(train_dataset)
-    train_dataset = train_dataset.filter(
-        drop_long,
-        num_proc=cfg.dataset_processes,
-        load_from_cache_file=not cfg.is_preprocess,
-        desc="Dropping Long Sequences From Train Dataset",
-    )
-    _dropped_rows = _len_pre_drop - len(train_dataset)
-    if _dropped_rows > 0:
-        LOG.warning(f"Dropped {_dropped_rows} rows from train dataset")
-        if not cfg.drop_long_sequences:
-            raise ValueError(
-                f"Found {_dropped_rows} sequences longer than {cfg.sequence_len} tokens in train dataset. "
-                f"Please either increase --sequence_len or set --drop_long_sequences to True to drop and ignore such sequences."
-            )
-
-    _len_pre_drop = len(train_dataset)
-    train_dataset = train_dataset.filter(
-        drop_no_outputs,
-        num_proc=cfg.dataset_processes,
-        load_from_cache_file=not cfg.is_preprocess,
-        desc="Dropping Sequences Without Outputs",
+    train_dataset = _maybe_drop_sequences(
+        cfg=cfg, dataset=train_dataset, ds_split_name="Train"
     )
-    _dropped_rows = _len_pre_drop - len(train_dataset)
-    if _dropped_rows > 0:
-        LOG.warning(f"Dropped {_dropped_rows} rows with no outputs from train dataset")
 
     if eval_dataset:
-        _len_pre_drop = len(eval_dataset)
-        eval_dataset = eval_dataset.filter(
-            drop_long,
-            num_proc=cfg.dataset_processes,
-            load_from_cache_file=not cfg.is_preprocess,
-            desc="Dropping Long Sequences From Eval Dataset",
+        eval_dataset = _maybe_drop_sequences(
+            cfg=cfg, dataset=eval_dataset, ds_split_name="Eval"
         )
-        _dropped_rows = _len_pre_drop - len(eval_dataset)
-        if _dropped_rows > 0:
-            LOG.warning(f"Dropped {_dropped_rows} rows from eval dataset")
-            if not cfg.drop_long_sequences:
-                raise ValueError(
-                    f"Found {_dropped_rows} sequences longer than {cfg.sequence_len} tokens in eval dataset. "
-                    f"Please either increase --sequence_len or set --drop_long_sequences to True to drop and ignore such sequences."
-                )
-
-        _len_pre_drop = len(eval_dataset)
-        eval_dataset = eval_dataset.filter(
-            drop_no_outputs,
-            num_proc=cfg.dataset_processes,
-            load_from_cache_file=not cfg.is_preprocess,
-            desc="Dropping Sequences Without Outputs",
-        )
-        _dropped_rows = _len_pre_drop - len(eval_dataset)
-        if _dropped_rows > 0:
-            LOG.warning(
-                f"Dropped {_dropped_rows} rows with no outputs from eval dataset"
-            )
 
     if cfg.group_by_length:
         train_dataset = train_dataset.map(