From bd2a594b8954103719f8d1ef739e2c3267ca36f6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 17 Dec 2024 17:46:44 -0500 Subject: [PATCH] use DataCollatorWithFlattening when not sample packing (#2167) --- docs/config.qmd | 3 + src/axolotl/core/trainer_builder.py | 13 +++- .../config/models/input/v0_4_1/__init__.py | 26 +++++++ tests/e2e/test_llama.py | 39 +++++++++++ tests/patched/test_validation.py | 70 +++++++++++++++++++ 5 files changed, 149 insertions(+), 2 deletions(-) diff --git a/docs/config.qmd b/docs/config.qmd index d52170959..70679791e 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -245,6 +245,9 @@ sample_packing_group_size: 100000 # The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples. sample_packing_bin_size: 200 +# Use batch flattening for speedups when not using sample_packing +batch_flattening: + # Passed through to transformers when loading the model when launched without accelerate # Use `sequential` when training w/ model parallelism to limit memory device_map: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 0f30f511c..54ee19536 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -28,6 +28,7 @@ from torch.optim.lr_scheduler import OneCycleLR from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from transformers import ( + DataCollatorWithFlattening, EarlyStoppingCallback, Trainer, TrainerCallback, @@ -1989,9 +1990,11 @@ def build_collator( V2BatchSamplerDataCollatorForSeq2Seq, BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, + DataCollatorWithFlattening, RewardDataCollatorWithPadding, ] ] + collator_args = [self.tokenizer] if self.cfg.reward_model: collator = RewardDataCollatorWithPadding if "max_length" in kwargs: @@ -2011,12 +2014,18 @@ def build_collator( collator = MultiModalChatDataCollator kwargs["processor"] = self.processor kwargs["chat_template"] = training_args.chat_template + elif self.cfg.batch_flattening: + collator = DataCollatorWithFlattening + collator_args.pop(0) + kwargs.pop("pad_to_multiple_of", None) + kwargs.pop("padding", None) else: collator = DataCollatorForSeq2Seq + kwargs["return_tensors"] = "pt" + return collator( - self.tokenizer, - return_tensors="pt", + *collator_args, **kwargs, ) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 69baf9af2..5ddf04811 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -696,6 +696,8 @@ class Config: curriculum_sampling: Optional[bool] = None multipack_real_batches: Optional[bool] = None + batch_flattening: Optional[Union[Literal["auto"], bool]] = None + # for PoSE context length extension use_pose: Optional[bool] = None pose_split_on_token_ids: Optional[List[int]] = None @@ -924,6 +926,30 @@ def check_sample_packing_wo_flash(cls, data): return data + @model_validator(mode="before") + @classmethod + def check_batch_flattening_fa(cls, data): + if data.get("batch_flattening"): + batch_flattening_auto = data.get("batch_flattening") == "auto" + if not data.get("flash_attention") and not batch_flattening_auto: + raise ValueError("batch_flattening requires flash attention") + if data.get("sample_packing") and not batch_flattening_auto: + raise ValueError("batch_flattening not compatible with sample_packing") + if data.get("micro_batch_size") == 1 and not batch_flattening_auto: + LOG.warning("batch_flattening has no effect with micro_batch_size == 1") + + if ( + batch_flattening_auto + and data.get("flash_attention") + and not data.get("sample_packing") + and data.get("micro_batch_size") > 1 + ): + data["batch_flattening"] = True + elif batch_flattening_auto: + data["batch_flattening"] = False + + return data + @model_validator(mode="before") @classmethod def check_sample_packing_w_rl(cls, data): diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py index 33d12157a..1ce9d60b9 100644 --- a/tests/e2e/test_llama.py +++ b/tests/e2e/test_llama.py @@ -104,3 +104,42 @@ def test_fix_untrained_tokens(self, temp_dir): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(temp_dir) / "model.safetensors").exists() + + def test_batch_flattening(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "trust_remote_code": True, + "sequence_len": 512, + "val_set_size": 0.01, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 5, + "micro_batch_size": 4, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_8bit", + "lr_scheduler": "cosine", + "flash_attention": True, + "sample_packing": False, + "batch_flattening": True, + "bf16": True, + "save_safetensors": True, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "model.safetensors").exists() diff --git a/tests/patched/test_validation.py b/tests/patched/test_validation.py index 3d1b74789..9d41dac76 100644 --- a/tests/patched/test_validation.py +++ b/tests/patched/test_validation.py @@ -1236,6 +1236,76 @@ def test_torch_compile_auto(self, minimal_cfg): assert updated_cfg.torch_compile is False +class TestSampleOptimConfigValidation(BaseValidation): + """ + test configurations for sample optimizations like batch flattening + """ + + def test_batch_flattening_auto_enables(self, minimal_cfg): + cfg = ( + DictDefault( + { + "flash_attention": True, + "sample_packing": None, + "micro_batch_size": 2, + "batch_flattening": "auto", + } + ) + | minimal_cfg + ) + + new_cfg = validate_config(cfg) + assert new_cfg["batch_flattening"] is True + + def test_batch_flattening_auto_no_fa(self, minimal_cfg): + cfg = ( + DictDefault( + { + "flash_attention": False, + "sample_packing": None, + "micro_batch_size": 2, + "batch_flattening": "auto", + } + ) + | minimal_cfg + ) + + new_cfg = validate_config(cfg) + assert new_cfg["batch_flattening"] is False + + def test_batch_flattening_auto_mbsz_1(self, minimal_cfg): + cfg = ( + DictDefault( + { + "flash_attention": True, + "sample_packing": None, + "micro_batch_size": 1, + "batch_flattening": "auto", + } + ) + | minimal_cfg + ) + + new_cfg = validate_config(cfg) + assert new_cfg["batch_flattening"] is False + + def test_batch_flattening_auto_packing(self, minimal_cfg): + cfg = ( + DictDefault( + { + "flash_attention": True, + "sample_packing": True, + "micro_batch_size": 2, + "batch_flattening": "auto", + } + ) + | minimal_cfg + ) + + new_cfg = validate_config(cfg) + assert new_cfg["batch_flattening"] is False + + class TestValidationCheckModelConfig(BaseValidation): """ Test the validation for the config when the model config is available