diff --git a/src/agent.py b/src/agent.py index 52195c5..881d160 100644 --- a/src/agent.py +++ b/src/agent.py @@ -2,7 +2,7 @@ import os from collections.abc import Iterable -from anthropic import Anthropic +from anthropic import Anthropic, MessageStreamManager from dotenv import load_dotenv from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -31,9 +31,7 @@ def get_set(self) -> set[str]: class Dialectic: - def __init__( - self, agent_input: str, user_representation: str, chat_history: list[str] - ): + def __init__(self, agent_input: str, user_representation: str, chat_history: str): self.agent_input = agent_input self.user_representation = user_representation self.chat_history = chat_history @@ -68,8 +66,7 @@ def stream(self): {self.chat_history} Provide a brief, matter-of-fact, and appropriate response to the query based on the context provided. If the context provided doesn't aid in addressing the query, return only the word "None". """ - - yield from self.client.messages.create( + return self.client.messages.stream( model="claude-3-5-sonnet-20240620", messages=[ { @@ -78,26 +75,25 @@ def stream(self): } ], max_tokens=300, - stream=True, ) -async def chat_history(app_id: str, user_id: str, session_id: str) -> list[str]: +async def chat_history(app_id: str, user_id: str, session_id: str) -> str: async with SessionLocal() as db: stmt = await crud.get_messages(db, app_id, user_id, session_id) results = await db.execute(stmt) messages = results.scalars() - history = [] + history = "" for message in messages: if message.is_user: - history.append(f"user:{message.content}") + history += f"user:{message.content}\n" else: - history.append(f"assistant:{message.content}") + history += f"assistant:{message.content}\n" return history async def get_latest_user_representation( - db: AsyncSession, app_id: str, user_id: str, session_id: str + db: AsyncSession, app_id: str, user_id: str ) -> str: stmt = ( select(models.Metamessage) @@ -126,13 +122,13 @@ async def chat( session_id: str, query: schemas.AgentQuery, stream: bool = False, -): +) -> schemas.AgentChat | MessageStreamManager: questions = [query.queries] if isinstance(query.queries, str) else query.queries final_query = "\n".join(questions) if len(questions) > 1 else questions[0] async with SessionLocal() as db: # Run user representation retrieval and chat history retrieval concurrently - user_rep_task = get_latest_user_representation(db, app_id, user_id, session_id) + user_rep_task = get_latest_user_representation(db, app_id, user_id) history_task = chat_history(app_id, user_id, session_id) # Wait for both tasks to complete diff --git a/src/routers/sessions.py b/src/routers/sessions.py index 432b4a6..e2ff8a4 100644 --- a/src/routers/sessions.py +++ b/src/routers/sessions.py @@ -1,5 +1,6 @@ from typing import Optional +from anthropic import MessageStreamManager from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse from fastapi_pagination import Page @@ -150,8 +151,10 @@ async def parse_stream(): query=query, stream=True, ) - for chunk in stream: - yield chunk.content + if type(stream) is MessageStreamManager: + with stream as stream_manager: + for text in stream_manager.text_stream: + yield text return StreamingResponse( content=parse_stream(), media_type="text/event-stream", status_code=200