Moved fake agent to LLM propose complete tool (#695)
jamesbraza authored Nov 16, 2024
1 parent b45d437 commit 6fda922
Showing 3 changed files with 48 additions and 15 deletions.
1 change: 1 addition & 0 deletions paperqa/agents/helpers.py
@@ -41,6 +41,7 @@ async def litellm_get_search_query(
             " Ignoring template and using default search prompt."
         )
     if not search_prompt:
+        # TODO: move to use tools instead of DIY schema in prompt
        search_prompt = (
            "We want to answer the following question: {question}\nProvide"
            " {count} unique keyword searches (one search per line) and year ranges"
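For orientation, a minimal sketch of how this helper is consumed (mirroring the call site in run_fake_agent below; the question string is hypothetical):

    # litellm_get_search_query prompts an LLM for `count` keyword searches,
    # one per line, using a DIY schema in the prompt (hence the TODO above)
    searches = await litellm_get_search_query(
        "How do fake agents select tools?",  # Hypothetical question
        llm=query.settings.get_llm(),
        count=3,
    )
    for search in searches:  # Each item is one keyword-search string
        print(search)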
50 changes: 36 additions & 14 deletions paperqa/agents/main.py
@@ -6,7 +6,6 @@
 from aviary.core import (
     MalformedMessageError,
     Message,
-    Tool,
     ToolCall,
     ToolRequestMessage,
     ToolSelector,
@@ -38,7 +37,13 @@ class Callback:  # type: ignore[no-redef]
 from .helpers import litellm_get_search_query, table_formatter
 from .models import AgentStatus, AnswerResponse, QueryRequest, SimpleProfiler
 from .search import SearchDocumentStorage, SearchIndex, get_directory_index
-from .tools import EnvironmentState, GatherEvidence, GenerateAnswer, PaperSearch
+from .tools import (
+    Complete,
+    EnvironmentState,
+    GatherEvidence,
+    GenerateAnswer,
+    PaperSearch,
+)
 
 if TYPE_CHECKING:
     from aviary.core import Environment
@@ -186,7 +191,7 @@ async def run_fake_agent(
             " applicable with the fake agent, ignoring it."
         )
     env = env_class(query, docs, **env_kwargs)
-    _, tools = await env.reset()
+    obs, tools = await env.reset()
     if on_env_reset_callback:
         await on_env_reset_callback(env.state)
 
@@ -198,25 +203,42 @@
     generate_answer_tool = next(
         filter(lambda x: x.info.name == GenerateAnswer.TOOL_FN_NAME, tools)
     )
-
-    async def step(tool: Tool, **call_kwargs) -> None:
-        action = ToolRequestMessage(
-            tool_calls=[ToolCall.from_tool(tool, **call_kwargs)]
+    complete_tool = next(filter(lambda x: x.info.name == Complete.TOOL_FN_NAME, tools))
+    agent_messages = obs.copy()  # Copy just to be safe
+
+    async def step(action: list[ToolCall] | ToolRequestMessage) -> None:
+        action = (
+            action
+            if isinstance(action, ToolRequestMessage)
+            else ToolRequestMessage(tool_calls=action)
         )
+        agent_messages.append(action)
         if on_agent_action_callback:
             await on_agent_action_callback(action, env.state)
         obs, reward, done, truncated = await env.step(action)
+        agent_messages.extend(obs)
         if on_env_step_callback:
             await on_env_step_callback(obs, reward, done, truncated)
 
     async def rollout() -> AgentStatus:
-        # Seed docs with a few keyword searches
-        for search in await litellm_get_search_query(
-            question, llm=query.settings.get_llm(), count=3
-        ):
-            await step(search_tool, query=search, min_year=None, max_year=None)
-        await step(gather_evidence_tool, question=question)
-        await step(generate_answer_tool)
+        llm_model = query.settings.get_llm()
+
+        # Seed docs with a few LLM-proposed search calls
+        # TODO: make properly support year ranges
+        for search in await litellm_get_search_query(question, llm=llm_model, count=3):
+            search_tcs = [
+                ToolCall.from_tool(
+                    search_tool, query=search, min_year=None, max_year=None
+                )
+            ]
+            await step(search_tcs)
+        await step([ToolCall.from_tool(gather_evidence_tool, question=question)])
+        await step([ToolCall.from_tool(generate_answer_tool)])
+        # Complete with an LLM-proposed complete call
+        complete_action = await llm_model.select_tool(
+            messages=agent_messages, tools=tools, tool_choice=complete_tool
+        )
+        await step(complete_action)
         return AgentStatus.SUCCESS
 
     return await _run_with_timeout_failure(rollout, query, env)
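Net effect: the fake agent still scripts search, gather evidence, and generate answer, but the terminal Complete call is now proposed by the LLM. A sketch of the two invocation styles the reworked step() accepts (the search query string is illustrative):

    # step() normalizes a list[ToolCall] into a ToolRequestMessage...
    await step([
        ToolCall.from_tool(search_tool, query="fake agents", min_year=None, max_year=None)
    ])
    # ...and passes an LLM-proposed ToolRequestMessage through unchanged
    await step(
        await llm_model.select_tool(
            messages=agent_messages, tools=tools, tool_choice=complete_tool
        )
    )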
12 changes: 11 additions & 1 deletion paperqa/llms.py
@@ -21,6 +21,7 @@
 import litellm
 import numpy as np
 import tiktoken
+from aviary.core import ToolRequestMessage, ToolSelector
 from pydantic import (
     BaseModel,
     ConfigDict,
@@ -573,7 +574,7 @@ def get_litellm_retrying_config(timeout: float = 60.0) -> dict[str, Any]:
     return {"num_retries": 3, "timeout": timeout}
 
 
-class PassThroughRouter(litellm.Router):
+class PassThroughRouter(litellm.Router):  # TODO: add rate_limited
     """Router that is just a wrapper on LiteLLM's normal free functions."""
 
     def __init__(self, **kwargs):
@@ -753,6 +754,15 @@ def infer_llm_type(self) -> str:
     def count_tokens(self, text: str) -> int:
         return litellm.token_counter(model=self.name, text=text)
 
+    async def select_tool(
+        self, *selection_args, **selection_kwargs
+    ) -> ToolRequestMessage:
+        """Shim to aviary.core.ToolSelector that supports tool schemata."""
+        tool_selector = ToolSelector(
+            model_name=self.name, acompletion=self.router.acompletion
+        )
+        return await tool_selector(*selection_args, **selection_kwargs)
+
 
 def cosine_similarity(a, b):
     norm_product = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
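A usage sketch of the new shim, assuming the enclosing class is paperqa's LiteLLMModel (the model name and message are illustrative; tools and complete_tool would come from env.reset() as in main.py above):

    from aviary.core import Message

    llm = LiteLLMModel(name="gpt-4o-mini")  # Hypothetical model name
    action = await llm.select_tool(
        # ToolSelector forwards these to an OpenAI-style tool-choice completion
        messages=[Message(content="Decide whether we are done.")],
        tools=tools,
        tool_choice=complete_tool,  # Force the Complete tool; the LLM fills its arguments
    )
    assert isinstance(action, ToolRequestMessage)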
