
Commit

Turn off truncation in chat template strategy to do length checks correctly
chiragjn committed Jul 2, 2024
1 parent d070b88 commit 54060d5
Showing 3 changed files with 66 additions and 59 deletions.
2 changes: 1 addition & 1 deletion src/axolotl/prompt_strategies/chat_template.py
@@ -51,7 +51,7 @@ def build_prompt(self, conversation, add_generation_prompt=False):
 
         return self.tokenizer.apply_chat_template(
             turns,
-            truncation=True,
+            truncation=False,
             max_length=self.max_length,
             add_generation_prompt=add_generation_prompt,
             chat_template=self.chat_template,
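Why this one-line change matters: with truncation=True, apply_chat_template silently clips every tokenized conversation to max_length, so the length checks this commit adds in trainer.py below could never see an over-long sample. A minimal sketch of the difference, assuming any Hugging Face tokenizer that ships a chat template (the checkpoint name is an arbitrary stand-in, not one used by this commit):

from transformers import AutoTokenizer

# Any tokenizer whose config ships a chat template works here; this
# checkpoint is an arbitrary stand-in chosen for illustration.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

turns = [{"role": "user", "content": "lorem " * 5000}]

clipped = tokenizer.apply_chat_template(
    turns, tokenize=True, truncation=True, max_length=2048
)
full = tokenizer.apply_chat_template(
    turns, tokenize=True, truncation=False, max_length=2048
)

print(len(clipped))  # 2048 -- the overflow is silently cut off
print(len(full))     # several thousand tokens -- downstream checks see the real length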
21 changes: 21 additions & 0 deletions src/axolotl/utils/samplers/utils.py
@@ -15,3 +15,24 @@ def get_dataset_lengths(dataset):
         lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
         return lengths
     return lengths
+
+
+def plot_ascii_lengths_histogram(data, title, logger):
+    max_value = max(data)
+    bucket_width = 512
+    bins = np.arange(0, max_value + bucket_width, bucket_width)
+    histogram, _ = np.histogram(data, bins=bins)
+    top = " ".join(("-" * 10, title, "-" * 10))
+    bottom = "-" * len(top)
+    logger.info(top)
+    scale_factor = 40 / max(histogram)
+    for i, value in enumerate(histogram):
+        lower_bound = i * bucket_width
+        upper_bound = (i + 1) * bucket_width - 1
+        if value:
+            hist_bar = "□" * int(value * scale_factor)
+        else:
+            hist_bar = "x"
+        logger.info(f"{hist_bar} ({lower_bound}-{upper_bound} tokens, Count: {value})")
+    logger.info(bottom)
+    logger.info("\n")
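The new helper depends only on NumPy and a logger exposing a standard info method, so it can be tried outside axolotl. A quick usage sketch, with synthetic token counts standing in for get_dataset_lengths(dataset):

import logging

import numpy as np

from axolotl.utils.samplers.utils import plot_ascii_lengths_histogram

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")

# Fake per-sample token counts in place of real get_dataset_lengths() output.
lengths = np.random.default_rng(0).integers(16, 4096, size=1_000)
plot_ascii_lengths_histogram(data=lengths, title="Train Dataset Lengths", logger=logger)

Each row of the output is one 512-token bucket, with a bar of □ characters scaled so the tallest bucket spans 40 characters and empty buckets marked with an x.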
102 changes: 44 additions & 58 deletions src/axolotl/utils/trainer.py
@@ -17,6 +17,7 @@
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
 from axolotl.utils.distributed import reduce_and_broadcast
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
+from axolotl.utils.samplers.utils import plot_ascii_lengths_histogram
 
 LOG = get_logger("axolotl")
 
@@ -180,18 +181,51 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
     )
 
 
-def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
+def _maybe_drop_sequences(cfg, dataset, ds_split_name: str):
+    _ds_lens = get_dataset_lengths(dataset)
+    plot_ascii_lengths_histogram(
+        data=_ds_lens, title=f"{ds_split_name} Dataset Lengths", logger=LOG
+    )
+    min_len, max_len = np.min(_ds_lens), np.max(_ds_lens)
+    LOG.debug(f"min_input_len: {min_len}", main_process_only=True)
+    LOG.debug(f"max_input_len: {max_len}", main_process_only=True)
     drop_long = partial(
         drop_long_seq,
         sequence_len=cfg.sequence_len,
         min_sequence_len=cfg.min_sample_len or 2,
     )
-    if cfg.is_preprocess:
-        min_input_len = np.min(get_dataset_lengths(train_dataset))
-        LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
-        max_input_len = np.max(get_dataset_lengths(train_dataset))
-        LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
+    len_pre_drop = len(dataset)
+    dataset = dataset.filter(
+        drop_long,
+        num_proc=cfg.dataset_processes,
+        load_from_cache_file=not cfg.is_preprocess,
+        desc=f"Dropping Long Sequences From {ds_split_name} Dataset",
+    )
+    dropped_rows = len_pre_drop - len(dataset)
+    if dropped_rows > 0:
+        LOG.warning(f"Dropped {dropped_rows} rows from {ds_split_name} dataset")
+        if not cfg.drop_long_sequences:
+            raise ValueError(
+                f"Found {dropped_rows} sequences longer than {cfg.sequence_len} tokens in {ds_split_name} Dataset. "
+                f"Longest sequence is {max_len} tokens. "
+                f"Please either increase --sequence_len or set --drop_long_sequences to True to drop and ignore such sequences."
+            )
+    len_pre_drop = len(dataset)
+    dataset = dataset.filter(
+        drop_no_outputs,
+        num_proc=cfg.dataset_processes,
+        load_from_cache_file=not cfg.is_preprocess,
+        desc="Dropping Sequences Without Outputs",
+    )
+    dropped_rows = len_pre_drop - len(dataset)
+    if dropped_rows > 0:
+        LOG.warning(
+            f"Dropped {dropped_rows} rows with no outputs from {ds_split_name} Dataset"
+        )
+    return dataset
+
+
+def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     if (
         cfg.is_mistral_derived_model and cfg.flash_attention
     ) or cfg.model_config_type == "mamba":
@@ -207,62 +241,14 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     if eval_dataset and "token_type_ids" in eval_dataset.column_names:
         eval_dataset = eval_dataset.remove_columns("token_type_ids")
 
-    _len_pre_drop = len(train_dataset)
-    train_dataset = train_dataset.filter(
-        drop_long,
-        num_proc=cfg.dataset_processes,
-        load_from_cache_file=not cfg.is_preprocess,
-        desc="Dropping Long Sequences From Train Dataset",
-    )
-    _dropped_rows = _len_pre_drop - len(train_dataset)
-    if _dropped_rows > 0:
-        LOG.warning(f"Dropped {_dropped_rows} rows from train dataset")
-        if not cfg.drop_long_sequences:
-            raise ValueError(
-                f"Found {_dropped_rows} sequences longer than {cfg.sequence_len} tokens in train dataset. "
-                f"Please either increase --sequence_len or set --drop_long_sequences to True to drop and ignore such sequences."
-            )
-
-    _len_pre_drop = len(train_dataset)
-    train_dataset = train_dataset.filter(
-        drop_no_outputs,
-        num_proc=cfg.dataset_processes,
-        load_from_cache_file=not cfg.is_preprocess,
-        desc="Dropping Sequences Without Outputs",
+    train_dataset = _maybe_drop_sequences(
+        cfg=cfg, dataset=train_dataset, ds_split_name="Train"
     )
-    _dropped_rows = _len_pre_drop - len(train_dataset)
-    if _dropped_rows > 0:
-        LOG.warning(f"Dropped {_dropped_rows} rows with no outputs from train dataset")
-
     if eval_dataset:
-        _len_pre_drop = len(eval_dataset)
-        eval_dataset = eval_dataset.filter(
-            drop_long,
-            num_proc=cfg.dataset_processes,
-            load_from_cache_file=not cfg.is_preprocess,
-            desc="Dropping Long Sequences From Eval Dataset",
+        eval_dataset = _maybe_drop_sequences(
+            cfg=cfg, dataset=eval_dataset, ds_split_name="Eval"
         )
-        _dropped_rows = _len_pre_drop - len(eval_dataset)
-        if _dropped_rows > 0:
-            LOG.warning(f"Dropped {_dropped_rows} rows from eval dataset")
-            if not cfg.drop_long_sequences:
-                raise ValueError(
-                    f"Found {_dropped_rows} sequences longer than {cfg.sequence_len} tokens in eval dataset. "
-                    f"Please either increase --sequence_len or set --drop_long_sequences to True to drop and ignore such sequences."
-                )
-
-        _len_pre_drop = len(eval_dataset)
-        eval_dataset = eval_dataset.filter(
-            drop_no_outputs,
-            num_proc=cfg.dataset_processes,
-            load_from_cache_file=not cfg.is_preprocess,
-            desc="Dropping Sequences Without Outputs",
-        )
-        _dropped_rows = _len_pre_drop - len(eval_dataset)
-        if _dropped_rows > 0:
-            LOG.warning(
-                f"Dropped {_dropped_rows} rows with no outputs from eval dataset"
-            )
 
     if cfg.group_by_length:
         train_dataset = train_dataset.map(
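The net effect of the trainer.py change: four near-identical filter/count/warn blocks collapse into one _maybe_drop_sequences helper shared by the train and eval splits. The underlying pattern — measure dataset size before and after datasets.Dataset.filter, then warn or raise on the difference — can be sketched standalone; the lambda below only approximates drop_long_seq, which additionally honors cfg.min_sample_len:

from datasets import Dataset

sequence_len = 2048
ds = Dataset.from_dict({"input_ids": [[1] * 10, [1] * 5000, [1] * 30]})

len_pre_drop = len(ds)
# Rough stand-in for drop_long_seq: keep samples within [2, sequence_len] tokens.
ds = ds.filter(lambda sample: 2 <= len(sample["input_ids"]) <= sequence_len)

dropped_rows = len_pre_drop - len(ds)
if dropped_rows > 0:
    # The real helper raises a ValueError here unless cfg.drop_long_sequences
    # is set; this sketch just reports the count.
    print(f"Dropped {dropped_rows} rows")  # the 5000-token sample is dropped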
