Skip to content

Commit

Permalink
Enhance sequence lens validation
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragjn committed Nov 9, 2024
1 parent 24e870d commit ed1ecc3
Showing 1 changed file with 91 additions and 39 deletions.
130 changes: 91 additions & 39 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,46 +204,14 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids")

if cfg.get("drop_long_sequences") is False:
if "num_tokens_pre_truncation" not in train_dataset:
raise ValueError(
"`drop_long_sequences` is set to False but `num_tokens_pre_truncation` is missing from dataset"
)
plot_ascii_lengths_histogram(
data=train_dataset["num_tokens_pre_truncation"],
title="Train Dataset lengths",
logger=LOG,
)
num_longer_seqs = sum(
1
for seq_len in train_dataset["num_tokens_pre_truncation"]
if seq_len > cfg.sequence_len
)
max_len = max(train_dataset["num_tokens_pre_truncation"])
if num_longer_seqs > 0:
raise ValueError(
f"Found {num_longer_seqs} sequences longer than {cfg.sequence_len} tokens in Train Dataset. "
f"Longest sequence is {max_len} tokens. "
f"Please either increase --sequence_len or set --drop_long_sequences to True to drop and ignore such sequences."
)

plot_ascii_lengths_histogram(
data=eval_dataset["num_tokens_pre_truncation"],
title="Eval Dataset lengths",
logger=LOG,
)
num_longer_seqs = sum(
1
for seq_len in eval_dataset["num_tokens_pre_truncation"]
if seq_len > cfg.sequence_len
drop_long = (
_validate_datasets_sequence_lengths(
cfg=cfg,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
max_len = max(eval_dataset["num_tokens_pre_truncation"])
if num_longer_seqs > 0:
raise ValueError(
f"Found {num_longer_seqs} sequences longer than {cfg.sequence_len} tokens in Eval Dataset. "
f"Longest sequence is {max_len} tokens. "
f"Please either increase --sequence_len or set --drop_long_sequences to True to drop and ignore such sequences."
)
or drop_long
)

train_dataset = train_dataset.filter(
drop_long,
Expand All @@ -259,6 +227,11 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
desc="Dropping Long Sequences",
)

train_dataset, eval_dataset = _drop_num_tokens_pre_truncation(
train_dataset,
eval_dataset,
)

# drop samples with where the number of elements with labels not equal to -100 is zero
def drop_no_trainable_tokens(sample):
return np.sum(np.array(sample["labels"]) != -100) > 0
Expand Down Expand Up @@ -543,3 +516,82 @@ def setup_trainer(
trainer_builder.eval_dataset = eval_dataset

return trainer_builder.build(total_num_steps)


def _drop_long_seq(sample, max_sequence_len, min_sequence_len):
return min_sequence_len <= sample["num_tokens_pre_truncation"] <= max_sequence_len


def _validate_dataset_sequence_lengths(
dataset,
dataset_type,
sequence_len,
long_sequences_strategy,
):
if "num_tokens_pre_truncation" not in dataset.features:
raise ValueError(
f"`long_sequences_strategy` is set to {long_sequences_strategy} but `num_tokens_pre_truncation` is missing from {dataset_type} dataset"
)
plot_ascii_lengths_histogram(
data=dataset["num_tokens_pre_truncation"],
title=f"{dataset_type} Dataset lengths",
logger=LOG,
)
num_longer_seqs = sum(
1 for seq_len in dataset["num_tokens_pre_truncation"] if seq_len > sequence_len
)
max_len = max(dataset["num_tokens_pre_truncation"])
if num_longer_seqs > 0:
message = f"""\
Found {num_longer_seqs}/{len(dataset)} sequences longer than {sequence_len} tokens in {dataset_type} Dataset.
Longest sequence is {max_len} tokens."""
if long_sequences_strategy == "error":
raise ValueError(
f"{message}\n"
f"Please either increase --sequence_len or set --long_sequences_strategy to `drop` to drop and ignore such sequences."
)

LOG.warning(f"{message}\n" f"These sequences will be dropped.")


def _validate_datasets_sequence_lengths(
cfg,
train_dataset,
eval_dataset,
):
long_sequences_strategy = cfg.get("long_sequences_strategy", "truncate")
if long_sequences_strategy in ["drop", "error"]:
_validate_dataset_sequence_lengths(
dataset=train_dataset,
dataset_type="Train",
sequence_len=cfg.sequence_len,
long_sequences_strategy=long_sequences_strategy,
)
if eval_dataset:
_validate_dataset_sequence_lengths(
dataset=eval_dataset,
dataset_type="Eval",
sequence_len=cfg.sequence_len,
long_sequences_strategy=long_sequences_strategy,
)
if long_sequences_strategy == "drop":
drop_long = partial(
_drop_long_seq,
min_sequence_len=cfg.min_sample_len or 2,
max_sequence_len=cfg.sequence_len,
)
return drop_long

return None


def _drop_num_tokens_pre_truncation(
train_dataset,
eval_dataset,
):
if "num_tokens_pre_truncation" in train_dataset.features:
train_dataset = train_dataset.remove_columns(["num_tokens_pre_truncation"])
if eval_dataset and "num_tokens_pre_truncation" in eval_dataset.features:
eval_dataset = eval_dataset.remove_columns(["num_tokens_pre_truncation"])

return train_dataset, eval_dataset

0 comments on commit ed1ecc3

Please sign in to comment.