From e9418237d0dd06f640b0f6bbae7a5a3f9c1029ea Mon Sep 17 00:00:00 2001
From: Shahul ES
Date: Mon, 1 Apr 2024 14:19:42 -0700
Subject: [PATCH] fix: use aparse in all metrics (#831)

---
 src/ragas/llms/output_parser.py               |  5 ++++-
 src/ragas/metrics/_answer_relevance.py        |  4 +++-
 src/ragas/metrics/_context_entities_recall.py | 13 ++++++-----
 src/ragas/metrics/_context_precision.py       | 11 +++++++---
 src/ragas/metrics/_context_recall.py          |  9 +++++---
 src/ragas/metrics/_faithfulness.py            | 22 ++++++++++++-------
 6 files changed, 43 insertions(+), 21 deletions(-)

diff --git a/src/ragas/llms/output_parser.py b/src/ragas/llms/output_parser.py
index 609c6e72d..8ff6c7cfb 100644
--- a/src/ragas/llms/output_parser.py
+++ b/src/ragas/llms/output_parser.py
@@ -1,4 +1,5 @@
 import json
+import logging
 import typing as t
 
 from langchain_core.exceptions import OutputParserException
@@ -8,6 +9,7 @@
 from ragas.llms import BaseRagasLLM
 from ragas.llms.prompt import Prompt, PromptValue
 
+logger = logging.getLogger(__name__)
 # The get_format_instructions function is a modified version from
 # langchain_core.output_parser.pydantic. The original version removed the "type" json schema
 # property that confused some older LLMs.
@@ -53,7 +55,7 @@ def get_json_format_instructions(pydantic_object: t.Type[TBaseModel]) -> str:
 
 class RagasoutputParser(PydanticOutputParser):
     async def aparse(  # type: ignore
-        self, result: str, prompt: PromptValue, llm: BaseRagasLLM, max_retries: int
+        self, result: str, prompt: PromptValue, llm: BaseRagasLLM, max_retries: int = 1
     ):
         try:
             output = super().parse(result)
@@ -66,5 +68,6 @@ async def aparse(  # type: ignore
                 result = output.generations[0][0].text
                 return await self.aparse(result, prompt, llm, max_retries - 1)
             else:
+                logger.warning("Failed to parse output. Returning None.")
                 return None
         return output
diff --git a/src/ragas/metrics/_answer_relevance.py b/src/ragas/metrics/_answer_relevance.py
index e859bb3d6..af95aa06d 100644
--- a/src/ragas/metrics/_answer_relevance.py
+++ b/src/ragas/metrics/_answer_relevance.py
@@ -157,11 +157,13 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> fl
         )
 
         answers = [
-            _output_parser.parse(result.text) for result in result.generations[0]
+            await _output_parser.aparse(result.text, prompt, self.llm)
+            for result in result.generations[0]
         ]
         if any(answer is None for answer in answers):
             return np.nan
 
+        answers = [answer for answer in answers if answer is not None]
         return self._calculate_score(answers, row)
 
     def adapt(self, language: str, cache_dir: str | None = None) -> None:
diff --git a/src/ragas/metrics/_context_entities_recall.py b/src/ragas/metrics/_context_entities_recall.py
index 8b2b55686..ba249107f 100644
--- a/src/ragas/metrics/_context_entities_recall.py
+++ b/src/ragas/metrics/_context_entities_recall.py
@@ -135,6 +135,7 @@ class ContextEntityRecall(MetricWithLLM):
         default_factory=lambda: TEXT_ENTITY_EXTRACTION
     )
     batch_size: int = 15
+    max_retries: int = 1
 
     def _compute_score(
         self, ground_truth_entities: t.Sequence[str], context_entities: t.Sequence[str]
     ) -> float:
@@ -151,17 +152,19 @@ async def get_entities(
         self,
         text: str,
         callbacks: Callbacks,
         is_async: bool,
     ) -> t.Optional[ContextEntitiesResponse]:
         assert self.llm is not None, "LLM is not initialized"
-
+        p_value = self.context_entity_recall_prompt.format(
+            text=text,
+        )
         result = await self.llm.generate(
-            prompt=self.context_entity_recall_prompt.format(
-                text=text,
-            ),
+            prompt=p_value,
             callbacks=callbacks,
             is_async=is_async,
         )
 
         result_text = result.generations[0][0].text
-        answer = _output_parser.parse(result_text)
+        answer = await _output_parser.aparse(
+            result_text, p_value, self.llm, self.max_retries
+        )
         if answer is None:
             return ContextEntitiesResponse(entities=[])
diff --git a/src/ragas/metrics/_context_precision.py b/src/ragas/metrics/_context_precision.py
index 302c91667..e6a0ff41e 100644
--- a/src/ragas/metrics/_context_precision.py
+++ b/src/ragas/metrics/_context_precision.py
@@ -89,6 +89,7 @@ class ContextPrecision(MetricWithLLM):
     name: str = "context_precision"  # type: ignore
     evaluation_mode: EvaluationMode = EvaluationMode.qcg  # type: ignore
     context_precision_prompt: Prompt = field(default_factory=lambda: CONTEXT_PRECISION)
+    max_retries: int = 1
 
     def _get_row_attributes(self, row: t.Dict) -> t.Tuple[str, t.List[str], t.Any]:
         answer = "ground_truth"
@@ -138,7 +139,7 @@ async def _ascore(
         assert self.llm is not None, "LLM is not set"
 
         human_prompts = self._context_precision_prompt(row)
-        responses: t.List[str] = []
+        responses = []
        for hp in human_prompts:
             result = await self.llm.generate(
                 hp,
@@ -146,12 +147,16 @@ async def _ascore(
                 callbacks=callbacks,
                 is_async=is_async,
             )
-            responses.append(result.generations[0][0].text)
+            responses.append([result.generations[0][0].text, hp])
 
-        items = [_output_parser.parse(item) for item in responses]
+        items = [
+            await _output_parser.aparse(item, hp, self.llm, self.max_retries)
+            for item, hp in responses
+        ]
         if any(item is None for item in items):
             return np.nan
 
+        items = [item for item in items if item is not None]
         answers = ContextPrecisionVerifications(__root__=items)
         score = self._calculate_average_precision(answers.__root__)
         return score
diff --git a/src/ragas/metrics/_context_recall.py b/src/ragas/metrics/_context_recall.py
index e4613f3dc..f77c49ce9 100644
--- a/src/ragas/metrics/_context_recall.py
+++ b/src/ragas/metrics/_context_recall.py
@@ -122,6 +122,7 @@ class ContextRecall(MetricWithLLM):
     name: str = "context_recall"  # type: ignore
     evaluation_mode: EvaluationMode = EvaluationMode.qcg  # type: ignore
     context_recall_prompt: Prompt = field(default_factory=lambda: CONTEXT_RECALL_RA)
+    max_retries: int = 1
 
     def _create_context_recall_prompt(self, row: t.Dict) -> PromptValue:
         qstn, ctx, gt = row["question"], row["contexts"], row["ground_truth"]
@@ -142,15 +143,17 @@ def _compute_score(self, response: t.Any) -> float:
 
     async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> float:
         assert self.llm is not None, "set LLM before use"
-
+        p_value = self._create_context_recall_prompt(row)
         result = await self.llm.generate(
-            self._create_context_recall_prompt(row),
+            p_value,
             callbacks=callbacks,
             is_async=is_async,
         )
 
         result_text = result.generations[0][0].text
-        answers = _output_parser.parse(result_text)
+        answers = await _output_parser.aparse(
+            result_text, p_value, self.llm, self.max_retries
+        )
         if answers is None:
             return np.nan
diff --git a/src/ragas/metrics/_faithfulness.py b/src/ragas/metrics/_faithfulness.py
index 83c589d70..ceacc4528 100644
--- a/src/ragas/metrics/_faithfulness.py
+++ b/src/ragas/metrics/_faithfulness.py
@@ -6,7 +6,6 @@
 from dataclasses import dataclass, field
 
 import numpy as np
-from langchain_core.output_parsers import PydanticOutputParser
 from langchain_core.pydantic_v1 import BaseModel, Field
 
 from ragas.llms.output_parser import RagasoutputParser, get_json_format_instructions
@@ -85,7 +84,7 @@ def dicts(self) -> t.List[t.Dict]:
 _faithfulness_output_instructions = get_json_format_instructions(
     StatementFaithfulnessAnswers
 )
-_faithfulness_output_parser = PydanticOutputParser(
+_faithfulness_output_parser = RagasoutputParser(
     pydantic_object=StatementFaithfulnessAnswers
 )
@@ -157,6 +156,7 @@ class Faithfulness(MetricWithLLM):
     nli_statements_message: Prompt = field(
         default_factory=lambda: NLI_STATEMENTS_MESSAGE
     )
+    max_retries: int = 1
 
     def _create_answer_prompt(self, row: t.Dict) -> PromptValue:
         question, answer = row["question"], row["answer"]
@@ -200,20 +200,26 @@ async def _ascore(
         returns the NLI score for each (q, c, a) pair
         """
         assert self.llm is not None, "LLM is not set"
-        p = self._create_answer_prompt(row)
+        p_value = self._create_answer_prompt(row)
         answer_result = await self.llm.generate(
-            p, callbacks=callbacks, is_async=is_async
+            p_value, callbacks=callbacks, is_async=is_async
         )
         answer_result_text = answer_result.generations[0][0].text
-        statements = _statements_output_parser.parse(answer_result_text)
+        statements = await _statements_output_parser.aparse(
+            answer_result_text, p_value, self.llm, self.max_retries
+        )
         if statements is None:
             return np.nan
 
-        p = self._create_nli_prompt(row, statements.__root__)
-        nli_result = await self.llm.generate(p, callbacks=callbacks, is_async=is_async)
+        p_value = self._create_nli_prompt(row, statements.__root__)
+        nli_result = await self.llm.generate(
+            p_value, callbacks=callbacks, is_async=is_async
+        )
         nli_result_text = nli_result.generations[0][0].text
-        faithfulness = _faithfulness_output_parser.parse(nli_result_text)
+        faithfulness = await _faithfulness_output_parser.aparse(
+            nli_result_text, p_value, self.llm, self.max_retries
+        )
         if faithfulness is None:
             return np.nan
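
The pattern this patch applies everywhere is: await the parser's `aparse` with the prompt, the LLM, and `max_retries`, let it ask the model to repair malformed output, and treat a final `None` as an `np.nan` score. Below is a minimal, self-contained sketch of that retry-on-parse-failure flow; `ToyParser`, `FakeLLM`, and the fix-up prompt text are hypothetical stand-ins and not part of the ragas API shown in the diff.

```python
# Hypothetical sketch of the retry-on-parse-failure pattern used above.
import asyncio
import json
import math
import typing as t


class FakeLLM:
    """Stand-in for an LLM client; always returns valid JSON when re-asked."""

    async def generate(self, prompt: str) -> str:
        return '{"score": 1}'


class ToyParser:
    def parse(self, text: str) -> t.Dict:
        return json.loads(text)  # raises ValueError on malformed output

    async def aparse(
        self, result: str, prompt: str, llm: FakeLLM, max_retries: int = 1
    ) -> t.Optional[t.Dict]:
        try:
            return self.parse(result)
        except ValueError:
            if max_retries != 0:
                # Ask the model to repair its own output, then parse again.
                fix_prompt = f"Return valid JSON for: {prompt}\nBroken output: {result}"
                fixed = await llm.generate(fix_prompt)
                return await self.aparse(fixed, prompt, llm, max_retries - 1)
            return None  # callers map this to np.nan


async def main() -> None:
    parser, llm = ToyParser(), FakeLLM()
    answer = await parser.aparse("not-json", "score the answer", llm)
    print(answer["score"] if answer is not None else math.nan)


asyncio.run(main())
```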