From b31ce79ca1cb6b6fe19bdb534151a4aeb81cc617 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 11 Mar 2024 13:25:55 +0100 Subject: [PATCH 01/40] add galore v1 --- src/transformers/trainer.py | 68 ++++++++++++++++++++++++++++++- src/transformers/training_args.py | 15 +++++++ 2 files changed, 81 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f32be25e5326..dd51b440f952 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1009,7 +1009,12 @@ def create_optimizer(self): }, ] - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model) + + # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs` + # e.g. for GaLoRe optimizer. + if "params" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("params") self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) if optimizer_cls.__name__ == "Adam8bit": @@ -1032,7 +1037,9 @@ def create_optimizer(self): return self.optimizer @staticmethod - def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]: + def get_optimizer_cls_and_kwargs( + args: TrainingArguments, model: Optional[PreTrainedModel] = None + ) -> Tuple[Any, Any]: """ Returns the optimizer class and optimizer parameters based on the training arguments. @@ -1170,6 +1177,63 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]: optimizer_cls = torch.optim.Adagrad elif args.optim == OptimizerNames.RMSPROP: optimizer_cls = torch.optim.RMSprop + elif args.optim in [ + OptimizerNames.GALORE_ADAMW, + OptimizerNames.GALORE_ADAMW_8BIT, + OptimizerNames.GALORE_ADAFACTOR, + ]: + from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit + + optimizer_mapping = { + OptimizerNames.GALORE_ADAMW: GaLoreAdamW, + OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit, + OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor, + } + + optimizer_cls = optimizer_mapping[args.optim] + + if args.galore_target_modules is None: + raise ValueError( + "You need to define a `galore_target_modules` in order to properly use GaLoRe optimizers" + ) + + if not isinstance(args.galore_target_modules, list): + raise ValueError( + f"`galore_target_modules` has to be a list of strings, you passed {args.galore_target_modules}" + ) + + if model is None: + raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.") + + galore_params = [] + for module_name, module in model.named_modules(): + if not isinstance(module, nn.Linear): + continue + + if not any(target_key in module_name for target_key in args.galore_target_modules): + continue + + galore_params.append(module.weight) + + if len(galore_params) == 0: + raise ValueError("Target modules not found ! 
Please make sure to pass a valid target_modules.") + + id_galore_params = [id(p) for p in galore_params] + non_galore_params = [p for p in model.parameters() if id(p) not in id_galore_params] + + # The default args are from the official repository: https://github.com/jiaweizzhao/GaLore + param_groups = [ + {"params": non_galore_params}, + { + "params": galore_params, + "rank": optim_args.pop("rank", 128), + "update_proj_gap": optim_args.pop("update_proj_gap", 200), + "scale": optim_args.pop("scale", 0.25), + "proj_type": optim_args.pop("proj_type", "std"), + }, + ] + + optimizer_kwargs.update({"params": param_groups}) else: raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") return optimizer_cls, optimizer_kwargs diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 884ea8ad6fb8..4e9c78317c42 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -162,6 +162,9 @@ class OptimizerNames(ExplicitEnum): RMSPROP_BNB = "rmsprop_bnb" RMSPROP_8BIT = "rmsprop_bnb_8bit" RMSPROP_32BIT = "rmsprop_bnb_32bit" + GALORE_ADAMW = "galore_adamw" + GALORE_ADAMW_8BIT = "galore_adamw_8bit" + GALORE_ADAFACTOR = "galore_adafactor" # TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903 @@ -694,6 +697,11 @@ class TrainingArguments: for instruction fine-tuning. Check out the [original paper](https://arxiv.org/abs/2310.05914) and the [original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also `PeftModel` from peft. + galore_target_modules (`Optional[List[str]]`): + The GaLoRe target modules, i.e. the module names that you would like to train, using GaLoRe algorithm + https://arxiv.org/abs/2403.03507 + See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaloRe + optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor" """ framework = "pt" @@ -1352,6 +1360,13 @@ class TrainingArguments: }, ) + galore_target_modules: Optional[list] = field( + default=None, + metadata={ + "help": "Target modules for GaLoRE optimizer. See https://github.com/jiaweizzhao/GaLore for more details." 
+ }, + ) + def __post_init__(self): # expand paths, if not os.makedirs("~/bar") will make directory # in the current directory instead of the actual home From 58169f15a2a6fb8c3f157da5961bc7e7e6089f99 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Mon, 11 Mar 2024 15:23:46 +0000 Subject: [PATCH 02/40] add import --- src/transformers/trainer.py | 7 +++++++ src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 5 +++++ 3 files changed, 13 insertions(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index dd51b440f952..062e7df547ab 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -140,6 +140,7 @@ is_apex_available, is_bitsandbytes_available, is_datasets_available, + is_galore_torch_available, is_in_notebook, is_ipex_available, is_peft_available, @@ -1182,6 +1183,12 @@ def get_optimizer_cls_and_kwargs( OptimizerNames.GALORE_ADAMW_8BIT, OptimizerNames.GALORE_ADAFACTOR, ]: + if not is_galore_torch_available(): + raise ValueError( + "You need to insall `galore_torch` in order to use GaLore optimizers" + " install it with `pip install git+https://github.com/jiaweizzhao/GaLore`" + ) + from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit optimizer_mapping = { diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 2fe931b3f38f..c9f4f3f35d82 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -125,6 +125,7 @@ is_fsdp_available, is_ftfy_available, is_g2p_en_available, + is_galore_torch_available, is_in_notebook, is_ipex_available, is_jieba_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index db2278fc5f58..d12d37508045 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -76,6 +76,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _apex_available = _is_package_available("apex") _aqlm_available = _is_package_available("aqlm") _bitsandbytes_available = _is_package_available("bitsandbytes") +_galore_torch_available = _is_package_available("galore_torch") # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed. _bs4_available = importlib.util.find_spec("bs4") is not None _coloredlogs_available = _is_package_available("coloredlogs") @@ -282,6 +283,10 @@ def is_torchvision_available(): return _torchvision_available +def is_galore_torch_available(): + return _galore_torch_available + + def is_pyctcdecode_available(): return _pyctcdecode_available From 9032635c26a74762d0938968ff8834efd2a1d829 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Mon, 11 Mar 2024 15:55:40 +0000 Subject: [PATCH 03/40] add tests and doc --- docs/source/en/trainer.md | 51 ++++++++++++++++++++++++ src/transformers/testing_utils.py | 9 +++++ src/transformers/trainer.py | 3 ++ tests/trainer/test_trainer.py | 66 +++++++++++++++++++++++++++++++ 4 files changed, 129 insertions(+) diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md index 65bfa4176dd2..eab0d4f2e5fe 100644 --- a/docs/source/en/trainer.md +++ b/docs/source/en/trainer.md @@ -252,6 +252,57 @@ trainer = Trainer(..., args=training_args) NEFTune is disabled after training to restore the original embedding layer to avoid any unexpected behavior. 
+## GaLore + +Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training strategy that allows full-parameter learning but is more memory-efficient than common low-rank adaptation methods, such as LoRA. + +First make sure to install GaLore official repository: + +```bash +pip install git+https://github.com/jiaweizzhao/GaLore +``` + +Then simply add one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` in `optim` together with `galore_target_modules`, which should be a list of strings, corresponding to the target module names you want to adapt. Below is an end-to-end example script (make sure to `pip install trl datasets`): + +```python +import torch +import datasets +import trl + +from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM + +train_dataset = datasets.load_dataset('imdb', split='train') + +args = TrainingArguments( + output_dir="./test-galore", + max_steps=100, + per_device_train_batch_size=2, + optim="galore_adamw", + galore_target_modules=["attn", "mlp"] +) + +model_id = "google/gemma-2b" + +config = AutoConfig.from_pretrained(model_id) + +tokenizer = AutoTokenizer.from_pretrained(model_id) +model = AutoModelForCausalLM.from_config(config).to(0) + +trainer = trl.SFTTrainer( + model=model, + args=args, + train_dataset=train_dataset, + dataset_text_field='text', + max_seq_length=512, +) + +trainer.train() +``` + +You can read more about the method in the [original repository](https://github.com/jiaweizzhao/GaLore) or the [paper](https://arxiv.org/abs/2403.03507). + +Note it will take a bit of time before starting the training (~3 minutes for a 2B model on a NVIDIA A100), but training should go smoothly afterwards. + ## Accelerate and Trainer The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/). diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index b333678427a8..9540ffdd9cd7 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -70,6 +70,7 @@ is_fsdp_available, is_ftfy_available, is_g2p_en_available, + is_galore_torch_available, is_ipex_available, is_jieba_available, is_jinja_available, @@ -324,6 +325,14 @@ def require_bs4(test_case): return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case) +def require_galore_torch(test_case): + """ + Decorator marking a test that requires Galore. These tests are skipped when Galore isn't installed. + https://github.com/jiaweizzhao/GaLore + """ + return unittest.skipUnless(is_galore_torch_available(), "test requires Galore")(test_case) + + def require_cv2(test_case): """ Decorator marking a test that requires OpenCV. 
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 062e7df547ab..a31da2792845 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1241,6 +1241,9 @@ def get_optimizer_cls_and_kwargs( ] optimizer_kwargs.update({"params": param_groups}) + + if args.optim == OptimizerNames.GALORE_ADAFACTOR: + optimizer_kwargs.update({"scale_parameter": False, "relative_step": False}) else: raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") return optimizer_cls, optimizer_kwargs diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index bd704bc8b59e..743179235090 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -60,6 +60,7 @@ require_accelerate, require_bitsandbytes, require_deepspeed, + require_galore_torch, require_intel_extension_for_pytorch, require_optuna, require_peft, @@ -114,6 +115,8 @@ GPT2Config, GPT2LMHeadModel, LineByLineTextDataset, + LlamaConfig, + LlamaForCausalLM, PreTrainedModel, Trainer, TrainerState, @@ -1069,6 +1072,69 @@ def test_dataloader_without_dataset(self): trainer.train() trainer.evaluate() + @require_galore_torch + def test_galore(self): + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + learning_rate=1e-9, + logging_steps=5, + optim="galore_adamw", + galore_target_modules=["attn", "mlp"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # Check this works + _ = trainer.train() + + @require_galore_torch + def test_galore_adamw_8bit(self): + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + learning_rate=1e-9, + logging_steps=5, + optim="galore_adamw_8bit", + galore_target_modules=["attn", "mlp"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # Check this works + _ = trainer.train() + + @require_galore_torch + def test_galore_adafactor(self): + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + learning_rate=1e-9, + logging_steps=5, + optim="galore_adafactor", + galore_target_modules=["attn", "mlp"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # Check this works + _ = trainer.train() + @require_torch_multi_accelerator def test_data_is_not_parallelized_when_model_is_parallel(self): model = RegressionModel() From 136f104c1bba0e174e50d200495b6f2dc7d99eb5 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Mon, 11 Mar 2024 16:01:26 +0000 Subject: [PATCH 04/40] fix doctest --- src/transformers/training_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 4e9c78317c42..63797f355834 100644 --- a/src/transformers/training_args.py +++ 
b/src/transformers/training_args.py @@ -697,7 +697,7 @@ class TrainingArguments: for instruction fine-tuning. Check out the [original paper](https://arxiv.org/abs/2310.05914) and the [original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also `PeftModel` from peft. - galore_target_modules (`Optional[List[str]]`): + galore_target_modules (`list[str]`, *optional*): The GaLoRe target modules, i.e. the module names that you would like to train, using GaLoRe algorithm https://arxiv.org/abs/2403.03507 See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaloRe From a5483b361a2767349d6df9e2d1870b4f1a3b9016 Mon Sep 17 00:00:00 2001 From: Maxime Date: Mon, 11 Mar 2024 16:02:37 +0000 Subject: [PATCH 05/40] forward contrib credits from discussions From 887d3adcde8b99c15e04dfb0e0a862f61fb97ed6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 11 Mar 2024 16:03:07 +0000 Subject: [PATCH 06/40] forward contrib credits from discussions From d6f119fbeb650e177c5c353e7b37bef8c0db207f Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 11 Mar 2024 17:16:25 +0100 Subject: [PATCH 07/40] Apply suggestions from code review Co-authored-by: Zach Mueller --- src/transformers/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a31da2792845..26960bf3373f 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1184,8 +1184,8 @@ def get_optimizer_cls_and_kwargs( OptimizerNames.GALORE_ADAFACTOR, ]: if not is_galore_torch_available(): - raise ValueError( - "You need to insall `galore_torch` in order to use GaLore optimizers" + raise ImportError( + "You need to install `galore_torch` in order to use GaLore optimizers" " install it with `pip install git+https://github.com/jiaweizzhao/GaLore`" ) @@ -1223,7 +1223,7 @@ def get_optimizer_cls_and_kwargs( galore_params.append(module.weight) if len(galore_params) == 0: - raise ValueError("Target modules not found ! Please make sure to pass a valid target_modules.") + raise ValueError(f"None of the target modules were found! ({args.galore_target_modules}). Please make sure to pass a valid `target_modules`.") id_galore_params = [id(p) for p in galore_params] non_galore_params = [p for p in model.parameters() if id(p) not in id_galore_params] From c8c50f80534af25d37ddaaf2e75d157ea348d206 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Mon, 11 Mar 2024 16:22:33 +0000 Subject: [PATCH 08/40] fix failing tests' --- src/transformers/trainer.py | 4 +++- src/transformers/training_args.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 522979e4cd08..678da620729f 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1223,7 +1223,9 @@ def get_optimizer_cls_and_kwargs( galore_params.append(module.weight) if len(galore_params) == 0: - raise ValueError(f"None of the target modules were found! ({args.galore_target_modules}). Please make sure to pass a valid `target_modules`.") + raise ValueError( + f"None of the target modules were found! ({args.galore_target_modules}). Please make sure to pass a valid `target_modules`." 
+ ) id_galore_params = [id(p) for p in galore_params] non_galore_params = [p for p in model.parameters() if id(p) not in id_galore_params] diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 9c65058b3ab8..fcfe8194c1d8 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -699,7 +699,7 @@ class TrainingArguments: for instruction fine-tuning. Check out the [original paper](https://arxiv.org/abs/2310.05914) and the [original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also `PeftModel` from peft. - galore_target_modules (`list[str]`, *optional*): + galore_target_modules (`List[str]`, *optional*): The GaLoRe target modules, i.e. the module names that you would like to train, using GaLoRe algorithm https://arxiv.org/abs/2403.03507 See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaloRe @@ -1362,7 +1362,7 @@ class TrainingArguments: }, ) - galore_target_modules: Optional[list] = field( + galore_target_modules: Optional[List[str]] = field( default=None, metadata={ "help": "Target modules for GaLoRE optimizer. See https://github.com/jiaweizzhao/GaLore for more details." From 630bd13cd60ec11f6fde6fcbd3d4ac89ba6b6d71 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Wed, 13 Mar 2024 08:08:11 +0000 Subject: [PATCH 09/40] switch to `optim_target_modules` and clarify docs --- docs/source/en/trainer.md | 6 ++++-- src/transformers/trainer.py | 12 ++++++------ src/transformers/training_args.py | 9 +++++---- tests/trainer/test_trainer.py | 6 +++--- 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md index eab0d4f2e5fe..7650cb4a4320 100644 --- a/docs/source/en/trainer.md +++ b/docs/source/en/trainer.md @@ -262,7 +262,7 @@ First make sure to install GaLore official repository: pip install git+https://github.com/jiaweizzhao/GaLore ``` -Then simply add one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` in `optim` together with `galore_target_modules`, which should be a list of strings, corresponding to the target module names you want to adapt. Below is an end-to-end example script (make sure to `pip install trl datasets`): +Then simply add one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` in `optim` together with `optim_target_modules`, which should be a list of strings, corresponding to the target module names you want to adapt. Below is an end-to-end example script (make sure to `pip install trl datasets`): ```python import torch @@ -278,7 +278,7 @@ args = TrainingArguments( max_steps=100, per_device_train_batch_size=2, optim="galore_adamw", - galore_target_modules=["attn", "mlp"] + optim_target_modules=["attn", "mlp"] ) model_id = "google/gemma-2b" @@ -301,6 +301,8 @@ trainer.train() You can read more about the method in the [original repository](https://github.com/jiaweizzhao/GaLore) or the [paper](https://arxiv.org/abs/2403.03507). +Currently you can only train Linear layers that are considered as GaLore layers and will use low-rank decomposition to be trained. + Note it will take a bit of time before starting the training (~3 minutes for a 2B model on a NVIDIA A100), but training should go smoothly afterwards. 
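
The GaLore-specific hyperparameters default to the values used in the official repository (`rank=128`, `update_proj_gap=200`, `scale=0.25`, `proj_type="std"`), as set up in the `param_groups` built by `get_optimizer_cls_and_kwargs`. Below is a minimal sketch of overriding them through the existing `optim_args` string field; this assumes the Trainer's generic `key=value` parsing of `optim_args`, and since the parsed values arrive as strings, an explicit cast inside the GaLore branch may still be needed depending on the optimizer version:

```python
from transformers import TrainingArguments

# Illustrative only: a smaller projection rank, more frequent projection updates,
# and a smaller scale than the repository defaults.
args = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",
    optim_target_modules=["attn", "mlp"],
    optim_args="rank=64, update_proj_gap=100, scale=0.10",
)
```
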
## Accelerate and Trainer diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 28e63be7f37a..e699b09704c9 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1199,14 +1199,14 @@ def get_optimizer_cls_and_kwargs( optimizer_cls = optimizer_mapping[args.optim] - if args.galore_target_modules is None: + if args.optim_target_modules is None: raise ValueError( - "You need to define a `galore_target_modules` in order to properly use GaLoRe optimizers" + "You need to define a `optim_target_modules` in order to properly use GaLoRe optimizers" ) - if not isinstance(args.galore_target_modules, list): + if not isinstance(args.optim_target_modules, list): raise ValueError( - f"`galore_target_modules` has to be a list of strings, you passed {args.galore_target_modules}" + f"`optim_target_modules` has to be a list of strings, you passed {args.optim_target_modules}" ) if model is None: @@ -1217,14 +1217,14 @@ def get_optimizer_cls_and_kwargs( if not isinstance(module, nn.Linear): continue - if not any(target_key in module_name for target_key in args.galore_target_modules): + if not any(target_key in module_name for target_key in args.optim_target_modules): continue galore_params.append(module.weight) if len(galore_params) == 0: raise ValueError( - f"None of the target modules were found! ({args.galore_target_modules}). Please make sure to pass a valid `target_modules`." + f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`." ) id_galore_params = [id(p) for p in galore_params] diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index fcfe8194c1d8..4cb2c98ba716 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -699,11 +699,12 @@ class TrainingArguments: for instruction fine-tuning. Check out the [original paper](https://arxiv.org/abs/2310.05914) and the [original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also `PeftModel` from peft. - galore_target_modules (`List[str]`, *optional*): - The GaLoRe target modules, i.e. the module names that you would like to train, using GaLoRe algorithm + optim_target_modules (`List[str]`, *optional*): + The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLoRe algorithm https://arxiv.org/abs/2403.03507 See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaloRe - optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor" + optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules + only. """ framework = "pt" @@ -1362,7 +1363,7 @@ class TrainingArguments: }, ) - galore_target_modules: Optional[List[str]] = field( + optim_target_modules: Optional[List[str]] = field( default=None, metadata={ "help": "Target modules for GaLoRE optimizer. See https://github.com/jiaweizzhao/GaLore for more details." 
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 743179235090..95429d95d41b 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1086,7 +1086,7 @@ def test_galore(self): learning_rate=1e-9, logging_steps=5, optim="galore_adamw", - galore_target_modules=["attn", "mlp"], + optim_target_modules=["attn", "mlp"], ) trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) @@ -1107,7 +1107,7 @@ def test_galore_adamw_8bit(self): learning_rate=1e-9, logging_steps=5, optim="galore_adamw_8bit", - galore_target_modules=["attn", "mlp"], + optim_target_modules=["attn", "mlp"], ) trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) @@ -1128,7 +1128,7 @@ def test_galore_adafactor(self): learning_rate=1e-9, logging_steps=5, optim="galore_adafactor", - galore_target_modules=["attn", "mlp"], + optim_target_modules=["attn", "mlp"], ) trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) From a871b75acc9ae1d695fd809fcc3f2d11b8266c75 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Wed, 13 Mar 2024 08:12:37 +0000 Subject: [PATCH 10/40] more clarification --- docs/source/en/trainer.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md index 7650cb4a4320..3efaa648eacc 100644 --- a/docs/source/en/trainer.md +++ b/docs/source/en/trainer.md @@ -301,7 +301,7 @@ trainer.train() You can read more about the method in the [original repository](https://github.com/jiaweizzhao/GaLore) or the [paper](https://arxiv.org/abs/2403.03507). -Currently you can only train Linear layers that are considered as GaLore layers and will use low-rank decomposition to be trained. +Currently you can only train Linear layers that are considered as GaLore layers and will use low-rank decomposition to be trained while remaining layers will be optimized in the conventional manner. Note it will take a bit of time before starting the training (~3 minutes for a 2B model on a NVIDIA A100), but training should go smoothly afterwards. From 51b7b29219b8a892aecca087497d43936b396101 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Wed, 13 Mar 2024 08:21:55 +0000 Subject: [PATCH 11/40] enhance lookup logic --- src/transformers/trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index e699b09704c9..5bd54169775c 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1213,6 +1213,7 @@ def get_optimizer_cls_and_kwargs( raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.") galore_params = [] + galore_params_names = [] for module_name, module in model.named_modules(): if not isinstance(module, nn.Linear): continue @@ -1221,14 +1222,14 @@ def get_optimizer_cls_and_kwargs( continue galore_params.append(module.weight) + galore_params_names.append(module_name + ".weight") if len(galore_params) == 0: raise ValueError( f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`." 
) - id_galore_params = [id(p) for p in galore_params] - non_galore_params = [p for p in model.parameters() if id(p) not in id_galore_params] + non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names] # The default args are from the official repository: https://github.com/jiaweizzhao/GaLore param_groups = [ From 3da3b90d38eb083206ca3e8b1d846edb51de5792 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Wed, 13 Mar 2024 08:35:36 +0000 Subject: [PATCH 12/40] update a test to add peak memory --- tests/trainer/test_trainer.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 95429d95d41b..bd2f0f68a0ed 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -149,6 +149,31 @@ def __getitem__(self, i): return result +# Converting Bytes to Megabytes +def bytes2megabytes(x): + return int(x / 2**20) + + +# Copied from acclerate: https://github.com/huggingface/accelerate/blob/ee163b66fb7848892519e804688cb4ae981aacbe/src/accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py#L40C1-L73C68 +class TorchTracemalloc: + def __enter__(self): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero + self.begin = torch.cuda.memory_allocated() + return self + + def __exit__(self, *exc): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + self.end = torch.cuda.memory_allocated() + self.peak = torch.cuda.max_memory_allocated() + self.used = bytes2megabytes(self.end - self.begin) + self.peaked = bytes2megabytes(self.peak - self.begin) + + @dataclasses.dataclass class RegressionTrainingArguments(TrainingArguments): a: float = 0.0 @@ -1073,6 +1098,7 @@ def test_dataloader_without_dataset(self): trainer.evaluate() @require_galore_torch + @require_torch_gpu def test_galore(self): config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) tiny_llama = LlamaForCausalLM(config) @@ -1094,6 +1120,7 @@ def test_galore(self): _ = trainer.train() @require_galore_torch + @require_torch_gpu def test_galore_adamw_8bit(self): config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) tiny_llama = LlamaForCausalLM(config) @@ -1115,13 +1142,16 @@ def test_galore_adamw_8bit(self): _ = trainer.train() @require_galore_torch + @require_torch_gpu def test_galore_adafactor(self): + upper_bound_pm = 700 + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) tiny_llama = LlamaForCausalLM(config) x = torch.randint(0, 100, (128,)) train_dataset = RepeatDataset(x) - with tempfile.TemporaryDirectory() as tmpdir: + with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc: # Trainer without inf/nan filter args = TrainingArguments( tmpdir, @@ -1135,6 +1165,9 @@ def test_galore_adafactor(self): # Check this works _ = trainer.train() + galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin) + self.assertTrue(galore_peak_memory < upper_bound_pm) + @require_torch_multi_accelerator def test_data_is_not_parallelized_when_model_is_parallel(self): model = RegressionModel() From 9115c94d385db042073ff72194e98a7a90c96845 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Wed, 13 Mar 2024 09:46:26 +0000 Subject: [PATCH 13/40] add regex, all-linear and single string support --- 
src/transformers/trainer.py | 12 +++++-- src/transformers/trainer_utils.py | 29 +++++++++++++++ tests/trainer/test_trainer.py | 60 +++++++++++++++++++++++++++++-- 3 files changed, 95 insertions(+), 6 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 5bd54169775c..1a87eb2af972 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -111,6 +111,7 @@ RemoveColumnsCollator, TrainerMemoryTracker, TrainOutput, + check_target_module_exists, default_compute_objective, denumpify_detensorize, enable_full_determinism, @@ -1204,21 +1205,26 @@ def get_optimizer_cls_and_kwargs( "You need to define a `optim_target_modules` in order to properly use GaLoRe optimizers" ) - if not isinstance(args.optim_target_modules, list): + if not isinstance(args.optim_target_modules, (list, str)): raise ValueError( - f"`optim_target_modules` has to be a list of strings, you passed {args.optim_target_modules}" + f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}" ) if model is None: raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.") + all_linear = ( + isinstance(args.optim_target_modules, str) + and args.optim_target_modules.replace("_", "-") == "all-linear" + ) + galore_params = [] galore_params_names = [] for module_name, module in model.named_modules(): if not isinstance(module, nn.Linear): continue - if not any(target_key in module_name for target_key in args.optim_target_modules): + if not check_target_module_exists(args.optim_target_modules, module_name) and not all_linear: continue galore_params.append(module.weight) diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 5d528317e54f..b6494bc6138b 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -785,3 +785,32 @@ def _remove_columns(self, feature: dict) -> dict: def __call__(self, features: List[dict]): features = [self._remove_columns(feature) for feature in features] return self.data_collator(features) + + +def check_target_module_exists(optim_target_modules, key: str): + """A helper method to check if the passed module's key name matches any of the target modules in the optim_target_modules. + + Args: + optim_target_modules (`Union[str, List[str]]`): + A list of strings to try to match. Can be also a full string. 
+ key (`str`): + A key to search any matches in optim_target_modules + + Returns: + `bool` | `re.Match[str]` | `None`: True of match object if key matches any target modules from config, False or + None if no match found + """ + if isinstance(optim_target_modules, str): + target_module_found = re.fullmatch(optim_target_modules, key) + elif key in optim_target_modules: + # this module is specified directly in target_modules + target_module_found = True + else: + target_module_found = any(key.endswith(f".{target_key}") for target_key in optim_target_modules) + # Check also if the user passed a list of regex + if not target_module_found: + target_module_found = any( + bool(re.fullmatch(optim_target_module, key)) for optim_target_module in optim_target_modules + ) + + return target_module_found diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index bd2f0f68a0ed..4eeadace5ff5 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1112,7 +1112,7 @@ def test_galore(self): learning_rate=1e-9, logging_steps=5, optim="galore_adamw", - optim_target_modules=["attn", "mlp"], + optim_target_modules=[r".*attn.*", r".*mlp.*"], ) trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) @@ -1134,7 +1134,7 @@ def test_galore_adamw_8bit(self): learning_rate=1e-9, logging_steps=5, optim="galore_adamw_8bit", - optim_target_modules=["attn", "mlp"], + optim_target_modules=[r".*attn.*", r".*mlp.*"], ) trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) @@ -1158,7 +1158,61 @@ def test_galore_adafactor(self): learning_rate=1e-9, logging_steps=5, optim="galore_adafactor", - optim_target_modules=["attn", "mlp"], + optim_target_modules=[r".*attn.*", r".*mlp.*"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # Check this works + _ = trainer.train() + + galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin) + self.assertTrue(galore_peak_memory < upper_bound_pm) + + @require_galore_torch + @require_torch_gpu + def test_galore_adafactor_attention_only(self): + upper_bound_pm = 700 + + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + learning_rate=1e-9, + logging_steps=5, + optim="galore_adafactor", + optim_target_modules=["q_proj", "k_proj", "v_proj"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # Check this works + _ = trainer.train() + + galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin) + self.assertTrue(galore_peak_memory < upper_bound_pm) + + @require_galore_torch + @require_torch_gpu + def test_galore_adafactor_all_linear(self): + upper_bound_pm = 700 + + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + learning_rate=1e-9, + logging_steps=5, + optim="galore_adafactor", + optim_target_modules="all-linear", ) trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) From 
0b4ba83886e74b543c9be41f3db0d37b79b78b84 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Wed, 13 Mar 2024 12:03:18 +0100 Subject: [PATCH 14/40] add layer-wise optimization through DummyOptimizers and LRSchedulers --- docs/source/en/trainer.md | 37 ++++++++++++ src/transformers/trainer.py | 86 +++++++++++++++++++++++++--- src/transformers/trainer_pt_utils.py | 45 +++++++++++++++ src/transformers/training_args.py | 3 + tests/trainer/test_trainer.py | 45 +++++++++++++++ 5 files changed, 209 insertions(+), 7 deletions(-) diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md index 3efaa648eacc..d1c918369d98 100644 --- a/docs/source/en/trainer.md +++ b/docs/source/en/trainer.md @@ -305,6 +305,43 @@ Currently you can only train Linear layers that are considered as GaLore layers Note it will take a bit of time before starting the training (~3 minutes for a 2B model on a NVIDIA A100), but training should go smoothly afterwards. +You can also perform layer-wise optimization by post-pending the optimizer name with `layerwise` like below: + +```python +import torch +import datasets +import trl + +from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM + +train_dataset = datasets.load_dataset('imdb', split='train') + +args = TrainingArguments( + output_dir="./test-galore", + max_steps=100, + per_device_train_batch_size=2, + optim="galore_adamw_layerwise", + optim_target_modules=["attn", "mlp"] +) + +model_id = "google/gemma-2b" + +config = AutoConfig.from_pretrained(model_id) + +tokenizer = AutoTokenizer.from_pretrained(model_id) +model = AutoModelForCausalLM.from_config(config).to(0) + +trainer = trl.SFTTrainer( + model=model, + args=args, + train_dataset=train_dataset, + dataset_text_field='text', + max_seq_length=512, +) + +trainer.train() +``` + ## Accelerate and Trainer The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/). diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 1a87eb2af972..ded8f276c855 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -83,6 +83,8 @@ DistributedTensorGatherer, IterableDatasetShard, LabelSmoother, + LayerWiseDummyOptimizer, + LayerWiseDummyScheduler, LengthGroupedSampler, SequentialDistributedSampler, distributed_broadcast_scalars, @@ -1018,6 +1020,11 @@ def create_optimizer(self): if "params" in optimizer_kwargs: optimizer_grouped_parameters = optimizer_kwargs.pop("params") + # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict` + # to avoid arguments conflicts. 
+ if "optimizer_dict" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict") + self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) if optimizer_cls.__name__ == "Adam8bit": import bitsandbytes @@ -1183,6 +1190,9 @@ def get_optimizer_cls_and_kwargs( OptimizerNames.GALORE_ADAMW, OptimizerNames.GALORE_ADAMW_8BIT, OptimizerNames.GALORE_ADAFACTOR, + OptimizerNames.GALORE_ADAMW_LAYERWISE, + OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE, + OptimizerNames.GALORE_ADAFACTOR_LAYERWISE, ]: if not is_galore_torch_available(): raise ImportError( @@ -1196,6 +1206,9 @@ def get_optimizer_cls_and_kwargs( OptimizerNames.GALORE_ADAMW: GaLoreAdamW, OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit, OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor, + OptimizerNames.GALORE_ADAMW_LAYERWISE: GaLoreAdamW, + OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE: GaLoreAdamW8bit, + OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor, } optimizer_cls = optimizer_mapping[args.optim] @@ -1213,11 +1226,17 @@ def get_optimizer_cls_and_kwargs( if model is None: raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.") + logger.warning( + "Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !" + ) + all_linear = ( isinstance(args.optim_target_modules, str) and args.optim_target_modules.replace("_", "-") == "all-linear" ) + is_layerwise = args.optim.lower().endswith("layerwise") + galore_params = [] galore_params_names = [] for module_name, module in model.named_modules(): @@ -1237,18 +1256,47 @@ def get_optimizer_cls_and_kwargs( non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names] + galore_optim_kwargs = { + "rank": optim_args.pop("rank", 128), + "update_proj_gap": optim_args.pop("update_proj_gap", 200), + "scale": optim_args.pop("scale", 0.25), + "proj_type": optim_args.pop("proj_type", "std"), + } + # The default args are from the official repository: https://github.com/jiaweizzhao/GaLore param_groups = [ {"params": non_galore_params}, - { - "params": galore_params, - "rank": optim_args.pop("rank", 128), - "update_proj_gap": optim_args.pop("update_proj_gap", 200), - "scale": optim_args.pop("scale", 0.25), - "proj_type": optim_args.pop("proj_type", "std"), - }, + {"params": galore_params, **galore_optim_kwargs}, ] + if is_layerwise: + # For layer-wise optimizers, the optimization step is done through post accumulation + # gradient hooks. The trick is to first attach these hooks to the model parameters then + # create a dummy optimizer that will perform no-ops in the Trainer. 
+ # See the original implementation or the nice implementation from @hiyouga + # here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba + if args.gradient_accumulation_steps != 1: + raise ValueError("Layerwise GaLoRE optimizer do not support gradient accumulation !") + + optimizer_dict = {} + for param in non_galore_params: + param_groups = [{"params": [param]}] + optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs) + for param in galore_params: + param_groups = [{"params": [param], **galore_optim_kwargs}] + optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs) + + def optimizer_hook(param): + if param.grad is not None: + optimizer_dict[param].step() + optimizer_dict[param].zero_grad() + + for param in model.parameters(): + param.register_post_accumulate_grad_hook(optimizer_hook) + + optimizer_cls = LayerWiseDummyOptimizer + optimizer_kwargs.update({"optimizer_dict": optimizer_dict}) + optimizer_kwargs.update({"params": param_groups}) if args.optim == OptimizerNames.GALORE_ADAFACTOR: @@ -1265,6 +1313,30 @@ def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optim Args: num_training_steps (int): The number of training steps to do. """ + if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer): + optimizer_dict = optimizer.optimizer_dict + scheduler_dict = {} + + for param in optimizer_dict.keys(): + scheduler_dict[param] = get_scheduler( + self.args.lr_scheduler_type, + optimizer=optimizer_dict[param], + num_warmup_steps=self.args.get_warmup_steps(num_training_steps) * 2, + num_training_steps=num_training_steps * 2, + ) + + def scheduler_hook(param): + # Since the optimizer hook has been already attached we only need to + # attach the scheduler hook + if param.grad is not None: + scheduler_dict[param].step() + + for param in optimizer_dict.keys(): + param.register_post_accumulate_grad_hook(scheduler_hook) + + self._created_lr_scheduler = True + self.lr_scheduler = LayerWiseDummyScheduler() + if self.lr_scheduler is None: self.lr_scheduler = get_scheduler( self.args.lr_scheduler_type, diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 34d2c8416b59..394d29411d4d 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -34,6 +34,7 @@ import torch import torch.distributed as dist from torch import nn +from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler from torch.utils.data.distributed import DistributedSampler @@ -1226,3 +1227,47 @@ def from_json_file(cls, json_file): def to_dict(self): return copy.deepcopy(self.__dict__) + + +class LayerWiseDummyOptimizer(torch.optim.Optimizer): + """ + For Layer-wise optimizers such as GaLoRE optimizer, the optimization + step is already done through the post gradient hooks. Therefore + the trick is to create a dummy optimizer that can take arbitrary + args and kwargs and return a no-op during training. 
+ + Initial idea from @hiyouga in LLaMA-Factory: + https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba + """ + + def __init__(self, optimizer_dict=None, *args, **kwargs): + dummy_tensor = torch.randn(1, 1) + self.optimizer_dict = optimizer_dict + super().__init__([dummy_tensor], {"lr": 1e-03}) + + def zero_grad(self, set_to_none: bool = True) -> None: + pass + + def step(self, closure=None) -> Optional[float]: + pass + + +class LayerWiseDummyScheduler(LRScheduler): + """ + For Layer-wise optimizers such as GaLoRE optimizer, the optimization and scheduling step + are already done through the post gradient hooks. Therefore + the trick is to create a dummy scheduler that can take arbitrary + args and kwargs and return a no-op during training. + """ + + def __init__(self, *args, **kwargs): + optimizer = LayerWiseDummyOptimizer() + last_epoch = -1 + verbose = False + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + return [group["lr"] for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return self.base_lrs diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 4cb2c98ba716..b74a305ac4b7 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -167,6 +167,9 @@ class OptimizerNames(ExplicitEnum): GALORE_ADAMW = "galore_adamw" GALORE_ADAMW_8BIT = "galore_adamw_8bit" GALORE_ADAFACTOR = "galore_adafactor" + GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise" + GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise" + GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise" # TODO: `TrainingArguments` users rely on it being fully mutable. 
In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903 diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 4eeadace5ff5..0e0fd4c1b09a 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1119,6 +1119,51 @@ def test_galore(self): # Check this works _ = trainer.train() + @require_galore_torch + @require_torch_gpu + def test_galore_layerwise(self): + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + learning_rate=1e-9, + logging_steps=5, + optim="galore_adamw_layerwise", + optim_target_modules=[r".*attn.*", r".*mlp.*"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # Check this works + _ = trainer.train() + + @require_galore_torch + @require_torch_gpu + def test_galore_layerwise_with_scheduler(self): + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + learning_rate=1e-9, + logging_steps=5, + optim="galore_adamw_layerwise", + lr_scheduler_type="cosine", + optim_target_modules=[r".*attn.*", r".*mlp.*"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # Check this works + _ = trainer.train() + @require_galore_torch @require_torch_gpu def test_galore_adamw_8bit(self): From 3e5930ef34b8b96135e71dfc0d09da8566a3fb2d Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 13 Mar 2024 12:04:04 +0100 Subject: [PATCH 15/40] forward contrib credits from discussions and original idea From a16d3a87ba45bc46bd6bc0221f96f4b84ee2f95b Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Wed, 13 Mar 2024 14:31:30 +0000 Subject: [PATCH 16/40] add a section about DDP not supported in layerwise --- docs/source/en/trainer.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md index d1c918369d98..7006b3cea9f6 100644 --- a/docs/source/en/trainer.md +++ b/docs/source/en/trainer.md @@ -342,6 +342,8 @@ trainer = trl.SFTTrainer( trainer.train() ``` +Note layerwise optimization does not support DDP (Distributed Data Parallel), thus you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. + ## Accelerate and Trainer The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/). 
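
As a self-contained illustration of the layer-wise mechanism introduced in PATCH 14 (a sketch, not part of the patch itself): the trick is to attach one per-parameter optimizer through PyTorch's `register_post_accumulate_grad_hook` (available from PyTorch 2.1) and let the hooks perform the updates during `backward()`, while the Trainer only sees the no-op `LayerWiseDummyOptimizer` / `LayerWiseDummyScheduler`. A plain `AdamW` stands in for the GaLore optimizers here:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))

# One small optimizer per parameter, so no full-model optimizer state is built at once.
optimizer_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}

def optimizer_hook(param):
    # Called right after this parameter's gradient has been accumulated.
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

# The training loop then only needs the backward pass; the hooks do the stepping,
# which is why the layer-wise variants above reject gradient accumulation.
x, y = torch.randn(8, 16), torch.randint(0, 2, (8,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
```
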
From 29e7e94f380afbc78e825a29459ec50085a5085d Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 13 Mar 2024 15:59:36 +0100 Subject: [PATCH 17/40] Update src/transformers/trainer.py Co-authored-by: Zach Mueller --- src/transformers/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index ded8f276c855..80fac40dd205 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1199,7 +1199,8 @@ def get_optimizer_cls_and_kwargs( "You need to install `galore_torch` in order to use GaLore optimizers" " install it with `pip install git+https://github.com/jiaweizzhao/GaLore`" ) - + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + raise NotImplementedError("GaLore does not support DDP at this time") from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit optimizer_mapping = { From 18ea144a17d9221f607b13e4aa1fa9aa96808064 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Wed, 13 Mar 2024 15:00:07 +0000 Subject: [PATCH 18/40] fix self --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 80fac40dd205..58f600ffd05d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1199,7 +1199,7 @@ def get_optimizer_cls_and_kwargs( "You need to install `galore_torch` in order to use GaLore optimizers" " install it with `pip install git+https://github.com/jiaweizzhao/GaLore`" ) - if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + if args.parallel_mode == ParallelMode.DISTRIBUTED: raise NotImplementedError("GaLore does not support DDP at this time") from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit From 7800bf1f7a792fd3b891f6b4348ec0474f67fc8b Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Wed, 13 Mar 2024 15:01:52 +0000 Subject: [PATCH 19/40] check only if layer_wise --- src/transformers/trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 58f600ffd05d..26245ed18adc 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1199,10 +1199,12 @@ def get_optimizer_cls_and_kwargs( "You need to install `galore_torch` in order to use GaLore optimizers" " install it with `pip install git+https://github.com/jiaweizzhao/GaLore`" ) - if args.parallel_mode == ParallelMode.DISTRIBUTED: - raise NotImplementedError("GaLore does not support DDP at this time") from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit + is_layerwise = args.optim.lower().endswith("layerwise") + if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED: + raise NotImplementedError("GaLore does not support DDP at this time") + optimizer_mapping = { OptimizerNames.GALORE_ADAMW: GaLoreAdamW, OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit, @@ -1236,8 +1238,6 @@ def get_optimizer_cls_and_kwargs( and args.optim_target_modules.replace("_", "-") == "all-linear" ) - is_layerwise = args.optim.lower().endswith("layerwise") - galore_params = [] galore_params_names = [] for module_name, module in model.named_modules(): From e022bdda8514af98a6b008e9502a56515523e773 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 14 Mar 2024 09:52:17 +0100 Subject: [PATCH 20/40] Update src/transformers/training_args.py Co-authored-by: 
amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/training_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index b74a305ac4b7..6132a54f961d 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1369,7 +1369,7 @@ class TrainingArguments: optim_target_modules: Optional[List[str]] = field( default=None, metadata={ - "help": "Target modules for GaLoRE optimizer. See https://github.com/jiaweizzhao/GaLore for more details." + "help": "Target modules for the optimizer defined in the `optim` argument. Only used for the GaLore optimizer at the moment. }, ) From 830c68dfac792d90415f475b2bcc751150720510 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Thu, 14 Mar 2024 09:14:22 +0000 Subject: [PATCH 21/40] oops --- src/transformers/training_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 6132a54f961d..25c1b75a1179 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1369,7 +1369,7 @@ class TrainingArguments: optim_target_modules: Optional[List[str]] = field( default=None, metadata={ - "help": "Target modules for the optimizer defined in the `optim` argument. Only used for the GaLore optimizer at the moment. + "help": "Target modules for the optimizer defined in the `optim` argument. Only used for the GaLore optimizer at the moment." }, ) From b640e980718dc3fb64321cdff4faa836ae7a8b10 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Thu, 14 Mar 2024 09:24:13 +0000 Subject: [PATCH 22/40] make use of intervals --- tests/trainer/test_trainer.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 0e0fd4c1b09a..2d1199175019 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1189,7 +1189,10 @@ def test_galore_adamw_8bit(self): @require_galore_torch @require_torch_gpu def test_galore_adafactor(self): + # These are the intervals of the peak memory usage of training such a tiny model + # if the peak memory goes outside that range, then we know there might be a bug somewhere upper_bound_pm = 700 + lower_bound_pm = 650 config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) tiny_llama = LlamaForCausalLM(config) @@ -1211,12 +1214,17 @@ def test_galore_adafactor(self): _ = trainer.train() galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin) + self.assertTrue(galore_peak_memory < upper_bound_pm) + self.assertTrue(lower_bound_pm < galore_peak_memory) @require_galore_torch @require_torch_gpu def test_galore_adafactor_attention_only(self): + # These are the intervals of the peak memory usage of training such a tiny model + # if the peak memory goes outside that range, then we know there might be a bug somewhere upper_bound_pm = 700 + lower_bound_pm = 650 config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) tiny_llama = LlamaForCausalLM(config) @@ -1239,11 +1247,15 @@ def test_galore_adafactor_attention_only(self): galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin) self.assertTrue(galore_peak_memory < upper_bound_pm) + self.assertTrue(lower_bound_pm < galore_peak_memory) @require_galore_torch @require_torch_gpu def test_galore_adafactor_all_linear(self): + # These are the 
intervals of the peak memory usage of training such a tiny model + # if the peak memory goes outside that range, then we know there might be a bug somewhere upper_bound_pm = 700 + lower_bound_pm = 650 config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) tiny_llama = LlamaForCausalLM(config) @@ -1266,6 +1278,7 @@ def test_galore_adafactor_all_linear(self): galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin) self.assertTrue(galore_peak_memory < upper_bound_pm) + self.assertTrue(lower_bound_pm < galore_peak_memory) @require_torch_multi_accelerator def test_data_is_not_parallelized_when_model_is_parallel(self): From 14a89b2f0ebfb1b65145311796f4d73e6260f177 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Thu, 14 Mar 2024 09:24:54 +0000 Subject: [PATCH 23/40] clarify comment --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 26245ed18adc..ba8bbd4ebc52 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1203,7 +1203,7 @@ def get_optimizer_cls_and_kwargs( is_layerwise = args.optim.lower().endswith("layerwise") if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED: - raise NotImplementedError("GaLore does not support DDP at this time") + raise NotImplementedError("Layer-wise GaLore does not support DDP at this time") optimizer_mapping = { OptimizerNames.GALORE_ADAMW: GaLoreAdamW, From 6f7102db64907f3d8dee9234c956057cb65932f1 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Thu, 14 Mar 2024 09:40:08 +0000 Subject: [PATCH 24/40] add matching tests --- src/transformers/trainer_utils.py | 2 +- tests/trainer/test_trainer.py | 59 ++++++++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index b6494bc6138b..9d8cda6eafd8 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -801,7 +801,7 @@ def check_target_module_exists(optim_target_modules, key: str): None if no match found """ if isinstance(optim_target_modules, str): - target_module_found = re.fullmatch(optim_target_modules, key) + target_module_found = bool(re.fullmatch(optim_target_modules, key)) elif key in optim_target_modules: # this module is specified directly in target_modules target_module_found = True diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 2d1199175019..79bb053ceac9 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -85,7 +85,7 @@ slow, torch_device, ) -from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend, check_target_module_exists from transformers.training_args import OptimizerNames from transformers.utils import ( SAFE_WEIGHTS_INDEX_NAME, @@ -1097,6 +1097,63 @@ def test_dataloader_without_dataset(self): trainer.train() trainer.evaluate() + def test_galore_matched_modules(self): + regex_patterns = [r".*.attn.*", r".*.mlp.*"] + + module_names = [ + "model.transformer.h.0.ln_1", + "model.transformer.h.0.attn.q_proj", + "model.lm_head", + "model.transformer.h.0.mlp.up_proj", + ] + expected_values = [False, True, False, True] + + for expected_value, module_name in zip(expected_values, module_names): + is_module_matched = check_target_module_exists(regex_patterns, module_name) + 
self.assertTrue(is_module_matched == expected_value) + + exact_patterns = ["q_proj", "up_proj"] + + module_names = [ + "model.transformer.h.0.ln_1", + "model.transformer.h.0.attn.q_proj", + "model.lm_head", + "model.transformer.h.0.mlp.up_proj", + ] + expected_values = [False, True, False, True] + + for expected_value, module_name in zip(expected_values, module_names): + is_module_matched = check_target_module_exists(exact_patterns, module_name) + self.assertTrue(is_module_matched == expected_value) + + simple_regex = r".*.attn.*" + + module_names = [ + "model.transformer.h.0.ln_1", + "model.transformer.h.0.attn.q_proj", + "model.lm_head", + "model.transformer.h.0.mlp.up_proj", + ] + expected_values = [False, True, False, False] + + for expected_value, module_name in zip(expected_values, module_names): + is_module_matched = check_target_module_exists(simple_regex, module_name) + self.assertTrue(is_module_matched == expected_value) + + simple_regex = "model.transformer.h.0.attn.q_proj" + + module_names = [ + "model.transformer.h.0.ln_1", + "model.transformer.h.0.attn.q_proj", + "model.lm_head", + "model.transformer.h.0.mlp.up_proj", + ] + expected_values = [False, True, False, False] + + for expected_value, module_name in zip(expected_values, module_names): + is_module_matched = check_target_module_exists(simple_regex, module_name) + self.assertTrue(is_module_matched == expected_value) + @require_galore_torch @require_torch_gpu def test_galore(self): From c11cb63ef37ede478e5d3f7fbe932e89557dacf1 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Thu, 14 Mar 2024 09:42:41 +0000 Subject: [PATCH 25/40] GaLoRe -> GaLore --- src/transformers/testing_utils.py | 4 ++-- src/transformers/trainer.py | 4 ++-- src/transformers/training_args.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index cd2066a6fa37..984657aff748 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -327,10 +327,10 @@ def require_bs4(test_case): def require_galore_torch(test_case): """ - Decorator marking a test that requires Galore. These tests are skipped when Galore isn't installed. + Decorator marking a test that requires GaLore. These tests are skipped when GaLore isn't installed. https://github.com/jiaweizzhao/GaLore """ - return unittest.skipUnless(is_galore_torch_available(), "test requires Galore")(test_case) + return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case) def require_cv2(test_case): diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index ba8bbd4ebc52..89f310be08fb 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1016,7 +1016,7 @@ def create_optimizer(self): optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model) # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs` - # e.g. for GaLoRe optimizer. + # e.g. for GaLore optimizer. 
if "params" in optimizer_kwargs: optimizer_grouped_parameters = optimizer_kwargs.pop("params") @@ -1218,7 +1218,7 @@ def get_optimizer_cls_and_kwargs( if args.optim_target_modules is None: raise ValueError( - "You need to define a `optim_target_modules` in order to properly use GaLoRe optimizers" + "You need to define a `optim_target_modules` in order to properly use GaLore optimizers" ) if not isinstance(args.optim_target_modules, (list, str)): diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 25c1b75a1179..69d628ff09f2 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -703,7 +703,7 @@ class TrainingArguments: [original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also `PeftModel` from peft. optim_target_modules (`List[str]`, *optional*): - The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLoRe algorithm + The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm https://arxiv.org/abs/2403.03507 See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaloRe optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules From 3678201911ca016ce0528827cfe9241069d6a1bd Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Thu, 14 Mar 2024 09:48:14 +0000 Subject: [PATCH 26/40] move to `get_scheduler` --- src/transformers/optimization.py | 27 +++++++++++++++++++++++++++ src/transformers/trainer.py | 25 ------------------------- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py index 65a41d1b1a44..1c76ddda9f0b 100644 --- a/src/transformers/optimization.py +++ b/src/transformers/optimization.py @@ -24,6 +24,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau +from .trainer_pt_utils import LayerWiseDummyOptimizer, LayerWiseDummyScheduler from .trainer_utils import SchedulerType from .utils import logging from .utils.versions import require_version @@ -362,6 +363,32 @@ def get_scheduler( """ name = SchedulerType(name) schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + + # If a `LayerWiseDummyOptimizer` is passed we extract the optimizer dict and + # recursively call `get_scheduler` to get the proper schedulers on each parameter + if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer): + optimizer_dict = optimizer.optimizer_dict + scheduler_dict = {} + + for param in optimizer_dict.keys(): + scheduler_dict[param] = get_scheduler( + name, + optimizer=optimizer_dict[param], + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + ) + + def scheduler_hook(param): + # Since the optimizer hook has been already attached we only need to + # attach the scheduler hook + if param.grad is not None: + scheduler_dict[param].step() + + for param in optimizer_dict.keys(): + param.register_post_accumulate_grad_hook(scheduler_hook) + + return LayerWiseDummyScheduler() + if name == SchedulerType.CONSTANT: return schedule_func(optimizer) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 89f310be08fb..21fbef2ab11e 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -84,7 +84,6 @@ IterableDatasetShard, LabelSmoother, 
LayerWiseDummyOptimizer, - LayerWiseDummyScheduler, LengthGroupedSampler, SequentialDistributedSampler, distributed_broadcast_scalars, @@ -1314,30 +1313,6 @@ def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optim Args: num_training_steps (int): The number of training steps to do. """ - if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer): - optimizer_dict = optimizer.optimizer_dict - scheduler_dict = {} - - for param in optimizer_dict.keys(): - scheduler_dict[param] = get_scheduler( - self.args.lr_scheduler_type, - optimizer=optimizer_dict[param], - num_warmup_steps=self.args.get_warmup_steps(num_training_steps) * 2, - num_training_steps=num_training_steps * 2, - ) - - def scheduler_hook(param): - # Since the optimizer hook has been already attached we only need to - # attach the scheduler hook - if param.grad is not None: - scheduler_dict[param].step() - - for param in optimizer_dict.keys(): - param.register_post_accumulate_grad_hook(scheduler_hook) - - self._created_lr_scheduler = True - self.lr_scheduler = LayerWiseDummyScheduler() - if self.lr_scheduler is None: self.lr_scheduler = get_scheduler( self.args.lr_scheduler_type, From fdc4b2a2fdbeb03332e48ed5af8f6515a1e2a3f5 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Thu, 14 Mar 2024 09:54:20 +0000 Subject: [PATCH 27/40] add note on docs --- docs/source/en/trainer.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md index 7006b3cea9f6..32e07b043101 100644 --- a/docs/source/en/trainer.md +++ b/docs/source/en/trainer.md @@ -342,7 +342,7 @@ trainer = trl.SFTTrainer( trainer.train() ``` -Note layerwise optimization does not support DDP (Distributed Data Parallel), thus you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. +Note layerwise optimization is a bit experimental and does not support DDP (Distributed Data Parallel), thus you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such issue. ## Accelerate and Trainer From e7ce9b7d3915f903908af53ce76658c74369205a Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Thu, 14 Mar 2024 09:58:30 +0000 Subject: [PATCH 28/40] add a warning --- src/transformers/trainer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 21fbef2ab11e..53a1b0f43400 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1241,6 +1241,12 @@ def get_optimizer_cls_and_kwargs( galore_params_names = [] for module_name, module in model.named_modules(): if not isinstance(module, nn.Linear): + # Warn in case we match but it's not a linear layer + if check_target_module_exists(args.optim_target_modules, module_name): + logger.warning( + f"{module_name} has been matched but ignored as GaLore only supports linear layers. If you passed a regex `.*.attn.*` this is expected, otherwise please double check your `optim_target_modules`!" 
+ ) + continue if not check_target_module_exists(args.optim_target_modules, module_name) and not all_linear: From 91d6436808695d0a99960aa64cb0905e1d6b5863 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Fri, 15 Mar 2024 09:57:34 +0000 Subject: [PATCH 29/40] adapt a bit the docs --- docs/source/en/trainer.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md index 32e07b043101..47d1efd7eb90 100644 --- a/docs/source/en/trainer.md +++ b/docs/source/en/trainer.md @@ -262,7 +262,7 @@ First make sure to install GaLore official repository: pip install git+https://github.com/jiaweizzhao/GaLore ``` -Then simply add one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` in `optim` together with `optim_target_modules`, which should be a list of strings, corresponding to the target module names you want to adapt. Below is an end-to-end example script (make sure to `pip install trl datasets`): +Then simply add one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` in `optim` together with `optim_target_modules`, which can be a list of strings, regew or full path corresponding to the target module names you want to adapt. Below is an end-to-end example script (make sure to `pip install trl datasets`): ```python import torch From b9e338a01d83ed4a4f040d2aef49f58dec894708 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Fri, 15 Mar 2024 09:58:38 +0000 Subject: [PATCH 30/40] update docstring --- src/transformers/training_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 69d628ff09f2..49557f43c704 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -702,7 +702,7 @@ class TrainingArguments: for instruction fine-tuning. Check out the [original paper](https://arxiv.org/abs/2310.05914) and the [original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also `PeftModel` from peft. - optim_target_modules (`List[str]`, *optional*): + optim_target_modules (`Union[str, List[str]]`, *optional*): The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm https://arxiv.org/abs/2403.03507 See: https://github.com/jiaweizzhao/GaLore for more details. 
You need to make sure to pass a valid GaloRe From 6ff37620ecac75b4f986cb3dde451ef1ab32ef0e Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Sun, 17 Mar 2024 10:11:33 +0000 Subject: [PATCH 31/40] support original API --- src/transformers/trainer_utils.py | 2 +- tests/trainer/test_trainer.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 9d8cda6eafd8..50e08951d336 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -806,7 +806,7 @@ def check_target_module_exists(optim_target_modules, key: str): # this module is specified directly in target_modules target_module_found = True else: - target_module_found = any(key.endswith(f".{target_key}") for target_key in optim_target_modules) + target_module_found = any(target_key in key for target_key in optim_target_modules) # Check also if the user passed a list of regex if not target_module_found: target_module_found = any( diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 79bb053ceac9..042d21f17509 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1154,6 +1154,20 @@ def test_galore_matched_modules(self): is_module_matched = check_target_module_exists(simple_regex, module_name) self.assertTrue(is_module_matched == expected_value) + target_modules = ["attn", "mlp"] + + module_names = [ + "model.transformer.h.0.ln_1", + "model.transformer.h.0.attn.q_proj", + "model.lm_head", + "model.transformer.h.0.mlp.up_proj", + ] + expected_values = [False, True, False, True] + + for expected_value, module_name in zip(expected_values, module_names): + is_module_matched = check_target_module_exists(target_modules, module_name) + self.assertTrue(is_module_matched == expected_value) + @require_galore_torch @require_torch_gpu def test_galore(self): From 0d0440a1f2523aa67026de89761dadb8f3719d32 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Sun, 17 Mar 2024 17:57:21 +0100 Subject: [PATCH 32/40] Update docs/source/en/trainer.md --- docs/source/en/trainer.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md index 47d1efd7eb90..7a807c58f57b 100644 --- a/docs/source/en/trainer.md +++ b/docs/source/en/trainer.md @@ -259,7 +259,7 @@ Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training st First make sure to install GaLore official repository: ```bash -pip install git+https://github.com/jiaweizzhao/GaLore +pip install galore-torch ``` Then simply add one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` in `optim` together with `optim_target_modules`, which can be a list of strings, regew or full path corresponding to the target module names you want to adapt. 
Below is an end-to-end example script (make sure to `pip install trl datasets`): From 832f2be94ebc91203fcecb827bef01df554d1175 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Mon, 18 Mar 2024 14:51:52 +0000 Subject: [PATCH 33/40] slightly refactor --- src/transformers/trainer_utils.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 50e08951d336..cf42d655de7f 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -800,17 +800,15 @@ def check_target_module_exists(optim_target_modules, key: str): `bool` | `re.Match[str]` | `None`: True of match object if key matches any target modules from config, False or None if no match found """ + target_module_found = False if isinstance(optim_target_modules, str): target_module_found = bool(re.fullmatch(optim_target_modules, key)) - elif key in optim_target_modules: + elif key in optim_target_modules: # from here, target_module_found must be a list of str # this module is specified directly in target_modules target_module_found = True - else: - target_module_found = any(target_key in key for target_key in optim_target_modules) - # Check also if the user passed a list of regex - if not target_module_found: - target_module_found = any( - bool(re.fullmatch(optim_target_module, key)) for optim_target_module in optim_target_modules - ) + elif any(target_key in key for target_key in optim_target_modules): + target_module_found = True + elif any(bool(re.fullmatch(optim_target_module, key)) for optim_target_module in optim_target_modules): + target_module_found = True return target_module_found From 898a3c5a4bec8ca48451a2e509bad694596a38fa Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 18 Mar 2024 17:34:03 +0100 Subject: [PATCH 34/40] Update docs/source/en/trainer.md Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> --- docs/source/en/trainer.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md index 7a807c58f57b..61cb42016edc 100644 --- a/docs/source/en/trainer.md +++ b/docs/source/en/trainer.md @@ -262,7 +262,7 @@ First make sure to install GaLore official repository: pip install galore-torch ``` -Then simply add one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` in `optim` together with `optim_target_modules`, which can be a list of strings, regew or full path corresponding to the target module names you want to adapt. Below is an end-to-end example script (make sure to `pip install trl datasets`): +Then simply add one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` in `optim` together with `optim_target_modules`, which can be a list of strings, regex or full path corresponding to the target module names you want to adapt. 
Below is an end-to-end example script (make sure to `pip install trl datasets`): ```python import torch From ed3ad4ad757bcde08f9c4598f98c38df2e0953b4 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 19 Mar 2024 09:21:42 +0100 Subject: [PATCH 35/40] Update src/transformers/training_args.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/training_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 49557f43c704..e644ace94e58 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1366,7 +1366,7 @@ class TrainingArguments: }, ) - optim_target_modules: Optional[List[str]] = field( + optim_target_modules: Optional[Union[str, List[str]]] = field( default=None, metadata={ "help": "Target modules for the optimizer defined in the `optim` argument. Only used for the GaLore optimizer at the moment." From 57e7096eb9f397dc27e98d2d31470f1fc18436ce Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Tue, 19 Mar 2024 08:29:15 +0000 Subject: [PATCH 36/40] fix args parsing and add tests --- src/transformers/trainer.py | 6 +++--- tests/trainer/test_trainer.py | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 53a1b0f43400..53a7c72b4c13 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1263,9 +1263,9 @@ def get_optimizer_cls_and_kwargs( non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names] galore_optim_kwargs = { - "rank": optim_args.pop("rank", 128), - "update_proj_gap": optim_args.pop("update_proj_gap", 200), - "scale": optim_args.pop("scale", 0.25), + "rank": int(optim_args.pop("rank", 128)), + "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)), + "scale": float(optim_args.pop("scale", 0.25)), "proj_type": optim_args.pop("proj_type", "std"), } diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 042d21f17509..c715fe3b3236 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1190,6 +1190,29 @@ def test_galore(self): # Check this works _ = trainer.train() + @require_galore_torch + @require_torch_gpu + def test_galore_extra_args(self): + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + learning_rate=1e-9, + logging_steps=5, + optim="galore_adamw", + optim_args="rank=64, update_proj_gap=100, scale=0.10", + optim_target_modules=[r".*attn.*", r".*mlp.*"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # Check this works + _ = trainer.train() + @require_galore_torch @require_torch_gpu def test_galore_layerwise(self): From 64ccfa6b6728d7c99a8aca54abf5291abc10560f Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Tue, 19 Mar 2024 08:46:58 +0000 Subject: [PATCH 37/40] remove warning for regex --- src/transformers/trainer.py | 8 +++++--- src/transformers/trainer_utils.py | 10 ++++++++-- tests/trainer/test_trainer.py | 20 +++++++++++++++----- 3 files changed, 28 insertions(+), 10 deletions(-) diff --git a/src/transformers/trainer.py 
b/src/transformers/trainer.py index 53a7c72b4c13..8ac05678501d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1240,16 +1240,18 @@ def get_optimizer_cls_and_kwargs( galore_params = [] galore_params_names = [] for module_name, module in model.named_modules(): + target_module_exists, is_regex = check_target_module_exists(args.optim_target_modules, module_name) + if not isinstance(module, nn.Linear): # Warn in case we match but it's not a linear layer - if check_target_module_exists(args.optim_target_modules, module_name): + if target_module_exists and not is_regex: logger.warning( - f"{module_name} has been matched but ignored as GaLore only supports linear layers. If you passed a regex `.*.attn.*` this is expected, otherwise please double check your `optim_target_modules`!" + f"{module_name} has been matched but ignored as GaLore only supports linear layers. Please double check your `optim_target_modules`!" ) continue - if not check_target_module_exists(args.optim_target_modules, module_name) and not all_linear: + if not target_module_exists and not all_linear: continue galore_params.append(module.weight) diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index cf42d655de7f..129ad47f9dae 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -797,12 +797,17 @@ def check_target_module_exists(optim_target_modules, key: str): A key to search any matches in optim_target_modules Returns: - `bool` | `re.Match[str]` | `None`: True of match object if key matches any target modules from config, False or + `bool` : True of match object if key matches any target modules from config, False or None if no match found + `bool` : If the matched target module is a regex to silence out the warnings in Trainer + for extra modules being found (only if `target_module_found=True` for an array of regex). 
""" target_module_found = False + is_regex = False + if isinstance(optim_target_modules, str): target_module_found = bool(re.fullmatch(optim_target_modules, key)) + is_regex = True if not optim_target_modules == key else False elif key in optim_target_modules: # from here, target_module_found must be a list of str # this module is specified directly in target_modules target_module_found = True @@ -810,5 +815,6 @@ def check_target_module_exists(optim_target_modules, key: str): target_module_found = True elif any(bool(re.fullmatch(optim_target_module, key)) for optim_target_module in optim_target_modules): target_module_found = True + is_regex = True - return target_module_found + return target_module_found, is_regex diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index c715fe3b3236..9fcd6ce8a196 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1109,8 +1109,10 @@ def test_galore_matched_modules(self): expected_values = [False, True, False, True] for expected_value, module_name in zip(expected_values, module_names): - is_module_matched = check_target_module_exists(regex_patterns, module_name) + is_module_matched, is_regex = check_target_module_exists(regex_patterns, module_name) self.assertTrue(is_module_matched == expected_value) + if is_module_matched: + self.assertTrue(is_regex) exact_patterns = ["q_proj", "up_proj"] @@ -1123,8 +1125,10 @@ def test_galore_matched_modules(self): expected_values = [False, True, False, True] for expected_value, module_name in zip(expected_values, module_names): - is_module_matched = check_target_module_exists(exact_patterns, module_name) + is_module_matched, is_regex = check_target_module_exists(exact_patterns, module_name) self.assertTrue(is_module_matched == expected_value) + if is_module_matched: + self.assertFalse(is_regex) simple_regex = r".*.attn.*" @@ -1137,8 +1141,10 @@ def test_galore_matched_modules(self): expected_values = [False, True, False, False] for expected_value, module_name in zip(expected_values, module_names): - is_module_matched = check_target_module_exists(simple_regex, module_name) + is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name) self.assertTrue(is_module_matched == expected_value) + if is_module_matched: + self.assertTrue(is_regex) simple_regex = "model.transformer.h.0.attn.q_proj" @@ -1151,8 +1157,10 @@ def test_galore_matched_modules(self): expected_values = [False, True, False, False] for expected_value, module_name in zip(expected_values, module_names): - is_module_matched = check_target_module_exists(simple_regex, module_name) + is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name) self.assertTrue(is_module_matched == expected_value) + if is_module_matched: + self.assertFalse(is_regex) target_modules = ["attn", "mlp"] @@ -1165,8 +1173,10 @@ def test_galore_matched_modules(self): expected_values = [False, True, False, True] for expected_value, module_name in zip(expected_values, module_names): - is_module_matched = check_target_module_exists(target_modules, module_name) + is_module_matched, is_regex = check_target_module_exists(target_modules, module_name) self.assertTrue(is_module_matched == expected_value) + if is_module_matched: + self.assertFalse(is_regex) @require_galore_torch @require_torch_gpu From 73dcabb8dec1b7532120c479b314a25810e76748 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Tue, 19 Mar 2024 09:04:03 +0000 Subject: [PATCH 38/40] fix type hint --- src/transformers/training_args.py | 2 
+- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 7ee622432bf4..a52a77e9a766 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1366,7 +1366,7 @@ class TrainingArguments: }, ) - optim_target_modules: Optional[Union[str, List[str]]] = field( + optim_target_modules: Union[None, str, List[str]] = field( default=None, metadata={ "help": "Target modules for the optimizer defined in the `optim` argument. Only used for the GaLore optimizer at the moment." From 1987b7ae4f8f972172ce3b897a835914f20b2ac4 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Tue, 19 Mar 2024 09:40:12 +0000 Subject: [PATCH 39/40] add note about extra args --- docs/source/en/trainer.md | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md index 61cb42016edc..3d57220fe827 100644 --- a/docs/source/en/trainer.md +++ b/docs/source/en/trainer.md @@ -299,6 +299,44 @@ trainer = trl.SFTTrainer( trainer.train() ``` +To pass extra arguments supports by GaLore, you should pass correctly `optim_args`, for example: + +```python +import torch +import datasets +import trl + +from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM + +train_dataset = datasets.load_dataset('imdb', split='train') + +args = TrainingArguments( + output_dir="./test-galore", + max_steps=100, + per_device_train_batch_size=2, + optim="galore_adamw", + optim_target_modules=["attn", "mlp"], + optim_args="rank=64, update_proj_gap=100, scale=0.10", +) + +model_id = "google/gemma-2b" + +config = AutoConfig.from_pretrained(model_id) + +tokenizer = AutoTokenizer.from_pretrained(model_id) +model = AutoModelForCausalLM.from_config(config).to(0) + +trainer = trl.SFTTrainer( + model=model, + args=args, + train_dataset=train_dataset, + dataset_text_field='text', + max_seq_length=512, +) + +trainer.train() +``` + You can read more about the method in the [original repository](https://github.com/jiaweizzhao/GaLore) or the [paper](https://arxiv.org/abs/2403.03507). Currently you can only train Linear layers that are considered as GaLore layers and will use low-rank decomposition to be trained while remaining layers will be optimized in the conventional manner. 
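(Editorial note, not part of the patch series.) The `optim_args` string shown in the docs above ("rank=64, update_proj_gap=100, scale=0.10") reaches the trainer as plain text, which is why an earlier patch in this series casts `rank` and `update_proj_gap` to `int` and `scale` to `float` before building the GaLore parameter group. The sketch below illustrates that flow under the assumption of a simple comma-separated `key=value` format; `parse_optim_args` is a hypothetical helper for illustration, not the actual transformers implementation.

```python
# Illustrative sketch only: turning an `optim_args` string into the keyword
# arguments consumed by the GaLore parameter group. Hypothetical helper,
# not the transformers code path itself.
def parse_optim_args(optim_args: str) -> dict:
    parsed = {}
    if optim_args:
        for mapping in optim_args.replace(" ", "").split(","):
            key, value = mapping.split("=")
            parsed[key] = value  # values are still strings at this point
    return parsed

optim_args = parse_optim_args("rank=64, update_proj_gap=100, scale=0.10")

# Explicit casts mirror the ones added in the trainer, since everything
# parsed out of the string is a `str`:
galore_optim_kwargs = {
    "rank": int(optim_args.pop("rank", 128)),                        # 64
    "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),  # 100
    "scale": float(optim_args.pop("scale", 0.25)),                   # 0.1
    "proj_type": optim_args.pop("proj_type", "std"),                 # "std"
}
print(galore_optim_kwargs)
```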
From db2bf219fd95c2ae740dea2ab2f43a03dc4e047c Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Tue, 19 Mar 2024 10:19:49 +0000 Subject: [PATCH 40/40] make `is_regex` return optional --- src/transformers/trainer.py | 4 +++- src/transformers/trainer_utils.py | 10 ++++++++-- tests/trainer/test_trainer.py | 10 +++++----- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 12a598370e04..bef4b24c517c 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1241,7 +1241,9 @@ def get_optimizer_cls_and_kwargs( galore_params = [] galore_params_names = [] for module_name, module in model.named_modules(): - target_module_exists, is_regex = check_target_module_exists(args.optim_target_modules, module_name) + target_module_exists, is_regex = check_target_module_exists( + args.optim_target_modules, module_name, return_is_regex=True + ) if not isinstance(module, nn.Linear): # Warn in case we match but it's not a linear layer diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 129ad47f9dae..0faf657387ba 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -787,7 +787,7 @@ def __call__(self, features: List[dict]): return self.data_collator(features) -def check_target_module_exists(optim_target_modules, key: str): +def check_target_module_exists(optim_target_modules, key: str, return_is_regex: bool = False): """A helper method to check if the passed module's key name matches any of the target modules in the optim_target_modules. Args: @@ -795,6 +795,9 @@ def check_target_module_exists(optim_target_modules, key: str): A list of strings to try to match. Can be also a full string. key (`str`): A key to search any matches in optim_target_modules + return_is_regex (`bool`): + If set to `True`, the method will return whether the passed `optim_target_modules` + is a regex or not. 
Returns: `bool` : True of match object if key matches any target modules from config, False or @@ -817,4 +820,7 @@ def check_target_module_exists(optim_target_modules, key: str): target_module_found = True is_regex = True - return target_module_found, is_regex + if return_is_regex: + return target_module_found, is_regex + + return target_module_found diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 9fcd6ce8a196..ebc628146b96 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1109,7 +1109,7 @@ def test_galore_matched_modules(self): expected_values = [False, True, False, True] for expected_value, module_name in zip(expected_values, module_names): - is_module_matched, is_regex = check_target_module_exists(regex_patterns, module_name) + is_module_matched, is_regex = check_target_module_exists(regex_patterns, module_name, return_is_regex=True) self.assertTrue(is_module_matched == expected_value) if is_module_matched: self.assertTrue(is_regex) @@ -1125,7 +1125,7 @@ def test_galore_matched_modules(self): expected_values = [False, True, False, True] for expected_value, module_name in zip(expected_values, module_names): - is_module_matched, is_regex = check_target_module_exists(exact_patterns, module_name) + is_module_matched, is_regex = check_target_module_exists(exact_patterns, module_name, return_is_regex=True) self.assertTrue(is_module_matched == expected_value) if is_module_matched: self.assertFalse(is_regex) @@ -1141,7 +1141,7 @@ def test_galore_matched_modules(self): expected_values = [False, True, False, False] for expected_value, module_name in zip(expected_values, module_names): - is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name) + is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name, return_is_regex=True) self.assertTrue(is_module_matched == expected_value) if is_module_matched: self.assertTrue(is_regex) @@ -1157,7 +1157,7 @@ def test_galore_matched_modules(self): expected_values = [False, True, False, False] for expected_value, module_name in zip(expected_values, module_names): - is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name) + is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name, return_is_regex=True) self.assertTrue(is_module_matched == expected_value) if is_module_matched: self.assertFalse(is_regex) @@ -1173,7 +1173,7 @@ def test_galore_matched_modules(self): expected_values = [False, True, False, True] for expected_value, module_name in zip(expected_values, module_names): - is_module_matched, is_regex = check_target_module_exists(target_modules, module_name) + is_module_matched, is_regex = check_target_module_exists(target_modules, module_name, return_is_regex=True) self.assertTrue(is_module_matched == expected_value) if is_module_matched: self.assertFalse(is_regex)
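(Editorial note, not part of the patch series.) As a companion to `test_galore_matched_modules` above, here is a minimal, self-contained sketch of the matching rules these patches settle on: a single string is treated as a full regex over the module name, while a list is matched by exact name, substring, or per-entry full regex. It reproduces the tested behaviour for illustration; it is not the library's `check_target_module_exists` function.

```python
# Standalone sketch of the matching behaviour exercised by the tests above.
import re
from typing import List, Union


def matches(optim_target_modules: Union[str, List[str]], key: str) -> bool:
    if isinstance(optim_target_modules, str):
        # A single string is a full regex over the module name.
        return bool(re.fullmatch(optim_target_modules, key))
    # A list matches on exact name, on substring, or on per-entry full regex.
    if key in optim_target_modules:
        return True
    if any(target_key in key for target_key in optim_target_modules):
        return True
    return any(bool(re.fullmatch(pattern, key)) for pattern in optim_target_modules)


# Mirrors a few of the cases in `test_galore_matched_modules`:
assert matches([r".*.attn.*", r".*.mlp.*"], "model.transformer.h.0.attn.q_proj")
assert matches(["q_proj", "up_proj"], "model.transformer.h.0.mlp.up_proj")
assert not matches(r".*.attn.*", "model.transformer.h.0.mlp.up_proj")
assert not matches(["attn", "mlp"], "model.lm_head")
```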