Skip to content

Commit

Permalink
adding gold answers to judges
Browse files: browse the repository at this point in the history
  • Loading branch information
gaetanlop committed Oct 7, 2024
1 parent 3da4a06 commit 765768b
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions trl/trainer/judges.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ class BaseConstraintJudge(BaseJudge):
"""

@abstractmethod
def judge(self, prompts: List[str], completions: List[str], shuffle_order: bool = True) -> List[int]:
def judge(
self, prompts: List[str], completions: List[str], gold_answers: List[str] = None, shuffle_order: bool = True
) -> List[int]:
"""
Judge the completion for a given prompt. Used to assess if a completion satisfies a constraint.
Expand All @@ -150,6 +152,7 @@ def judge(self, prompts: List[str], completions: List[str], shuffle_order: bool
Args:
prompts (`List[str]`): List of prompts.
completions (`List[str]`): List of completions.
gold_answers (`List[str]`): List of gold answers, if they exist.
shuffle_order (`bool`): Whether to shuffle the order of the completions to avoid positional bias.
Returns:
Expand All @@ -170,7 +173,7 @@ class RandomConstraintJudge(BaseConstraintJudge):
Random constraint judge, for testing purposes.
"""

def judge(self, prompts, completions, shuffle_order=True):
def judge(self, prompts, completions, gold_answers=None, shuffle_order=True):
return [random.choice([0, 1]) for _ in range(len(prompts))]


Expand Down Expand Up @@ -360,8 +363,12 @@ class MixtureOfConstraintJudges(BaseConstraintJudge):
def __init__(self, judges: List[BaseConstraintJudge]):
self.judges = judges

def judge(self, prompts: List[str], completions: List[str], shuffle_order: bool = True) -> List[bool]:
all_constraint_judgments = [judge.judge(prompts, completions, shuffle_order) for judge in self.judges]
def judge(
self, prompts: List[str], completions: List[str], gold_answers: List[str] = None, shuffle_order: bool = True
) -> List[bool]:
all_constraint_judgments = [
judge.judge(prompts, completions, gold_answers, shuffle_order) for judge in self.judges
]

return [
1 if all(constraint_judgment == 1 for constraint_judgment in constraint_judgments) else 0
Expand Down

0 comments on commit 765768b

Please sign in to comment.