diff --git a/src/inspect_ai/solver/_chain.py b/src/inspect_ai/solver/_chain.py index ebb0fa16d..f57be985b 100644 --- a/src/inspect_ai/solver/_chain.py +++ b/src/inspect_ai/solver/_chain.py @@ -71,8 +71,12 @@ async def __call__( state: TaskState, generate: Generate, ) -> TaskState: + from ._transcript import solver_transcript + for solver in self._solvers: - state = await solver(state, generate) + with solver_transcript(solver, state) as st: + state = await solver(state, generate) + st.complete(state) if state.completed: break diff --git a/src/inspect_ai/solver/_plan.py b/src/inspect_ai/solver/_plan.py index 99651dc4d..e9d02f5e2 100644 --- a/src/inspect_ai/solver/_plan.py +++ b/src/inspect_ai/solver/_plan.py @@ -95,11 +95,15 @@ async def __call__( state: TaskState, generate: Generate, ) -> TaskState: + from ._transcript import solver_transcript + try: # execute steps for index, solver in enumerate(self.steps): # run solver - state = await solver(state, generate) + with solver_transcript(solver, state) as st: + state = await solver(state, generate) + st.complete(state) # tick progress self.progress() @@ -114,7 +118,9 @@ async def __call__( # execute finish if self.finish: - state = await self.finish(state, generate) + with solver_transcript(self.finish, state) as st: + state = await self.finish(state, generate) + st.complete(state) self.progress() # mark completed diff --git a/src/inspect_ai/solver/_solver.py b/src/inspect_ai/solver/_solver.py index fe09d1c92..2429cf882 100644 --- a/src/inspect_ai/solver/_solver.py +++ b/src/inspect_ai/solver/_solver.py @@ -192,33 +192,15 @@ def solver_wrapper(*args: Any, **kwargs: dict[str, Any]) -> Solver: if not is_callable_coroutine(solver): raise TypeError(f"'{solver}' is not declared as an async callable.") - async def solver_with_transcript( - state: TaskState, generate: Generate - ) -> TaskState: - from ._transcript import solver_transcript - - with solver_transcript(solver, state, solver_name) as st: # type: ignore - state = await solver(state, generate) - st.complete(state) - return state - - # don't wrap transcript around compound solver types - from ._chain import Chain - from ._plan import Plan - - target_solver = ( - solver if isinstance(solver, Chain | Plan) else solver_with_transcript - ) - registry_tag( solver_type, - target_solver, + solver, RegistryInfo(type="solver", name=solver_name), *args, **kwargs, ) - return target_solver + return solver return solver_register(cast(SolverType, solver_wrapper), solver_name)