Skip to content

Commit

Permalink
feat(dialectic) Allow for batch questions and load session history
Browse files Browse the repository at this point in the history
  • Loading branch information
VVoruganti committed Sep 4, 2024
1 parent 5cac118 commit 665e8dc
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 28 deletions.
100 changes: 80 additions & 20 deletions src/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ class Dialectic(OpenAICall):
---
query: {agent_input}
context: {retrieved_facts}
conversation_history: {chat_history}
---
Provide a brief, matter-of-fact, and appropriate response to the query based on the context provided. If the context provided doesn't aid in addressing the query, return None.
"""
agent_input: str
retrieved_facts: str
chat_history: list[str]

configuration = BaseConfig(
client_wrappers=[
Expand All @@ -35,15 +37,29 @@ class Dialectic(OpenAICall):
call_params = OpenAICallParams(
model=os.getenv("AZURE_OPENAI_DEPLOYMENT"), temperature=1.2, top_p=0.5
)
# call_params = OpenAICallParams(model="gpt-4o-2024-05-13")


async def chat_history(
db: AsyncSession, app_id: uuid.UUID, user_id: uuid.UUID, session_id: uuid.UUID
) -> list[str]:
stmt = await crud.get_messages(db, app_id, user_id, session_id)
results = await db.execute(stmt)
messages = results.scalars()
history = []
for message in messages:
if message.is_user:
history.append(f"user:{message.content}")
else:
history.append(f"assistant:{message.content}")
return history


async def prep_inference(
db: AsyncSession,
app_id: uuid.UUID,
user_id: uuid.UUID,
query: str,
):
) -> None | list[str]:
collection = await crud.get_collection_by_name(db, app_id, user_id, "honcho")
retrieved_facts = None
if collection is None:
Expand All @@ -61,35 +77,79 @@ async def prep_inference(
user_id=user_id,
collection_id=collection.id,
query=query,
top_k=1,
top_k=3,
)
if len(retrieved_documents) > 0:
retrieved_facts = retrieved_documents[0].content
retrieved_facts = [d.content for d in retrieved_documents]

return retrieved_facts

chain = Dialectic(
agent_input=query,
retrieved_facts=retrieved_facts if retrieved_facts else "None",
)
return chain

# async def chat(
# app_id: uuid.UUID,
# user_id: uuid.UUID,
# query: str,
# db: AsyncSession,
# stream: bool = False,
# ):
# retrieved_facts = await prep_inference(db, app_id, user_id, query)
# facts = "None"
# if retrieved_facts is not None:
# facts = "\n".join(retrieved_facts)
# chain = Dialectic(
# agent_input=query,
# retrieved_facts=facts,
# )
#
# if stream:
# return chain.stream_async()
# response = chain.call()
# return schemas.AgentChat(content=response.content)


async def chat(
app_id: uuid.UUID,
user_id: uuid.UUID,
query: str,
session_id: uuid.UUID,
queries: schemas.AgentQuery,
db: AsyncSession,
stream: bool = False,
):
chain = await prep_inference(db, app_id, user_id, query)
response = await chain.call_async()
if isinstance(queries.queries, str):
questions = [queries.queries]
else:
questions = queries.queries

all_facts = set()
for query in questions:
retrieved_facts = await prep_inference(db, app_id, user_id, query)
if retrieved_facts is not None:
all_facts.update(retrieved_facts)
facts = "None"
if len(all_facts) > 0:
facts = "\n".join(all_facts)
query = "\n".join(questions)
history = await chat_history(db, app_id, user_id, session_id)
chain = Dialectic(
agent_input=query,
retrieved_facts=facts,
chat_history=history,
)
if stream:
return chain.stream_async()
response = chain.call()
return schemas.AgentChat(content=response.content)


async def stream(
app_id: uuid.UUID,
user_id: uuid.UUID,
query: str,
db: AsyncSession,
):
chain = await prep_inference(db, app_id, user_id, query)
return chain.stream_async()
# async def stream(
# app_id: uuid.UUID,
# user_id: uuid.UUID,
# query: str,
# db: AsyncSession,
# ):
# retrieved_facts = await prep_inference(db, app_id, user_id, query)
# chain = Dialectic(
# agent_input=query,
# retrieved_facts=retrieved_facts if retrieved_facts else "None",
# )
# return chain.stream_async()
25 changes: 17 additions & 8 deletions src/routers/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,20 +188,22 @@ async def get_session(
return honcho_session


@router.get("/{session_id}/chat", response_model=schemas.AgentChat)
async def get_chat(
request: Request,
@router.post("/{session_id}/chat", response_model=schemas.AgentChat)
async def chat(
app_id: uuid.UUID,
user_id: uuid.UUID,
session_id: uuid.UUID,
query: str,
query: schemas.AgentQuery,
db=db,
auth=Depends(auth),
):
return await agent.chat(app_id=app_id, user_id=user_id, query=query, db=db)
print(query)
return await agent.chat(
app_id=app_id, user_id=user_id, session_id=session_id, queries=query, db=db
)


@router.get(
@router.post(
"/{session_id}/chat/stream",
responses={
200: {
Expand All @@ -217,12 +219,19 @@ async def get_chat_stream(
app_id: uuid.UUID,
user_id: uuid.UUID,
session_id: uuid.UUID,
query: str,
query: schemas.AgentQuery,
db=db,
auth=Depends(auth),
):
async def parse_stream():
stream = await agent.stream(app_id=app_id, user_id=user_id, query=query, db=db)
stream = await agent.chat(
app_id=app_id,
user_id=user_id,
session_id=session_id,
queries=query,
db=db,
stream=True,
)
async for chunk in stream:
yield chunk.content

Expand Down
4 changes: 4 additions & 0 deletions src/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,5 +231,9 @@ def fetch_h_metadata(cls, value, info):
)


class AgentQuery(BaseModel):
queries: str | list[str]


class AgentChat(BaseModel):
content: str

0 comments on commit 665e8dc

Please sign in to comment.