diff --git a/openagent/agents/__init__.py b/openagent/agents/__init__.py deleted file mode 100644 index 264c4259..00000000 --- a/openagent/agents/__init__.py +++ /dev/null @@ -1,51 +0,0 @@ -from os import environ - -from phi.agent import Agent -from phi.model.base import Model -from phi.model.ollama import Ollama -from phi.model.openai import OpenAIChat -from phi.model.anthropic import Claude -from phi.model.google import Gemini - -from .finance import finance_agent -from .feed import feed_agent -from .twitter import twitter_agent - - -class UnsupportedModel(Exception): - def __init__(self, model: str): - self.model = model - - def __str__(self): - return f"Unsupported model {self.model}" - - -def build_model(model: str) -> Model: - (provider, model_id) = model.split("/") - - match provider: - case "openai": - return OpenAIChat(id=model_id, base_url=environ.get("OPENAI_BASE_URL")) - case "anthropic": - return Claude( - id=model_id, - client_params={"base_url": environ.get("ANTHROPIC_BASE_URL")}, - ) - case "google": - return Gemini( - id=model_id, client_params={"base_url": environ.get("GOOGLE_BASE_URL")} - ) - case "ollama": - return Ollama(id=model_id, host=environ.get("OLLAMA_BASE_URL")) - case _: - raise UnsupportedModel(model) - - -def build_agent_team(model: str) -> Agent: - return Agent( - team=[finance_agent, feed_agent, twitter_agent], - model=build_model(model), - ) - - -__all__ = [build_agent_team] diff --git a/openagent/agents/feed.py b/openagent/agents/feed.py deleted file mode 100644 index eeb8f7fb..00000000 --- a/openagent/agents/feed.py +++ /dev/null @@ -1,8 +0,0 @@ -from phi.agent import Agent - -from openagent.tools import DSLTools - -feed_agent = Agent( - name="Feed Agent", - tools=[DSLTools()], -) diff --git a/openagent/agents/finance.py b/openagent/agents/finance.py deleted file mode 100644 index 5fb4131f..00000000 --- a/openagent/agents/finance.py +++ /dev/null @@ -1,8 +0,0 @@ -from phi.agent import Agent - -from openagent.tools import 
CoinGeckoTools - -finance_agent = Agent( - name="Finance Agent", - tools=[CoinGeckoTools()], -) diff --git a/openagent/agents/twitter.py b/openagent/agents/twitter.py deleted file mode 100644 index 26c92473..00000000 --- a/openagent/agents/twitter.py +++ /dev/null @@ -1,10 +0,0 @@ -from phi.agent import Agent -from ..tools.twitter.tweet_generator import TweetGeneratorTools - -twitter_agent = Agent( - name="twitter_agent", - description="An agent that generates and posts tweets in different personalities", - tools=[TweetGeneratorTools()], -) - -__all__ = ["twitter_agent"] diff --git a/openagent/database/models/agent.py b/openagent/database/models/agent.py index 85b1a31c..77eaea8b 100644 --- a/openagent/database/models/agent.py +++ b/openagent/database/models/agent.py @@ -1,15 +1,20 @@ import enum +from typing import List from sqlalchemy import Column, Integer, String, DateTime, JSON, Enum from datetime import datetime, UTC from openagent.database.models.base import Base +from openagent.tools import ToolConfig -class AgentStatus(enum.Enum): +class AgentStatus(str, enum.Enum): INACTIVE = "inactive" ACTIVE = "active" + def __str__(self): + return self.value + class Agent(Base): __tablename__ = "agents" @@ -28,7 +33,10 @@ class Agent(Base): telegram = Column(String) website = Column(String) tool_configs = Column(JSON) - status = Column(Enum(AgentStatus), nullable=False) + status = Column( + Enum(AgentStatus, values_callable=lambda x: [e.value for e in x]), + nullable=False, + ) created_at = Column(DateTime, default=lambda: datetime.now(UTC), nullable=False) updated_at = Column( DateTime, @@ -36,3 +44,18 @@ class Agent(Base): onupdate=lambda: datetime.now(UTC), nullable=False, ) + + def __init__(self, *args, **kwargs): + if "status" in kwargs and isinstance(kwargs["status"], AgentStatus): + kwargs["status"] = kwargs["status"].value + super().__init__(*args, **kwargs) + + @property + def tool_configs_list(self) -> List[ToolConfig]: + if not self.tool_configs: + return [] 
+ return [ToolConfig.model_validate(config) for config in self.tool_configs] + + @tool_configs_list.setter + def tool_configs_list(self, configs: List[ToolConfig]): + self.tool_configs = [config.model_dump() for config in configs] diff --git a/openagent/database/models/tool.py b/openagent/database/models/tool.py index 53f1c459..0fe2db7a 100644 --- a/openagent/database/models/tool.py +++ b/openagent/database/models/tool.py @@ -8,6 +8,9 @@ class ToolType(enum.Enum): TEXT_GENERATION = "text_generation" SOCIAL_INTEGRATION = "social_integration" + def __str__(self): + return self.value + class Tool(Base): __tablename__ = "tools" @@ -15,7 +18,10 @@ class Tool(Base): id = Column(Integer, primary_key=True) name = Column(String, nullable=False) description = Column(Text) - type = Column(Enum(ToolType), nullable=False) + type = Column( + Enum(ToolType, values_callable=lambda x: [e.value for e in x]), + nullable=False, + ) created_at = Column(DateTime, default=lambda: datetime.now(UTC), nullable=False) updated_at = Column( DateTime, @@ -23,3 +29,8 @@ class Tool(Base): onupdate=lambda: datetime.now(UTC), nullable=False, ) + + def __init__(self, *args, **kwargs): + if "type" in kwargs and isinstance(kwargs["type"], ToolType): + kwargs["type"] = kwargs["type"].value + super().__init__(*args, **kwargs) diff --git a/openagent/router/routes/__init__.py b/openagent/router/routes/__init__.py index 7a9813fc..f350aa01 100644 --- a/openagent/router/routes/__init__.py +++ b/openagent/router/routes/__init__.py @@ -1,6 +1,5 @@ -from .chat import router as chat_router from .agent import router as agent_router from .model import router as model_router from .tool import router as tool_router -__all__ = ["chat_router", "agent_router", "model_router", "tool_router"] +__all__ = ["agent_router", "model_router", "tool_router"] diff --git a/openagent/router/routes/agent.py b/openagent/router/routes/agent.py index b2b6f9b2..b79b4b03 100644 --- a/openagent/router/routes/agent.py +++ 
b/openagent/router/routes/agent.py @@ -1,10 +1,9 @@ +import os import binascii - from eth_utils import remove_0x_prefix, to_checksum_address -from fastapi import APIRouter, Depends, HTTPException, status, Header +from fastapi import APIRouter, Depends, HTTPException, status, Header, Path from sqlalchemy.orm import Session -from typing import List, Optional, Union - +from typing import List, Optional, Tuple, Union, Dict, Any from openagent.database.models.agent import Agent, AgentStatus from openagent.database.models.model import Model from openagent.database.models.tool import Tool @@ -15,18 +14,36 @@ AgentListResponse, ) from openagent.router.error import APIExceptionResponse -from openagent.tools import ToolConfig +from openagent.tools import BaseTool, ToolConfig, get_tool_executor from openagent.database import get_db from eth_account.messages import encode_defunct from web3 import Web3 from typing import Annotated +from phi.model.base import Model as AI_Model +from phi.model.ollama import Ollama +from phi.model.openai import OpenAIChat +from phi.model.anthropic import Claude +from phi.model.google import Gemini +from dotenv import load_dotenv router = APIRouter(prefix="/agents", tags=["agents"]) +load_dotenv() + def check_tool_configs( tool_configs: List[ToolConfig], db: Session ) -> Optional[APIExceptionResponse]: + # check if the tool_names are unique + tool_names = [] + for tool_config in tool_configs: + if tool_config.name in tool_names: + return APIExceptionResponse( + status_code=status.HTTP_400_BAD_REQUEST, + error=f"Duplicate tool name: {tool_config.name}", + ) + tool_names.append(tool_config.name) + tool_ids = {tool_config.tool_id for tool_config in tool_configs} model_ids = {tool_config.model_id for tool_config in tool_configs} @@ -140,7 +157,7 @@ def create_agent( twitter=request.twitter, telegram=request.telegram, website=request.website, - tool_configs=request.tool_configs, + tool_configs=request.get_tool_configs_data(), 
status=AgentStatus.INACTIVE, # default status ) @@ -295,7 +312,6 @@ def update_agent( ) def delete_agent( agent_id: int, - request: CreateAgentRequest, verified_address: str = Depends(verify_wallet_auth), db: Session = Depends(get_db), ) -> ResponseModel: @@ -323,3 +339,209 @@ def delete_agent( raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) ) + + +@router.post( + "/{agent_id}/run", + response_model=ResponseModel[AgentResponse], + summary="Run an agent", + description="Start an agent by setting its status to active", + responses={ + 200: {"description": "Successfully started agent"}, + 403: {"description": "Not authorized to run this agent"}, + 404: {"description": "Agent not found"}, + 500: {"description": "Internal server error"}, + }, +) +def run_agent( + agent_id: int, + verified_address: str = Depends(verify_wallet_auth), + db: Session = Depends(get_db), +) -> Union[ResponseModel[AgentResponse], APIExceptionResponse]: + try: + # get agent + agent = db.query(Agent).filter(Agent.id == agent_id).first() + if not agent: + return APIExceptionResponse( + status_code=status.HTTP_404_NOT_FOUND, + error=f"Agent with ID {agent_id} not found", + ) + + # check if the user is authorized to run this agent + if agent.wallet_address.lower() != verified_address.lower(): + return APIExceptionResponse( + status_code=status.HTTP_403_FORBIDDEN, + error="Not authorized to run this agent", + ) + + # update the agent status to active + agent.status = AgentStatus.ACTIVE + db.commit() + db.refresh(agent) + + # TODO: start the agent + + return ResponseModel( + code=status.HTTP_200_OK, + data=AgentResponse.model_validate(agent), + message="Agent started successfully", + ) + except Exception as error: + db.rollback() + return APIExceptionResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + error=error, + ) + + +@router.post( + "/{agent_id}/execute/{tool_name}", + response_model=ResponseModel[Dict[str, Any]], + summary="Execute a specific 
tool", + description="Execute a specific tool of an agent with optional parameters", + responses={ + 200: {"description": "Successfully executed tool"}, + 400: {"description": "Tool not found in agent's configuration"}, + 403: {"description": "Not authorized to execute this agent's tool"}, + 404: {"description": "Agent not found"}, + 500: {"description": "Internal server error"}, + }, +) +@router.post( + "/{agent_id}/execute/{tool_name}", + response_model=ResponseModel[Dict[str, Any]], + summary="Execute a specific tool", + description="Execute a specific tool of an agent with optional parameters", + responses={ + 200: {"description": "Successfully executed tool"}, + 400: {"description": "Tool not found in agent's configuration"}, + 403: {"description": "Not authorized to execute this agent's tool"}, + 404: {"description": "Agent not found"}, + 500: {"description": "Internal server error"}, + }, + operation_id="execute_agent_tool", +) +def execute_tool( + agent_id: int, + tool_name: str = Path(..., description="Name of the tool to execute"), + verified_address: str = Depends(verify_wallet_auth), + db: Session = Depends(get_db), +) -> Union[ResponseModel[Dict[str, Any]], APIExceptionResponse]: + try: + # get agent + agent = db.query(Agent).filter(Agent.id == agent_id).first() + if not agent: + return APIExceptionResponse( + status_code=status.HTTP_404_NOT_FOUND, + error=f"Agent with ID {agent_id} not found", + ) + + # check if the user is authorized to execute this agent's tool + if agent.wallet_address.lower() != verified_address.lower(): + return APIExceptionResponse( + status_code=status.HTTP_403_FORBIDDEN, + error="Not authorized to execute this agent's tool", + ) + + # check if the agent is active + if agent.status != AgentStatus.ACTIVE: + return APIExceptionResponse( + status_code=status.HTTP_400_BAD_REQUEST, + error="Agent is not active", + ) + + # find the specified tool config + tool_config = None + for config in agent.tool_configs_list: + if config.name == 
tool_name: + tool_config = config + break + + if not tool_config: + return APIExceptionResponse( + status_code=status.HTTP_400_BAD_REQUEST, + error=f"Tool '{tool_name}' not found in agent's configuration", + ) + + # get the tool and model + tool, model = get_tool_and_model(tool_config, db) + + # initialize the tool + tool_executor = initialize_tool_executor(tool, model) + + # execute the tool + success, result = execute_tool_action(tool_executor, agent, tool_config) + + if not success: + return APIExceptionResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + error=result, + ) + + return ResponseModel( + code=status.HTTP_200_OK, + data={"result": result}, + message=f"Tool {tool_name} executed successfully", + ) + + except Exception as error: + return APIExceptionResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + error=error, + ) + + +def get_tool_and_model(tool_config: ToolConfig, db: Session) -> Tuple[Tool, Model]: + tool = db.query(Tool).filter(Tool.id == tool_config.tool_id).first() + model = db.query(Model).filter(Model.id == tool_config.model_id).first() + + if not tool or not model: + raise ValueError("Tool or model not found in database") + + return tool, model + + +def build_model(model: Model) -> AI_Model: + (provider, model_id) = model.name.split("/") + + match provider: + case "openai": + return OpenAIChat(id=model_id, base_url=os.getenv("OPENAI_BASE_URL")) + case "anthropic": + return Claude( + id=model_id, + client_params={"base_url": os.getenv("ANTHROPIC_BASE_URL")}, + ) + case "google": + return Gemini( + id=model_id, client_params={"base_url": os.getenv("GOOGLE_BASE_URL")} + ) + case "ollama": + return Ollama(id=model_id, host=os.getenv("OLLAMA_BASE_URL")) + case _: + raise ValueError(f"Unsupported model: {model.name}") + + +def initialize_tool_executor(tool: Tool, model: Model) -> BaseTool: + model_instance = build_model(model) + + return get_tool_executor(tool, model_instance) + + +def execute_tool_action( + tool_executor: BaseTool, agent: 
Agent, tool_config: ToolConfig +) -> Tuple[bool, str]: + try: + match tool_executor.name: + case "tweet_generator": + return tool_executor.run( + personality=agent.personality, + description=tool_config.description + if tool_config.parameters + else None, + ) + case _: + raise ValueError(f"Unsupported tool: {tool_executor.name}") + + except Exception as e: + return False, f"Tool execution failed: {str(e)}" diff --git a/openagent/router/routes/chat.py b/openagent/router/routes/chat.py deleted file mode 100644 index 1e6d8ff8..00000000 --- a/openagent/router/routes/chat.py +++ /dev/null @@ -1,45 +0,0 @@ -import time -import uuid -from typing import Union -from fastapi import APIRouter, status -from openai.types.chat import ( - ChatCompletion, - ChatCompletionMessage, -) -from openai.types.chat.chat_completion import Choice -from openagent.agents import build_agent_team -from openagent.router.error import APIExceptionResponse -from openagent.router.routes.models.request import CreateChatCompletionRequest - -router = APIRouter(tags=["chat"]) - - -@router.post( - "/v1/chat/completions", response_model=None, response_model_exclude_none=True -) -async def create_chat_completion( - request: CreateChatCompletionRequest, -) -> Union[ChatCompletion, APIExceptionResponse]: - try: - agent_team = build_agent_team(request.model) - result = agent_team.run(messages=request.messages) - - return ChatCompletion( - id=str(uuid.uuid4()), - created=int(time.time()), - model=request.model, - object="chat.completion", - choices=[ - Choice( - message=ChatCompletionMessage( - role="assistant", content=result.content - ), - finish_reason="stop", - index=0, - ) - ], - ) - except Exception as error: - return APIExceptionResponse( - status_code=status.HTTP_400_BAD_REQUEST, error=error - ) diff --git a/openagent/router/routes/models/request.py b/openagent/router/routes/models/request.py index 2964f2c9..f398f192 100644 --- a/openagent/router/routes/models/request.py +++ 
b/openagent/router/routes/models/request.py @@ -1,15 +1,8 @@ from typing import Optional, List from pydantic import BaseModel -from openai.types.chat import ChatCompletionMessageParam from openagent.tools import ToolConfig -# Create Chat Completion -class CreateChatCompletionRequest(BaseModel): - model: str - messages: List[ChatCompletionMessageParam] - - # Create Agent class CreateAgentRequest(BaseModel): name: str @@ -25,6 +18,9 @@ class CreateAgentRequest(BaseModel): website: Optional[str] = None tool_configs: Optional[List[ToolConfig]] = None + def get_tool_configs_data(self) -> List[dict]: + return [config.model_dump() for config in (self.tool_configs or [])] + # Run Agent class RunAgentRequest(BaseModel): diff --git a/openagent/router/routes/models/response.py b/openagent/router/routes/models/response.py index c1e308a6..5884d1ae 100644 --- a/openagent/router/routes/models/response.py +++ b/openagent/router/routes/models/response.py @@ -1,5 +1,5 @@ from typing import Optional, Generic, TypeVar, List -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from openagent.database.models.agent import AgentStatus from openagent.database.models.tool import ToolType from openagent.tools import ToolConfig @@ -14,6 +14,8 @@ class ResponseModel(BaseModel, Generic[T]): class AgentResponse(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: int name: str description: Optional[str] = None @@ -32,6 +34,8 @@ class AgentResponse(BaseModel): class AgentListResponse(BaseModel): + model_config = ConfigDict(from_attributes=True) + agents: List[AgentResponse] total: int diff --git a/openagent/router/server.py b/openagent/router/server.py index f102aba6..577bcb5d 100644 --- a/openagent/router/server.py +++ b/openagent/router/server.py @@ -1,6 +1,6 @@ from fastapi import FastAPI from starlette.middleware.cors import CORSMiddleware -from .routes import chat_router, agent_router, model_router, tool_router +from .routes import agent_router, model_router, 
tool_router app = FastAPI( title="OpenAgent API", @@ -19,7 +19,6 @@ allow_headers=["*"], ) -app.include_router(chat_router) app.include_router(agent_router) app.include_router(model_router) app.include_router(tool_router) diff --git a/openagent/tools/__init__.py b/openagent/tools/__init__.py index c06b4826..e70f2530 100644 --- a/openagent/tools/__init__.py +++ b/openagent/tools/__init__.py @@ -1,12 +1,99 @@ -from .coingecko import CoinGeckoTools -from typing import Optional -from .dsl import DSLTools -from .twitter import TweetGeneratorTools -from pydantic import BaseModel +from abc import ABC, abstractmethod +from typing import Any, Dict, Tuple, Optional +from enum import Enum +from phi.tools import Toolkit +from phi.model.base import Model +from pydantic import BaseModel, ConfigDict + +from openagent.database.models.tool import Tool + + +class TriggerType(Enum): + SCHEDULED = "scheduled" + AUTO = "auto" + Manual = "manual" + + +class ToolParameters(BaseModel): + model_config = ConfigDict(from_attributes=True) + trigger_type: TriggerType + schedule: Optional[str] = None # cron, such as "0 */2 * * *" + auth: Optional[Dict[str, Any]] = None + + def validate_schedule(self): + if self.trigger_type == TriggerType.SCHEDULED and not self.schedule: + raise ValueError("Schedule must be set when trigger_type is SCHEDULED") -__all__ = [CoinGeckoTools, DSLTools, TweetGeneratorTools] class ToolConfig(BaseModel): + model_config = ConfigDict(from_attributes=True) + + name: str + description: Optional[str] = None tool_id: int model_id: int - parameters: Optional[dict] = None \ No newline at end of file + parameters: Optional[ToolParameters] = None + + def validate_parameters(self): + if self.parameters: + self.parameters.validate_schedule() + + def model_dump(self, *args, **kwargs) -> dict: + data = super().model_dump(*args, **kwargs) + if data.get("parameters") and "trigger_type" in data["parameters"]: + data["parameters"]["trigger_type"] = data["parameters"][ + "trigger_type" + 
].value + return data + + +class BaseTool(Toolkit, ABC): + def __init__(self, name: str, model: Optional[Model] = None): + super().__init__(name=name) + self.model = model + + @abstractmethod + def run(self, **kwargs) -> Tuple[bool, Any]: + """ + execute the tool + + Args: + **kwargs: the input parameters + + Returns: + Tuple[bool, Any]: (success, result) + """ + pass + + @abstractmethod + def validate_params(self, params: Dict[str, Any]) -> Tuple[bool, str]: + """ + validate the input parameters + + Args: + params: the input parameters + + Returns: + Tuple[bool, str]: (success, error message) + """ + pass + + +def get_tool_executor(tool: Tool, model: Model) -> BaseTool: + match tool.name: + case "twitter.post": + from .twitter.tweet_generator import TweetGeneratorTools + + return TweetGeneratorTools(model=model) + # TODO: add more tools + case _: + raise ValueError(f"Unsupported tool: {tool.name}") + + +__all__ = [ + "BaseTool", + "ToolConfig", + "TriggerType", + "ToolParameters", + "get_tool_executor", +] diff --git a/openagent/tools/tests/test_twitter_tools.py b/openagent/tools/tests/test_twitter_tools.py index 166fc0dc..2b1e27a4 100644 --- a/openagent/tools/tests/test_twitter_tools.py +++ b/openagent/tools/tests/test_twitter_tools.py @@ -37,14 +37,16 @@ def test_tweet_generation_and_posting(self): # Test case for tweet generation and posting personality = "tech enthusiast" - topic = "AI and machine learning innovations" + description = "AI and machine learning innovations" expected_terms = ["AI", "tech", "machine learning", "innovation"] try: # Generate and post tweet - print(f"\nGenerating and posting tweet as {personality} about {topic}...") + print( + f"\nGenerating and posting tweet as {personality} about {description}..." 
+ ) success, tweet_content = self.tweet_tools.generate_tweet( - personality=personality, topic=topic + personality=personality, description=description ) # Validate the result diff --git a/openagent/tools/twitter/tweet_generator.py b/openagent/tools/twitter/tweet_generator.py index f62c9c64..a0c63e2d 100644 --- a/openagent/tools/twitter/tweet_generator.py +++ b/openagent/tools/twitter/tweet_generator.py @@ -1,11 +1,12 @@ import logging import os -from typing import Tuple, Optional -from phi.tools import Toolkit +from typing import Any, Dict, Tuple, Optional from phi.model.openai import OpenAIChat from phi.model.base import Model from phi.model.message import Message from dotenv import load_dotenv + +from openagent.tools import BaseTool from .twitter_handler import TwitterHandler # Configure logging @@ -15,7 +16,7 @@ load_dotenv() SYSTEM_PROMPT = """You are a creative tweet writer who can adapt to different personalities and styles. -Your task is to generate engaging tweets that match the given personality and topic. +Your task is to generate engaging tweets that match the given personality and description. ALWAYS include relevant hashtags in your tweets to increase visibility and engagement. 
Format your response exactly like a real human tweet - no quotes, no additional text.""" @@ -31,15 +32,13 @@ """ -class TweetGeneratorTools(Toolkit): +class TweetGeneratorTools(BaseTool): def __init__(self, model: Optional[Model] = None): - super().__init__(name="tweet_generator_tools") + super().__init__(name="tweet_generator", model=model) self.twitter_handler = TwitterHandler() # Use provided model (from agent) or create a new one - if model: - self.model = model - else: + if not model: # Initialize OpenAI model for standalone use openai_api_key = os.getenv("OPENAI_API_KEY") if not openai_api_key: @@ -47,7 +46,7 @@ def __init__(self, model: Optional[Model] = None): # Initialize model params model_params = { - "id": "gpt-4o", + "id": "gpt-4", "name": "TweetGenerator", "temperature": 0.7, "max_tokens": 280, @@ -62,24 +61,43 @@ def __init__(self, model: Optional[Model] = None): self.model = OpenAIChat(**model_params) - # Register only the tweet generation function - self.register(self.generate_tweet) + # Register the run method + self.register(self.run) + + def validate_params(self, params: Dict[str, Any]) -> Tuple[bool, str]: + if "personality" not in params: + return False, "Missing required parameter: personality" + + if not isinstance(params["personality"], str): + return False, "Parameter 'personality' must be a string" + + if "description" in params and not isinstance(params["description"], str): + return False, "Parameter 'description' must be a string" - def generate_tweet(self, personality: str, topic: str = None) -> Tuple[bool, str]: + return True, "" + + def run(self, personality: str, description: str = None) -> Tuple[bool, str]: + return self.generate_tweet(personality, description) + + def generate_tweet( + self, personality: str, description: str = None + ) -> Tuple[bool, str]: """ - Generate a tweet using the model based on personality and topic, and post it. + Generate a tweet using the model based on personality and description, and post it. 
Args: personality (str): The personality/role to use for tweet generation - topic (str, optional): Specific topic to tweet about + description (str, optional): Specific description to tweet about Returns: tuple: (success: bool, message: str) - Success status and response message """ try: # Generate prompt messages user_prompt = f"Generate a tweet as {personality}." - if topic: - user_prompt += f" The tweet should be about: {topic}." + + if description: + user_prompt += f" Its content is centered around: {description}." + user_prompt += TWEET_REQUIREMENTS messages = [ @@ -106,7 +124,7 @@ def generate_tweet(self, personality: str, topic: str = None) -> Tuple[bool, str logger.warning( "Generated tweet does not contain hashtags, regenerating..." ) - return self.generate_tweet(personality, topic) + return self.generate_tweet(personality, description) # Post the generated tweet logger.info(f"Posting generated tweet: {tweet_content}")