From 5aa50974ce98f31324d56b609af6753ac24105fd Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 2 Apr 2024 05:42:16 -0700
Subject: [PATCH] Pretrain multipack v2 (#1470)

---
 requirements.txt                                   |  1 +
 src/axolotl/utils/collators.py                     | 13 ++++++++++++-
 .../utils/config/models/input/v0_4_1/__init__.py   |  8 ++++++++
 src/axolotl/utils/data.py                          |  8 +++++++-
 src/axolotl/utils/trainer.py                       | 14 +++++++++-----
 5 files changed, 37 insertions(+), 7 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 8733885d56..785ede535e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -40,3 +40,4 @@ gcsfs
 # adlfs
 
 trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
+zstandard==0.22.0
diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators.py
index 8512b9408c..f0a1fb1261 100644
--- a/src/axolotl/utils/collators.py
+++ b/src/axolotl/utils/collators.py
@@ -217,13 +217,24 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     Collator for multipack specific to the using the BatchSampler
     """
 
+    def __init__(self, *args, multipack_attn=True, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.multipack_attn = multipack_attn
+
     def __call__(self, features, return_tensors=None):
         chunked_data = {}
         for feature in features.keys():
             if feature == "length":
                 continue
             if feature == "attention_mask":
-                arrays = [(1) * np.array(item) for item in features[feature]]
+                if self.multipack_attn:
+                    arrays = [
+                        (i + 1) * np.array(item[feature])
+                        for i, item in enumerate(features[feature])
+                        if feature in item
+                    ]
+                else:
+                    arrays = [(1) * np.array(item) for item in features[feature]]
                 chunked_data[feature] = np.concatenate(arrays)
             else:
                 arrays = [np.array(item) for item in features[feature]]
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 2850debd0c..8dfbf0e731 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -511,6 +511,14 @@ class Config:
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None
 
+    pretrain_multipack_buffer_size: Optional[int] = 10_000
+    pretrain_multipack_attn: Optional[bool] = Field(
+        default=True,
+        metadata={
+            "help": "whether to prevent cross attention for packed sequences during pretraining",
+        },
+    )
+
     xformers_attention: Optional[bool] = None
     sdp_attention: Optional[bool] = None
     s2_attention: Optional[bool] = None
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 6cc27fbdbd..5f13a2a63f 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -108,6 +108,7 @@ def prepare_dataset(cfg, tokenizer):
             max_tokens=cfg.sequence_len,
             batch_size=cfg.micro_batch_size,
             seed=cfg.seed or 42,
+            buffer_size=cfg.pretrain_multipack_buffer_size or 10_000,
        )
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         train_dataset = train_dataset.with_format("torch")
@@ -816,6 +817,7 @@ def wrap_pretraining_dataset(
         return_tensors="pt",
         padding=True,
         pad_to_multiple_of=max_tokens * batch_size,
+        multipack_attn=cfg.pretrain_multipack_attn,
     )
     encode = functools.partial(
         encode_packed_pretraining,
@@ -823,6 +825,7 @@ def wrap_pretraining_dataset(
         ds_wrapper_fn,
         max_seq_length=max_tokens,
         batch_size=batch_size,
+        multipack_attn=cfg.pretrain_multipack_attn,
     )
     # set this to 1 so downstream data_loader doesn't try to increase the batch again
     cfg.micro_batch_size = 1
@@ -861,6 +864,7 @@ def encode_packed_pretraining(
     examples: Dict[str, List],
     max_seq_length: int = 2048,
     batch_size: int = 4,
+    multipack_attn: Optional[bool] = False,
 ) -> Dict[str, List]:
     # pylint: disable=duplicate-code
     # tokenize all the examples
@@ -868,7 +872,9 @@ def encode_packed_pretraining(
     train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
 
     train_dataset = process_pretraining_datasets_for_packing(
-        train_dataset, max_seq_length
+        train_dataset,
+        max_seq_length,
+        skip_position_ids=not multipack_attn,
     )
 
     sampler = MultipackBatchSampler(
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index dc995fda8e..2de2c54cce 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -172,17 +172,21 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     return train_dataset, eval_dataset
 
 
-def process_pretraining_datasets_for_packing(train_dataset, sequence_len):
+def process_pretraining_datasets_for_packing(
+    train_dataset, sequence_len, skip_position_ids=True
+):
     drop_long = partial(drop_long_seq, sequence_len=sequence_len)
 
     train_dataset = train_dataset.filter(
         drop_long,
         desc="Dropping Long Sequences",
     )
-    train_dataset = train_dataset.map(
-        add_position_ids,
-        desc="Add position_id column (Pretraining Sample Packing)",
-    )
+    if skip_position_ids:
+        train_dataset = train_dataset.map(
+            add_position_ids,
+            desc="Add position_id column (Pretraining Sample Packing)",
+        )
+
     return train_dataset
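
Note on the collator change (illustration only, not part of the patch): with multipack_attn enabled, the collator numbers each packed sub-sequence in attention_mask instead of emitting flat 1s, so downstream attention code can tell packed sequences apart and keep them from attending to each other, which is the behavior the pretrain_multipack_attn help text describes. The following is a minimal, self-contained sketch of that idea; the helper names and the toy masks are made up for the example and are not code from axolotl.

# Illustration only -- not part of the patch.
import numpy as np


def number_packed_masks(masks):
    """Mirror the multipack_attn=True idea: the i-th sequence's mask becomes
    a run of (i + 1) values instead of plain 1s."""
    arrays = [(i + 1) * np.array(mask) for i, mask in enumerate(masks)]
    return np.concatenate(arrays)


def block_diagonal_allowed(segment_ids):
    """Tokens may attend to each other only when their segment ids match
    and neither position is padding (id == 0)."""
    same_segment = segment_ids[:, None] == segment_ids[None, :]
    not_padding = (segment_ids != 0)[:, None] & (segment_ids != 0)[None, :]
    return same_segment & not_padding


if __name__ == "__main__":
    # two short sequences packed into one row, plus one pad position
    masks = [[1, 1, 1], [1, 1]]
    segment_ids = np.concatenate([number_packed_masks(masks), [0]])
    print(segment_ids)  # [1 1 1 2 2 0]
    print(block_diagonal_allowed(segment_ids).astype(int))  # block-diagonal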
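
The skip_position_ids plumbing ties into add_position_ids, which in the sample-packing path gives each example its own 0..len-1 position ids before the collator concatenates examples, so positions restart at every packed boundary. The snippet below is only an illustrative sketch of that effect; the function name and lengths are invented, and the exact behavior should be checked against trainer.py.

# Illustration only -- not part of the patch.
import numpy as np


def positions_for_packed(lengths):
    """Concatenate a fresh arange per sequence, mimicking an
    add-position-ids-then-pack pipeline."""
    return np.concatenate([np.arange(n) for n in lengths])


if __name__ == "__main__":
    print(positions_for_packed([3, 2, 4]))
    # [0 1 2 0 1 0 1 2 3] -- positions restart at each packed sequence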