Commit f5058f2
when carrying out iterative spin, need to update reference model with policy after each iteration
lucidrains committed Jan 31, 2024
1 parent d403689 commit f5058f2
Showing 3 changed files with 29 additions and 11 deletions.
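
In iterative SPIN (self-play fine-tuning), each cycle trains the policy against generations from a frozen reference model; for the next cycle to make progress, that reference must become the policy the previous cycle just produced. This commit threads one shared SPIN instance through all spin cycles and calls update_reference_model_with_policy() after each. A minimal sketch of what such a hard sync presumably looks like (an illustration, not the library's actual implementation):

import torch
from copy import deepcopy

class SPINSketch(torch.nn.Module):
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.policy_model = model
        # frozen copy of the policy, used as the self-play opponent
        self.ref_model = deepcopy(model)
        self.ref_model.requires_grad_(False)

    @torch.no_grad()
    def update_reference_model_with_policy(self):
        # hard sync: the next spin cycle plays against the latest policy
        self.ref_model.load_state_dict(self.policy_model.state_dict())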
20 changes: 18 additions & 2 deletions self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
@@ -30,7 +30,10 @@
     adam_optimizer_with_linear_decay
 )

-from self_rewarding_lm_pytorch.spin import SPINTrainer
+from self_rewarding_lm_pytorch.spin import (
+    SPIN,
+    SPINTrainer
+)

 from einops import rearrange, repeat

@@ -635,6 +638,7 @@ def __init__(
         accelerate_kwargs: dict = dict(),
         sft_trainer_kwargs: dict = dict(),
         spin_trainer_kwargs: dict = dict(),
+        spin_kwargs: dict = dict(),
         dpo_trainer_kwargs: dict = dict(),
     ):
         super().__init__()
@@ -677,10 +681,20 @@ def __init__(

         assert len(self.spin_trainers) == 0 or exists(train_sft_dataset)

+        self.spin = None
+
+        if num_spin_cycles > 0:
+            self.spin = SPIN(
+                model,
+                pad_id = pad_id,
+                λ = spin_λ,
+                **spin_kwargs
+            )
+
         for _ in range(num_spin_cycles):

             spin_trainer = SPINTrainer(
-                model,
+                self.spin,
                 accelerator = self.accelerator,
                 train_sft_dataset = train_sft_dataset,
                 valid_sft_dataset = valid_sft_dataset,
@@ -790,6 +804,8 @@ def forward(

             spin_trainer()

+            self.spin.update_reference_model_with_policy()
+
             self.save(f'spin.{spin_cycle}.ckpt.pt', overwrite = overwrite_checkpoints)

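With the changes above, SelfRewardingTrainer builds the SPIN wrapper once and hands the same instance to every per-cycle SPINTrainer, so the reference model survives across cycles. A hedged usage sketch (keyword names λ, pad_id, and the SPINTrainer arguments come from this diff; the model, dataset, loop variable, and values are assumed):

spin = SPIN(model, pad_id = 0, λ = 0.1)

for spin_cycle in range(num_spin_cycles):
    spin_trainer = SPINTrainer(
        spin,                      # the shared wrapper, not the raw model
        train_sft_dataset = train_sft_dataset,
        max_seq_len = 1024
    )
    spin_trainer()
    # the just-trained policy becomes the opponent for the next cycle
    spin.update_reference_model_with_policy()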
18 changes: 10 additions & 8 deletions self_rewarding_lm_pytorch/spin.py
@@ -1,7 +1,7 @@
 from pathlib import Path

 from beartype import beartype
-from beartype.typing import Optional, Callable
+from beartype.typing import Optional, Callable, Union
 from torchtyping import TensorType

 import torch
@@ -140,7 +140,7 @@ def forward(
 class SPINTrainer(Module):
     def __init__(
         self,
-        model: Module,
+        model: Union[Module, SPIN],
         *,
         train_sft_dataset: Dataset,
         max_seq_len: int,
@@ -169,13 +169,15 @@ def __init__(
         if not exists(self.accelerator):
             self.accelerator = Accelerator(**accelerate_kwargs)

-        self.model = SPIN(
-            model,
-            λ = spin_λ,
-            pad_id = pad_id,
-            ref_model_ema_decay = ref_model_ema_decay
-        )
+        if not isinstance(model, SPIN):
+            model = SPIN(
+                model,
+                λ = spin_λ,
+                pad_id = pad_id,
+                ref_model_ema_decay = ref_model_ema_decay
+            )

+        self.model = model
         self.epochs = epochs
         self.train_dataloader = DataLoader(train_sft_dataset, batch_size = batch_size, shuffle = True, drop_last = True)

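Accepting Union[Module, SPIN] lets a caller pass a prebuilt SPIN wrapper that the trainer will not re-wrap, which is what keeps a single reference model alive across trainers. The ref_model_ema_decay keyword in this hunk suggests the reference can also track the policy as an exponential moving average during training, in contrast to the hard sync between cycles. A hedged sketch of such an EMA step (helper name and decay value assumed):

import torch

@torch.no_grad()
def ema_update_reference(ref_model, policy_model, decay = 0.999):
    # reference parameters drift a small step toward the policy each update
    for ref_p, p in zip(ref_model.parameters(), policy_model.parameters()):
        ref_p.lerp_(p, 1. - decay)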
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'self-rewarding-lm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.36',
+  version = '0.0.37',
   license='MIT',
   description = 'Self Rewarding LM - Pytorch',
   author = 'Phil Wang',
