Skip to content

Commit

Permalink
Merge branch 'main' into nerdai/message-queue
Browse files Browse the repository at this point in the history
  • Loading branch information
nerdai authored Jun 6, 2024
2 parents d920dd8 + 68e19ee commit 9182713
Show file tree
Hide file tree
Showing 10 changed files with 272 additions and 24 deletions.
1 change: 0 additions & 1 deletion .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ on:

env:
POETRY_VERSION: "1.6.1"
LLAMA_CLOUD_API_KEY: ${{ secrets.LLAMA_CLOUD_API_KEY }}

jobs:
test:
Expand Down
15 changes: 0 additions & 15 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,16 +1 @@
.pants.d/
dist/
migration_scripts/
venv/
.idea
.venv/
.ipynb_checkpoints
.__pycache__
__pycache__
dev_notebooks/
llamaindex_registry.txt
packages_to_bump_deduped.txt
.env
credentials.json
token.json
.python-version
21 changes: 21 additions & 0 deletions agentfile/agent_server/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from agentfile.agent_server.base import BaseAgentServer
from agentfile.agent_server.fastapi_agent import FastAPIAgentServer
from agentfile.agent_server.types import (
_Task,
_TaskSate,
_TaskStep,
_TaskStepOutput,
_ChatMessage,
AgentRole,
)

__all__ = [
"BaseAgentServer",
"FastAPIAgentServer",
"_Task",
"_TaskSate",
"_TaskStep",
"_TaskStepOutput",
"_ChatMessage",
"AgentRole",
]
51 changes: 51 additions & 0 deletions agentfile/agent_server/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from abc import ABC, abstractmethod
from typing import Dict, List

from agentfile.agent_server.types import _Task, _TaskSate, _TaskStep, _TaskStepOutput


class BaseAgentServer(ABC):
@abstractmethod
def launch(self) -> None:
"""Launch the agent server."""
...

@abstractmethod
async def home(self) -> Dict[str, str]:
"""Get the home page of the server, usually containing status info."""
...

@abstractmethod
async def create_task(self, input: str) -> _Task:
"""Create a new task."""
...

@abstractmethod
async def get_tasks(self) -> List[_Task]:
"""Get a list of all tasks."""
...

@abstractmethod
async def get_task_state(self, task_id: str) -> _TaskSate:
"""Get a specific state of a task."""
...

@abstractmethod
async def get_completed_tasks(self) -> List[_Task]:
"""Get a list of all completed tasks."""
...

@abstractmethod
async def get_task_output(self, task_id: str) -> _TaskStepOutput:
"""Get the output of a task."""
...

@abstractmethod
async def get_task_steps(self, task_id: str) -> List[_TaskStep]:
"""Get the steps of a task."""
...

@abstractmethod
async def get_completed_steps(self, task_id: str) -> List[_TaskStepOutput]:
"""Get the completed steps of a task."""
...
17 changes: 12 additions & 5 deletions agentfile/agent.py → agentfile/agent_server/fastapi_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
from fastapi import FastAPI, HTTPException
from typing import AsyncGenerator, Dict, List, Literal

from agentfile.schema import _Task, _TaskSate, _TaskStep, _TaskStepOutput, _ChatMessage
from agentfile.agent_server.base import BaseAgentServer
from agentfile.agent_server.types import (
_Task,
_TaskSate,
_TaskStep,
_TaskStepOutput,
_ChatMessage,
)
from llama_index.core.agent import AgentRunner

import logging
Expand All @@ -14,7 +21,7 @@
logging.basicConfig(level=logging.INFO)


class AgentServer:
class FastAPIAgentServer(BaseAgentServer):
def __init__(
self,
agent: AgentRunner,
Expand All @@ -38,7 +45,7 @@ def __init__(
"/tasks", self.get_tasks, methods=["GET"], tags=["Tasks"]
)
self.app.add_api_route(
"/tasks/{task_id}", self.get_task, methods=["GET"], tags=["Tasks"]
"/tasks/{task_id}", self.get_task_state, methods=["GET"], tags=["Tasks"]
)
self.app.add_api_route(
"/completed_tasks",
Expand Down Expand Up @@ -146,7 +153,7 @@ async def get_tasks(self) -> List[_Task]:

return _tasks

async def get_task(self, task_id: str) -> _TaskSate:
async def get_task_state(self, task_id: str) -> _TaskSate:
task_state = self.agent.state.task_dict.get(task_id)
if task_state is None:
raise HTTPException(status_code=404, detail="Task not found")
Expand Down Expand Up @@ -214,5 +221,5 @@ async def reset_agent(self) -> Dict[str, str]:
index = VectorStoreIndex.from_documents([Document.example()])
agent = index.as_chat_engine()

server = AgentServer(agent)
server = FastAPIAgentServer(agent)
server.launch()
20 changes: 19 additions & 1 deletion agentfile/schema.py → agentfile/agent_server/types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import uuid
from pydantic import BaseModel, Field
from typing import Dict, List, Optional
from pydantic import BaseModel

from llama_index.core.agent.types import TaskStep, TaskStepOutput, Task
from llama_index.core.agent.runner.base import AgentState, TaskState
from llama_index.core.llms import ChatMessage

# ------ FastAPI types ------


class _Task(BaseModel):
task_id: str
Expand Down Expand Up @@ -113,3 +116,18 @@ def from_chat_message(cls, chat_message: ChatMessage) -> "_ChatMessage":
role=str(chat_message.role),
additional_kwargs=chat_message.additional_kwargs,
)


# ------ General types ------


class AgentRole(BaseModel):
agent_name: str = Field(description="The name of the agent.")
description: str = Field(description="A description of the agent and it's purpose.")
prompt: List[ChatMessage] = Field(
default_factory=list, description="Specific instructions for the agent."
)
agent_id: str = Field(
default_factory=str(uuid.uuid4()),
description="A unique identifier for the agent.",
)
Empty file.
122 changes: 122 additions & 0 deletions agentfile/control_plane/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""
What does the processing loop for the control plane look like?
- check message queue
- handle incoming new tasks
- handle incoming general chats
- handle agents returning a completed task
"""

from abc import ABC, abstractmethod


class BaseControlPlane(ABC):
@abstractmethod
def register_agent(self, agent_id: str, agent_role: str) -> None:
"""
Register an agent with the control plane.
:param agent_id: Unique identifier of the agent.
:param agent_role: Role of the agent.
"""
...

@abstractmethod
def deregister_agent(self, agent_id: str) -> None:
"""
Deregister an agent from the control plane.
:param agent_id: Unique identifier of the agent.
"""
...

@abstractmethod
def register_flow(self, flow_id: str, flow_definition: dict) -> None:
"""
Register a flow with the control plane.
:param flow_id: Unique identifier of the flow.
:param flow_definition: Definition of the flow.
"""
...

@abstractmethod
def deregister_flow(self, flow_id: str) -> None:
"""
Deregister a flow from the control plane.
:param flow_id: Unique identifier of the flow.
"""
...

@abstractmethod
def handle_new_task(self, task_id: str, task_definition: dict) -> None:
"""
Submit a task to the control plane.
:param task_id: Unique identifier of the task.
:param task_definition: Definition of the task.
"""
...

@abstractmethod
def send_task_to_agent(self, task_id: str, agent_id: str) -> None:
"""
Send a task to an agent.
:param task_id: Unique identifier of the task.
:param agent_id: Unique identifier of the agent.
"""
...

@abstractmethod
def handle_agent_completion(
self, task_id: str, agent_id: str, result: dict
) -> None:
"""
Handle the completion of a task by an agent.
:param task_id: Unique identifier of the task.
:param agent_id: Unique identifier of the agent.
:param result: Result of the task.
"""
...

@abstractmethod
def get_next_agent(self, task_id: str) -> str:
"""
Get the next agent for a task.
:param task_id: Unique identifier of the task.
:return: Unique identifier of the next agent.
"""
...

@abstractmethod
def get_task_state(self, task_id: str) -> dict:
"""
Get the current state of a task.
:param task_id: Unique identifier of the task.
:return: Current state of the task.
"""
...

@abstractmethod
def request_user_input(self, task_id: str, message: str) -> None:
"""
Request input from the user for a task.
:param task_id: Unique identifier of the task.
:param message: Message to send to the user.
"""
...

@abstractmethod
def handle_user_input(self, task_id: str, user_input: str) -> None:
"""
Handle the user input for a task.
:param task_id: Unique identifier of the task.
:param user_input: Input provided by the user.
"""
...
45 changes: 45 additions & 0 deletions agentfile/control_plane/fastapi_control_plane.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Any, Dict

from agentfile.control_plane.base import BaseControlPlane


class FastAPIControlPlane(BaseControlPlane):
def __init__(self) -> None:
self.agents: Dict[str, Any] = {}
self.flows: Dict[str, Any] = {}
self.tasks: Dict[str, Any] = {}

def register_agent(self, agent_id: str, agent_role: str) -> None:
self.agents[agent_id] = agent_role

def deregister_agent(self, agent_id: str) -> None:
del self.agents[agent_id]

def register_flow(self, flow_id: str, flow_definition: dict) -> None:
self.flows[flow_id] = flow_definition

def deregister_flow(self, flow_id: str) -> None:
del self.flows[flow_id]

def handle_new_task(self, task_id: str, task_definition: dict) -> None:
self.tasks[task_id] = task_definition

def send_task_to_agent(self, task_id: str, agent_id: str) -> None:
pass

def handle_agent_completion(
self, task_id: str, agent_id: str, result: dict
) -> None:
pass

def get_next_agent(self, task_id: str) -> str:
return ""

def get_task_state(self, task_id: str) -> dict:
return self.tasks.get(task_id, None)

def request_user_input(self, task_id: str, message: str) -> None:
pass

def handle_user_input(self, task_id: str, user_input: str) -> None:
pass
4 changes: 2 additions & 2 deletions tests/test_agent_server.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from llama_index.core.llms import MockLLM
from llama_index.core.agent import ReActAgent

from agentfile.agent import AgentServer
from agentfile.agent_server import FastAPIAgentServer


def test_init() -> None:
agent = ReActAgent.from_tools([], llm=MockLLM())
server = AgentServer(
server = FastAPIAgentServer(
agent, running=False, description="Test Agent Server", step_interval=0.5
)

Expand Down

0 comments on commit 9182713

Please sign in to comment.