Pretrain multipack v2 (#1470)
winglian authored Apr 2, 2024
1 parent cae608f commit 5aa5097
Showing 5 changed files with 37 additions and 7 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -40,3 +40,4 @@ gcsfs
 # adlfs
 
 trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
+zstandard==0.22.0
13 changes: 12 additions & 1 deletion src/axolotl/utils/collators.py
@@ -217,13 +217,24 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     Collator for multipack specific to the using the BatchSampler
     """
 
+    def __init__(self, *args, multipack_attn=True, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.multipack_attn = multipack_attn
+
     def __call__(self, features, return_tensors=None):
         chunked_data = {}
         for feature in features.keys():
             if feature == "length":
                 continue
             if feature == "attention_mask":
-                arrays = [(1) * np.array(item) for item in features[feature]]
+                if self.multipack_attn:
+                    arrays = [
+                        (i + 1) * np.array(item[feature])
+                        for i, item in enumerate(features[feature])
+                        if feature in item
+                    ]
+                else:
+                    arrays = [(1) * np.array(item) for item in features[feature]]
                 chunked_data[feature] = np.concatenate(arrays)
             else:
                 arrays = [np.array(item) for item in features[feature]]
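
For orientation, here is a minimal standalone sketch (made-up mask values, not the collator itself) of what the multipack_attn branch does to the attention mask: each packed sub-sequence is scaled by (i + 1), so the concatenated mask carries a distinct segment id per original sequence and the boundaries stay recoverable downstream.

import numpy as np

# Attention masks of three sequences that will be packed into one row.
attention_masks = [
    [1, 1, 1],      # sequence 0
    [1, 1, 1, 1],   # sequence 1
    [1, 1],         # sequence 2
]

# multipack_attn=True: tag each sub-sequence with a distinct (i + 1) segment id.
segment_ids = np.concatenate(
    [(i + 1) * np.array(mask) for i, mask in enumerate(attention_masks)]
)
print(segment_ids)  # [1 1 1 2 2 2 2 3 3]

# multipack_attn=False: a flat packed mask with no segment information.
flat_mask = np.concatenate([np.array(mask) for mask in attention_masks])
print(flat_mask)    # [1 1 1 1 1 1 1 1 1]

# Boundaries between the original sequences can be read back off the id changes,
# which is what lets packing-aware attention avoid cross-sequence attention.
boundaries = np.where(np.diff(segment_ids) != 0)[0] + 1
print(boundaries)   # [3 7]
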
8 changes: 8 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -511,6 +511,14 @@ class Config:
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None
 
+    pretrain_multipack_buffer_size: Optional[int] = 10_000
+    pretrain_multipack_attn: Optional[bool] = Field(
+        default=True,
+        metadata={
+            "help": "whether to prevent cross attention for packed sequences during pretraining",
+        },
+    )
+
     xformers_attention: Optional[bool] = None
     sdp_attention: Optional[bool] = None
     s2_attention: Optional[bool] = None
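
A small, self-contained sketch of the defaults these two new options carry, using a hypothetical PretrainPackingConfig that stands in for the much larger real config model:

from typing import Optional

from pydantic import BaseModel


class PretrainPackingConfig(BaseModel):
    """Hypothetical stand-in for the real config model; only the two new fields."""

    pretrain_multipack_buffer_size: Optional[int] = 10_000
    pretrain_multipack_attn: Optional[bool] = True


cfg = PretrainPackingConfig()
print(cfg.pretrain_multipack_buffer_size)  # 10000 (default buffer size)
print(cfg.pretrain_multipack_attn)         # True: cross-sequence attention blocked by default

# Overriding a field, as a parsed YAML config would:
cfg = PretrainPackingConfig(pretrain_multipack_attn=False)
print(cfg.pretrain_multipack_attn)         # False
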
8 changes: 7 additions & 1 deletion src/axolotl/utils/data.py
@@ -108,6 +108,7 @@ def prepare_dataset(cfg, tokenizer):
             max_tokens=cfg.sequence_len,
             batch_size=cfg.micro_batch_size,
             seed=cfg.seed or 42,
+            buffer_size=cfg.pretrain_multipack_buffer_size or 10_000,
         )
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         train_dataset = train_dataset.with_format("torch")
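
One aside on the `or 10_000` fallback added above, illustrated with a hypothetical helper: `or` falls back on any falsy value, not only on None.

def resolve_buffer_size(configured):
    # Mirrors `cfg.pretrain_multipack_buffer_size or 10_000`.
    return configured or 10_000


print(resolve_buffer_size(None))    # 10000
print(resolve_buffer_size(50_000))  # 50000
print(resolve_buffer_size(0))       # 10000: an explicit 0 is also treated as unset
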
@@ -816,13 +817,15 @@ def wrap_pretraining_dataset(
             return_tensors="pt",
             padding=True,
             pad_to_multiple_of=max_tokens * batch_size,
+            multipack_attn=cfg.pretrain_multipack_attn,
         )
         encode = functools.partial(
             encode_packed_pretraining,
             collate_fn,
             ds_wrapper_fn,
             max_seq_length=max_tokens,
             batch_size=batch_size,
+            multipack_attn=cfg.pretrain_multipack_attn,
         )
         # set this to 1 so downstream data_loader doesn't try to increase the batch again
         cfg.micro_batch_size = 1
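
The wiring above leans on functools.partial to bind configuration such as multipack_attn once, so the resulting encode callable only needs the per-batch examples later. A toy version with stand-in names:

import functools


def encode_batch(collate_fn, examples, max_seq_length=2048, batch_size=4, multipack_attn=False):
    # Stand-in for encode_packed_pretraining: just report the bound configuration.
    return {
        "n_examples": len(examples),
        "max_seq_length": max_seq_length,
        "batch_size": batch_size,
        "multipack_attn": multipack_attn,
    }


def collate_fn(features):  # stand-in for the pretraining collator
    return features


encode = functools.partial(
    encode_batch,
    collate_fn,
    max_seq_length=2048,
    batch_size=4,
    multipack_attn=True,
)

print(encode(["row a", "row b"]))
# {'n_examples': 2, 'max_seq_length': 2048, 'batch_size': 4, 'multipack_attn': True}
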
@@ -861,14 +864,17 @@ def encode_packed_pretraining(
     examples: Dict[str, List],
     max_seq_length: int = 2048,
     batch_size: int = 4,
+    multipack_attn: Optional[bool] = False,
 ) -> Dict[str, List]:
     # pylint: disable=duplicate-code
     # tokenize all the examples
     # rows get split with stride (overlap)
     train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
 
     train_dataset = process_pretraining_datasets_for_packing(
-        train_dataset, max_seq_length
+        train_dataset,
+        max_seq_length,
+        skip_position_ids=not multipack_attn,
     )
 
     sampler = MultipackBatchSampler(
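
One background detail for the Dataset.from_dict call above: indexing a datasets.Dataset with a list of row indices (the shape a batch sampler yields) returns a dict of lists, which is the features layout the pretraining collator earlier iterates over via features[feature]. A quick check:

from datasets import Dataset

ds = Dataset.from_dict(
    {
        "input_ids": [[5, 6, 7], [8, 9], [10, 11, 12, 13]],
        "attention_mask": [[1, 1, 1], [1, 1], [1, 1, 1, 1]],
    }
)

batch = ds[[0, 2]]  # e.g. one packed batch of row indices from the sampler
print(type(batch))         # <class 'dict'>
print(batch["input_ids"])  # [[5, 6, 7], [10, 11, 12, 13]]
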
14 changes: 9 additions & 5 deletions src/axolotl/utils/trainer.py
Expand Up @@ -172,17 +172,21 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
return train_dataset, eval_dataset


def process_pretraining_datasets_for_packing(train_dataset, sequence_len):
def process_pretraining_datasets_for_packing(
train_dataset, sequence_len, skip_position_ids=True
):
drop_long = partial(drop_long_seq, sequence_len=sequence_len)

train_dataset = train_dataset.filter(
drop_long,
desc="Dropping Long Sequences",
)
train_dataset = train_dataset.map(
add_position_ids,
desc="Add position_id column (Pretraining Sample Packing)",
)
if skip_position_ids:
train_dataset = train_dataset.map(
add_position_ids,
desc="Add position_id column (Pretraining Sample Packing)",
)

return train_dataset


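
Finally, a rough sketch of what the new skip_position_ids switch controls, with a stand-in add_position_ids (assumed here to simply enumerate token positions): the position_ids column is only materialized when the flag is truthy, mirroring the conditional above.

from datasets import Dataset


def add_position_ids(sample):
    # Assumed behavior: one position id per token in the packed row.
    sample["position_ids"] = list(range(len(sample["input_ids"])))
    return sample


def process_for_packing(train_dataset, skip_position_ids=True):
    if skip_position_ids:
        train_dataset = train_dataset.map(
            add_position_ids,
            desc="Add position_id column (Pretraining Sample Packing)",
        )
    return train_dataset


ds = Dataset.from_dict({"input_ids": [[5, 6, 7], [8, 9]]})
print(process_for_packing(ds, skip_position_ids=True).column_names)
# ['input_ids', 'position_ids']
print(process_for_packing(ds, skip_position_ids=False).column_names)
# ['input_ids']
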
