diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 28490e91c7..c37af689fa 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -204,46 +204,14 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): if eval_dataset and "token_type_ids" in eval_dataset.column_names: eval_dataset = eval_dataset.remove_columns("token_type_ids") - if cfg.get("drop_long_sequences") is False: - if "num_tokens_pre_truncation" not in train_dataset: - raise ValueError( - "`drop_long_sequences` is set to False but `num_tokens_pre_truncation` is missing from dataset" - ) - plot_ascii_lengths_histogram( - data=train_dataset["num_tokens_pre_truncation"], - title="Train Dataset lengths", - logger=LOG, - ) - num_longer_seqs = sum( - 1 - for seq_len in train_dataset["num_tokens_pre_truncation"] - if seq_len > cfg.sequence_len - ) - max_len = max(train_dataset["num_tokens_pre_truncation"]) - if num_longer_seqs > 0: - raise ValueError( - f"Found {num_longer_seqs} sequences longer than {cfg.sequence_len} tokens in Train Dataset. " - f"Longest sequence is {max_len} tokens. " - f"Please either increase --sequence_len or set --drop_long_sequences to True to drop and ignore such sequences." - ) - - plot_ascii_lengths_histogram( - data=eval_dataset["num_tokens_pre_truncation"], - title="Eval Dataset lengths", - logger=LOG, - ) - num_longer_seqs = sum( - 1 - for seq_len in eval_dataset["num_tokens_pre_truncation"] - if seq_len > cfg.sequence_len + drop_long = ( + _validate_datasets_sequence_lengths( + cfg=cfg, + train_dataset=train_dataset, + eval_dataset=eval_dataset, ) - max_len = max(eval_dataset["num_tokens_pre_truncation"]) - if num_longer_seqs > 0: - raise ValueError( - f"Found {num_longer_seqs} sequences longer than {cfg.sequence_len} tokens in Eval Dataset. " - f"Longest sequence is {max_len} tokens. " - f"Please either increase --sequence_len or set --drop_long_sequences to True to drop and ignore such sequences." - ) + or drop_long + ) train_dataset = train_dataset.filter( drop_long, @@ -259,6 +227,11 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): desc="Dropping Long Sequences", ) + train_dataset, eval_dataset = _drop_num_tokens_pre_truncation( + train_dataset, + eval_dataset, + ) + # drop samples with where the number of elements with labels not equal to -100 is zero def drop_no_trainable_tokens(sample): return np.sum(np.array(sample["labels"]) != -100) > 0 @@ -543,3 +516,82 @@ def setup_trainer( trainer_builder.eval_dataset = eval_dataset return trainer_builder.build(total_num_steps) + + +def _drop_long_seq(sample, max_sequence_len, min_sequence_len): + return min_sequence_len <= sample["num_tokens_pre_truncation"] <= max_sequence_len + + +def _validate_dataset_sequence_lengths( + dataset, + dataset_type, + sequence_len, + long_sequences_strategy, +): + if "num_tokens_pre_truncation" not in dataset.features: + raise ValueError( + f"`long_sequences_strategy` is set to {long_sequences_strategy} but `num_tokens_pre_truncation` is missing from {dataset_type} dataset" + ) + plot_ascii_lengths_histogram( + data=dataset["num_tokens_pre_truncation"], + title=f"{dataset_type} Dataset lengths", + logger=LOG, + ) + num_longer_seqs = sum( + 1 for seq_len in dataset["num_tokens_pre_truncation"] if seq_len > sequence_len + ) + max_len = max(dataset["num_tokens_pre_truncation"]) + if num_longer_seqs > 0: + message = f"""\ +Found {num_longer_seqs}/{len(dataset)} sequences longer than {sequence_len} tokens in {dataset_type} Dataset. +Longest sequence is {max_len} tokens.""" + if long_sequences_strategy == "error": + raise ValueError( + f"{message}\n" + f"Please either increase --sequence_len or set --long_sequences_strategy to `drop` to drop and ignore such sequences." + ) + + LOG.warning(f"{message}\n" f"These sequences will be dropped.") + + +def _validate_datasets_sequence_lengths( + cfg, + train_dataset, + eval_dataset, +): + long_sequences_strategy = cfg.get("long_sequences_strategy", "truncate") + if long_sequences_strategy in ["drop", "error"]: + _validate_dataset_sequence_lengths( + dataset=train_dataset, + dataset_type="Train", + sequence_len=cfg.sequence_len, + long_sequences_strategy=long_sequences_strategy, + ) + if eval_dataset: + _validate_dataset_sequence_lengths( + dataset=eval_dataset, + dataset_type="Eval", + sequence_len=cfg.sequence_len, + long_sequences_strategy=long_sequences_strategy, + ) + if long_sequences_strategy == "drop": + drop_long = partial( + _drop_long_seq, + min_sequence_len=cfg.min_sample_len or 2, + max_sequence_len=cfg.sequence_len, + ) + return drop_long + + return None + + +def _drop_num_tokens_pre_truncation( + train_dataset, + eval_dataset, +): + if "num_tokens_pre_truncation" in train_dataset.features: + train_dataset = train_dataset.remove_columns(["num_tokens_pre_truncation"]) + if eval_dataset and "num_tokens_pre_truncation" in eval_dataset.features: + eval_dataset = eval_dataset.remove_columns(["num_tokens_pre_truncation"]) + + return train_dataset, eval_dataset