Skip to content

Commit

Permalink
naming
Browse files Browse the repository at this point in the history
  • Loading branch information
art-dsit committed Sep 20, 2024
1 parent 8ad98f2 commit 85305e8
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
5 changes: 2 additions & 3 deletions tests/test_eval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from copy import deepcopy
from pathlib import Path

from test_helpers.utils import failing_solver, failing_task, failing_task_det
from test_helpers.utils import failing_solver, failing_task, failing_task_deterministic

from inspect_ai import Task, task
from inspect_ai._eval.evalset import (
Expand Down Expand Up @@ -224,11 +224,10 @@ def test_eval_set_s3(mock_s3) -> None:
def test_eval_zero_retries() -> None:
with tempfile.TemporaryDirectory() as log_dir:
success, logs = eval_set(
tasks=failing_task_det([False, True]),
tasks=failing_task_deterministic([True, False]),
log_dir=log_dir,
retry_attempts=0,
retry_wait=0.1,
model="mockllm/model",
)
assert not success
assert logs[0].status == "error"
17 changes: 11 additions & 6 deletions tests/test_helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,27 +166,32 @@ def failing_task(rate=0.5, samples=1) -> Task:
scorer=match(),
)


@solver
def failing_solver_det(yay_or_nay: Sequence[bool]):
it = iter(yay_or_nay)
def failing_solver_deterministic(should_fail: Sequence[bool]):
it = iter(should_fail)

async def solve(state: TaskState, generate: Generate):
if it.__next__():
should_fail_this_time = it.__next__()
if should_fail_this_time:
raise ValueError("Eval failed!")
return state

return solve


@task
def failing_task_det(yay_or_nay: Sequence[bool]) -> Task:
def failing_task_deterministic(should_fail: Sequence[bool]) -> Task:
dataset: list[Sample] = []
for _ in range(0, len(yay_or_nay)):
for _ in range(0, len(should_fail)):
dataset.append(Sample(input="Say hello", target="hello"))
return Task(
dataset=dataset,
plan=[failing_solver_det(yay_or_nay), generate()],
plan=[failing_solver_deterministic(should_fail), generate()],
scorer=match(),
)


def ensure_test_package_installed():
try:
import inspect_package # type: ignore # noqa: F401
Expand Down

0 comments on commit 85305e8

Please sign in to comment.