Skip to content

Commit

Permalink
track sample task state in solver decorator rather than solver transcript (#532)
Browse files Browse the repository at this point in the history

Co-authored-by: aisi-inspect <[email protected]>
  • Loading branch information
jjallaire-aisi and aisi-inspect authored Sep 26, 2024
1 parent 7a17f1b commit 26acdb7
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 7 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
- Capture solver input params for subtasks created by `fork()`.
- Allow Docker sandboxes configured with `x-default` to be referred to by their declared service name.
- Require a `max_messages` for use of `basic_agent()` (as without it, the agent could end up in an infinite loop).
- Track sample task state in solver decorator rather than solver transcript.
- Display solver input parameters for forked subtasks.
- Improvements to docker compose down cleanup: timeout, survive missing compose files.


## v0.3.32 (25 September 2024)

- Fix issue with subtasks not getting a fresh `store()` (regression from the introduction of `fork()` in v0.3.30).
Expand Down
3 changes: 2 additions & 1 deletion src/inspect_ai/_eval/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
from inspect_ai.solver._chain import Chain, unroll
from inspect_ai.solver._fork import set_task_generate
from inspect_ai.solver._solver import Solver
from inspect_ai.solver._task_state import state_jsonable
from inspect_ai.solver._task_state import set_sample_state, state_jsonable
from inspect_ai.util._subtask import init_subtask

from ..context import init_task_context
Expand Down Expand Up @@ -385,6 +385,7 @@ async def task_run_sample(
)

# initialise subtask and scoring context
set_sample_state(state)
init_subtask(SAMPLE_SUBTASK, state.store)
if scorers:
init_scoring_context(scorers, Target(sample.target))
Expand Down
37 changes: 34 additions & 3 deletions src/inspect_ai/solver/_solver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import inspect
from dataclasses import dataclass, field
from functools import wraps
from typing import (
Any,
Callable,
Expand All @@ -22,7 +24,7 @@
)
from inspect_ai.model import CachePolicy, GenerateConfigArgs

from ._task_state import TaskState
from ._task_state import TaskState, set_sample_state


@runtime_checkable
Expand Down Expand Up @@ -192,15 +194,44 @@ def solver_wrapper(*args: Any, **kwargs: dict[str, Any]) -> Solver:
if not is_callable_coroutine(solver):
raise TypeError(f"'{solver}' is not declared as an async callable.")

# if the solver is a class then we inject state tracking
# by patching the __call__ method (this is because we
# want to preserve the type, especially for code that e.g.
# checks for Chain or Plan)
if inspect.isclass(type(solver)):
original_call = solver.__call__

async def call_with_state(
state: TaskState, generate: Generate
) -> TaskState:
state = await original_call(state, generate)
set_sample_state(state)
return state

registered_solver = solver
setattr(registered_solver, "__call__", call_with_state)

# if it's a function then use ordinary @wraps to preserve
# the metadata of the wrapped solver
else:

@wraps(solver)
async def registered_solver(
state: TaskState, generate: Generate
) -> TaskState:
state = await solver(state, generate)
set_sample_state(state)
return state

registry_tag(
solver_type,
solver,
registered_solver,
RegistryInfo(type="solver", name=solver_name),
*args,
**kwargs,
)

return solver
return registered_solver

return solver_register(cast(SolverType, solver_wrapper), solver_name)

Expand Down
3 changes: 1 addition & 2 deletions src/inspect_ai/solver/_transcript.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from inspect_ai._util.registry import registry_log_name

from ._solver import Solver
from ._task_state import TaskState, set_sample_state, state_jsonable
from ._task_state import TaskState, state_jsonable


class SolverTranscript:
Expand All @@ -28,7 +28,6 @@ def solver_transcript(
) -> Iterator[SolverTranscript]:
from inspect_ai.log._transcript import transcript

set_sample_state(state)
name = registry_log_name(name or solver)
with transcript().step(name=name, type="solver"):
yield SolverTranscript(name, state)

0 comments on commit 26acdb7

Please sign in to comment.