Commit 2e899a0

allow for learning rate annealing for spin as well

lucidrains committed Jan 27, 2024
1 parent 0e1bf95 commit 2e899a0

Showing 3 changed files with 24 additions and 17 deletions.

self_rewarding_lm_pytorch/dpo.py (5 additions & 1 deletion)

@@ -52,7 +52,7 @@ def freeze_all_layers_(module):

 def log_prob_from_model_and_seq(model, seq, eps = 1e-20):
     logits = model(seq)
-    log_probs = logits.softmax(dim = -1)
+    log_probs = logits.log_softmax(dim = -1)
     seq = rearrange(seq, '... -> ... 1')
     log_probs = log_probs.gather(-1, seq)
     return rearrange(log_probs, '... 1 -> ...')
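
The change above swaps softmax for log_softmax: the old line gathered raw probabilities despite the log_probs name, while log_softmax yields actual log probabilities and is the numerically stable way to compute them (softmax followed by log can underflow). A minimal standalone sketch of the same gather pattern, separate from the repository code:

    import torch
    from einops import rearrange

    logits = torch.randn(2, 5, 32000)          # (batch, seq, vocab)
    seq = torch.randint(0, 32000, (2, 5))      # token ids

    log_probs = logits.log_softmax(dim = -1)   # stable per-token log probabilities
    per_token = log_probs.gather(-1, rearrange(seq, '... -> ... 1'))
    per_token = rearrange(per_token, '... 1 -> ...')   # (batch, seq)
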
@@ -93,6 +93,10 @@ def adam_optimizer_with_linear_decay(
         wd = weight_decay
     )

+    scheduler = None
+    if start_learning_rate != end_learning_rate:
+        scheduler = LinearLR
+
     return OptimizerWithWarmupSchedule(
         optimizer = adam,
         accelerator = accelerator,
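
In the hunk above only the scheduler class is selected; how its arguments are derived from start_learning_rate, end_learning_rate and the decay steps happens inside OptimizerWithWarmupSchedule, which is not part of this diff. A rough standalone sketch of the intended annealing, assuming LinearLR refers to torch.optim.lr_scheduler.LinearLR with the factors computed from the two rates:

    import torch
    from torch.optim import Adam
    from torch.optim.lr_scheduler import LinearLR

    start_lr, end_lr, num_decay_steps = 1e-6, 1e-7, 1000

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = Adam(params, lr = start_lr)

    # anneal the learning rate linearly from start_lr down to end_lr over num_decay_steps steps
    scheduler = LinearLR(
        optimizer,
        start_factor = 1.,
        end_factor = end_lr / start_lr,
        total_iters = num_decay_steps
    )

    for _ in range(num_decay_steps):
        optimizer.step()
        scheduler.step()
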
self_rewarding_lm_pytorch/spin.py (18 additions & 15 deletions)

@@ -16,15 +16,14 @@

 from einops import rearrange

-from pytorch_custom_utils import (
-    get_adam_optimizer,
-    OptimizerWithWarmupSchedule
-)
-
 from pytorch_custom_utils.utils import (
     masked_mean
 )

+from self_rewarding_lm_pytorch.dpo import (
+    adam_optimizer_with_linear_decay
+)
+
 from self_rewarding_lm_pytorch.sampling_utils import (
     sample,
     top_p,

@@ -156,8 +155,11 @@ def __init__(
         accelerator_kwargs: dict = dict(),
         batch_size = 16,
         epochs = 2,
-        learning_rate = 3e-4,
+        start_learning_rate = 1e-6,
+        end_learning_rate = 1e-7,
+        learning_rate_num_decay_steps = 1000,
         weight_decay = 0.,
+        adam_kwargs: dict = dict(),
         temperature = 0.7,
         nucleus_p = 0.9,
         pad_id: int = -1,
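
With the new defaults above, the learning rate anneals from 1e-6 down to 1e-7 over 1000 optimizer steps. A small sketch of the resulting values, assuming plain linear interpolation between the two rates (the exact curve depends on how OptimizerWithWarmupSchedule applies the scheduler, which is outside this diff):

    start_lr, end_lr, num_decay_steps = 1e-6, 1e-7, 1000

    def annealed_lr(step):
        # clamp to the decay window, then interpolate linearly from start to end
        t = min(step, num_decay_steps) / num_decay_steps
        return start_lr + (end_lr - start_lr) * t

    print(annealed_lr(0))     # 1e-06
    print(annealed_lr(500))   # ~5.5e-07
    print(annealed_lr(1000))  # ~1e-07
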
@@ -175,13 +177,14 @@ def __init__(
         self.epochs = epochs
         self.train_dataloader = DataLoader(sft_dataset, batch_size = batch_size, shuffle = True, drop_last = True)

-        self.optimizer = OptimizerWithWarmupSchedule(
-            optimizer = get_adam_optimizer(
-                model.parameters(),
-                lr = learning_rate,
-                wd = weight_decay
-            ),
-            accelerator = self.accelerator
+        self.optimizer = adam_optimizer_with_linear_decay(
+            model,
+            start_learning_rate,
+            end_learning_rate,
+            num_decay_steps = learning_rate_num_decay_steps,
+            accelerator = self.accelerator,
+            weight_decay = weight_decay,
+            adam_kwargs = adam_kwargs
         )

         (

@@ -237,7 +240,7 @@ def save(self, path: str, overwrite: bool = False):

         torch.save(pkg, str(path))

-    def forward(self):
+    def forward(self, overwrite_checkpoints: bool = True):
         """
         Algorithm 1 - https://arxiv.org/abs/2401.01335v1
         """

@@ -289,7 +292,7 @@ def forward(self):

                 if self.should_checkpoint and not (self.checkpoint_every % self.steps):
                     checkpoint_num = self.steps // self.checkpoint_every
-                    self.save(f'spin.ckpt.{checkpoint_num}.pt')
+                    self.save(f'spin.ckpt.{checkpoint_num}.pt', overwrite = overwrite_checkpoints)

             self.wait()

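
The new overwrite_checkpoints flag is threaded from forward into the existing overwrite argument of save. The body of save is not shown in this commit; the guard below is only a hypothetical illustration of the kind of check such a flag typically controls (save_checkpoint is a stand-in name):

    from pathlib import Path

    import torch

    def save_checkpoint(pkg, path: str, overwrite: bool = False):
        path = Path(path)
        # refuse to clobber an existing checkpoint unless overwriting was requested
        assert overwrite or not path.exists(), f'{path} already exists'
        torch.save(pkg, str(path))
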
setup.py (1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'self-rewarding-lm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.4',
+  version = '0.0.5',
   license='MIT',
   description = 'Self Rewarding LM - Pytorch',
   author = 'Phil Wang',