diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 6b5f6bd22b..b5727d9629 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -510,6 +510,7 @@ class Config: # for PoSE context length extension use_pose: Optional[bool] = None pose_split_on_token_ids: Optional[List[int]] = None + pose_max_context_len: Optional[int] = None pretrain_multipack_buffer_size: Optional[int] = 10_000 pretrain_multipack_attn: Optional[bool] = Field( diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 99ce774c55..80df4ccd24 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -223,7 +223,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): if cfg.use_pose: pose_fn = partial( add_pose_position_ids, - max_context_len=cfg.sequence_len, + max_context_len=cfg.pose_max_context_len, split_on_token_ids=cfg.pose_split_on_token_ids, ) train_dataset = train_dataset.map(