FEAT / Optim: Add GaLore optimizer #29588
@@ -140,6 +140,7 @@
     is_apex_available,
     is_bitsandbytes_available,
     is_datasets_available,
+    is_galore_torch_available,
     is_in_notebook,
     is_ipex_available,
     is_peft_available,
@@ -1009,7 +1010,12 @@ def create_optimizer(self):
                 },
             ]

-            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model)

+            # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
+            # e.g. for the GaLore optimizer.
+            if "params" in optimizer_kwargs:
+                optimizer_grouped_parameters = optimizer_kwargs.pop("params")
+
             self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
             if optimizer_cls.__name__ == "Adam8bit":
@@ -1032,7 +1038,9 @@ def create_optimizer(self):
         return self.optimizer

     @staticmethod
-    def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
+    def get_optimizer_cls_and_kwargs(
+        args: TrainingArguments, model: Optional[PreTrainedModel] = None
+    ) -> Tuple[Any, Any]:
         """
         Returns the optimizer class and optimizer parameters based on the training arguments.
@@ -1170,6 +1178,72 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
             optimizer_cls = torch.optim.Adagrad
         elif args.optim == OptimizerNames.RMSPROP:
             optimizer_cls = torch.optim.RMSprop
+        elif args.optim in [
+            OptimizerNames.GALORE_ADAMW,
+            OptimizerNames.GALORE_ADAMW_8BIT,
+            OptimizerNames.GALORE_ADAFACTOR,
+        ]:
+            if not is_galore_torch_available():
+                raise ValueError(
+                    "You need to install `galore_torch` in order to use GaLore optimizers, you can"
+                    " install it with `pip install git+https://github.com/jiaweizzhao/GaLore`"
+                )
+
+            from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
Review comments on the `galore_torch` import:

- "We need an import check here, no? 😄"
- "yeah indeed, many things can be optimized, for now it's a really rough draft, will focus on polishing everything next!"
- "Great, ping me when you're ready for a full review :)"
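The diff already guards on `is_galore_torch_available()` a few lines above, which is what answers the import-check question in this thread. Purely for illustration, and without assuming anything about the exact helper shipped in `transformers.utils`, such an availability check can be sketched like this:

    import importlib.util

    def is_galore_torch_available() -> bool:
        # Hypothetical re-implementation for illustration only: reports whether
        # the `galore_torch` package can be imported in the current environment.
        return importlib.util.find_spec("galore_torch") is not None

    if not is_galore_torch_available():
        raise ImportError(
            "GaLore optimizers require `galore_torch`; install it with `pip install galore-torch` "
            "or from source at https://github.com/jiaweizzhao/GaLore"
        )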
+            optimizer_mapping = {
+                OptimizerNames.GALORE_ADAMW: GaLoreAdamW,
+                OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit,
+                OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor,
+            }
+
+            optimizer_cls = optimizer_mapping[args.optim]
+
+            if args.galore_target_modules is None:
+                raise ValueError(
+                    "You need to define `galore_target_modules` in order to properly use GaLore optimizers"
+                )
+
+            if not isinstance(args.galore_target_modules, list):
+                raise ValueError(
+                    f"`galore_target_modules` has to be a list of strings, you passed {args.galore_target_modules}"
+                )
+
+            if model is None:
+                raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.")
+
+            galore_params = []
+            for module_name, module in model.named_modules():
+                if not isinstance(module, nn.Linear):
Review comments on the `isinstance(module, nn.Linear)` check:

- "I wonder if we should raise an error if the target module name matches but the layer type is not `nn.Linear`."
- "This sounds like a good idea!"
- "Yes! I think we should raise a warning, to be able to pass simple regex such as …"

+                    continue
+
+                if not any(target_key in module_name for target_key in args.galore_target_modules):
+                    continue
+
+                galore_params.append(module.weight)
+
+            if len(galore_params) == 0:
Review comments on the empty `galore_params` check:

- "@younesbelkada But it is not allowed for pure layered_adamw optimizers, right?"
- "Hmmm in that case users should just pass …"
- "ahh I think I got your point, we could indeed extend the optimizers and enable layer-wise optimizations! This can be done in the scope of another follow-up PR!"
- "cc @muellerzr @amyeroberts @BenjaminBossan @pacman100"
- "@younesbelkada Yeah, that sounds good to me. I guess I can comment out this check locally and wait for your PR."
- "I guess it makes more sense to a) have a generic API for layerwise optimization using the approach presented here, b) have GaLore, c) use the generic API to instantiate layerwise GaLore; otherwise, after you add a generic API (and it seems like you already have a lot of the code to do that, I don't see anything here that's GaLore-specific), you would have to go back and do another PR just to refactor layerwise GaLore?"
- "@kiddyboots216 I think it's fine to add GaLore first. We're just merging into the dev branch atm, so there aren't guarantees about API stability until we have a release. Adding it in a more general sense is more involved and will require more tests / hitting possible blockers. In this order, we can release the feature without being blocked by the development of the more general API."
raise ValueError("Target modules not found ! Please make sure to pass a valid target_modules.") | ||
younesbelkada marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
id_galore_params = [id(p) for p in galore_params] | ||
non_galore_params = [p for p in model.parameters() if id(p) not in id_galore_params] | ||
younesbelkada marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# The default args are from the official repository: https://github.com/jiaweizzhao/GaLore | ||
param_groups = [ | ||
{"params": non_galore_params}, | ||
{ | ||
"params": galore_params, | ||
"rank": optim_args.pop("rank", 128), | ||
"update_proj_gap": optim_args.pop("update_proj_gap", 200), | ||
"scale": optim_args.pop("scale", 0.25), | ||
"proj_type": optim_args.pop("proj_type", "std"), | ||
}, | ||
] | ||
|
||
optimizer_kwargs.update({"params": param_groups}) | ||
|
||
if args.optim == OptimizerNames.GALORE_ADAFACTOR: | ||
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False}) | ||
else: | ||
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") | ||
return optimizer_cls, optimizer_kwargs | ||
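To make the effect of this hunk concrete, here is a minimal, self-contained sketch of the same parameter-group construction applied directly with `galore_torch`, outside of `Trainer`. It assumes the `galore-torch` package is installed, reuses the defaults hard-coded above (`rank=128`, `update_proj_gap=200`, `scale=0.25`, `proj_type="std"`), and uses a toy model plus placeholder target-module names that are not part of this PR.

    import torch.nn as nn
    from galore_torch import GaLoreAdamW  # pip install galore-torch

    # Toy stand-in for a transformer: only the "attn"/"mlp" Linear layers are
    # routed through the GaLore projection, mirroring the matching logic above.
    model = nn.Sequential()
    model.add_module("attn_proj", nn.Linear(256, 256))
    model.add_module("mlp_up", nn.Linear(256, 1024))
    model.add_module("lm_head", nn.Linear(256, 32))

    target_modules = ["attn", "mlp"]  # placeholder, analogous to `galore_target_modules`

    galore_params = []
    for module_name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(key in module_name for key in target_modules):
            galore_params.append(module.weight)

    id_galore_params = {id(p) for p in galore_params}
    non_galore_params = [p for p in model.parameters() if id(p) not in id_galore_params]

    # Same group layout as the diff; defaults taken from the official GaLore repository.
    param_groups = [
        {"params": non_galore_params},
        {"params": galore_params, "rank": 128, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"},
    ]

    optimizer = GaLoreAdamW(param_groups, lr=2e-5)

Only the weights of the matched `nn.Linear` modules get the low-rank GaLore projection; biases and every other parameter fall into the first, plain AdamW group, exactly as in the diff.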
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Review comment on the install instructions:

- "GaLore has released an official package: `pip install galore-torch` (see https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#install-galore-optimizer)."
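For completeness, here is a rough sketch of how a user would be expected to enable GaLore through `TrainingArguments` once this PR lands. The optimizer string `"galore_adamw"` is assumed from the `OptimizerNames.GALORE_ADAMW` entry added in the diff, and `galore_target_modules` follows the argument name used above; both come from a draft and may change before merge, so treat this as illustrative only.

    from datasets import load_dataset
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        DataCollatorForLanguageModeling,
        Trainer,
        TrainingArguments,
    )

    model_id = "facebook/opt-125m"  # any small causal LM will do for a smoke test
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    dataset = load_dataset("imdb", split="train[:1%]")

    def tokenize(batch):
        return tokenizer(batch["text"], truncation=True, max_length=128)

    dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)

    args = TrainingArguments(
        output_dir="opt-125m-galore",
        per_device_train_batch_size=4,
        max_steps=100,
        optim="galore_adamw",                   # assumed string value of OptimizerNames.GALORE_ADAMW
        galore_target_modules=["attn", "mlp"],  # argument name as introduced in this PR
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    trainer.train()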