From fd338cb6b14c309c7e5b234fe69e53a21649ca69 Mon Sep 17 00:00:00 2001
From: Control-derek <1067245742@qq.com>
Date: Sun, 7 Apr 2024 15:15:09 +0800
Subject: [PATCH 1/2] Fixed deep copy, shallow copy problem

---
 self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py b/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
index 146b1e0..91793ef 100644
--- a/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
+++ b/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
@@ -315,7 +315,7 @@ def get_cross_entropy_loss(
         else:
             prompt_mask = prompt_len_or_mask

-        seq, labels = seq[:, :-1], seq[:, 1:]
+        seq, labels = seq[:, :-1].clone(), seq[:, 1:].clone()

         labels.masked_fill_(prompt_mask[:, 1:], self.ignore_index)

@@ -822,6 +822,7 @@ def __init__(
             eval_filter_fn = config.eval_filter_fn,
             eval_filter_kwargs = config.eval_filter_kwargs,
             accelerator = self.accelerator,
+            pad_id = pad_id,
             **config.reward_generator_kwargs
         )

@@ -862,6 +863,7 @@ def __init__(
             eval_temperature = config.eval_temperature,
             eval_filter_fn = config.eval_filter_fn,
             eval_filter_kwargs = config.eval_filter_kwargs,
+            pad_id = pad_id,
             **config.reward_generator_kwargs
         )

From 2c3ec617a369776b880a92e763b3a6803b49af5c Mon Sep 17 00:00:00 2001
From: Control-derek <1067245742@qq.com>
Date: Sun, 7 Apr 2024 15:16:37 +0800
Subject: [PATCH 2/2] Fixed label mask error

---
 self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py b/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
index 91793ef..16143d6 100644
--- a/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
+++ b/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
@@ -317,7 +317,7 @@ def get_cross_entropy_loss(

         seq, labels = seq[:, :-1].clone(), seq[:, 1:].clone()

-        labels.masked_fill_(prompt_mask[:, 1:], self.ignore_index)
+        labels.masked_fill_(~prompt_mask[:, 1:], self.ignore_index)

         logits = self.model(seq)