From f5058f262495b3746abca3596fc39ff6e57ef82f Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Wed, 31 Jan 2024 08:13:23 -0800
Subject: [PATCH] when carrying out iterative spin, need to update reference
 model with policy after each iteration

---
 .../self_rewarding_lm_pytorch.py  | 20 +++++++++++++++++--
 self_rewarding_lm_pytorch/spin.py | 18 +++++++++--------
 setup.py                          |  2 +-
 3 files changed, 29 insertions(+), 11 deletions(-)

diff --git a/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py b/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
index b9a928d..fc35b6a 100644
--- a/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
+++ b/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
@@ -30,7 +30,10 @@
     adam_optimizer_with_linear_decay
 )
 
-from self_rewarding_lm_pytorch.spin import SPINTrainer
+from self_rewarding_lm_pytorch.spin import (
+    SPIN,
+    SPINTrainer
+)
 
 from einops import rearrange, repeat
 
@@ -635,6 +638,7 @@ def __init__(
         accelerate_kwargs: dict = dict(),
         sft_trainer_kwargs: dict = dict(),
         spin_trainer_kwargs: dict = dict(),
+        spin_kwargs: dict = dict(),
         dpo_trainer_kwargs: dict = dict(),
     ):
         super().__init__()
@@ -677,10 +681,20 @@ def __init__(
 
         assert len(self.spin_trainers) == 0 or exists(train_sft_dataset)
 
+        self.spin = None
+
+        if num_spin_cycles > 0:
+            self.spin = SPIN(
+                model,
+                pad_id = pad_id,
+                λ = spin_λ,
+                **spin_kwargs
+            )
+
         for _ in range(num_spin_cycles):
 
             spin_trainer = SPINTrainer(
-                model,
+                self.spin,
                 accelerator = self.accelerator,
                 train_sft_dataset = train_sft_dataset,
                 valid_sft_dataset = valid_sft_dataset,
@@ -790,6 +804,8 @@ def forward(
 
             spin_trainer()
 
+            self.spin.update_reference_model_with_policy()
+
             self.save(f'spin.{spin_cycle}.ckpt.pt', overwrite = overwrite_checkpoints)
 
 
diff --git a/self_rewarding_lm_pytorch/spin.py b/self_rewarding_lm_pytorch/spin.py
index 9e87f83..1e60b7e 100644
--- a/self_rewarding_lm_pytorch/spin.py
+++ b/self_rewarding_lm_pytorch/spin.py
@@ -1,7 +1,7 @@
 from pathlib import Path
 
 from beartype import beartype
-from beartype.typing import Optional, Callable
+from beartype.typing import Optional, Callable, Union
 
 from torchtyping import TensorType
 import torch
@@ -140,7 +140,7 @@ def forward(
 class SPINTrainer(Module):
     def __init__(
         self,
-        model: Module,
+        model: Union[Module, SPIN],
         *,
         train_sft_dataset: Dataset,
         max_seq_len: int,
@@ -169,13 +169,15 @@ def __init__(
 
         if not exists(self.accelerator):
             self.accelerator = Accelerator(**accelerate_kwargs)
 
-        self.model = SPIN(
-            model,
-            λ = spin_λ,
-            pad_id = pad_id,
-            ref_model_ema_decay = ref_model_ema_decay
-        )
+        if not isinstance(model, SPIN):
+            model = SPIN(
+                model,
+                λ = spin_λ,
+                pad_id = pad_id,
+                ref_model_ema_decay = ref_model_ema_decay
+            )
 
+        self.model = model
         self.epochs = epochs
         self.train_dataloader = DataLoader(train_sft_dataset, batch_size = batch_size, shuffle = True, drop_last = True)
 
diff --git a/setup.py b/setup.py
index 2192266..60bdebf 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'self-rewarding-lm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.36',
+  version = '0.0.37',
   license='MIT',
   description = 'Self Rewarding LM - Pytorch',
   author = 'Phil Wang',
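
Note: the training loop above calls self.spin.update_reference_model_with_policy(), but that method's body is not part of this patch. Below is a minimal, hypothetical sketch of what such a sync typically looks like, assuming the SPIN wrapper holds the trainable policy and a frozen reference copy under the attribute names policy_model and ref_model (those names, and the SpinLikeWrapper class itself, are illustrative assumptions rather than code from this repository):

from copy import deepcopy

import torch
from torch import nn

class SpinLikeWrapper(nn.Module):
    # hypothetical stand-in for the SPIN wrapper, only to illustrate the reference-model sync
    def __init__(self, model: nn.Module):
        super().__init__()
        self.policy_model = model                 # trained during a SPIN cycle
        self.ref_model = deepcopy(model)          # frozen snapshot the policy is compared against
        self.ref_model.requires_grad_(False)

    @torch.no_grad()
    def update_reference_model_with_policy(self):
        # after a SPIN cycle finishes, overwrite the reference weights with the
        # current policy weights, so the next cycle is measured against the
        # latest policy rather than the original SFT checkpoint
        self.ref_model.load_state_dict(self.policy_model.state_dict())

# usage mirroring the new loop in forward(): run one SPIN cycle, then sync
# the reference model before the next cycle begins
wrapper = SpinLikeWrapper(nn.Linear(8, 8))
# ... one SPIN training cycle updating wrapper.policy_model ...
wrapper.update_reference_model_with_policy()

Building a single SPIN instance up front and handing it to every SPINTrainer, as the patch does, is what makes this per-cycle sync meaningful: all cycles share one reference model instead of each trainer wrapping the raw module with its own copy.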