feat: introduce components experimental #1339

Open · wants to merge 5 commits into base: main
20 changes: 8 additions & 12 deletions src/ragas/experimental/metrics/_faithfulness.py
@@ -8,6 +8,7 @@
from pydantic import BaseModel, Field

from ragas.experimental.llms.prompt import PydanticPrompt
from ragas.experimental.metrics.component import BaseNLIComponent, LLMNLIComponent
from ragas.metrics.base import (
MetricType,
MetricWithLLM,
@@ -166,6 +167,7 @@ class FaithfulnessExperimental(MetricWithLLM, SingleTurnMetric):
MetricType.SINGLE_TURN: {"user_input", "response", "retrieved_contexts"}
}
)
nli_component: t.Optional[BaseNLIComponent] = None
sentence_segmenter: t.Optional[HasSegmentMethod] = None
max_retries: int = 1
_reproducibility: int = 1
@@ -188,21 +190,22 @@ def reproducibility(self, value):

def __post_init__(self):
self.long_form_answer_prompt = LongFormAnswerPrompt()
self.nli_statement_prompt = NLIStatementPrompt()
self.nli_component = LLMNLIComponent(llm=self.llm) # type: ignore
if self.sentence_segmenter is None:
# TODO: make this dynamic, taking language from prompt
language = "english"
self.sentence_segmenter = get_segmenter(language=language, clean=False)

async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
assert self.llm is not None, "LLM is not set"
assert self.nli_component is not None, "NLI component is not set"

answer, question, contexts = (
row["response"],
row["user_input"],
row["retrieved_contexts"],
)

contexts = "\n".join(contexts)
# get the sentences from the answer
if self.sentence_segmenter is None:
raise ValueError("Sentence segmenter is not set")
@@ -226,19 +229,12 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
for component in sentence_components.sentences
for statement in component.simpler_statements
]
verdicts = await self.nli_statement_prompt.generate(
data=NLIStatementInput(
context="\n".join(contexts),
statements=statements,
),
llm=self.llm,
callbacks=callbacks,
verdicts = await self.nli_component.apply(
hypothesis=statements, premise=contexts, callbacks=callbacks
)

# compute the score
num_faithful_statements = sum(
verdict.verdict for verdict in verdicts.statements
)
num_faithful_statements = sum(verdicts)
if len(statements):
score = num_faithful_statements / len(statements)
else:
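Taken together, the changes above mean FaithfulnessExperimental no longer calls the NLI prompt itself: verdict generation is delegated to a pluggable nli_component that returns one boolean per statement. A minimal sketch of swapping in a custom component follows; AlwaysTrueNLIComponent and evaluator_llm are illustrative assumptions rather than part of this PR, and the metric is assumed to accept an llm argument at construction as other ragas metrics do.

import typing as t

from ragas.experimental.metrics._faithfulness import FaithfulnessExperimental
from ragas.experimental.metrics.component import BaseNLIComponent


class AlwaysTrueNLIComponent(BaseNLIComponent):
    """Toy component for illustration: marks every statement as supported."""

    async def apply(
        self, hypothesis: t.List[str], premise: str, callbacks=None
    ) -> t.List[bool]:
        return [True for _ in hypothesis]


# evaluator_llm is assumed to be whatever BaseRagasLLM wrapper you already use.
metric = FaithfulnessExperimental(llm=evaluator_llm)
# __post_init__ installs LLMNLIComponent by default; override it afterwards.
metric.nli_component = AlwaysTrueNLIComponent()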
177 changes: 177 additions & 0 deletions src/ragas/experimental/metrics/component.py
@@ -0,0 +1,177 @@
from __future__ import annotations

import logging
import typing as t
from abc import ABC, abstractmethod
from dataclasses import dataclass, field

from pydantic import BaseModel, Field
from transformers import AutoTokenizer, Pipeline, pipeline

from ragas.experimental.llms.prompt import PydanticPrompt
from ragas.llms.base import BaseRagasLLM

if t.TYPE_CHECKING:
from langchain_core.callbacks import Callbacks

logger = logging.getLogger(__name__)


class StatementFaithfulnessAnswer(BaseModel):
statement: str = Field(..., description="the original statement, word-by-word")
reason: str = Field(..., description="the reason for the verdict")
verdict: int = Field(..., description="the verdict (0/1) of the faithfulness.")


class NLIStatementOutput(BaseModel):
statements: t.List[StatementFaithfulnessAnswer]


class NLIStatementInput(BaseModel):
context: str = Field(..., description="The context of the question")
statements: t.List[str] = Field(..., description="The statements to judge")


class NLIStatementPrompt(PydanticPrompt[NLIStatementInput, NLIStatementOutput]):
instruction = "Your task is to judge the faithfulness of a series of statements based on a given context. For each statement you must return verdict as 1 if the statement can be directly inferred based on the context or 0 if the statement can not be directly inferred based on the context."
input_model = NLIStatementInput
output_model = NLIStatementOutput
examples = [
(
NLIStatementInput(
context="""John is a student at XYZ University. He is pursuing a degree in Computer Science. He is enrolled in several courses this semester, including Data Structures, Algorithms, and Database Management. John is a diligent student and spends a significant amount of time studying and completing assignments. He often stays late in the library to work on his projects.""",
statements=[
"John is majoring in Biology.",
"John is taking a course on Artificial Intelligence.",
"John is a dedicated student.",
"John has a part-time job.",
],
),
NLIStatementOutput(
statements=[
StatementFaithfulnessAnswer(
statement="John is majoring in Biology.",
reason="John's major is explicitly mentioned as Computer Science. There is no information suggesting he is majoring in Biology.",
verdict=0,
),
StatementFaithfulnessAnswer(
statement="John is taking a course on Artificial Intelligence.",
reason="The context mentions the courses John is currently enrolled in, and Artificial Intelligence is not mentioned. Therefore, it cannot be deduced that John is taking a course on AI.",
verdict=0,
),
StatementFaithfulnessAnswer(
statement="John is a dedicated student.",
reason="The context states that he spends a significant amount of time studying and completing assignments. Additionally, it mentions that he often stays late in the library to work on his projects, which implies dedication.",
verdict=1,
),
StatementFaithfulnessAnswer(
statement="John has a part-time job.",
reason="There is no information given in the context about John having a part-time job.",
verdict=0,
),
]
),
),
(
NLIStatementInput(
context="Photosynthesis is a process used by plants, algae, and certain bacteria to convert light energy into chemical energy.",
statements=[
"Albert Einstein was a genius.",
],
),
NLIStatementOutput(
statements=[
StatementFaithfulnessAnswer(
statement="Albert Einstein was a genius.",
reason="The context and statement are unrelated",
verdict=0,
)
]
),
),
]


class BaseNLIComponent(ABC):
@abstractmethod
async def apply(
self, hypothesis: t.List[str], premise: str, callbacks: Callbacks
) -> t.List[bool]:
"""
Apply the NLI component to a single premise and a list of hypothesis statements.
"""
raise NotImplementedError("apply method must be implemented by subclasses")


@dataclass
class LLMNLIComponent(BaseNLIComponent):
llm: BaseRagasLLM
nli_prompt: PydanticPrompt = field(default_factory=NLIStatementPrompt)

async def apply(
self, hypothesis: t.List[str], premise: str, callbacks: Callbacks
) -> t.List[bool]:
assert self.llm is not None, "LLM must be set"
prompt_input = NLIStatementInput(context=premise, statements=hypothesis)
response = await self.nli_prompt.generate(
data=prompt_input, llm=self.llm, callbacks=callbacks
)
return [bool(result.verdict) for result in response.statements]


@dataclass
class TextClassificationNLIComponent(BaseNLIComponent):
hf_pipeline: Pipeline
prompt: str
label: str
batch_size: int = 32
model_kwargs: t.Dict[str, t.Any] = field(default_factory=dict)

def __post_init__(self):
if "{premise}" not in self.prompt or "{hypothesis}" not in self.prompt:
raise ValueError("Prompt must contain both '{premise}' and '{hypothesis}' placeholders")

self.model_kwargs["top_k"] = 1

@classmethod
def from_model_id(
cls,
model_id: str,
prompt: str,
label: str,
model_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
pipeline_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
) -> TextClassificationNLIComponent:
tokenizer = AutoTokenizer.from_pretrained(model_id)
hf_pipeline = pipeline(
"text-classification",
model=model_id,
tokenizer=tokenizer,
**(pipeline_kwargs or {}),
)

return cls(
hf_pipeline=hf_pipeline,
prompt=prompt,
label=label,
model_kwargs=model_kwargs or {},
)

async def apply(
self, hypothesis: t.List[str], premise: str, callbacks: Callbacks = None
) -> t.List[bool]:
scores = []
prompt_input_list = [
self.prompt.format(hypothesis=text, premise=premise) for text in hypothesis
]
for i in range(0, len(prompt_input_list), self.batch_size):
prompt_input_list_batch = prompt_input_list[i : i + self.batch_size]
response = self.hf_pipeline(prompt_input_list_batch, **self.model_kwargs)
Reviewer comment (Member): this might be a problem with async calls. (A possible mitigation is sketched after this file's diff.)

assert isinstance(response, list), "Response should be a list"
assert all(
isinstance(item, list) and len(item) > 0 for item in response
), "Each item in the response should be a non-empty list of label dicts"
response = [item[0].get("label") == self.label for item in response] # type: ignore
scores.extend(response)

return scores
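On the reviewer's concern above: self.hf_pipeline(...) is a synchronous call, so even though apply is async it blocks the event loop while the model runs. One possible mitigation, sketched here as an assumption rather than as part of the PR, is to push the blocking call into a worker thread with asyncio.to_thread; classify_batch_async is a hypothetical helper, and the item[0]["label"] access assumes the list-of-dicts-per-input shape that transformers returns when top_k is set.

import asyncio
import typing as t

from transformers import Pipeline


async def classify_batch_async(
    hf_pipeline: Pipeline,
    prompts: t.List[str],
    label: str,
    **model_kwargs: t.Any,
) -> t.List[bool]:
    """Run the synchronous HF pipeline in a worker thread so the event loop stays responsive."""
    response = await asyncio.to_thread(hf_pipeline, prompts, **model_kwargs)
    # With top_k set, each element of response is a list of {"label", "score"} dicts.
    return [item[0].get("label") == label for item in response]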
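For completeness, a hedged usage sketch of the Hugging Face-backed component. The model id and the "entailment" label below are placeholders for whichever NLI-style classifier is actually used, and the prompt must contain the {premise} and {hypothesis} placeholders that __post_init__ checks for.

import asyncio

from ragas.experimental.metrics.component import TextClassificationNLIComponent


async def main() -> None:
    # "some-org/some-nli-model" is an illustrative placeholder, not a real checkpoint.
    nli = TextClassificationNLIComponent.from_model_id(
        model_id="some-org/some-nli-model",
        prompt="premise: {premise} hypothesis: {hypothesis}",
        label="entailment",
    )
    verdicts = await nli.apply(
        hypothesis=["John is a dedicated student."],
        premise="John often stays late in the library to work on his projects.",
    )
    print(verdicts)  # e.g. [True] if the classifier labels the pair as entailment


asyncio.run(main())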