fix: use aparse in all metrics (#831)
shahules786 authored Apr 1, 2024
1 parent b976369 commit e941823
Showing 6 changed files with 43 additions and 21 deletions.
5 changes: 4 additions & 1 deletion src/ragas/llms/output_parser.py
@@ -1,4 +1,5 @@
import json
+import logging
import typing as t

from langchain_core.exceptions import OutputParserException
@@ -8,6 +9,7 @@
from ragas.llms import BaseRagasLLM
from ragas.llms.prompt import Prompt, PromptValue

+logger = logging.getLogger(__name__)
# The get_format_instructions function is a modified version from
# langchain_core.output_parser.pydantic. The original version removed the "type" json schema
# property that confused some older LLMs.
@@ -53,7 +55,7 @@ def get_json_format_instructions(pydantic_object: t.Type[TBaseModel]) -> str:

class RagasoutputParser(PydanticOutputParser):
async def aparse( # type: ignore
-self, result: str, prompt: PromptValue, llm: BaseRagasLLM, max_retries: int
+self, result: str, prompt: PromptValue, llm: BaseRagasLLM, max_retries: int = 1
):
try:
output = super().parse(result)
@@ -66,5 +68,6 @@ async def aparse( # type: ignore
result = output.generations[0][0].text
return await self.aparse(result, prompt, llm, max_retries - 1)
else:
+logger.warning("Failed to parse output. Returning None.")
return None
return output
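
The default of max_retries = 1 together with the new warning-and-None fallback is what lets every metric treat a failed parse the same way. Below is a minimal, self-contained sketch of that retry-then-give-up flow; parse_with_retries and fix_output are illustrative stand-ins, not ragas APIs; the real aparse calls PydanticOutputParser.parse and re-prompts the wrapped LLM before recursing with max_retries - 1.

import json
import logging
import typing as t

logger = logging.getLogger(__name__)


async def parse_with_retries(
    result: str,
    fix_output: t.Callable[[str], t.Awaitable[str]],  # stand-in for the LLM call that repairs malformed output
    max_retries: int = 1,
) -> t.Optional[dict]:
    try:
        # stands in for PydanticOutputParser.parse(result)
        return json.loads(result)
    except json.JSONDecodeError:
        if max_retries > 0:
            repaired = await fix_output(result)
            return await parse_with_retries(repaired, fix_output, max_retries - 1)
        logger.warning("Failed to parse output. Returning None.")
        return None

With the default of a single retry, each generation gets at most one repair attempt before the metric marks the sample as unscorable.
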
4 changes: 3 additions & 1 deletion src/ragas/metrics/_answer_relevance.py
@@ -157,11 +157,13 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> float:
)

answers = [
-_output_parser.parse(result.text) for result in result.generations[0]
+await _output_parser.aparse(result.text, prompt, self.llm)
+for result in result.generations[0]
]
if any(answer is None for answer in answers):
return np.nan

+answers = [answer for answer in answers if answer is not None]
return self._calculate_score(answers, row)

def adapt(self, language: str, cache_dir: str | None = None) -> None:
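
The guard above is the consumer-side half of the change: one unparsable answer makes the whole sample unscorable, and the comprehension that follows only narrows Optional types for the scoring call (nothing is dropped once the any() check has passed). A minimal sketch of the pattern, with score_or_nan as a hypothetical helper rather than ragas code:

import typing as t

import numpy as np


def score_or_nan(parsed: t.List[t.Optional[dict]]) -> float:
    if any(item is None for item in parsed):
        return np.nan  # at least one generation failed to parse, even after retrying
    items = [item for item in parsed if item is not None]  # type narrowing only; nothing is removed here
    return float(len(items))  # placeholder for the metric's real scoring step
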
13 changes: 8 additions & 5 deletions src/ragas/metrics/_context_entities_recall.py
@@ -135,6 +135,7 @@ class ContextEntityRecall(MetricWithLLM):
default_factory=lambda: TEXT_ENTITY_EXTRACTION
)
batch_size: int = 15
+max_retries: int = 1

def _compute_score(
self, ground_truth_entities: t.Sequence[str], context_entities: t.Sequence[str]
@@ -151,17 +152,19 @@ async def get_entities(
is_async: bool,
) -> t.Optional[ContextEntitiesResponse]:
assert self.llm is not None, "LLM is not initialized"

+p_value = self.context_entity_recall_prompt.format(
+text=text,
+)
result = await self.llm.generate(
-prompt=self.context_entity_recall_prompt.format(
-text=text,
-),
+prompt=p_value,
callbacks=callbacks,
is_async=is_async,
)

result_text = result.generations[0][0].text
-answer = _output_parser.parse(result_text)
+answer = await _output_parser.aparse(
+result_text, p_value, self.llm, self.max_retries
+)
if answer is None:
return ContextEntitiesResponse(entities=[])

Expand Down
11 changes: 8 additions & 3 deletions src/ragas/metrics/_context_precision.py
@@ -89,6 +89,7 @@ class ContextPrecision(MetricWithLLM):
name: str = "context_precision" # type: ignore
evaluation_mode: EvaluationMode = EvaluationMode.qcg # type: ignore
context_precision_prompt: Prompt = field(default_factory=lambda: CONTEXT_PRECISION)
+max_retries: int = 1

def _get_row_attributes(self, row: t.Dict) -> t.Tuple[str, t.List[str], t.Any]:
answer = "ground_truth"
@@ -138,20 +139,24 @@ async def _ascore(
assert self.llm is not None, "LLM is not set"

human_prompts = self._context_precision_prompt(row)
-responses: t.List[str] = []
+responses = []
for hp in human_prompts:
result = await self.llm.generate(
hp,
n=1,
callbacks=callbacks,
is_async=is_async,
)
-responses.append(result.generations[0][0].text)
+responses.append([result.generations[0][0].text, hp])

-items = [_output_parser.parse(item) for item in responses]
+items = [
+await _output_parser.aparse(item, hp, self.llm, self.max_retries)
+for item, hp in responses
+]
if any(item is None for item in items):
return np.nan

+items = [item for item in items if item is not None]
answers = ContextPrecisionVerifications(__root__=items)
score = self._calculate_average_precision(answers.__root__)
return score
9 changes: 6 additions & 3 deletions src/ragas/metrics/_context_recall.py
@@ -122,6 +122,7 @@ class ContextRecall(MetricWithLLM):
name: str = "context_recall" # type: ignore
evaluation_mode: EvaluationMode = EvaluationMode.qcg # type: ignore
context_recall_prompt: Prompt = field(default_factory=lambda: CONTEXT_RECALL_RA)
+max_retries: int = 1

def _create_context_recall_prompt(self, row: t.Dict) -> PromptValue:
qstn, ctx, gt = row["question"], row["contexts"], row["ground_truth"]
@@ -142,15 +143,17 @@ def _compute_score(self, response: t.Any) -> float:

async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> float:
assert self.llm is not None, "set LLM before use"

+p_value = self._create_context_recall_prompt(row)
result = await self.llm.generate(
-self._create_context_recall_prompt(row),
+p_value,
callbacks=callbacks,
is_async=is_async,
)
result_text = result.generations[0][0].text

-answers = _output_parser.parse(result_text)
+answers = await _output_parser.aparse(
+result_text, p_value, self.llm, self.max_retries
+)
if answers is None:
return np.nan

22 changes: 14 additions & 8 deletions src/ragas/metrics/_faithfulness.py
@@ -6,7 +6,6 @@
from dataclasses import dataclass, field

import numpy as np
-from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field

from ragas.llms.output_parser import RagasoutputParser, get_json_format_instructions
@@ -85,7 +84,7 @@ def dicts(self) -> t.List[t.Dict]:
_faithfulness_output_instructions = get_json_format_instructions(
StatementFaithfulnessAnswers
)
-_faithfulness_output_parser = PydanticOutputParser(
+_faithfulness_output_parser = RagasoutputParser(
pydantic_object=StatementFaithfulnessAnswers
)

@@ -157,6 +156,7 @@ class Faithfulness(MetricWithLLM):
nli_statements_message: Prompt = field(
default_factory=lambda: NLI_STATEMENTS_MESSAGE
)
+max_retries: int = 1

def _create_answer_prompt(self, row: t.Dict) -> PromptValue:
question, answer = row["question"], row["answer"]
@@ -200,20 +200,26 @@ async def _ascore(
returns the NLI score for each (q, c, a) pair
"""
assert self.llm is not None, "LLM is not set"
-p = self._create_answer_prompt(row)
+p_value = self._create_answer_prompt(row)
answer_result = await self.llm.generate(
-p, callbacks=callbacks, is_async=is_async
+p_value, callbacks=callbacks, is_async=is_async
)
answer_result_text = answer_result.generations[0][0].text
-statements = _statements_output_parser.parse(answer_result_text)
+statements = await _statements_output_parser.aparse(
+answer_result_text, p_value, self.llm, self.max_retries
+)
if statements is None:
return np.nan

-p = self._create_nli_prompt(row, statements.__root__)
-nli_result = await self.llm.generate(p, callbacks=callbacks, is_async=is_async)
+p_value = self._create_nli_prompt(row, statements.__root__)
+nli_result = await self.llm.generate(
+p_value, callbacks=callbacks, is_async=is_async
+)
nli_result_text = nli_result.generations[0][0].text

-faithfulness = _faithfulness_output_parser.parse(nli_result_text)
+faithfulness = await _faithfulness_output_parser.aparse(
+nli_result_text, p_value, self.llm, self.max_retries
+)
if faithfulness is None:
return np.nan
