make the sampling performant, SPIN should be finished
lucidrains committed Jan 26, 2024
1 parent 09996ac commit 5d6e843
Showing 6 changed files with 94 additions and 42 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -73,12 +73,13 @@ trainer(overwrite_checkpoints = True)

## Todo

- [x] generalize the sampling so that it can progress at different positions in the batch, fix all sampling to be batched. also allow for left-padded sequences, in case some transformers use relative positions that allow for that
- [x] handle eos

- [ ] remove early stopper in favor of just a few lines of simple logic - have a function that accepts List[float] and decides what to do
- [ ] generalize the sampling so that it can progress at different positions in the batch, fix all sampling to be batched. also allow for left-padded sequences, in case some transformers use relative positions that allow for that
- [ ] figure out how best to handle different impl of kv cache, for now just do without
- [ ] allow for different strategies for sampling the pairs
- [ ] consider KTO
- [ ] handle eos
- [ ] any order of sft, spin, self-rewarding dpo, dpo with external reward model
- [ ] show an example for using your own reward prompt instead of default llm-as-judge

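For readers skimming the checklist, the newly checked "handle eos" item above comes down to masking out everything after the first eos token in each sampled row. A minimal sketch of that cumsum trick, using made-up token ids purely for illustration:

```python
import torch
import torch.nn.functional as F

out = torch.tensor([[4, 7, 2, 9, 9],
                    [5, 2, 3, 2, 8]])
eos_id, pad_id = 2, -1

is_eos = out == eos_id

# True at every position strictly after each row's first eos (the eos itself is kept)
after_eos = F.pad(is_eos.cumsum(dim = -1) > 0, (1, -1), value = False)

out = out.masked_fill(after_eos, pad_id)
# tensor([[ 4,  7,  2, -1, -1],
#         [ 5,  2, -1, -1, -1]])
```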
18 changes: 14 additions & 4 deletions self_rewarding_lm_pytorch/dpo.py
@@ -59,7 +59,7 @@ def log_prob_from_model_and_seq(model, seq, eps = 1e-20):

def prompt_mask_from_len(lengths, seq):
seq_len, device = seq.shape[-1], seq.device
return torch.arange(seq_len, device = device) < rearrange(prompt_len, '... -> ... 1')
return torch.arange(seq_len, device = device) < rearrange(lengths, '... -> ... 1')

def maybe_and_mask(*masks):
masks = [*filter(exists, masks)]
@@ -243,6 +243,8 @@ def forward(
preferred_seq_mask: Optional[TensorType['b', 'n', bool]] = None,
unpreferred_seq_mask: Optional[TensorType['b', 'n', bool]] = None
):
self.policy_model.train()

"""
b - batch
n - sequence length
@@ -344,6 +346,12 @@ def __init__(
def is_main(self):
return self.accelerator.is_main_process

def wait(self):
return self.accelerator.wait_for_everyone()

def log(self, **data):
self.accelerator.log(data, step = self.steps)

def forward(
self,
train_self_reward_dataset: Dataset
@@ -363,27 +371,29 @@ def forward(
dpo_loss = self.model(*batch)
self.accelerator.backward(dpo_loss)

self.accelerator.log(dict(loss = dpo_loss.item()), step = self.steps)
self.log(loss = dpo_loss.item())

self.optimizer.step()
self.optimizer.zero_grad()

self.steps += 1
pbar.update(1)
self.accelerator.wait_for_everyone()
self.wait()

if not (self.steps % self.check_early_stop_every) and exists(self.early_stopper):

early_stop_return = self.early_stopper()

self.log(dpo_valid_score = early_stop_return.score)

if self.is_main and early_stop_return.should_stop:
self.break_signal.copy_(1.)
dist.all_reduce(self.break_signal)

if self.break_signal.item() == 1:
break

self.accelerator.wait_for_everyone()
self.wait()

pbar.close()
print('dpo training finished')
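For context on the early-stop handling above: the decision to stop is made on the main process, but every rank has to see it, otherwise the processes desynchronize mid-loop. A minimal sketch of that break-signal pattern (the helper name is hypothetical, and the all_reduce collective must run on every rank):

```python
import torch
import torch.distributed as dist

def all_ranks_should_stop(should_stop_on_main: bool, device) -> bool:
    # rank 0 decides; summing the flag across ranks lets every process read the same value
    flag = torch.zeros((), device = device)

    if dist.get_rank() == 0 and should_stop_on_main:
        flag.fill_(1.)

    dist.all_reduce(flag)  # default reduce op is SUM
    return flag.item() > 0
```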
66 changes: 54 additions & 12 deletions self_rewarding_lm_pytorch/sampling_utils.py
@@ -2,11 +2,12 @@
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
from torch.nn.utils.rnn import pad_sequence

from beartype import beartype
from beartype.typing import Optional, Callable
from beartype.typing import Optional, Callable, List, Tuple

from tqdm import tqdm
from einops import rearrange

def exists(v):
return v is not None
@@ -57,22 +58,63 @@ def top_k(logits, frac_num_tokens = 0.1, k: Optional[int] = None):
@beartype
def sample(
net: Module,
prompt: Tensor,
prompts,
seq_len: int,
temperature = 1.,
filter_fn: Callable = top_p,
filter_kwargs: dict = dict()
filter_kwargs: dict = dict(),
pad_id: int = -1,
eos_id: Optional[int] = None,
):
prompt_seq_len, out = prompt.shape[-1], prompt.clone()
sample_num_times = max(0, seq_len - prompt_seq_len)
device = next(net.parameters()).device
net.eval()

for _ in tqdm(range(sample_num_times)):
logits = net(out)
logits = logits[:, -1]
if isinstance(prompts, (tuple, list)):
prompts = pad_sequence(prompts, batch_first = True, padding_value = pad_id)

batch, prompts_tensor_len = prompts.shape

batch_arange = torch.arange(batch, device = device)[..., None]

prompt_lens = (prompts != pad_id).sum(dim = -1)
curr_seq_indices = prompt_lens[..., None]

out = prompts.clone()

while (curr_seq_indices < seq_len).any():
out = F.pad(out, (0, 1), value = pad_id)

net_input = out.masked_fill(out == pad_id, 0)

logits = net(net_input)

logits = logits[batch_arange, curr_seq_indices]
logits = rearrange(logits, 'b 1 d -> b d')

logits = filter_fn(logits, **filter_kwargs)
sample = gumbel_sample(logits, temperature = temperature, dim = -1)
sampled_tokens = gumbel_sample(logits, temperature = temperature, dim = -1)

out[batch_arange, curr_seq_indices] = sampled_tokens

curr_seq_indices += 1
curr_seq_indices.clamp_(max = seq_len)

if not exists(eos_id):
continue

is_eos_mask = out == eos_id
all_eos = is_eos_mask.any(dim = -1).all()

if all_eos:
break

if exists(eos_id):
after_eos_mask = F.pad(is_eos_mask.cumsum(dim = -1) > 0, (1, -1), value = False)
out = out.masked_fill_(after_eos_mask, pad_id)

prompt_mask = torch.arange(out.shape[-1], device = device) < prompt_lens[..., None]

out = torch.cat((out, sample), dim = -1)
generated_seq_mask = (out != pad_id) & ~prompt_mask
seq_lens = generated_seq_mask.sum(dim = -1).tolist()

return out[..., prompt_seq_len:]
return out[generated_seq_mask].split(seq_lens)
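The reworked sample above accepts either a padded tensor or a list of variable-length prompt tensors, advances every row of the batch in one decoding loop, and returns the continuations split per prompt with prompt and padding tokens stripped. A rough usage sketch, with a toy stub standing in for a real decoder (the stub, token ids, and eos id are invented for illustration):

```python
import torch
from torch import nn

from self_rewarding_lm_pytorch.sampling_utils import sample, top_p

# toy stand-in for a decoder: maps (batch, seq) token ids to (batch, seq, vocab) logits
net = nn.Sequential(nn.Embedding(256, 32), nn.Linear(32, 256))

prompts = [
    torch.randint(0, 256, (5,)),   # two prompts of different lengths
    torch.randint(0, 256, (9,)),
]

generated = sample(
    net,
    prompts,
    seq_len = 64,
    temperature = 0.9,
    filter_fn = top_p,
    filter_kwargs = dict(thres = 0.95),
    pad_id = -1,
    eos_id = 2                     # assumed eos token id
)

# `generated` is a tuple of 1-d tensors, one per prompt, holding only the sampled tokens
```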
11 changes: 6 additions & 5 deletions self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
@@ -379,7 +379,7 @@ def generate_reward(

reward_responses = sample(
self.model,
prompt = reward_prompt,
prompts = reward_prompt.long(),
seq_len = self.generate_reward_max_seq_len,
temperature = self.eval_temperature,
filter_fn = top_p,
@@ -424,9 +424,9 @@ def forward(self) -> DPODataset:
prompt_len = prompt_tensor.shape[-1]
repeated_prompt_tensor = repeat(prompt_tensor, 'n -> r n', r = self.num_candidate_responses)

candidate_responses_tensor = sample(
candidate_tensor_responses = sample(
self.model,
prompt = repeated_prompt_tensor,
prompts = repeated_prompt_tensor.long(),
seq_len = self.preference_max_seq_len,
temperature = self.gen_temperature,
filter_fn = top_p,
@@ -435,15 +435,16 @@
)
)

candidate_responses: List[str] = [*map(self.tokenizer_decode, candidate_responses_tensor.long().tolist())]
candidate_int_responses: List[List[int]] = [response.tolist() for response in candidate_tensor_responses]
candidate_responses: List[str] = [*map(self.tokenizer_decode, candidate_int_responses)]

# get rewards

rewards: List[Optional[float]] = [self.generate_reward(prompt, response) for response in candidate_responses]

# zip together the responses and rewards and filter out if reward is not generated correctly

paired_reward_response = [(reward, candidate_response) for reward, candidate_response in zip(rewards, candidate_responses_tensor)]
paired_reward_response = [(reward, candidate_response) for reward, candidate_response in zip(rewards, candidate_tensor_responses)]

paired_reward_response = [*filter(lambda pair: exists(first(pair)), paired_reward_response)]
paired_reward_response.sort(key = first)
34 changes: 16 additions & 18 deletions self_rewarding_lm_pytorch/spin.py
@@ -50,7 +50,7 @@ def log_prob_from_model_and_seq(model, seq, eps = 1e-20):

def prompt_mask_from_len(lengths, seq):
seq_len, device = seq.shape[-1], seq.device
return torch.arange(seq_len, device = device) < rearrange(prompt_len, '... -> ... 1')
return torch.arange(seq_len, device = device) < rearrange(lengths, '... -> ... 1')

def maybe_and_mask(*masks):
masks = [*filter(exists, masks)]
@@ -97,15 +97,16 @@ def forward(
self,
generated_seq: TensorType['b', 'n', int],
real_seq: TensorType['b', 'n', int],
prompt_len: Optional[TensorType['b', int]],
prompt_len: TensorType['b', int],
generated_seq_mask: Optional[TensorType['b', 'n', bool]] = None,
real_seq_mask: Optional[TensorType['b', 'n', bool]] = None
):
self.policy_model.train()

"""
b - batch
n - sequence length
"""

assert generated_seq.ndim == real_seq.ndim == 2

real_prompt_mask = torch.arange(real_seq.shape[-1], device = self.device) < prompt_len[:, None]
@@ -194,6 +195,9 @@ def __init__(

self.spin_λ = spin_λ

def wait(self):
return self.accelerator.wait_for_everyone()

def forward(self):
"""
Algorithm 1 - https://arxiv.org/abs/2401.01335v1
@@ -214,22 +218,16 @@

prompts = [one_real_seq[one_prompt_mask] for one_real_seq, one_prompt_mask in zip(real_seq, prompt_mask)]

generated_seqs = []

for prompt in prompts:
one_generated_seq = sample(
self.model,
prompt = rearrange(prompt, '... -> 1 ...'),
seq_len = self.max_seq_len,
temperature = self.temperature,
filter_fn = top_p,
filter_kwargs = dict(
thres = self.nucleus_p
)
generated_seqs = sample(
self.model,
prompts = prompts,
seq_len = self.max_seq_len,
temperature = self.temperature,
filter_fn = top_p,
filter_kwargs = dict(
thres = self.nucleus_p
)

one_generated_seq = rearrange(one_generated_seq, '1 ... -> ...')
generated_seqs.append(torch.cat((prompt, one_generated_seq), dim = -1))
)

generated_seqs = pad_sequence(generated_seqs, padding_value = self.pad_id, batch_first = True)

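The SPIN change above drops the per-prompt sampling loop in favor of a single batched sample call over a list of variable-length prompts pulled out of the padded real sequences. A small sketch of that prompt-extraction step, with invented ids (the pad id of 0 is arbitrary here):

```python
import torch

real_seq = torch.tensor([[5, 6, 7, 0, 0],
                         [8, 9, 1, 2, 0]])
prompt_len = torch.tensor([2, 3])

# boolean mask covering each row's prompt tokens
prompt_mask = torch.arange(real_seq.shape[-1]) < prompt_len[:, None]

# list of 1-d prompt tensors of differing lengths, ready for the batched sampler
prompts = [seq[mask] for seq, mask in zip(real_seq, prompt_mask)]
# [tensor([5, 6]), tensor([8, 9, 1])]
```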
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'self-rewarding-lm-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.1',
version = '0.0.2',
license='MIT',
description = 'Self Rewarding LM - Pytorch',
author = 'Phil Wang',