Refactor Engine based on PR feedback
collindutter committed Dec 13, 2024
1 parent caaeb44 commit 0449e00
Showing 7 changed files with 45 additions and 109 deletions.
7 changes: 4 additions & 3 deletions CHANGELOG.md
@@ -7,9 +7,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

## Fixed

- Exception when calling `Structure.to_json()` after it has run.

### Added

@@ -20,6 +17,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support for `BranchTask` in `StructureVisualizer`.
- `EvalEngine` for evaluating the performance of an LLM's output against a given input.

## Fixed

- Exception when calling `Structure.to_json()` after it has run.

## [1.0.0] - 2024-12-09

### Added
10 changes: 0 additions & 10 deletions docs/griptape-framework/engines/eval-engines.md
@@ -11,16 +11,6 @@ Eval Engines require either [criteria](../../reference/griptape/engines/eval/eva
If `criteria` is set, Griptape will generate `evaluation_steps` for you. This is useful for getting started, but you may want to customize the evaluation steps for more complex evaluations.
`evaluation_steps` are a list of instructions provided to the LLM to evaluate the input.

Eval Engines also require a list of [evaluation_params](../../reference/griptape/engines/eval/eval_engine.md#griptape.engines.eval.eval_engine.EvalEngine.evaluation_params).
These parameters must be passed as arguments when running [EvalEngine.evaluate](../../reference/griptape/engines/eval/eval_engine.md#griptape.engines.eval.eval_engine.EvalEngine.evaluate). The available parameters are:

- [EvalEngine.Param.INPUT](../../reference/griptape/engines/eval/eval_engine.md#griptape.engines.eval.eval_engine.EvalEngine.Param.INPUT): The original input to the LLM.
- [EvalEngine.Param.ACTUAL_OUTPUT](../../reference/griptape/engines/eval/eval_engine.md#griptape.engines.eval.eval_engine.EvalEngine.Param.ACTUAL_OUTPUT): The actual output from the LLM.
- [EvalEngine.Param.EXPECTED_OUTPUT](../../reference/griptape/engines/eval/eval_engine.md#griptape.engines.eval.eval_engine.EvalEngine.Param.EXPECTED_OUTPUT): The expected output from the LLM.
- [EvalEngine.Param.CONTEXT](../../reference/griptape/engines/eval/eval_engine.md#griptape.engines.eval.eval_engine.EvalEngine.Param.CONTEXT): Additional context to be included as part of the evaluation.

`INPUT` and `ACTUAL_OUTPUT` are required; all other parameters are optional.

```python
--8<-- "docs/griptape-framework/engines/src/eval_engines_1.py"
```
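
For readers skimming the diff, here is a minimal sketch of how the refactored API is called once `evaluation_params` is gone: the criteria string is taken from the diff above, while the argument values passed to `evaluate` are invented for illustration.

```python
from griptape.engines import EvalEngine

# criteria alone configures the engine; evaluation steps are generated
# from it when evaluate() runs.
engine = EvalEngine(
    criteria="Determine whether the actual output is factually correct based on the expected output.",
)

# input and actual_output are required keyword arguments; any extra keyword
# (e.g. expected_output) is folded into the evaluation under a title-cased
# label such as "Expected Output".
score, reason = engine.evaluate(
    input="What is the capital of France?",
    actual_output="Paris is the capital of France.",
    expected_output="Paris",
)
print(score, reason)
```

Because extra keywords are free-form, callers are no longer limited to the four members of the removed enum.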
5 changes: 0 additions & 5 deletions docs/griptape-framework/engines/src/eval_engines_1.py
@@ -4,11 +4,6 @@

engine = EvalEngine(
criteria="Determine whether the actual output is factually correct based on the expected output.",
evaluation_params=[
EvalEngine.Param.INPUT,
EvalEngine.Param.EXPECTED_OUTPUT,
EvalEngine.Param.ACTUAL_OUTPUT,
],
)

score, reason = engine.evaluate(
75 changes: 34 additions & 41 deletions griptape/engines/eval/eval_engine.py
@@ -2,29 +2,37 @@

import json
import uuid
from enum import Enum
from typing import TYPE_CHECKING, Optional

import schema
from attrs import Attribute, Factory, define, field, validators

from griptape.common import Message, PromptStack
from griptape.configs import Defaults
from griptape.engines import BaseEvalEngine
from griptape.mixins.serializable_mixin import SerializableMixin
from griptape.rules import JsonSchemaRule
from griptape.utils import J2

if TYPE_CHECKING:
from griptape.drivers import BasePromptDriver

STEPS_SCHEMA = schema.Schema(
{
"steps": [str],
}
)

RESULTS_SCHEMA = schema.Schema(
{
"score": float,
"reason": str,
}
)


@define(kw_only=True)
class EvalEngine(BaseEvalEngine, SerializableMixin):
class Param(Enum):
INPUT = "Input"
ACTUAL_OUTPUT = "Actual Output"
EXPECTED_OUTPUT = "Expected Output"
CONTEXT = "Context"

id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True})
name: str = field(
default=Factory(lambda self: self.id, takes_self=True),
@@ -33,9 +41,6 @@ class Param(Enum):
criteria: Optional[str] = field(default=None, metadata={"serializable": True})
evaluation_steps: Optional[list[str]] = field(default=None, metadata={"serializable": True})
prompt_driver: BasePromptDriver = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver))
evaluation_params: list[EvalEngine.Param] = field(
default=Factory(lambda: [EvalEngine.Param.INPUT, EvalEngine.Param.ACTUAL_OUTPUT])
)
generate_steps_system_template: J2 = field(default=Factory(lambda: J2("engines/eval/steps/system.j2")))
generate_steps_user_template: J2 = field(default=Factory(lambda: J2("engines/eval/steps/user.j2")))
generate_results_system_template: J2 = field(default=Factory(lambda: J2("engines/eval/results/system.j2")))
@@ -67,39 +72,23 @@ def validate_evaluation_steps(self, _: Attribute, value: Optional[list[str]]) ->
if not value:
raise ValueError("evaluation_steps must not be empty")

@evaluation_params.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_evaluation_params(self, attribute: Attribute, value: list[EvalEngine.Param]) -> None:
if not value:
raise ValueError("evaluation_params must not be empty")

if EvalEngine.Param.INPUT not in value:
raise ValueError("Input is required in evaluation_params")
def evaluate(self, input: str, actual_output: str, **kwargs) -> tuple[float, str]: # noqa: A002
evaluation_params = {
key.replace("_", " ").title(): value
for key, value in {"input": input, "actual_output": actual_output, **kwargs}.items()
}

if EvalEngine.Param.ACTUAL_OUTPUT not in value:
raise ValueError("Actual Output is required in evaluation_params")

def __attrs_post_init__(self) -> None:
if self.evaluation_steps is None:
with validators.disabled():
self.evaluation_steps = self._generate_steps()

def evaluate(self, **kwargs) -> tuple[float, str]:
if not all(key.upper() in EvalEngine.Param.__members__ for key in kwargs):
raise ValueError("All keys in kwargs must be a member of EvalEngine.Param")

if EvalEngine.Param.INPUT.name.lower() not in kwargs:
raise ValueError("Input is required for evaluation")
if EvalEngine.Param.ACTUAL_OUTPUT.name.lower() not in kwargs:
raise ValueError("Actual Output is required for evaluation")
self.evaluation_steps = self._generate_steps(evaluation_params)

text = "\n\n".join(f"{EvalEngine.Param[key.upper()]}: {value}" for key, value in kwargs.items())
return self._generate_results(evaluation_params)

return self._generate_results(text)

def _generate_steps(self) -> list[str]:
def _generate_steps(self, evaluation_params: dict[str, str]) -> list[str]:
system_prompt = self.generate_steps_system_template.render(
parameters=self.__generate_evaluation_params_str(self.evaluation_params),
evaluation_params=self.__generate_evaluation_params_text(evaluation_params),
criteria=self.criteria,
json_schema_rule=JsonSchemaRule(STEPS_SCHEMA.json_schema("Output Format")),
)
user_prompt = self.generate_steps_user_template.render()

@@ -116,11 +105,12 @@ def _generate_steps(self) -> list[str]:

return parsed_result["steps"]

def _generate_results(self, text: str) -> tuple[float, str]:
def _generate_results(self, evaluation_params: dict[str, str]) -> tuple[float, str]:
system_prompt = self.generate_results_system_template.render(
parameters=self.__generate_evaluation_params_str(self.evaluation_params),
evaluation_params=self.__generate_evaluation_params_text(evaluation_params),
evaluation_steps=self.evaluation_steps,
text=text,
evaluation_text=self.__generate_evaluation_text(evaluation_params),
json_schema_rule=JsonSchemaRule(RESULTS_SCHEMA.json_schema("Output Format")),
)
user_prompt = self.generate_results_user_template.render()

@@ -140,5 +130,8 @@ def _generate_results(self, text: str) -> tuple[float, str]:

return score, reason

def __generate_evaluation_params_str(self, evaluation_params: list[EvalEngine.Param]) -> str:
return ", ".join(param.value for param in evaluation_params)
def __generate_evaluation_params_text(self, evaluation_params: dict[str, str]) -> str:
return ", ".join(param for param in evaluation_params)

def __generate_evaluation_text(self, evaluation_params: dict[str, str]) -> str:
return "\n\n".join(f"{key}: {value}" for key, value in evaluation_params.items())
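
To make the change above concrete, here is a small, self-contained sketch (not part of the commit) of the key normalization that replaces the old `EvalEngine.Param` enum; the sample values are invented.

```python
# Snake_case keyword arguments become the title-cased labels that are rendered
# into the evaluation prompt, which is what the old Param enum values provided.
kwargs = {"expected_output": "Paris", "context": "A geography quiz"}
evaluation_params = {
    key.replace("_", " ").title(): value
    for key, value in {"input": "What is the capital of France?", "actual_output": "Paris.", **kwargs}.items()
}
print(evaluation_params)
# {'Input': 'What is the capital of France?', 'Actual Output': 'Paris.',
#  'Expected Output': 'Paris', 'Context': 'A geography quiz'}

# The same dict then feeds both templates: its keys become the comma-separated
# parameter list and its items become the "Key: value" evaluation text.
params_text = ", ".join(evaluation_params)
evaluation_text = "\n\n".join(f"{key}: {value}" for key, value in evaluation_params.items())
print(params_text)
print(evaluation_text)
```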
12 changes: 2 additions & 10 deletions griptape/templates/engines/eval/results/system.j2
@@ -5,14 +5,6 @@ Given the evaluation steps, return a JSON with two keys:
Evaluation Steps:
{{ evaluation_steps }}

{{ text }}
{{ evaluation_text }}

**
IMPORTANT: Please make sure to only return in JSON format, with the "score" and "reason" key. No words or explanation is needed.

Example JSON:
{
"score": 0,
"reason": "The text does not follow the evaluation steps provided."
}
**
{{ json_schema_rule }}
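
Both templates now end with `{{ json_schema_rule }}` rather than a hand-written example object. As a rough illustration (not part of the commit), printing the JSON Schema produced from the results schema shows the kind of structure that placeholder injects; the exact prompt wording comes from griptape's `JsonSchemaRule`.

```python
import json

import schema

# Mirrors RESULTS_SCHEMA from eval_engine.py.
RESULTS_SCHEMA = schema.Schema(
    {
        "score": float,
        "reason": str,
    }
)

# JsonSchemaRule wraps this dict; inspecting it shows the "score"/"reason"
# object shape the LLM is asked to return.
print(json.dumps(RESULTS_SCHEMA.json_schema("Output Format"), indent=2))
```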
10 changes: 3 additions & 7 deletions griptape/templates/engines/eval/steps/system.j2
@@ -1,11 +1,7 @@
Given an evaluation criteria which outlines how you should judge the {{ parameters }}, generate 3-4 concise evaluation steps based on the criteria below.
You MUST make it clear how to evaluate {{ parameters }} in relation to one another.
Given an evaluation criteria which outlines how you should judge the {{ evaluation_params }}, generate 3-4 concise evaluation steps based on the criteria below.
You MUST make it clear how to evaluate {{ evaluation_params }} in relation to one another.

Evaluation Criteria:
{{ criteria }}

IMPORTANT: Please make sure to only return in JSON format, with the "steps" key as a list of strings. No words or explanation is needed.
Example JSON:
{
"steps": <list_of_strings>
}
{{ json_schema_rule }}
35 changes: 2 additions & 33 deletions tests/unit/engines/eval/test_eval_engine.py
@@ -23,17 +23,8 @@ def engine(self):
),
)

def test_generate_evaluation_steps(self):
engine = EvalEngine(
criteria="foo",
prompt_driver=MockPromptDriver(
mock_output=json.dumps(
{
"steps": ["mock output"],
}
),
),
)
def test_generate_evaluation_steps(self, engine):
engine.evaluate(input="foo", actual_output="bar")
assert engine.evaluation_steps == ["mock output"]

def test_validate_criteria(self):
@@ -79,16 +70,6 @@ def test_validate_evaluation_steps(self):
with pytest.raises(ValueError, match="evaluation_steps must not be empty"):
assert EvalEngine(evaluation_steps=[])

def test_validate_evaluation_params(self):
with pytest.raises(ValueError, match="evaluation_params must not be empty"):
assert EvalEngine(evaluation_steps=["foo"], evaluation_params=[])

with pytest.raises(ValueError, match="Input is required in evaluation_params"):
assert EvalEngine(evaluation_steps=["foo"], evaluation_params=[EvalEngine.Param.EXPECTED_OUTPUT])

with pytest.raises(ValueError, match="Actual Output is required in evaluation_params"):
assert EvalEngine(evaluation_steps=["foo"], evaluation_params=[EvalEngine.Param.INPUT])

def test_evaluate(self):
engine = EvalEngine(
evaluation_steps=["foo"],
@@ -109,15 +90,3 @@ def test_evaluate(self):

assert score == 0.0
assert reason == "mock output"

def test_evaluate_invalid_param(self, engine):
with pytest.raises(ValueError, match="All keys in kwargs must be a member of EvalEngine.Param"):
engine.evaluate(foo="bar")

def test_evaluate_missing_input(self, engine):
with pytest.raises(ValueError, match="Input is required for evaluation"):
engine.evaluate()

def test_evaluate_missing_actual_output(self, engine):
with pytest.raises(ValueError, match="Actual Output is required for evaluation"):
engine.evaluate(input="foo")
