From 294e9097e2c4ea642198aea5ad0561d3b647e572 Mon Sep 17 00:00:00 2001
From: Chirag Jain <jain.chirag925@gmail.com>
Date: Tue, 2 Jul 2024 16:31:24 +0000
Subject: [PATCH] Turn off truncation in chat template strategy to do length
 checks correctly

---
 .../prompt_strategies/chat_template.py        |   4 +-
 src/axolotl/utils/samplers/utils.py           |  21 ++++
 src/axolotl/utils/trainer.py                  | 110 +++++++-----------
 3 files changed, 66 insertions(+), 69 deletions(-)

diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index c97d27ea5..4322ab9da 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -366,9 +366,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     }
 
     strategy = ChatTemplateStrategy(
-        ChatTemplatePrompter(**prompter_params), 
-        tokenizer=tokenizer, 
-        **strategy_params
+        ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
     )
 
     if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
diff --git a/src/axolotl/utils/samplers/utils.py b/src/axolotl/utils/samplers/utils.py
index e4af4e5f3..62a8299ea 100755
--- a/src/axolotl/utils/samplers/utils.py
+++ b/src/axolotl/utils/samplers/utils.py
@@ -15,3 +15,24 @@ def get_dataset_lengths(dataset):
         lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
         return lengths
     return lengths
+
+
+def plot_ascii_lengths_histogram(data, title, logger):
+    max_value = max(data)
+    bucket_width = 512
+    bins = np.arange(0, max_value + bucket_width, bucket_width)
+    histogram, _ = np.histogram(data, bins=bins)
+    top = " ".join(("-" * 10, title, "-" * 10))
+    bottom = "-" * len(top)
+    logger.info(top)
+    scale_factor = 40 / max(histogram)
+    for i, value in enumerate(histogram):
+        lower_bound = i * bucket_width
+        upper_bound = (i + 1) * bucket_width - 1
+        if value:
+            hist_bar = "□" * int(value * scale_factor)
+        else:
+            hist_bar = "x"
+        logger.info(f"{hist_bar} ({lower_bound}-{upper_bound} tokens, Count: {value})")
+    logger.info(bottom)
+    logger.info("\n")
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 24cb8e10e..ec6a9cab0 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -18,6 +18,7 @@
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
 from axolotl.utils.distributed import reduce_and_broadcast
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
+from axolotl.utils.samplers.utils import plot_ascii_lengths_histogram
 
 LOG = get_logger("axolotl")
 
@@ -181,27 +182,52 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
     )
 
 
-def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
+def _maybe_drop_sequences(cfg, dataset, ds_split_name: str):
+    _ds_lens = get_dataset_lengths(dataset)
+    plot_ascii_lengths_histogram(
+        data=_ds_lens, title=f"{ds_split_name} Dataset Lengths", logger=LOG
+    )
+    min_len, max_len = np.min(_ds_lens), np.max(_ds_lens)
+    LOG.debug(f"min_input_len: {min_len}", main_process_only=True)
+    LOG.debug(f"max_input_len: {max_len}", main_process_only=True)
     drop_long = partial(
         drop_long_seq,
         sequence_len=cfg.sequence_len,
         min_sequence_len=cfg.min_sample_len or 2,
     )
-    if cfg.is_preprocess:
-        min_input_len = np.min(get_dataset_lengths(train_dataset))
-        LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
-        max_input_len = np.max(get_dataset_lengths(train_dataset))
-        LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
+    len_pre_drop = len(dataset)
+    dataset = dataset.filter(
+        drop_long,
+        num_proc=cfg.dataset_processes,
+        load_from_cache_file=not cfg.is_preprocess,
+        desc=f"Dropping Long Sequences From {ds_split_name} Dataset",
+    )
+    dropped_rows = len_pre_drop - len(dataset)
+    if dropped_rows > 0:
+        LOG.warning(f"Dropped {dropped_rows} rows from {ds_split_name} dataset")
+        if not cfg.drop_long_sequences:
+            raise ValueError(
+                f"Found {dropped_rows} sequences longer than {cfg.sequence_len} tokens in {ds_split_name} Dataset. "
+                f"Longest sequence is {max_len} tokens. "
+                f"Please either increase --sequence_len or set --drop_long_sequences to True to drop and ignore such sequences."
+            )
+    len_pre_drop = len(dataset)
+    dataset = dataset.filter(
+        drop_no_outputs,
+        num_proc=cfg.dataset_processes,
+        load_from_cache_file=not cfg.is_preprocess,
+        desc="Dropping Sequences Without Outputs",
+    )
+    dropped_rows = len_pre_drop - len(dataset)
+    if dropped_rows > 0:
+        LOG.warning(
+            f"Dropped {dropped_rows} rows with no outputs from {ds_split_name} Dataset"
+        )
+    return dataset
 
-    if cfg.model_config_type == "mamba":
-        LOG.info("dropping attention_mask column")
-        train_dataset = train_dataset.remove_columns("attention_mask")
-        if eval_dataset:
-            eval_dataset = eval_dataset.remove_columns("attention_mask")
 
-    if (
-        cfg.is_mistral_derived_model and cfg.flash_attention
-    ) or cfg.model_config_type == "mamba":
+def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
+    if cfg.model_config_type == "mamba":
         LOG.info("dropping attention_mask column")
         train_dataset = train_dataset.remove_columns("attention_mask")
         if eval_dataset:
@@ -214,62 +240,14 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
         if eval_dataset and "token_type_ids" in eval_dataset.column_names:
             eval_dataset = eval_dataset.remove_columns("token_type_ids")
 
-    _len_pre_drop = len(train_dataset)
-    train_dataset = train_dataset.filter(
-        drop_long,
-        num_proc=cfg.dataset_processes,
-        load_from_cache_file=not cfg.is_preprocess,
-        desc="Dropping Long Sequences From Train Dataset",
-    )
-    _dropped_rows = _len_pre_drop - len(train_dataset)
-    if _dropped_rows > 0:
-        LOG.warning(f"Dropped {_dropped_rows} rows from train dataset")
-        if not cfg.drop_long_sequences:
-            raise ValueError(
-                f"Found {_dropped_rows} sequences longer than {cfg.sequence_len} tokens in train dataset. "
-                f"Please either increase --sequence_len or set --drop_long_sequences to True to drop and ignore such sequences."
-            )
-
-    _len_pre_drop = len(train_dataset)
-    train_dataset = train_dataset.filter(
-        drop_no_outputs,
-        num_proc=cfg.dataset_processes,
-        load_from_cache_file=not cfg.is_preprocess,
-        desc="Dropping Sequences Without Outputs",
+    train_dataset = _maybe_drop_sequences(
+        cfg=cfg, dataset=train_dataset, ds_split_name="Train"
     )
-    _dropped_rows = _len_pre_drop - len(train_dataset)
-    if _dropped_rows > 0:
-        LOG.warning(f"Dropped {_dropped_rows} rows with no outputs from train dataset")
 
     if eval_dataset:
-        _len_pre_drop = len(eval_dataset)
-        eval_dataset = eval_dataset.filter(
-            drop_long,
-            num_proc=cfg.dataset_processes,
-            load_from_cache_file=not cfg.is_preprocess,
-            desc="Dropping Long Sequences From Eval Dataset",
+        eval_dataset = _maybe_drop_sequences(
+            cfg=cfg, dataset=eval_dataset, ds_split_name="Eval"
         )
-        _dropped_rows = _len_pre_drop - len(eval_dataset)
-        if _dropped_rows > 0:
-            LOG.warning(f"Dropped {_dropped_rows} rows from eval dataset")
-            if not cfg.drop_long_sequences:
-                raise ValueError(
-                    f"Found {_dropped_rows} sequences longer than {cfg.sequence_len} tokens in eval dataset. "
-                    f"Please either increase --sequence_len or set --drop_long_sequences to True to drop and ignore such sequences."
-                )
-
-        _len_pre_drop = len(eval_dataset)
-        eval_dataset = eval_dataset.filter(
-            drop_no_outputs,
-            num_proc=cfg.dataset_processes,
-            load_from_cache_file=not cfg.is_preprocess,
-            desc="Dropping Sequences Without Outputs",
-        )
-        _dropped_rows = _len_pre_drop - len(eval_dataset)
-        if _dropped_rows > 0:
-            LOG.warning(
-                f"Dropped {_dropped_rows} rows with no outputs from eval dataset"
-            )
 
     if cfg.group_by_length:
         train_dataset = train_dataset.map(