
paged optim on bnb
Alessandro Sordoni committed Aug 3, 2024
1 parent 40d974b · commit 1b60a8a
Showing 1 changed file with 11 additions and 1 deletion.
mttl/models/get_optimizer.py

@@ -7,6 +7,13 @@
 from mttl.logging import logger
 
 
+def instantiate_bnb_optimizer(model_parameters, **kwargs):
+    import bitsandbytes as bnb
+
+    optimizer = bnb.optim.PagedAdamW(model_parameters, **kwargs)
+    return optimizer
+
+
 def get_optimizer(model, args, no_decay=None):
     """
     Construct optimizer based on args
@@ -92,7 +99,10 @@ def get_optimizer(model, args, no_decay=None):
         # from transformers import AdamW  # tloen uses adamw_torch
         from torch.optim import AdamW
 
-        optimizer = AdamW(param_groups, eps=args.adam_epsilon)
+        if args.load_in_4bit or args.load_in_8bit:
+            optimizer = instantiate_bnb_optimizer(param_groups, eps=args.adam_epsilon)
+        else:
+            optimizer = AdamW(param_groups, eps=args.adam_epsilon)
     elif optim_name.lower() == "adafactor":
         optimizer = Adafactor(
             param_groups,
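For context, a minimal, self-contained sketch of what the new code path does. This is illustrative only, not part of the commit: it assumes bitsandbytes is installed and a CUDA device is available, since paged optimizers keep their state in paged memory that can move between GPU and CPU.

# Illustrative sketch only (not mttl code); assumes bitsandbytes is
# installed and a CUDA device is available.
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(128, 128).cuda()

# PagedAdamW behaves like torch.optim.AdamW but allocates its optimizer
# state in paged memory that can spill from GPU to CPU under memory
# pressure, which is why the commit prefers it when the model is loaded
# in 4-bit or 8-bit.
optimizer = bnb.optim.PagedAdamW(model.parameters(), lr=1e-4, eps=1e-8)

loss = model(torch.randn(4, 128, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()

In get_optimizer above, the same PagedAdamW is constructed via instantiate_bnb_optimizer with the existing param_groups and args.adam_epsilon, so the quantized and non-quantized branches stay interchangeable.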
