diff --git a/self_rewarding_lm_pytorch/dpo.py b/self_rewarding_lm_pytorch/dpo.py
index fa8b9c1..ffed4ca 100644
--- a/self_rewarding_lm_pytorch/dpo.py
+++ b/self_rewarding_lm_pytorch/dpo.py
@@ -52,7 +52,7 @@ def freeze_all_layers_(module):
 
 def log_prob_from_model_and_seq(model, seq, eps = 1e-20):
     logits = model(seq)
-    log_probs = logits.softmax(dim = -1)
+    log_probs = logits.log_softmax(dim = -1)
     seq = rearrange(seq, '... -> ... 1')
     log_probs = log_probs.gather(-1, seq)
     return rearrange(log_probs, '... 1 -> ...')
@@ -93,6 +93,10 @@ def adam_optimizer_with_linear_decay(
         wd = weight_decay
     )
 
+    scheduler = None
+    if start_learning_rate != end_learning_rate:
+        scheduler = LinearLR
+
     return OptimizerWithWarmupSchedule(
         optimizer = adam,
         accelerator = accelerator,
diff --git a/self_rewarding_lm_pytorch/spin.py b/self_rewarding_lm_pytorch/spin.py
index 024b1f5..16a3752 100644
--- a/self_rewarding_lm_pytorch/spin.py
+++ b/self_rewarding_lm_pytorch/spin.py
@@ -16,15 +16,14 @@
 
 from einops import rearrange
 
-from pytorch_custom_utils import (
-    get_adam_optimizer,
-    OptimizerWithWarmupSchedule
-)
-
 from pytorch_custom_utils.utils import (
     masked_mean
 )
 
+from self_rewarding_lm_pytorch.dpo import (
+    adam_optimizer_with_linear_decay
+)
+
 from self_rewarding_lm_pytorch.sampling_utils import (
     sample,
     top_p,
@@ -156,8 +155,11 @@ def __init__(
         accelerator_kwargs: dict = dict(),
         batch_size = 16,
         epochs = 2,
-        learning_rate = 3e-4,
+        start_learning_rate = 1e-6,
+        end_learning_rate = 1e-7,
+        learning_rate_num_decay_steps = 1000,
         weight_decay = 0.,
+        adam_kwargs: dict = dict(),
         temperature = 0.7,
         nucleus_p = 0.9,
         pad_id: int = -1,
@@ -175,13 +177,14 @@ def __init__(
         self.epochs = epochs
         self.train_dataloader = DataLoader(sft_dataset, batch_size = batch_size, shuffle = True, drop_last = True)
 
-        self.optimizer = OptimizerWithWarmupSchedule(
-            optimizer = get_adam_optimizer(
-                model.parameters(),
-                lr = learning_rate,
-                wd = weight_decay
-            ),
-            accelerator = self.accelerator
+        self.optimizer = adam_optimizer_with_linear_decay(
+            model,
+            start_learning_rate,
+            end_learning_rate,
+            num_decay_steps = learning_rate_num_decay_steps,
+            accelerator = self.accelerator,
+            weight_decay = weight_decay,
+            adam_kwargs = adam_kwargs
         )
 
         (
@@ -237,7 +240,7 @@ def save(self, path: str, overwrite: bool = False):
 
         torch.save(pkg, str(path))
 
-    def forward(self):
+    def forward(self, overwrite_checkpoints: bool = True):
         """ Algorithm 1 - https://arxiv.org/abs/2401.01335v1 """
 
@@ -289,7 +292,7 @@ def forward(self):
 
             if self.should_checkpoint and not (self.checkpoint_every % self.steps):
                 checkpoint_num = self.steps // self.checkpoint_every
-                self.save(f'spin.ckpt.{checkpoint_num}.pt')
+                self.save(f'spin.ckpt.{checkpoint_num}.pt', overwrite = overwrite_checkpoints)
 
             self.wait()
 
diff --git a/setup.py b/setup.py
index f028158..0551fd7 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'self-rewarding-lm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.4',
+  version = '0.0.5',
   license='MIT',
   description = 'Self Rewarding LM - Pytorch',
   author = 'Phil Wang',
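
Note (not part of the patch): the dpo.py change above swaps softmax for log_softmax before gathering per-token values, since log_prob_from_model_and_seq is expected to return log-probabilities. Below is a minimal standalone sketch of that gather pattern; the tensor shapes and names are hypothetical and chosen only for illustration, not taken from the repository.

    import torch
    from einops import rearrange

    # illustrative sketch only - mirrors the fixed gather pattern, not the repo's API
    # hypothetical shapes: (batch, seq_len, vocab) logits and (batch, seq_len) token ids
    logits = torch.randn(2, 5, 10)
    seq = torch.randint(0, 10, (2, 5))

    # log_softmax yields log-probabilities; plain softmax returns probabilities
    # and would silently skew any loss that expects log-probs
    log_probs = logits.log_softmax(dim = -1)

    # gather the log-probability of each token actually present in the sequence
    indices = rearrange(seq, '... -> ... 1')
    per_token_log_probs = rearrange(log_probs.gather(-1, indices), '... 1 -> ...')

    assert per_token_log_probs.shape == seq.shape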