-
Notifications
You must be signed in to change notification settings - Fork 0
/
quickstart_server.py
96 lines (81 loc) · 3.63 KB
/
quickstart_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import asyncio
import uuid
import traceback
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
import redis.asyncio as redis
from dotenv import load_dotenv
from bolna.helpers.utils import store_file
from bolna.prompts import *
from bolna.helpers.logger_config import configure_logger
from bolna.models import *
from bolna.llms import LiteLLM
from bolna.agent_manager.assistant_manager import AssistantManager
load_dotenv()
logger = configure_logger(__name__)

# Fail fast with an actionable message: an unset REDIS_URL would otherwise
# surface as an obscure error from inside the redis client at import time.
_redis_url = os.getenv('REDIS_URL')
if not _redis_url:
    raise RuntimeError("REDIS_URL environment variable is not set")
# decode_responses=True makes reads return str instead of bytes, which is
# what json.loads on retrieved agent configs expects.
redis_pool = redis.ConnectionPool.from_url(_redis_url, decode_responses=True)
redis_client = redis.Redis.from_pool(redis_pool)
# Chat WebSockets currently tracked by websocket_endpoint.
active_websockets: List[WebSocket] = []

app = FastAPI()
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive (any site may make credentialed requests); fine for a local
# quickstart, tighten before exposing this server publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class CreateAgentPayload(BaseModel):
    """Request body for ``POST /agent``.

    ``agent_config`` is dumped with ``model_dump()`` and stored in Redis under
    the new agent's id; ``agent_prompts`` is written to
    ``<agent_id>/conversation_details.json`` via ``store_file``.
    """

    # Full agent definition (the Redis-persisted part of the agent).
    agent_config: AgentModel
    # Prompt texts as a two-level mapping of str -> str -> str; presumably
    # task key -> prompt name -> prompt text — TODO confirm against AssistantManager.
    # NOTE(review): under Pydantic v2 (model_dump is used by create_agent),
    # Optional[...] without a default is still a *required* field that merely
    # accepts null; add "= None" only if omitting it is meant to be allowed.
    agent_prompts: Optional[Dict[str, Dict[str, str]]]
@app.post("/agent")
async def create_agent(agent_data: CreateAgentPayload):
    """Register a new agent and persist its configuration and prompts.

    Generates a fresh UUID for the agent and marks it ``"seeding"``. For every
    task whose ``task_type`` is ``"extraction"``, it asks an LLM (model taken
    from ``EXTRACTION_PROMPT_GENERATION_MODEL``) to turn the task's
    ``extraction_details`` into ``extraction_json``. It then concurrently
    writes the config to Redis (key = agent UUID) and the prompts to
    ``<agent_uuid>/conversation_details.json`` via ``store_file``.

    Args:
        agent_data: validated request body.

    Returns:
        ``{"agent_id": <uuid str>, "state": "created"}``.
    """
    agent_uuid = str(uuid.uuid4())
    data_for_db = agent_data.agent_config.model_dump()
    data_for_db["assistant_status"] = "seeding"
    agent_prompts = agent_data.agent_prompts
    logger.info(f'Data for DB {data_for_db}')

    tasks = data_for_db['tasks']
    if tasks:
        logger.info("Setting up follow up tasks")

    # Built lazily on the first extraction task and reused for the rest, so
    # agents without extraction tasks never touch the extraction model config.
    extraction_prompt_generation_llm = None
    for task in tasks:
        if task['task_type'] != "extraction":
            continue
        if extraction_prompt_generation_llm is None:
            extraction_prompt_llm = os.getenv("EXTRACTION_PROMPT_GENERATION_MODEL")
            extraction_prompt_generation_llm = LiteLLM(model=extraction_prompt_llm, max_tokens=2000)
        # `task` is the very dict held in data_for_db["tasks"], so writing into
        # it updates the payload that is stored in Redis below.
        llm_agent_config = task['tools_config']["llm_agent"]
        llm_agent_config['extraction_json'] = await extraction_prompt_generation_llm.generate(
            messages=[
                {'role': 'system', 'content': EXTRACTION_PROMPT_GENERATION_PROMPT},
                {'role': 'user', 'content': llm_agent_config['extraction_details']},
            ])

    stored_prompt_file_path = f"{agent_uuid}/conversation_details.json"
    await asyncio.gather(
        redis_client.set(agent_uuid, json.dumps(data_for_db)),
        store_file(file_key=stored_prompt_file_path, file_data=agent_prompts, local=True),
    )
    return {"agent_id": agent_uuid, "state": "created"}
#############################################################################################
# Websocket
#############################################################################################
@app.websocket("/chat/v1/{agent_id}")
async def websocket_endpoint(agent_id: str, websocket: WebSocket, user_agent: str = Query(None)):
    """Run a chat session for a stored agent over a WebSocket.

    Accepts the connection and loads the agent's JSON config from Redis
    (key = ``agent_id``, the id returned by ``POST /agent``). It then drives
    ``AssistantManager.run(local=True)``, logging every task output.

    If the config is missing or unreadable, the socket is closed with code
    1008 and reason "Agent not found". ``HTTPException`` is deliberately not
    raised here: it is an HTTP-response mechanism and cannot be delivered over
    an already-accepted WebSocket.

    The socket is kept in ``active_websockets`` only for the lifetime of this
    call and is removed on every exit path.

    Args:
        agent_id: Redis key of the agent config.
        websocket: the client connection.
        user_agent: optional query parameter; accepted but currently unused.
    """
    logger.info("Connected to ws")
    await websocket.accept()
    active_websockets.append(websocket)
    try:
        agent_config = None
        try:
            retrieved_agent_config = await redis_client.get(agent_id)
            logger.info(f"Retrieved agent config: {retrieved_agent_config}")
            # redis returns None for an unknown key; check explicitly instead
            # of relying on json.loads(None) raising TypeError.
            if retrieved_agent_config is not None:
                agent_config = json.loads(retrieved_agent_config)
        except Exception:
            # Redis failure or corrupt JSON: treated like a missing agent,
            # but keep the traceback for diagnosis.
            traceback.print_exc()

        if agent_config is None:
            # 1008 = policy violation, the conventional close code for
            # rejecting a WebSocket session.
            await websocket.close(code=1008, reason="Agent not found")
            return

        assistant_manager = AssistantManager(agent_config, websocket, agent_id)
        try:
            async for _index, task_output in assistant_manager.run(local=True):
                logger.info(task_output)
        except WebSocketDisconnect:
            pass  # client went away; untracking happens in the outer finally
        except Exception as e:
            traceback.print_exc()
            logger.error(f"error in executing {e}")
    finally:
        # Previously the socket was only untracked on WebSocketDisconnect,
        # leaking it on lookup failures, errors and normal completion.
        if websocket in active_websockets:
            active_websockets.remove(websocket)