
Commit

another round of cleanup
lucidrains committed Jan 26, 2024
1 parent 4b137eb commit 09996ac
Showing 3 changed files with 22 additions and 10 deletions.
18 changes: 11 additions & 7 deletions self_rewarding_lm_pytorch/dpo.py
@@ -41,6 +41,11 @@
def exists(v):
    return v is not None

def cycle(dl):
    while True:
        for batch in dl:
            yield batch

def freeze_all_layers_(module):
    for param in module.parameters():
        param.requires_grad = False
@@ -52,6 +57,10 @@ def log_prob_from_model_and_seq(model, seq, eps = 1e-20):
    logprobs = probs.gather(-1, seq).clamp(min = eps).log()
    return rearrange(logprobs, '... 1 -> ...')

def prompt_mask_from_len(lengths, seq):
    seq_len, device = seq.shape[-1], seq.device
    return torch.arange(seq_len, device = device) < rearrange(lengths, '... -> ... 1')

def maybe_and_mask(*masks):
    masks = [*filter(exists, masks)]
    if len(masks) == 0:
@@ -241,8 +250,8 @@ def forward(

        assert preferred_seq.ndim == unpreferred_seq.ndim == 2

        preferred_prompt_mask = torch.arange(preferred_seq.shape[-1], device = self.device) < prompt_len[:, None]
        unpreferred_prompt_mask = torch.arange(unpreferred_seq.shape[-1], device = self.device) < prompt_len[:, None]
        preferred_prompt_mask = prompt_mask_from_len(prompt_len, preferred_seq)
        unpreferred_prompt_mask = prompt_mask_from_len(prompt_len, unpreferred_seq)

"""
Following Appendix B in https://arxiv.org/abs/2305.18290
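For context (not part of the diff): Appendix B of the cited paper derives the gradient of the DPO objective, which is presumably what this forward pass evaluates on the preferred / unpreferred pair,

    L_DPO(theta) = -E_(x, y_w, y_l) [ log sigmoid( beta * ( log( pi_theta(y_w | x) / pi_ref(y_w | x) ) - log( pi_theta(y_l | x) / pi_ref(y_l | x) ) ) ) ]

with preferred_seq in the role of y_w, unpreferred_seq in the role of y_l, and pi_ref the frozen reference model.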
@@ -281,11 +290,6 @@ def forward(

# trainer class

def cycle(dl):
    while True:
        for batch in dl:
            yield batch

class DPOTrainer(Module):
    @beartype
    def __init__(
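As a side note (not part of the commit), here is a minimal sketch of what the newly extracted prompt_mask_from_len helper computes, using only torch and einops as already imported in dpo.py; the tensor values are made up for illustration:

import torch
from einops import rearrange

def prompt_mask_from_len(lengths, seq):
    # True at every position that belongs to the prompt, False over the generated tail
    seq_len, device = seq.shape[-1], seq.device
    return torch.arange(seq_len, device = device) < rearrange(lengths, '... -> ... 1')

seq = torch.randint(0, 10, (2, 5))    # batch of 2 token sequences of length 5
prompt_len = torch.tensor([2, 3])     # first 2 / 3 tokens of each row are the prompt

print(prompt_mask_from_len(prompt_len, seq))
# tensor([[ True,  True, False, False, False],
#         [ True,  True,  True, False, False]])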
8 changes: 6 additions & 2 deletions self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
@@ -70,6 +70,10 @@ def cycle(dl):
        for batch in dl:
            yield batch

def prompt_mask_from_len(length, seq):
    seq_len, device = seq.shape[-1], seq.device
    return torch.arange(seq_len, device = device) < rearrange(length, '... -> ... 1')

# constants
# llm-as-judge prompt
# https://openreview.net/forum?id=uccHPGDlao
@@ -115,7 +119,7 @@ def default_parse_reward_fn(llm_response: str) -> float:
class RewardConfig:
    prompt_template: str
    parse_reward: Callable[[str], Optional[float]]
    render: Optional[Callable[..., str]] = None
    template_fn: Optional[Callable[..., str]] = None

    def init(self):
        prompt_template = self.prompt_template
@@ -206,7 +210,7 @@ def get_cross_entropy_loss(
    ]
):
    if prompt_len_or_mask.dtype == torch.long:
        prompt_mask = torch.arange(seq.shape[-1], device = seq.device) < prompt_len_or_mask[..., None]
        prompt_mask = prompt_mask_from_len(prompt_len_or_mask, seq)
    else:
        prompt_mask = prompt_len_or_mask

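Not part of the commit: the hunk above dispatches on dtype, so get_cross_entropy_loss accepts either integer prompt lengths or a precomputed boolean prompt mask. A hedged sketch of the two argument forms (tensor values are hypothetical):

seq = torch.randint(0, 10, (2, 5))

as_lengths = torch.tensor([2, 3])                    # dtype torch.long -> converted via prompt_mask_from_len
as_mask    = prompt_mask_from_len(as_lengths, seq)   # dtype torch.bool -> used as the prompt mask directly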
6 changes: 5 additions & 1 deletion self_rewarding_lm_pytorch/spin.py
@@ -48,6 +48,10 @@ def log_prob_from_model_and_seq(model, seq, eps = 1e-20):
    logprobs = probs.gather(-1, seq).clamp(min = eps).log()
    return rearrange(logprobs, '... 1 -> ...')

def prompt_mask_from_len(lengths, seq):
    seq_len, device = seq.shape[-1], seq.device
    return torch.arange(seq_len, device = device) < rearrange(lengths, '... -> ... 1')

def maybe_and_mask(*masks):
    masks = [*filter(exists, masks)]
    if len(masks) == 0:
@@ -206,7 +210,7 @@ def forward(self):
        for epoch in tqdm(range(self.epochs), desc = 'spin epoch'):
            for real_seq, prompt_len in tqdm(self.train_dataloader, desc = 'spin finetuning'):

                prompt_mask = torch.arange(real_seq.shape[-1], device = real_seq.device) < prompt_len[..., None]
                prompt_mask = prompt_mask_from_len(prompt_len, real_seq)

                prompts = [one_real_seq[one_prompt_mask] for one_real_seq, one_prompt_mask in zip(real_seq, prompt_mask)]

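As an aside (not part of the diff), the boolean mask above is also what lets the per-example prompts be sliced out with plain indexing in the last line of the hunk; a small sketch with made-up values:

real_seq   = torch.tensor([[5, 7, 1, 2, 0],
                           [3, 4, 6, 8, 0]])
prompt_len = torch.tensor([2, 3])

prompt_mask = prompt_mask_from_len(prompt_len, real_seq)
prompts = [one_real_seq[one_prompt_mask] for one_real_seq, one_prompt_mask in zip(real_seq, prompt_mask)]
# [tensor([5, 7]), tensor([3, 4, 6])]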
