FEAT / Optim: Add GaLore optimizer #29588

Merged (44 commits, Mar 19, 2024)

Changes from 6 commits

Commits (44)
b31ce79
add galore v1
younesbelkada Mar 11, 2024
58169f1
add import
younesbelkada Mar 11, 2024
9032635
add tests and doc
younesbelkada Mar 11, 2024
136f104
fix doctest
younesbelkada Mar 11, 2024
a5483b3
forward contrib credits from discussions
Mar 11, 2024
887d3ad
forward contrib credits from discussions
Mar 11, 2024
d6f119f
Apply suggestions from code review
younesbelkada Mar 11, 2024
3fae229
Merge remote-tracking branch 'upstream/main' into HEAD
younesbelkada Mar 11, 2024
c8c50f8
fix failing tests'
younesbelkada Mar 11, 2024
2bdda68
Merge remote-tracking branch 'upstream/main' into add-galore-optimizer
younesbelkada Mar 13, 2024
630bd13
switch to `optim_target_modules` and clarify docs
younesbelkada Mar 13, 2024
a871b75
more clarification
younesbelkada Mar 13, 2024
cb6cd7e
Merge remote-tracking branch 'upstream/main' into add-galore-optimizer
younesbelkada Mar 13, 2024
51b7b29
enhance lookup logic
younesbelkada Mar 13, 2024
3da3b90
update a test to add peak memory
younesbelkada Mar 13, 2024
9115c94
add regex, all-linear and single string support
younesbelkada Mar 13, 2024
0b4ba83
add layer-wise optimization through DummyOptimizers and LRSchedulers
younesbelkada Mar 13, 2024
3e5930e
forward contrib credits from discussions and original idea
hiyouga Mar 13, 2024
a16d3a8
add a section about DDP not supported in layerwise
younesbelkada Mar 13, 2024
29e7e94
Update src/transformers/trainer.py
younesbelkada Mar 13, 2024
18ea144
fix self
younesbelkada Mar 13, 2024
7800bf1
check only if layer_wise
younesbelkada Mar 13, 2024
e022bdd
Update src/transformers/training_args.py
younesbelkada Mar 14, 2024
830c68d
oops
younesbelkada Mar 14, 2024
b640e98
make use of intervals
younesbelkada Mar 14, 2024
14a89b2
clarify comment
younesbelkada Mar 14, 2024
6f7102d
add matching tests
younesbelkada Mar 14, 2024
c11cb63
GaLoRe -> GaLore
younesbelkada Mar 14, 2024
3678201
move to `get_scheduler`
younesbelkada Mar 14, 2024
fdc4b2a
add note on docs
younesbelkada Mar 14, 2024
e7ce9b7
add a warning
younesbelkada Mar 14, 2024
91d6436
adapt a bit the docs
younesbelkada Mar 15, 2024
b9e338a
update docstring
younesbelkada Mar 15, 2024
6ff3762
support original API
younesbelkada Mar 17, 2024
0d0440a
Update docs/source/en/trainer.md
younesbelkada Mar 17, 2024
832f2be
slightly refactor
younesbelkada Mar 18, 2024
898a3c5
Update docs/source/en/trainer.md
younesbelkada Mar 18, 2024
ed3ad4a
Update src/transformers/training_args.py
younesbelkada Mar 19, 2024
57e7096
fix args parsing and add tests
younesbelkada Mar 19, 2024
64ccfa6
remove warning for regex
younesbelkada Mar 19, 2024
4413f07
Merge remote-tracking branch 'upstream/main' into add-galore-optimizer
younesbelkada Mar 19, 2024
73dcabb
fix type hint
younesbelkada Mar 19, 2024
1987b7a
add note about extra args
younesbelkada Mar 19, 2024
db2bf21
make `is_regex` return optional
younesbelkada Mar 19, 2024
51 changes: 51 additions & 0 deletions docs/source/en/trainer.md
@@ -252,6 +252,57 @@ trainer = Trainer(..., args=training_args)

NEFTune is disabled after training to restore the original embedding layer to avoid any unexpected behavior.

## GaLore

Gradient Low-Rank Projection (GaLore) is a memory-efficient training strategy that allows full-parameter learning while using even less optimizer memory than common low-rank adaptation methods such as LoRA.
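At a high level, GaLore keeps the optimizer state in a low-rank subspace of each weight matrix's gradient instead of reducing the number of trainable parameters. Below is a heavily simplified sketch of the per-matrix update, not the `galore_torch` implementation; the Adam statistics and shape handling are omitted, and the defaults shown mirror the ones used later in this PR:

```python
# Heavily simplified sketch of a GaLore-style update for a single weight matrix W.
import torch

def galore_sketch_step(W, grad, state, rank=128, update_proj_gap=200, scale=0.25, lr=1e-3):
    if state["step"] % update_proj_gap == 0:
        # Periodically refresh the projector from the SVD of the current gradient.
        U, _, _ = torch.linalg.svd(grad, full_matrices=False)
        state["P"] = U[:, :rank]
    P = state["P"]
    low_rank_grad = P.T @ grad  # project the gradient into the low-rank subspace
    # ... Adam moments would be tracked on `low_rank_grad` here (omitted in this sketch) ...
    W.data.add_(P @ low_rank_grad, alpha=-lr * scale)  # project back and update the full weights
    state["step"] += 1
```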

First, make sure to install GaLore from its official repository:

```bash
pip install git+https://github.com/jiaweizzhao/GaLore
```

Contributor
GaLore has released an official package: pip install galore-torch

https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#install-galore-optimizer

Then simply add one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` in `optim` together with `galore_target_modules`, which should be a list of strings corresponding to the target module names you want to adapt. Below is an end-to-end example script (make sure to `pip install trl datasets`):

```python
import torch
import datasets
import trl

from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw",
galore_target_modules=["attn", "mlp"]
)

model_id = "google/gemma-2b"

config = AutoConfig.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)

trainer.train()
```

You can read more about the method in the [original repository](https://github.com/jiaweizzhao/GaLore) or the [paper](https://arxiv.org/abs/2403.03507).
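The GaLore-specific hyperparameters default to the values from the official repository (`rank=128`, `update_proj_gap=200`, `scale=0.25`, `proj_type="std"`). As a sketch only, they can in principle be overridden through the generic `optim_args` argument, assuming the usual comma-separated `key=value` parsing of that field and that numeric casting is handled on the optimizer side:

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",
    # hypothetical override of the GaLore defaults; the key names match those read by the Trainer
    optim_args="rank=64, update_proj_gap=100, scale=0.10",
    galore_target_modules=["attn", "mlp"],
)
```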

Note that it will take a bit of time before training starts (~3 minutes for a 2B model on an NVIDIA A100), but it should run smoothly afterwards.

## Accelerate and Trainer

The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
9 changes: 9 additions & 0 deletions src/transformers/testing_utils.py
@@ -70,6 +70,7 @@
is_fsdp_available,
is_ftfy_available,
is_g2p_en_available,
is_galore_torch_available,
is_ipex_available,
is_jieba_available,
is_jinja_available,
@@ -324,6 +325,14 @@ def require_bs4(test_case):
return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case)


def require_galore_torch(test_case):
"""
Decorator marking a test that requires GaLore. These tests are skipped when GaLore isn't installed.
https://github.com/jiaweizzhao/GaLore
"""
return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)


def require_cv2(test_case):
"""
Decorator marking a test that requires OpenCV.
78 changes: 76 additions & 2 deletions src/transformers/trainer.py
@@ -140,6 +140,7 @@
is_apex_available,
is_bitsandbytes_available,
is_datasets_available,
is_galore_torch_available,
is_in_notebook,
is_ipex_available,
is_peft_available,
@@ -1009,7 +1010,12 @@ def create_optimizer(self):
},
]

optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model)

# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for the GaLore optimizer.
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")

self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
@@ -1032,7 +1038,9 @@ def create_optimizer(self):
return self.optimizer

@staticmethod
def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
def get_optimizer_cls_and_kwargs(
args: TrainingArguments, model: Optional[PreTrainedModel] = None
) -> Tuple[Any, Any]:
"""
Returns the optimizer class and optimizer parameters based on the training arguments.

@@ -1170,6 +1178,72 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
optimizer_cls = torch.optim.Adagrad
elif args.optim == OptimizerNames.RMSPROP:
optimizer_cls = torch.optim.RMSprop
elif args.optim in [
OptimizerNames.GALORE_ADAMW,
OptimizerNames.GALORE_ADAMW_8BIT,
OptimizerNames.GALORE_ADAFACTOR,
]:
if not is_galore_torch_available():
raise ValueError(
"You need to install `galore_torch` in order to use GaLore optimizers. You can"
" install it with `pip install git+https://github.com/jiaweizzhao/GaLore`"
)

from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
Contributor
We need an import check here, no? 😄

Contributor Author
yeah indeed, many things can be optimized, for now it's a really rough draft, will focus on polishing everything next!

Contributor
Great, ping me when you're ready for a full review :)


optimizer_mapping = {
OptimizerNames.GALORE_ADAMW: GaLoreAdamW,
OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit,
OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor,
}

optimizer_cls = optimizer_mapping[args.optim]

if args.galore_target_modules is None:
raise ValueError(
"You need to define a `galore_target_modules` in order to properly use GaLoRe optimizers"
)

if not isinstance(args.galore_target_modules, list):
raise ValueError(
f"`galore_target_modules` has to be a list of strings, you passed {args.galore_target_modules}"
)

if model is None:
raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.")

galore_params = []
for module_name, module in model.named_modules():
if not isinstance(module, nn.Linear):
BenjaminBossan (Member), Mar 11, 2024
I wonder if we should raise an error if the target module name matches but the layer type is not Linear. Let's assume that a user matches N linear layers and accidentally 1 other type like Embedding, currently the embedding layer would be ignored but the user doesn't get any error message or warning.

Collaborator
This sounds like a good idea!

Contributor Author
Yes! I think we should pass a warning to be able to pass simple regex such as `.*.attn.*` - lmk wdyt!

continue

if not any(target_key in module_name for target_key in args.galore_target_modules):
continue

galore_params.append(module.weight)

if len(galore_params) == 0:
Contributor
@younesbelkada But it is not allowed for pure layered_adamw optimizers, right?

Contributor Author
Hmmm in that case users should just pass adamw instead of galore_adamw_layerwise I think

Contributor Author
ahh I think I got your point, we could indeed extend the optimizers and enable layer-wise optimizations! This can be done in the scope of another follow-up PR!

younesbelkada (Contributor Author), Mar 13, 2024
cc @muellerzr @amyeroberts @BenjaminBossan @pacman100
The idea here would be to leverage `optim_target_modules` and think of a more general API to enable layer-wise optimization within Trainer, using the approach presented here with DummyOptimizer / DummyLRScheduler and post-gradient hooks. I propose to do that in a separate PR, but I can also do it here if you think it makes more sense to introduce both GaLore and per-layer optimization for all optimizers in the same PR.
I think it's wiser to do it in a separate PR, as layer-wise optimization might not be supported out of the box for many scenarios such as DeepSpeed / DDP, etc.

Contributor
@younesbelkada Yeah, that sounds good to me. I guess I can comment out this check locally and wait for your PR.

I guess it makes more sense to a) have a generic API for layerwise optimization using the approach presented here, b) have GaLore, c) use the generic API to instantiate layerwise GaLore; otherwise, after you add a generic API (and it seems like you already have a lot of the code to do that, I don't see anything that's GaLore-specific), you would have to go back and do another PR just to refactor the layerwise GaLore?

Collaborator
@kiddyboots216 I think it's fine to add GaLore first. We're just merging into the dev branch atm, so there aren't guarantees about API stability until we have a release. Adding it in a more general sense is more involved and will require more tests / hitting possible blockers. In this order, we can release the feature without being blocked by the development of the more general API.

raise ValueError("Target modules not found ! Please make sure to pass a valid target_modules.")
younesbelkada marked this conversation as resolved.
Show resolved Hide resolved

id_galore_params = [id(p) for p in galore_params]
non_galore_params = [p for p in model.parameters() if id(p) not in id_galore_params]

# The default args are from the official repository: https://github.com/jiaweizzhao/GaLore
param_groups = [
{"params": non_galore_params},
{
"params": galore_params,
"rank": optim_args.pop("rank", 128),
"update_proj_gap": optim_args.pop("update_proj_gap", 200),
"scale": optim_args.pop("scale", 0.25),
"proj_type": optim_args.pop("proj_type", "std"),
},
]

optimizer_kwargs.update({"params": param_groups})

if args.optim == OptimizerNames.GALORE_ADAFACTOR:
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs
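The layer-wise optimization discussed in the review thread above (per-layer DummyOptimizer / DummyLRScheduler plus post-gradient hooks) is deferred to a follow-up. As a rough sketch of the idea only, assuming the `galore_torch` package and PyTorch >= 2.1 for `register_post_accumulate_grad_hook`, each targeted weight could get its own optimizer that steps as soon as its gradient is accumulated:

```python
import torch
from galore_torch import GaLoreAdamW  # assumes the official galore_torch package is installed

def attach_layerwise_galore(model, target_keys=("attn", "mlp"), lr=2e-5):
    """Rough sketch of hook-based layer-wise GaLore; not the implementation in this PR.

    Non-targeted parameters would still need a regular optimizer (omitted here).
    """
    optimizer_dict = {}
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and any(key in name for key in target_keys):
            param = module.weight
            optimizer_dict[param] = GaLoreAdamW(
                [{"params": [param], "rank": 128, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"}],
                lr=lr,
            )

    def optimizer_hook(param):
        # Step and free this layer's gradient right after it has been accumulated.
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()

    for param in optimizer_dict:
        param.register_post_accumulate_grad_hook(optimizer_hook)

    return optimizer_dict
```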
15 changes: 15 additions & 0 deletions src/transformers/training_args.py
@@ -162,6 +162,9 @@ class OptimizerNames(ExplicitEnum):
RMSPROP_BNB = "rmsprop_bnb"
RMSPROP_8BIT = "rmsprop_bnb_8bit"
RMSPROP_32BIT = "rmsprop_bnb_32bit"
GALORE_ADAMW = "galore_adamw"
GALORE_ADAMW_8BIT = "galore_adamw_8bit"
GALORE_ADAFACTOR = "galore_adafactor"


# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
@@ -694,6 +697,11 @@ class TrainingArguments:
for instruction fine-tuning. Check out the [original paper](https://arxiv.org/abs/2310.05914) and the
[original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also
`PeftModel` from peft.
galore_target_modules (`list[str]`, *optional*):
The GaLore target modules, i.e. the module names that you would like to train using the GaLore algorithm
(https://arxiv.org/abs/2403.03507). See https://github.com/jiaweizzhao/GaLore for more details. You need to
make sure to pass a valid GaLore optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor".
"""

framework = "pt"
@@ -1352,6 +1360,13 @@ class TrainingArguments:
},
)

galore_target_modules: Optional[list] = field(
default=None,
metadata={
"help": "Target modules for GaLoRE optimizer. See https://github.com/jiaweizzhao/GaLore for more details."
younesbelkada marked this conversation as resolved.
Show resolved Hide resolved
},
)

def __post_init__(self):
# expand paths, if not os.makedirs("~/bar") will make directory
# in the current directory instead of the actual home
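Since `galore_target_modules` is matched by simple substring inclusion against `model.named_modules()` (see the `trainer.py` hunk above), a small helper like the following sketch can preview which `nn.Linear` weights a given list would select; the model id is only illustrative:

```python
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM

def preview_galore_matches(model, galore_target_modules):
    """Mirror the Trainer's substring check to list the weights GaLore would train."""
    return [
        name
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear) and any(key in name for key in galore_target_modules)
    ]

config = AutoConfig.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_config(config)
print(preview_galore_matches(model, ["attn", "mlp"]))
```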
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -125,6 +125,7 @@
is_fsdp_available,
is_ftfy_available,
is_g2p_en_available,
is_galore_torch_available,
is_in_notebook,
is_ipex_available,
is_jieba_available,
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -76,6 +76,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_apex_available = _is_package_available("apex")
_aqlm_available = _is_package_available("aqlm")
_bitsandbytes_available = _is_package_available("bitsandbytes")
_galore_torch_available = _is_package_available("galore_torch")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
@@ -282,6 +283,10 @@ def is_torchvision_available():
return _torchvision_available


def is_galore_torch_available():
return _galore_torch_available


def is_pyctcdecode_available():
return _pyctcdecode_available

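A minimal sketch of how downstream code can guard the optional dependency, mirroring the availability check added to `trainer.py`:

```python
from transformers.utils import is_galore_torch_available

if is_galore_torch_available():
    # Only imported when the optional galore_torch package is installed.
    from galore_torch import GaLoreAdamW
else:
    raise ImportError(
        "Install GaLore with `pip install git+https://github.com/jiaweizzhao/GaLore` to use the GaLore optimizers."
    )
```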
66 changes: 66 additions & 0 deletions tests/trainer/test_trainer.py
@@ -60,6 +60,7 @@
require_accelerate,
require_bitsandbytes,
require_deepspeed,
require_galore_torch,
require_intel_extension_for_pytorch,
require_optuna,
require_peft,
@@ -114,6 +115,8 @@
GPT2Config,
GPT2LMHeadModel,
LineByLineTextDataset,
LlamaConfig,
LlamaForCausalLM,
PreTrainedModel,
Trainer,
TrainerState,
@@ -1069,6 +1072,69 @@ def test_dataloader_without_dataset(self):
trainer.train()
trainer.evaluate()

@require_galore_torch
def test_galore(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)

with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adamw",
galore_target_modules=["attn", "mlp"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

# Check this works
_ = trainer.train()

@require_galore_torch
def test_galore_adamw_8bit(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)

with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adamw_8bit",
galore_target_modules=["attn", "mlp"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

# Check this works
_ = trainer.train()

@require_galore_torch
def test_galore_adafactor(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)

with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adafactor",
galore_target_modules=["attn", "mlp"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

# Check this works
_ = trainer.train()

@require_torch_multi_accelerator
def test_data_is_not_parallelized_when_model_is_parallel(self):
model = RegressionModel()