diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 77a232129..fe9816dd2 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -1,4 +1,5 @@ """Module containing the Trainer class and related functions""" + import json import math import os @@ -171,7 +172,9 @@ def add_length(sample): return sample -def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2): +def drop_long_seq(sample, sequence_len=2048, min_sequence_len=None): + min_sequence_len = min_sequence_len or 2 + return ( len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) >= min_sequence_len @@ -182,7 +185,7 @@ def drop_long_seq_in_dataset(dataset, cfg): drop_long = partial( drop_long_seq, sequence_len=cfg.sequence_len, - min_sequence_len=cfg.min_sequence_len, + min_sequence_len=cfg.min_sample_len, ) min_input_len = np.min(get_dataset_lengths(dataset))