Skip to content

Commit

Permalink
Merge pull request #29 from Control-derek/derek
Browse files Browse the repository at this point in the history
Fixed deep copy, shallow copy error and label mask error.
  • Loading branch information
lucidrains authored Apr 11, 2024
2 parents 2db4fed + 2c3ec61 commit d4755a2
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,9 @@ def get_cross_entropy_loss(
else:
prompt_mask = prompt_len_or_mask

seq, labels = seq[:, :-1], seq[:, 1:]
seq, labels = seq[:, :-1].clone(), seq[:, 1:].clone()

labels.masked_fill_(prompt_mask[:, 1:], self.ignore_index)
labels.masked_fill_(~prompt_mask[:, 1:], self.ignore_index)

logits = self.model(seq)

Expand Down Expand Up @@ -822,6 +822,7 @@ def __init__(
eval_filter_fn = config.eval_filter_fn,
eval_filter_kwargs = config.eval_filter_kwargs,
accelerator = self.accelerator,
pad_id = pad_id,
**config.reward_generator_kwargs
)

Expand Down Expand Up @@ -862,6 +863,7 @@ def __init__(
eval_temperature = config.eval_temperature,
eval_filter_fn = config.eval_filter_fn,
eval_filter_kwargs = config.eval_filter_kwargs,
pad_id = pad_id,
**config.reward_generator_kwargs
)

Expand Down

0 comments on commit d4755a2

Please sign in to comment.