From cd089f9819f36d23832b9f1a6baa8d2f69ff8577 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 25 Apr 2024 16:02:55 -0400
Subject: [PATCH] support min sample len and define num chunks

---
 .../utils/config/models/input/v0_4_1/__init__.py | 2 ++
 src/axolotl/utils/trainer.py                     | 10 +++++++++-
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index b5727d9629..0fb794ba32 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -503,6 +503,7 @@ class Config:
     unfrozen_parameters: Optional[List[str]] = None
 
     sequence_len: int = Field(default=512)
+    min_sample_len: Optional[int] = None
     sample_packing: Optional[bool] = None
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None
@@ -511,6 +512,7 @@ class Config:
     use_pose: Optional[bool] = None
     pose_split_on_token_ids: Optional[List[int]] = None
     pose_max_context_len: Optional[int] = None
+    pose_num_chunks: Optional[int] = None
 
     pretrain_multipack_buffer_size: Optional[int] = 10_000
     pretrain_multipack_attn: Optional[bool] = Field(
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 80df4ccd24..95b6f11be9 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -175,7 +175,11 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
 
 
 def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
-    drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
+    drop_long = partial(
+        drop_long_seq,
+        sequence_len=cfg.sequence_len,
+        min_sequence_len=cfg.min_sample_len or 2,
+    )
     with zero_first(is_main_process()):
         if cfg.is_preprocess:
             min_input_len = np.min(get_dataset_lengths(train_dataset))
@@ -221,10 +225,14 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
             )
 
         if cfg.use_pose:
+            pose_kwargs = {}
+            if cfg.pose_num_chunks is not None:
+                pose_kwargs["chunks"] = cfg.pose_num_chunks
             pose_fn = partial(
                 add_pose_position_ids,
                 max_context_len=cfg.pose_max_context_len,
                 split_on_token_ids=cfg.pose_split_on_token_ids,
+                **pose_kwargs,
             )
             train_dataset = train_dataset.map(
                 pose_fn,