Adapted getting predictions to the new way for all metrics.
JoelNiklaus committed Dec 23, 2024
1 parent 4418e82 commit 075ebd2
Showing 1 changed file with 3 additions and 3 deletions.
community_tasks/swiss_legal_evals.py (6 changes: 3 additions & 3 deletions)
@@ -348,7 +348,7 @@ def compute(
         # There should be only one language each in the batch
         assert len(set(source_langs)) == len(set(target_langs)) == 1
         sources = [formatted_doc.specific["source"] for formatted_doc in formatted_docs]
-        predictions = [response[0].result[0] for response in responses]
+        predictions = [response[0].result for response in responses]
 
         answers, errors = get_gemba_scores(
             sources, predictions, source_langs[0], target_langs[0], method=self.method, model=self.model
@@ -461,7 +461,7 @@ def compute(
     ) -> dict[str, float]:
         logger.info(f"Scoring {len(formatted_docs)} samples with {self.metric_name}...")
         golds = [formatted_doc.get_golds()[0] for formatted_doc in formatted_docs]
-        predictions = [response[0].result[0] for response in responses]
+        predictions = [response[0].result for response in responses]
 
         all_scores = []
         for i in range(0, len(golds), self.batch_size):
@@ -531,7 +531,7 @@ def compute(
     ) -> dict[str, float]:
         logger.info(f"Scoring {len(formatted_docs)} samples with {self.metric_name}...")
         golds = [formatted_doc.get_golds()[0] for formatted_doc in formatted_docs]
-        predictions = [response[0].result[0] for response in responses]
+        predictions = [response[0].result for response in responses]
         sources = [kwargs["formatted_doc"].specific["source"] for kwargs["formatted_doc"] in formatted_docs]
 
         data = [{"src": src, "mt": pred, "ref": gold} for src, pred, gold in zip(sources, predictions, golds)]
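
For context on the change above: a minimal sketch, assuming the new response format in which each response's result field is the prediction string itself rather than a single-element list. The DummyResponse class and the sample texts below are hypothetical, for illustration only; they are not the project's actual classes or data.

    from dataclasses import dataclass

    @dataclass
    class DummyResponse:
        result: str  # previously something like ["prediction text"]

    # Two samples, one response object each, mirroring the structure iterated over in compute().
    responses = [
        [DummyResponse(result="Le recours est rejeté.")],
        [DummyResponse(result="Das Urteil wird aufgehoben.")],
    ]

    # Old extraction: predictions = [response[0].result[0] for response in responses]
    # New extraction (this commit):
    predictions = [response[0].result for response in responses]
    print(predictions)  # ['Le recours est rejeté.', 'Das Urteil wird aufgehoben.']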
