Commit: Add seq len validation

chiragjn committed Nov 9, 2024
1 parent bf3e3b6 commit 24e870d
Showing 3 changed files with 105 additions and 15 deletions.
57 changes: 42 additions & 15 deletions src/axolotl/prompt_strategies/chat_template.py
@@ -1,7 +1,7 @@
"""
HF Chat Templates prompt strategy
"""

import functools
import logging
from typing import Any, Dict, List, Optional

@@ -64,14 +64,16 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None):

if self.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]

if self.processor:
text = self.processor.apply_chat_template(
turns,
_apply_chat_template = functools.partial(
self.processor.apply_chat_template,
chat_template=self.chat_template,
tokenize=False,
add_generation_prompt=add_generation_prompt,
)
text = _apply_chat_template(
turns,
tokenize=False,
)
batch = self.processor(
text=text,
images=images,
@@ -85,15 +87,27 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None):
batch[k] = val.tolist()
else:
batch[k] = val.squeeze().tolist()
batch["num_tokens_pre_truncation"] = len(
_apply_chat_template(turns, tokenize=True)
)
return batch

return self.tokenizer.apply_chat_template(
turns,
truncation=True,
_apply_chat_template = functools.partial(
self.tokenizer.apply_chat_template,
max_length=self.max_length,
add_generation_prompt=add_generation_prompt,
chat_template=self.chat_template,
)
inputs = _apply_chat_template(
turns,
truncation=True,
)
return {
"input_ids": inputs,
"num_tokens_pre_truncation": len(
_apply_chat_template(turns, truncation=False)
),
}
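Note the pattern both branches now share: the same functools.partial is invoked twice, once with truncation enabled to produce the training inputs and once without it, so the sample's full length is preserved as num_tokens_pre_truncation. A minimal standalone sketch of the idea (the helper name and the generic tokenizer are assumptions for illustration, not part of this commit):

import functools

def tokenize_with_full_length(tokenizer, turns, chat_template, max_length):
    # Bind the arguments shared by both calls once.
    apply_template = functools.partial(
        tokenizer.apply_chat_template,
        chat_template=chat_template,
        max_length=max_length,
    )
    # Truncated ids actually used for training.
    input_ids = apply_template(turns, truncation=True)
    # Untruncated ids, kept only to measure the true length.
    full_ids = apply_template(turns, truncation=False)
    return {"input_ids": input_ids, "num_tokens_pre_truncation": len(full_ids)}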

def get_offsets_for_train_detail(
self, text: str, train_details: List[Dict], mask_untrainable: bool = True
@@ -237,20 +251,29 @@ def tokenize_prompt(self, prompt):
):
turns = self.get_conversation_thread(prompt)
images = self.get_images(prompt)
prompt_ids = self.prompter.build_prompt(
prompt_tokenized = self.prompter.build_prompt(
turns[:-1],
add_generation_prompt=True,
images=images,
)
tokenized_res = self.prompter.build_prompt(turns, images=images)
all_turns_tokenized = self.prompter.build_prompt(turns, images=images)
tokenized_prompt = {}
if isinstance(tokenized_res, list):
input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
if "attention_mask" not in all_turns_tokenized:
prompt_ids = prompt_tokenized["input_ids"]
input_ids = (
prompt_ids + all_turns_tokenized["input_ids"][len(prompt_ids) :]
)
tokenized_prompt["input_ids"] = input_ids
num_tokens_pre_truncation = all_turns_tokenized[
"num_tokens_pre_truncation"
]
tokenized_prompt["attention_mask"] = [1] * len(input_ids)
else:
input_ids = tokenized_res["input_ids"]
tokenized_prompt = tokenized_res
input_ids = all_turns_tokenized["input_ids"]
num_tokens_pre_truncation = all_turns_tokenized[
"num_tokens_pre_truncation"
]
tokenized_prompt = all_turns_tokenized

if not self.train_on_inputs:
user_prompt_len = len(prompt_ids)
@@ -259,11 +282,14 @@ def tokenize_prompt(self, prompt):
labels = input_ids

tokenized_prompt["labels"] = labels
tokenized_prompt["num_tokens_pre_truncation"] = num_tokens_pre_truncation

return tokenized_prompt
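For readers who do not expand the elided lines above, the masking they perform follows the conventional prompt-masking pattern; a hedged sketch (the exact body is hidden by the diff viewer, so this is the standard form, not a verbatim copy):

if not self.train_on_inputs:
    user_prompt_len = len(prompt_ids)
    # Mask prompt tokens so the loss is computed only on the response.
    labels = [IGNORE_TOKEN_ID] * user_prompt_len + input_ids[user_prompt_len:]
else:
    labels = input_ids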

turns = prompt[self.messages]
input_ids = self.prompter.build_prompt(turns)
tokenized_res = self.prompter.build_prompt(turns)
input_ids = tokenized_res["input_ids"]
num_tokens_pre_truncation = tokenized_res["num_tokens_pre_truncation"]
labels = [IGNORE_TOKEN_ID] * len(input_ids)

last_eos_idx = -1
@@ -342,6 +368,7 @@ def tokenize_prompt(self, prompt):
"input_ids": input_ids,
"labels": labels,
"attention_mask": [1] * len(input_ids),
"num_tokens_pre_truncation": num_tokens_pre_truncation,
}

def find_eos_token(self, input_ids, start_idx):
21 changes: 21 additions & 0 deletions src/axolotl/utils/samplers/utils.py
@@ -15,3 +15,24 @@ def get_dataset_lengths(dataset):
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
return lengths


def plot_ascii_lengths_histogram(data, title, logger):
    """Log a coarse ASCII histogram of token lengths in 512-token buckets."""
max_value = max(data)
bucket_width = 512
bins = np.arange(0, max_value + bucket_width, bucket_width)
histogram, _ = np.histogram(data, bins=bins)
top = " ".join(("-" * 10, title, "-" * 10))
bottom = "-" * len(top)
logger.info(top)
scale_factor = 40 / max(histogram)
for i, value in enumerate(histogram):
lower_bound = i * bucket_width
upper_bound = (i + 1) * bucket_width - 1
if value:
hist_bar = "□" * int(value * scale_factor)
else:
hist_bar = "x"
logger.info(f"{hist_bar} ({lower_bound}-{upper_bound} tokens, Count: {value})")
logger.info(bottom)
logger.info("\n")
42 changes: 42 additions & 0 deletions src/axolotl/utils/trainer.py
@@ -18,6 +18,7 @@
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.samplers.utils import plot_ascii_lengths_histogram

LOG = get_logger("axolotl")

@@ -203,6 +204,47 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids")

if cfg.get("drop_long_sequences") is False:
if "num_tokens_pre_truncation" not in train_dataset:
raise ValueError(
"`drop_long_sequences` is set to False but `num_tokens_pre_truncation` is missing from dataset"
)
plot_ascii_lengths_histogram(
data=train_dataset["num_tokens_pre_truncation"],
title="Train Dataset lengths",
logger=LOG,
)
num_longer_seqs = sum(
1
for seq_len in train_dataset["num_tokens_pre_truncation"]
if seq_len > cfg.sequence_len
)
max_len = max(train_dataset["num_tokens_pre_truncation"])
if num_longer_seqs > 0:
raise ValueError(
f"Found {num_longer_seqs} sequences longer than {cfg.sequence_len} tokens in Train Dataset. "
f"Longest sequence is {max_len} tokens. "
f"Please either increase --sequence_len or set --drop_long_sequences to True to drop and ignore such sequences."
)

if eval_dataset:
    plot_ascii_lengths_histogram(
        data=eval_dataset["num_tokens_pre_truncation"],
        title="Eval Dataset lengths",
        logger=LOG,
    )
    num_longer_seqs = sum(
        1
        for seq_len in eval_dataset["num_tokens_pre_truncation"]
        if seq_len > cfg.sequence_len
    )
    max_len = max(eval_dataset["num_tokens_pre_truncation"])
    if num_longer_seqs > 0:
        raise ValueError(
            f"Found {num_longer_seqs} sequences longer than {cfg.sequence_len} tokens in Eval Dataset. "
            f"Longest sequence is {max_len} tokens. "
            f"Please either increase --sequence_len or set --drop_long_sequences to True to drop and ignore such sequences."
        )

train_dataset = train_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
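Taken together: when drop_long_sequences is explicitly False, any sample whose pre-truncation length exceeds sequence_len aborts preprocessing rather than being silently truncated or dropped later. The same check can be reproduced offline on a tokenized dataset; a small sketch (the function name is hypothetical, the column name matches this commit):

def count_overflows(dataset, sequence_len):
    # Mirror the new validation: count samples longer than the training context.
    lengths = dataset["num_tokens_pre_truncation"]
    overflows = [n for n in lengths if n > sequence_len]
    return len(overflows), max(lengths)

n_long, longest = count_overflows(train_dataset, sequence_len=4096)
print(f"{n_long} sequences exceed 4096 tokens; longest is {longest} tokens")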
