Skip to content

Commit

Permalink
fix: dialectic endpoint stream method
Browse files Browse the repository at this point in the history
  • Loading branch information
VVoruganti committed Dec 4, 2024
1 parent 7608572 commit 3ad9dcc
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 16 deletions.
24 changes: 10 additions & 14 deletions src/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from collections.abc import Iterable

from anthropic import Anthropic
from anthropic import Anthropic, MessageStreamManager
from dotenv import load_dotenv
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down Expand Up @@ -31,9 +31,7 @@ def get_set(self) -> set[str]:


class Dialectic:
def __init__(
self, agent_input: str, user_representation: str, chat_history: list[str]
):
def __init__(self, agent_input: str, user_representation: str, chat_history: str):
self.agent_input = agent_input
self.user_representation = user_representation
self.chat_history = chat_history
Expand Down Expand Up @@ -68,8 +66,7 @@ def stream(self):
<conversation_history>{self.chat_history}</conversation_history>
Provide a brief, matter-of-fact, and appropriate response to the query based on the context provided. If the context provided doesn't aid in addressing the query, return only the word "None".
"""

yield from self.client.messages.create(
return self.client.messages.stream(
model="claude-3-5-sonnet-20240620",
messages=[
{
Expand All @@ -78,26 +75,25 @@ def stream(self):
}
],
max_tokens=300,
stream=True,
)


async def chat_history(app_id: str, user_id: str, session_id: str) -> list[str]:
async def chat_history(app_id: str, user_id: str, session_id: str) -> str:
async with SessionLocal() as db:
stmt = await crud.get_messages(db, app_id, user_id, session_id)
results = await db.execute(stmt)
messages = results.scalars()
history = []
history = ""
for message in messages:
if message.is_user:
history.append(f"user:{message.content}")
history += f"user:{message.content}\n"
else:
history.append(f"assistant:{message.content}")
history += f"assistant:{message.content}\n"
return history


async def get_latest_user_representation(
db: AsyncSession, app_id: str, user_id: str, session_id: str
db: AsyncSession, app_id: str, user_id: str
) -> str:
stmt = (
select(models.Metamessage)
Expand Down Expand Up @@ -126,13 +122,13 @@ async def chat(
session_id: str,
query: schemas.AgentQuery,
stream: bool = False,
):
) -> schemas.AgentChat | MessageStreamManager:
questions = [query.queries] if isinstance(query.queries, str) else query.queries
final_query = "\n".join(questions) if len(questions) > 1 else questions[0]

async with SessionLocal() as db:
# Run user representation retrieval and chat history retrieval concurrently
user_rep_task = get_latest_user_representation(db, app_id, user_id, session_id)
user_rep_task = get_latest_user_representation(db, app_id, user_id)
history_task = chat_history(app_id, user_id, session_id)

# Wait for both tasks to complete
Expand Down
7 changes: 5 additions & 2 deletions src/routers/sessions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional

from anthropic import MessageStreamManager
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from fastapi_pagination import Page
Expand Down Expand Up @@ -150,8 +151,10 @@ async def parse_stream():
query=query,
stream=True,
)
for chunk in stream:
yield chunk.content
if type(stream) is MessageStreamManager:
with stream as stream_manager:
for text in stream_manager.text_stream:
yield text

return StreamingResponse(
content=parse_stream(), media_type="text/event-stream", status_code=200
Expand Down

0 comments on commit 3ad9dcc

Please sign in to comment.