Skip to content

Commit

Permalink
Merge pull request #67 from Health-Informatics-UoN/feature/LLMPipelineTest-class
Browse files Browse the repository at this point in the history

Feature/llm pipeline test class
  • Loading branch information
Karthi-DStech authored Oct 24, 2024
2 parents 15d9a35 + 22e5327 commit bcd664c
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 4 deletions.
66 changes: 66 additions & 0 deletions Carrot-Assistant/evaluation/eval_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import Dict
from evaluation.evaltypes import (
SingleResultPipelineTest,
SingleResultMetric,
)
from evaluation.pipelines import LLMPipeline


class LLMPipelineTest(SingleResultPipelineTest):
    """
    Pipeline test specialised for LLM pipelines that produce a single result.

    A thin subclass of SingleResultPipelineTest: both overrides delegate
    straight to the parent, and exist to pin down the types and document
    the contract for LLM-based pipelines.
    """

    def __init__(self, name: str, pipeline: LLMPipeline, metrics: list[SingleResultMetric]):
        """
        Set up the test.

        Parameters
        ----------
        name: str
            Name given to the test
        pipeline: LLMPipeline
            The pipeline used to generate output
        metrics: list[SingleResultMetric]
            Metrics used to compare the pipeline output with the expected output
        """
        super().__init__(name, pipeline, metrics)

    def run_pipeline(self, input_data) -> str:
        """
        Run the attached pipeline on the given input.

        Parameters
        ----------
        input_data
            Input supplied to the pipeline

        Returns
        -------
        str
            The reply produced by the pipeline
        """
        # No LLM-specific behaviour is needed; defer to the base implementation.
        return super().run_pipeline(input_data)

    def evaluate(self, input_data, expected_output) -> Dict:
        """
        Score the pipeline's output against the expected output with each metric.

        Parameters
        ----------
        input_data
            Input supplied to the pipeline
        expected_output
            The result expected from running input_data through the pipeline

        Returns
        -------
        Dict
            A dictionary of results from evaluating the pipeline.
        """
        # Pure delegation, kept explicit so the subclass documents its contract.
        return super().evaluate(input_data, expected_output)
7 changes: 4 additions & 3 deletions Carrot-Assistant/evaluation/evaltypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,15 @@ def run(self, *args, **kwargs) -> Any:


# Type parameters for the generic PipelineTest below:
# M — any Metric subclass used to score pipeline output.
M = TypeVar("M", bound=Metric)
# P — any TestPipeline subclass under test.
P = TypeVar("P", bound=TestPipeline)


class PipelineTest(Generic[M]):
class PipelineTest(Generic[P, M]):
"""
Base class for Pipeline tests
"""

def __init__(self, name: str, pipeline: TestPipeline, metrics: list[M]):
def __init__(self, name: str, pipeline: P, metrics: list[M]):
self.name = name
self.pipeline = pipeline
self.metrics = metrics
Expand Down Expand Up @@ -77,7 +78,7 @@ class SingleResultPipeline(TestPipeline):
"""


class SingleResultPipelineTest(PipelineTest[SingleResultMetric]):
class SingleResultPipelineTest(PipelineTest[SingleResultPipeline, SingleResultMetric]):
def __init__(
self,
name: str,
Expand Down
21 changes: 20 additions & 1 deletion Carrot-Assistant/tests/test_evals.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from jinja2 import Environment, Template
from jinja2 import Environment

from evaluation.eval_tests import LLMPipelineTest
from evaluation.evaltypes import SingleResultPipeline, SingleResultPipelineTest
from evaluation.metrics import ExactMatchMetric
from evaluation.pipelines import LLMPipeline
Expand Down Expand Up @@ -90,3 +91,21 @@ def llm_pipeline(self, llm_prompt):
def test_returns_string(self, llm_pipeline):
model_output = llm_pipeline.run({"input_sentence": "Polly wants a cracker"})
assert isinstance(model_output, str)

@pytest.fixture
def llm_pipeline_test(self, llm_pipeline):
return LLMPipelineTest("Parrot Pipeline", llm_pipeline, [ExactMatchMetric()])

def test_pipeline_called_from_eval_returns_string(self, llm_pipeline_test):
model_output = llm_pipeline_test.run_pipeline(
{"input_sentence": "Polly wants a cracker"}
)
assert isinstance(model_output, str)

def test_llm_pipelinetest_evaluates(self, llm_pipeline_test):
model_eval = llm_pipeline_test.evaluate(
name="Testing the parrot pipeline",
input_data={"input_sentence": "Polly wants a cracker"},
expected_output="Polly wants a cracker",
)
assert isinstance(model_eval, dict)

0 comments on commit bcd664c

Please sign in to comment.