Skip to content

Commit

Permalink
feat(router): implement agent stop api
Browse files Browse the repository at this point in the history
  • Loading branch information
polebug committed Jan 27, 2025
1 parent bc49cf0 commit f095b25
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 75 deletions.
77 changes: 67 additions & 10 deletions openagent/router/routes/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
PublicAgentResponse,
ResponseModel,
)
from openagent.tools import BaseTool, ToolConfig, ToolParameters
from openagent.tools import BaseTool, ToolConfig

auth_handler = Auth()

Expand Down Expand Up @@ -175,7 +175,7 @@ def list_agents(
)
def get_agent(
agent_id: int,
wallet_address: str = Depends(auth_handler.auth_wrapper),
wallet_address: str | None = Depends(auth_handler.optional_auth_wrapper),
db: Session = Depends(get_db),
) -> ResponseModel[AgentResponse] | APIExceptionResponse:
try:
Expand All @@ -186,14 +186,16 @@ def get_agent(
error=f"Agent with ID {agent_id} not found",
)

if agent.wallet_address.lower() != wallet_address.lower():
return APIExceptionResponse(
status_code=status.HTTP_403_FORBIDDEN,
error="Not authorized to query this agent",
)
# return full agent info if authenticated and wallet addresses match
if wallet_address and agent.wallet_address.lower() == wallet_address.lower():
response_data = AgentResponse.model_validate(agent)
else:
# return public info for unauthenticated users or non-owners
response_data = PublicAgentResponse.model_validate(agent)

return ResponseModel(
code=status.HTTP_200_OK,
data=AgentResponse.model_validate(agent),
data=AgentResponse.model_validate(response_data),
message="Agent retrieved successfully",
)
except Exception as error:
Expand Down Expand Up @@ -355,6 +357,59 @@ def run_agent(
)


@router.post(
"/{agent_id}/stop",
response_model=ResponseModel[AgentResponse],
summary="Stop an agent",
description="Stop an agent by setting its status to inactive",
responses={
200: {"description": "Successfully stopped agent"},
403: {"description": "Not authorized to stop this agent"},
404: {"description": "Agent not found"},
500: {"description": "Internal server error"},
},
)
def stop_agent(
agent_id: int,
wallet_address: str = Depends(auth_handler.auth_wrapper),
db: Session = Depends(get_db),
) -> ResponseModel[AgentResponse] | APIExceptionResponse:
try:
# get agent
agent = db.query(Agent).filter(Agent.id == agent_id).first()
if not agent:
return APIExceptionResponse(
status_code=status.HTTP_404_NOT_FOUND,
error=f"Agent with ID {agent_id} not found",
)

# check if the user is authorized to stop this agent
if agent.wallet_address.lower() != wallet_address.lower():
return APIExceptionResponse(
status_code=status.HTTP_403_FORBIDDEN,
error="Not authorized to stop this agent",
)

# update the agent status to inactive
agent.status = AgentStatus.INACTIVE
db.commit()
db.refresh(agent)

# TODO: stop any running agent processes

return ResponseModel(
code=status.HTTP_200_OK,
data=AgentResponse.model_validate(agent),
message="Agent stopped successfully",
)
except Exception as error:
db.rollback()
return APIExceptionResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
error=error,
)


@router.post(
"/{agent_id}/execute/{tool_name}",
response_model=ResponseModel[dict[str, Any]],
Expand Down Expand Up @@ -473,9 +528,11 @@ def build_model(model: Model) -> AI_Model:
raise ValueError(f"Unsupported model: {model}")


def initialize_tool_executor(agent: Agent, tool: Tool, model: Model, tool_config: ToolConfig) -> BaseTool:
def initialize_tool_executor(
agent: Agent, tool: Tool, model: Model, tool_config: ToolConfig
) -> BaseTool:
model_instance = build_model(model)

return get_tool_executor(agent, tool, model_instance, tool_config)


Expand Down
18 changes: 14 additions & 4 deletions openagent/router/routes/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, status
from fastapi import APIRouter, Depends, status, Query
from sqlalchemy.orm import Session

from openagent.database import get_db
Expand All @@ -24,11 +24,21 @@
},
)
def list_models(
page: int = 0, limit: int = 10, db: Session = Depends(get_db)
page: int = 0,
limit: int = 10,
ids: list[int] | None = Query(default=None),
db: Session = Depends(get_db),
) -> ResponseModel[ModelListResponse] | APIExceptionResponse:
try:
total = db.query(Model).count()
models = db.query(Model).offset(page * limit).limit(limit).all()
query = db.query(Model)

# Add filter for model ids if provided
if ids:
query = query.filter(Model.id.in_(ids))

total = query.count()
models = query.offset(page * limit).limit(limit).all()

return ResponseModel(
code=status.HTTP_200_OK,
data=ModelListResponse(
Expand Down
22 changes: 20 additions & 2 deletions openagent/router/routes/models/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from web3 import Web3

security = HTTPBearer()
security = HTTPBearer(auto_error=False)


class Auth:
Expand Down Expand Up @@ -71,7 +71,25 @@ def verify_wallet_signature(
)

def auth_wrapper(
self, auth: HTTPAuthorizationCredentials = Security(security)
self, auth: HTTPAuthorizationCredentials | None = Security(security)
) -> str:
if not auth:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="No authorization token provided",
headers={"WWW-Authenticate": "Bearer"},
)

payload = self.decode_token(auth.credentials)
return payload["wallet_address"]

def optional_auth_wrapper(
self, auth: HTTPAuthorizationCredentials | None = Security(security)
) -> str | None:
if not auth:
return None
try:
payload = self.decode_token(auth.credentials)
return payload["wallet_address"]
except HTTPException:
return None
19 changes: 11 additions & 8 deletions openagent/router/routes/models/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ class AgentResponse(BaseModel):
created_at: datetime
updated_at: datetime


class PublicToolConfigResponse(BaseModel):
model_config = ConfigDict(from_attributes=True)

name: str
description: str | None = None
tool_id: int
model_id: int



class PublicAgentResponse(BaseModel):
model_config = ConfigDict(from_attributes=True)

Expand All @@ -70,18 +72,19 @@ class PublicAgentResponse(BaseModel):

@classmethod
def from_orm(cls, obj):
if hasattr(obj, 'tool_configs') and obj.tool_configs:
if hasattr(obj, "tool_configs") and obj.tool_configs:
obj.tool_configs = [
PublicToolConfigResponse(
name=tc['name'],
description=tc.get('description'),
tool_id=tc['tool_id'],
model_id=tc['model_id']
name=tc["name"],
description=tc.get("description"),
tool_id=tc["tool_id"],
model_id=tc["model_id"],
)
for tc in obj.tool_configs
]
return super().from_orm(obj)


class AgentListResponse(BaseModel):
model_config = ConfigDict(from_attributes=True)

Expand All @@ -91,7 +94,7 @@ class AgentListResponse(BaseModel):

class ModelResponse(BaseModel):
model_config = ConfigDict(from_attributes=True)

id: int
name: str
description: str | None = None
Expand All @@ -106,7 +109,7 @@ class ModelListResponse(BaseModel):

class ToolResponse(BaseModel):
model_config = ConfigDict(from_attributes=True)

id: int
name: str
description: str | None = None
Expand Down
18 changes: 14 additions & 4 deletions openagent/router/routes/tool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, status
from fastapi import APIRouter, Depends, status, Query
from sqlalchemy.orm import Session

from openagent.database import get_db
Expand All @@ -24,11 +24,21 @@
},
)
def list_tools(
page: int = 0, limit: int = 10, db: Session = Depends(get_db)
page: int = 0,
limit: int = 10,
ids: list[int] | None = Query(default=None),
db: Session = Depends(get_db),
) -> ResponseModel[ToolListResponse] | APIExceptionResponse:
try:
total = db.query(Tool).count()
tools = db.query(Tool).offset(page * limit).limit(limit).all()
query = db.query(Tool)

# Add filter for tool ids if provided
if ids:
query = query.filter(Tool.id.in_(ids))

total = query.count()
tools = query.offset(page * limit).limit(limit).all()

return ResponseModel(
code=status.HTTP_200_OK,
data=ToolListResponse(
Expand Down
43 changes: 22 additions & 21 deletions openagent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,40 @@
from phi.tools import Toolkit
from pydantic import BaseModel, ConfigDict, model_validator


class TriggerType(Enum):
SCHEDULED = "scheduled"
AUTO = "auto"
Manual = "manual"

def __str__(self) -> str:
return self.value


class ToolParameters(BaseModel):
model_config = ConfigDict(
from_attributes=True,
arbitrary_types_allowed=True,
json_encoders={TriggerType: lambda v: v.value}
json_encoders={TriggerType: lambda v: v.value},
)

trigger_type: TriggerType
schedule: str | None = None # cron, such as "0 */2 * * *"
config: dict | None = None

def validate_schedule(self):
if self.trigger_type == TriggerType.SCHEDULED and not self.schedule:
raise ValueError("Schedule must be set when trigger_type is SCHEDULED")

def model_dump(self, *args, **kwargs) -> dict:
# Add custom serialization for TriggerType
data = super().model_dump(*args, **kwargs)
data['trigger_type'] = self.trigger_type.value
data["trigger_type"] = self.trigger_type.value
return data


class ToolConfig(BaseModel):
model_config = ConfigDict(
from_attributes=True,
arbitrary_types_allowed=True
)
model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)

name: str
description: str | None = None
Expand All @@ -51,33 +50,34 @@ def validate_parameters(self):
if self.parameters:
self.parameters.validate_schedule()


def model_dump(self, *args, **kwargs) -> dict:
data = {
"name": self.name,
"description": self.description,
"tool_id": self.tool_id,
"model_id": self.model_id,
}

if self.parameters:
data["parameters"] = self.parameters.model_dump()

return data


class TwitterToolParameters(ToolParameters):
model_config = ConfigDict(
from_attributes=True,
arbitrary_types_allowed=True
)
model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)

@model_validator(mode='before')
@model_validator(mode="before")
@classmethod
def validate_twitter_config(cls, data: Dict) -> Dict:
if isinstance(data, dict) and "config" in data and isinstance(data["config"], dict):
if (
isinstance(data, dict)
and "config" in data
and isinstance(data["config"], dict)
):
data["config"] = {
"access_token": data["config"].get("access_token"),
"access_token_secret": data["config"].get("access_token_secret")
"access_token_secret": data["config"].get("access_token_secret"),
}
return data

Expand Down Expand Up @@ -113,10 +113,11 @@ def validate_params(self, params: dict[str, Any]) -> tuple[bool, str]:
"""
pass


__all__ = [
"BaseTool",
"ToolConfig",
"ToolParameters",
"TwitterToolParameters",
"TriggerType",
]
]
Loading

0 comments on commit f095b25

Please sign in to comment.