Skip to content

Commit

Permalink
Update the `__call__` method for slicpairpm (#128)
Browse files (browse the repository at this point in the history)
  • Loading branch information
WeiXiongUST authored May 16, 2024
1 parent ad38d67 commit 60faba7
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions rewardbench/models/slicpairpm.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, task, model, tokenizer):
self.token_id_B = token_id_B[0]
self.temperature = 1.0

def __call__(self, prompts: List[str], candidates_A: List[str], candidates_B: List[str]):
def __call__(self, candidates_A: List[str], candidates_B: List[str], **kwargs):
"""
Input:
prompts: [prompt1, prompt2, ..., promptn]
Expand All @@ -41,14 +41,14 @@ def __call__(self, prompts: List[str], candidates_A: List[str], candidates_B: Li
Output:
probs_choose_A: [P(responseA1 > responseB1 | prompt1), ...., P(responseAn > responseBn | promptn)]
"""
assert len(prompts) == len(candidates_A)

assert len(candidates_A) == len(candidates_B)
probs_choose_A = []
for i in range(len(prompts)):
instruction = [{"role": "user", "content": prompts[i]}]
context = self.tokenizer_data_format.apply_chat_template(instruction, tokenize=False)
responses = [candidates_A[i], candidates_B[i]]

for i in range(len(candidates_A)):
chosen = candidates_A[i]
rejected = candidates_B[i]
context = self.tokenizer_data_format.apply_chat_template(chosen[:-1], tokenize=False)
responses = [chosen[-1]["content"], rejected[-1]["content"]]
probs_chosen = []

for chosen_position in [0, 1]:
Expand Down Expand Up @@ -77,4 +77,4 @@ def __call__(self, prompts: List[str], candidates_A: List[str], candidates_B: Li
probs_chosen.append(prob_chosen)
probs_choose_A.append(np.mean(probs_chosen))
# probs_chose_B = 1 - probs_choose_A
return probs_choose_A
return torch.tensor([x > 0.5 for x in probs_choose_A])

0 comments on commit 60faba7

Please sign in to comment.