diff --git a/README.md b/README.md
index 4b32736..349d97f 100644
--- a/README.md
+++ b/README.md
@@ -73,12 +73,13 @@ trainer(overwrite_checkpoints = True)
 
 ## Todo
 
+- [x] generalize the sampling so that it can progress at different positions in the batch, fix all sampling to be batched. also allow for left padded sequences, in the case some people have transformers with relative positions that allow for that
+- [x] handle eos
+
 - [ ] remove early stopper in favor of just simple few line logic - have a function that accepts List[float] and decide what to do
-- [ ] generalize the sampling so that it can progress at different positions in the batch, fix all sampling to be batched. also allow for left padded sequences, in the case some people have transformers with relative positions that allow for that
 - [ ] figure out how best to handle different impl of kv cache, for now just do without
 - [ ] allow for different strategies for sampling the pairs
 - [ ] consider KTO
-- [ ] handle eos
 - [ ] any order of sft, spin, self-rewarding dpo, dpo with external reward model
 - [ ] show an example for using your own reward prompt instead of default llm-as-judge
diff --git a/self_rewarding_lm_pytorch/dpo.py b/self_rewarding_lm_pytorch/dpo.py
index 242f593..c9d0bb6 100644
--- a/self_rewarding_lm_pytorch/dpo.py
+++ b/self_rewarding_lm_pytorch/dpo.py
@@ -59,7 +59,7 @@ def log_prob_from_model_and_seq(model, seq, eps = 1e-20):
 
 def prompt_mask_from_len(lengths, seq):
     seq_len, device = seq.shape[-1], seq.device
-    return torch.arange(seq_len, device = device) < rearrange(prompt_len, '... -> ... 1')
+    return torch.arange(seq_len, device = device) < rearrange(lengths, '... -> ... 1')
 
 def maybe_and_mask(*masks):
     masks = [*filter(exists, masks)]
@@ -243,6 +243,8 @@ def forward(
         preferred_seq_mask: Optional[TensorType['b', 'n', bool]] = None,
         unpreferred_seq_mask: Optional[TensorType['b', 'n', bool]] = None
     ):
+        self.policy_model.train()
+
         """
         b - batch
         n - sequence length
@@ -344,6 +346,12 @@ def __init__(
     def is_main(self):
         return self.accelerator.is_main_process
 
+    def wait(self):
+        return self.accelerator.wait_for_everyone()
+
+    def log(self, **data):
+        self.accelerator.log(data, step = self.steps)
+
     def forward(
         self,
         train_self_reward_dataset: Dataset
@@ -363,19 +371,21 @@ def forward(
 
                 dpo_loss = self.model(*batch)
                 self.accelerator.backward(dpo_loss)
 
-                self.accelerator.log(dict(loss = dpo_loss.item()), step = self.steps)
+                self.log(loss = dpo_loss.item())
 
                 self.optimizer.step()
                 self.optimizer.zero_grad()
                 self.steps += 1
                 pbar.update(1)
 
-                self.accelerator.wait_for_everyone()
+                self.wait()
 
                 if not (self.steps % self.check_early_stop_every) and exists(self.early_stopper):
                     early_stop_return = self.early_stopper()
 
+                    self.log(dpo_valid_score = early_stop_return.score)
+
                     if self.is_main and early_stop_return.should_stop:
                         self.break_signal.copy_(1.)
                         dist.all_reduce(self.break_signal)
@@ -383,7 +393,7 @@ def forward(
 
                 if self.break_signal.item() == 1:
                     break
 
-        self.accelerator.wait_for_everyone()
+        self.wait()
 
         pbar.close()
         print('dpo training finished')
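
The `prompt_mask_from_len` hunk above fixes a `NameError`: the helper referenced an undefined `prompt_len` instead of its own `lengths` argument (the same fix appears in `spin.py` further down). For reference, here is a minimal, self-contained sketch of what the corrected helper computes; the example tensors below are made up for illustration.

```python
# Sketch of the corrected prompt_mask_from_len: a boolean mask that is True
# over each sequence's prompt region. Example values are made up.

import torch
from einops import rearrange

def prompt_mask_from_len(lengths, seq):
    seq_len, device = seq.shape[-1], seq.device
    return torch.arange(seq_len, device = device) < rearrange(lengths, '... -> ... 1')

seq = torch.randint(0, 10, (2, 6))   # batch of 2 sequences of length 6
lengths = torch.tensor([2, 4])       # per-sequence prompt lengths

print(prompt_mask_from_len(lengths, seq))
# tensor([[ True,  True, False, False, False, False],
#         [ True,  True,  True,  True, False, False]])
```
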
diff --git a/self_rewarding_lm_pytorch/sampling_utils.py b/self_rewarding_lm_pytorch/sampling_utils.py
index c1e3a36..8f98af1 100644
--- a/self_rewarding_lm_pytorch/sampling_utils.py
+++ b/self_rewarding_lm_pytorch/sampling_utils.py
@@ -2,11 +2,12 @@
 import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Module
+from torch.nn.utils.rnn import pad_sequence
 
 from beartype import beartype
-from beartype.typing import Optional, Callable
+from beartype.typing import Optional, Callable, List, Tuple
 
-from tqdm import tqdm
+from einops import rearrange
 
 def exists(v):
     return v is not None
@@ -57,22 +58,63 @@ def top_k(logits, frac_num_tokens = 0.1, k: Optional[int] = None):
 @beartype
 def sample(
     net: Module,
-    prompt: Tensor,
+    prompts,
     seq_len: int,
     temperature = 1.,
     filter_fn: Callable = top_p,
-    filter_kwargs: dict = dict()
+    filter_kwargs: dict = dict(),
+    pad_id: int = -1,
+    eos_id: Optional[int] = None,
 ):
-    prompt_seq_len, out = prompt.shape[-1], prompt.clone()
-    sample_num_times = max(0, seq_len - prompt_seq_len)
+    device = next(net.parameters()).device
+    net.eval()
 
-    for _ in tqdm(range(sample_num_times)):
-        logits = net(out)
-        logits = logits[:, -1]
+    if isinstance(prompts, (tuple, list)):
+        prompts = pad_sequence(prompts, batch_first = True, padding_value = pad_id)
+
+    batch, prompts_tensor_len = prompts.shape
+
+    batch_arange = torch.arange(batch, device = device)[..., None]
+
+    prompt_lens = (prompts != pad_id).sum(dim = -1)
+    curr_seq_indices = prompt_lens[..., None]
+
+    out = prompts.clone()
+
+    while (curr_seq_indices < seq_len).any():
+        out = F.pad(out, (0, 1), value = pad_id)
+
+        net_input = out.masked_fill(out == pad_id, 0)
+
+        logits = net(net_input)
+
+        logits = logits[batch_arange, curr_seq_indices]
+        logits = rearrange(logits, 'b 1 d -> b d')
 
         logits = filter_fn(logits, **filter_kwargs)
-        sample = gumbel_sample(logits, temperature = temperature, dim = -1)
+        sampled_tokens = gumbel_sample(logits, temperature = temperature, dim = -1)
+
+        out[batch_arange, curr_seq_indices] = sampled_tokens
+
+        curr_seq_indices += 1
+        curr_seq_indices.clamp_(max = seq_len)
+
+        if not exists(eos_id):
+            continue
+
+        is_eos_mask = out == eos_id
+        all_eos = is_eos_mask.any(dim = -1).all()
+
+        if all_eos:
+            break
+
+    if exists(eos_id):
+        after_eos_mask = F.pad(is_eos_mask.cumsum(dim = -1) > 0, (1, -1), value = False)
+        out = out.masked_fill_(after_eos_mask, pad_id)
+
+    prompt_mask = torch.arange(out.shape[-1], device = device) < prompt_lens[..., None]
 
-        out = torch.cat((out, sample), dim = -1)
+    generated_seq_mask = (out != pad_id) & ~prompt_mask
+    seq_lens = generated_seq_mask.sum(dim = -1).tolist()
 
-    return out[..., prompt_seq_len:]
+    return out[generated_seq_mask].split(seq_lens)
diff --git a/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py b/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
index 4eec610..b39292d 100644
--- a/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
+++ b/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
@@ -379,7 +379,7 @@ def generate_reward(
 
         reward_responses = sample(
             self.model,
-            prompt = reward_prompt,
+            prompts = reward_prompt.long(),
             seq_len = self.generate_reward_max_seq_len,
             temperature = self.eval_temperature,
             filter_fn = top_p,
@@ -424,9 +424,9 @@ def forward(self) -> DPODataset:
             prompt_len = prompt_tensor.shape[-1]
             repeated_prompt_tensor = repeat(prompt_tensor, 'n -> r n', r = self.num_candidate_responses)
 
-            candidate_responses_tensor = sample(
+            candidate_tensor_responses = sample(
                 self.model,
-                prompt = repeated_prompt_tensor,
+                prompts = repeated_prompt_tensor.long(),
                 seq_len = self.preference_max_seq_len,
                 temperature = self.gen_temperature,
                 filter_fn = top_p,
@@ -435,7 +435,8 @@ def forward(self) -> DPODataset:
                 )
             )
 
-            candidate_responses: List[str] = [*map(self.tokenizer_decode, candidate_responses_tensor.long().tolist())]
+            candidate_int_responses: List[List[int]] = [response.tolist() for response in candidate_tensor_responses]
+            candidate_responses: List[str] = [*map(self.tokenizer_decode, candidate_int_responses)]
 
             # get rewards
 
@@ -443,7 +444,7 @@ def forward(self) -> DPODataset:
 
             # zip together the responses and rewards and filter out if reward is not generated correctly
 
-            paired_reward_response = [(reward, candidate_response) for reward, candidate_response in zip(rewards, candidate_responses_tensor)]
+            paired_reward_response = [(reward, candidate_response) for reward, candidate_response in zip(rewards, candidate_tensor_responses)]
 
             paired_reward_response = [*filter(lambda pair: exists(first(pair)), paired_reward_response)]
             paired_reward_response.sort(key = first)
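
The call sites in `self_rewarding_lm_pytorch.py` above now hand a (possibly ragged) batch of prompts to the rewritten batched `sample` from `sampling_utils.py`. A usage sketch follows; the toy `TokenModel`, vocab size, token ids, and `seq_len` value are made-up stand-ins for the repo's actual transformer, and only the `sample(...)` call itself reflects the signature introduced in this diff.

```python
# Usage sketch for the batched sample(): variable-length prompts go in as a
# list of 1-d Long tensors, and a tuple of variable-length generations comes
# back (prompt tokens and padding are stripped). The toy model below is a
# made-up stand-in; it only needs to map (batch, seq) token ids to logits.

import torch
from torch import nn
from self_rewarding_lm_pytorch.sampling_utils import sample, top_p

class TokenModel(nn.Module):
    def __init__(self, num_tokens = 256, dim = 32):
        super().__init__()
        self.embed = nn.Embedding(num_tokens, dim)
        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, x):
        return self.to_logits(self.embed(x))  # (batch, seq, num_tokens)

net = TokenModel()

prompts = [
    torch.tensor([1, 2, 3]),   # prompts may have different lengths
    torch.tensor([4, 5])
]

generated = sample(
    net,
    prompts,
    seq_len = 16,
    temperature = 1.,
    filter_fn = top_p,
    filter_kwargs = dict(thres = 0.9),
    pad_id = -1,
    eos_id = None              # set to the tokenizer's eos id to stop early
)

for seq in generated:
    print(seq.shape)           # 1-d tensor of newly sampled tokens per prompt
```
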
diff --git a/self_rewarding_lm_pytorch/spin.py b/self_rewarding_lm_pytorch/spin.py
index 86cfb23..ef1270e 100644
--- a/self_rewarding_lm_pytorch/spin.py
+++ b/self_rewarding_lm_pytorch/spin.py
@@ -50,7 +50,7 @@ def log_prob_from_model_and_seq(model, seq, eps = 1e-20):
 
 def prompt_mask_from_len(lengths, seq):
     seq_len, device = seq.shape[-1], seq.device
-    return torch.arange(seq_len, device = device) < rearrange(prompt_len, '... -> ... 1')
+    return torch.arange(seq_len, device = device) < rearrange(lengths, '... -> ... 1')
 
 def maybe_and_mask(*masks):
     masks = [*filter(exists, masks)]
@@ -97,15 +97,16 @@ def forward(
         self,
         generated_seq: TensorType['b', 'n', int],
         real_seq: TensorType['b', 'n', int],
-        prompt_len: Optional[TensorType['b', int]],
+        prompt_len: TensorType['b', int],
         generated_seq_mask: Optional[TensorType['b', 'n', bool]] = None,
         real_seq_mask: Optional[TensorType['b', 'n', bool]] = None
     ):
+        self.policy_model.train()
+
         """
         b - batch
         n - sequence length
         """
-
         assert generated_seq.ndim == real_seq.ndim == 2
 
         real_prompt_mask = torch.arange(real_seq.shape[-1], device = self.device) < prompt_len[:, None]
@@ -194,6 +195,9 @@ def __init__(
 
         self.spin_λ = spin_λ
 
+    def wait(self):
+        return self.accelerator.wait_for_everyone()
+
     def forward(self):
         """
         Algorithm 1 - https://arxiv.org/abs/2401.01335v1
@@ -214,22 +218,16 @@ def forward(self):
 
                 prompts = [one_real_seq[one_prompt_mask] for one_real_seq, one_prompt_mask in zip(real_seq, prompt_mask)]
 
-                generated_seqs = []
-
-                for prompt in prompts:
-                    one_generated_seq = sample(
-                        self.model,
-                        prompt = rearrange(prompt, '... -> 1 ...'),
-                        seq_len = self.max_seq_len,
-                        temperature = self.temperature,
-                        filter_fn = top_p,
-                        filter_kwargs = dict(
-                            thres = self.nucleus_p
-                        )
+                generated_seqs = sample(
+                    self.model,
+                    prompts = prompts,
+                    seq_len = self.max_seq_len,
+                    temperature = self.temperature,
+                    filter_fn = top_p,
+                    filter_kwargs = dict(
+                        thres = self.nucleus_p
                     )
-
-                    one_generated_seq = rearrange(one_generated_seq, '1 ... -> ...')
-                    generated_seqs.append(torch.cat((prompt, one_generated_seq), dim = -1))
+                )
 
                 generated_seqs = pad_sequence(generated_seqs, padding_value = self.pad_id, batch_first = True)
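
For reference, a small sketch of the collation step that now follows the batched `sample` call in `spin.py`: the variable-length generations are right-padded into a single `(batch, max_len)` tensor. The token values and pad id below are made up for illustration.

```python
# pad_sequence collates 1-d tensors of different lengths into one
# right-padded (batch, max_len) tensor, filling with pad_id.

import torch
from torch.nn.utils.rnn import pad_sequence

pad_id = -1

generated_seqs = [
    torch.tensor([7, 3, 9]),         # e.g. a 3 token generation
    torch.tensor([5, 2, 8, 6, 1])    # e.g. a 5 token generation
]

padded = pad_sequence(generated_seqs, padding_value = pad_id, batch_first = True)

print(padded)
# tensor([[ 7,  3,  9, -1, -1],
#         [ 5,  2,  8,  6,  1]])
```
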
diff --git a/setup.py b/setup.py
index f0c1b93..71e054d 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'self-rewarding-lm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.1',
+  version = '0.0.2',
   license='MIT',
   description = 'Self Rewarding LM - Pytorch',
   author = 'Phil Wang',