Skip to content

Commit

Permalink
feat(router): implement model and tool handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
polebug committed Jan 17, 2025
1 parent deefeec commit f6aa449
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 8 deletions.
4 changes: 3 additions & 1 deletion openagent/router/routes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .chat import router as chat_router
from .agent import router as agent_router
from .model import router as model_router
from .tool import router as tool_router

__all__ = ["chat_router", "agent_router"]
__all__ = ["chat_router", "agent_router", "model_router", "tool_router"]
24 changes: 18 additions & 6 deletions openagent/router/routes/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from openagent.db.models.model import Model
from openagent.db.models.tool import Tool
from openagent.router.routes.models.request import CreateAgentRequest
from openagent.router.routes.models.response import AgentResponse, ResponseModel
from openagent.router.routes.models.response import (
AgentResponse,
ResponseModel,
AgentListResponse,
)
from openagent.router.error import APIExceptionResponse
from openagent.tools import ToolConfig
from openagent.db import get_db
Expand Down Expand Up @@ -95,7 +99,7 @@ def create_agent(

@router.get(
"",
response_model=ResponseModel[List[AgentResponse]],
response_model=ResponseModel[AgentListResponse],
summary="List all agents",
description="Get a paginated list of all agents",
responses={
Expand All @@ -105,13 +109,17 @@ def create_agent(
)
def list_agents(
page: int = 0, limit: int = 10, db: Session = Depends(get_db)
) -> Union[ResponseModel[List[AgentResponse]], APIExceptionResponse]:
) -> Union[ResponseModel[dict], APIExceptionResponse]:
try:
agents = db.query(Agent).offset(page).limit(limit).all()
total = db.query(Agent).count()
agents = db.query(Agent).offset(page * limit).limit(limit).all()
return ResponseModel(
code=status.HTTP_200_OK,
data=[AgentResponse.model_validate(agent) for agent in agents],
message=f"Retrieved {len(agents)} agents",
data=AgentListResponse(
agents=[AgentResponse.model_validate(agent) for agent in agents],
total=total,
),
message=f"Retrieved {len(agents)} agents out of {total}",
)
except Exception as error:
return APIExceptionResponse(
Expand Down Expand Up @@ -173,6 +181,10 @@ def update_agent(
error=f"Agent with ID {agent_id} not found",
)

# check if the tool_configs are valid
if error := check_tool_configs(request.tool_configs, db):
return error

for key, value in request.model_dump(exclude_unset=True).items():
setattr(agent, key, value)

Expand Down
44 changes: 44 additions & 0 deletions openagent/router/routes/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from fastapi import APIRouter, Depends, status
from sqlalchemy.orm import Session
from typing import Union

from openagent.db.models.model import Model
from openagent.router.routes.models.response import (
ModelResponse,
ModelListResponse,
ResponseModel,
)
from openagent.router.error import APIExceptionResponse
from openagent.db import get_db

router = APIRouter(prefix="/models", tags=["models"])


@router.get(
"",
response_model=ResponseModel[ModelListResponse],
summary="List all models",
description="Get a paginated list of all models",
responses={
200: {"description": "Successfully retrieved models"},
500: {"description": "Internal server error"},
},
)
def list_models(
page: int = 0, limit: int = 10, db: Session = Depends(get_db)
) -> Union[ResponseModel[ModelListResponse], APIExceptionResponse]:
try:
total = db.query(Model).count()
models = db.query(Model).offset(page * limit).limit(limit).all()
return ResponseModel(
code=status.HTTP_200_OK,
data=ModelListResponse(
models=[ModelResponse.model_validate(model) for model in models],
total=total,
),
message=f"Retrieved {len(models)} models out of {total}",
)
except Exception as error:
return APIExceptionResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, error=error
)
31 changes: 31 additions & 0 deletions openagent/router/routes/models/response.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional, Generic, TypeVar, List
from pydantic import BaseModel
from openagent.db.models.agent import AgentStatus
from openagent.db.models.tool import ToolType
from openagent.tools import ToolConfig

T = TypeVar("T")
Expand Down Expand Up @@ -28,3 +29,33 @@ class AgentResponse(BaseModel):
website: Optional[str] = None
tool_configs: Optional[List[ToolConfig]] = None
status: AgentStatus


class AgentListResponse(BaseModel):
agents: List[AgentResponse]
total: int


class ModelResponse(BaseModel):
id: int
name: str
description: Optional[str] = None
capability_score: float
capabilities: Optional[str] = None


class ModelListResponse(BaseModel):
models: List[ModelResponse]
total: int


class ToolResponse(BaseModel):
id: int
name: str
description: Optional[str] = None
type: ToolType


class ToolListResponse(BaseModel):
tools: List[ToolResponse]
total: int
43 changes: 43 additions & 0 deletions openagent/router/routes/tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from fastapi import APIRouter, Depends, status
from sqlalchemy.orm import Session
from typing import Union

from openagent.db.models.tool import Tool
from openagent.router.routes.models.response import (
ToolResponse,
ToolListResponse,
ResponseModel,
)
from openagent.router.error import APIExceptionResponse
from openagent.db import get_db

router = APIRouter(prefix="/tools", tags=["tools"])


@router.get(
"",
response_model=ResponseModel[ToolListResponse],
summary="List all tools",
description="Get a paginated list of all tools",
responses={
200: {"description": "Successfully retrieved tools"},
500: {"description": "Internal server error"},
},
)
def list_tools(
page: int = 0, limit: int = 10, db: Session = Depends(get_db)
) -> Union[ResponseModel[ToolListResponse], APIExceptionResponse]:
try:
total = db.query(Tool).count()
tools = db.query(Tool).offset(page * limit).limit(limit).all()
return ResponseModel(
code=status.HTTP_200_OK,
data=ToolListResponse(
tools=[ToolResponse.model_validate(tool) for tool in tools], total=total
),
message=f"Retrieved {len(tools)} tools out of {total}",
)
except Exception as error:
return APIExceptionResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, error=error
)
4 changes: 3 additions & 1 deletion openagent/router/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
from .routes import chat_router, agent_router
from .routes import chat_router, agent_router, model_router, tool_router

app = FastAPI(
title="OpenAgent API",
Expand All @@ -21,3 +21,5 @@

app.include_router(chat_router)
app.include_router(agent_router)
app.include_router(model_router)
app.include_router(tool_router)

0 comments on commit f6aa449

Please sign in to comment.