Skip to content

Commit

Permalink
support min sample len and define num chunks
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Apr 25, 2024
1 parent 8700784 commit cd089f9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ class Config:
unfrozen_parameters: Optional[List[str]] = None

sequence_len: int = Field(default=512)
min_sample_len: Optional[int] = None
sample_packing: Optional[bool] = None
eval_sample_packing: Optional[bool] = None
pad_to_sequence_len: Optional[bool] = None
Expand All @@ -511,6 +512,7 @@ class Config:
use_pose: Optional[bool] = None
pose_split_on_token_ids: Optional[List[int]] = None
pose_max_context_len: Optional[int] = None
pose_num_chunks: Optional[int] = None

pretrain_multipack_buffer_size: Optional[int] = 10_000
pretrain_multipack_attn: Optional[bool] = Field(
Expand Down
10 changes: 9 additions & 1 deletion src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,11 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):


def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
drop_long = partial(
drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len or 2,
)
with zero_first(is_main_process()):
if cfg.is_preprocess:
min_input_len = np.min(get_dataset_lengths(train_dataset))
Expand Down Expand Up @@ -221,10 +225,14 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
)

if cfg.use_pose:
pose_kwargs = {}
if cfg.pose_num_chunks is not None:
pose_kwargs["chunks"] = cfg.pose_num_chunks
pose_fn = partial(
add_pose_position_ids,
max_context_len=cfg.pose_max_context_len,
split_on_token_ids=cfg.pose_split_on_token_ids,
**pose_kwargs,
)
train_dataset = train_dataset.map(
pose_fn,
Expand Down

0 comments on commit cd089f9

Please sign in to comment.