Skip to content

Commit

Permalink
feat(dialectic) parallelize facts and history queries
Browse files Browse the repository at this point in the history
  • Loading branch information
VVoruganti committed Sep 4, 2024
1 parent 665e8dc commit 0de3be8
Showing 1 changed file with 24 additions and 42 deletions.
66 changes: 24 additions & 42 deletions src/agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import os
import uuid

Expand Down Expand Up @@ -85,26 +86,20 @@ async def prep_inference(
return retrieved_facts


# async def chat(
# app_id: uuid.UUID,
# user_id: uuid.UUID,
# query: str,
# db: AsyncSession,
# stream: bool = False,
# ):
# retrieved_facts = await prep_inference(db, app_id, user_id, query)
# facts = "None"
# if retrieved_facts is not None:
# facts = "\n".join(retrieved_facts)
# chain = Dialectic(
# agent_input=query,
# retrieved_facts=facts,
# )
#
# if stream:
# return chain.stream_async()
# response = chain.call()
# return schemas.AgentChat(content=response.content)
async def generate_facts(db, app_id, user_id, questions):
all_facts = set()

async def fetch_facts(query):
retrieved_facts = await prep_inference(db, app_id, user_id, query)
if retrieved_facts is not None:
all_facts.update(retrieved_facts)

await asyncio.gather(*[fetch_facts(query) for query in questions])

facts = "None"
if all_facts:
facts = "\n".join(all_facts)
return facts


async def chat(
Expand All @@ -120,16 +115,17 @@ async def chat(
else:
questions = queries.queries

all_facts = set()
for query in questions:
retrieved_facts = await prep_inference(db, app_id, user_id, query)
if retrieved_facts is not None:
all_facts.update(retrieved_facts)
facts = "None"
if len(all_facts) > 0:
facts = "\n".join(all_facts)
query = "\n".join(questions)

history = await chat_history(db, app_id, user_id, session_id)

# Run fact generation and chat history retrieval concurrently
facts_task = asyncio.create_task(generate_facts(db, app_id, user_id, questions))
history_task = asyncio.create_task(chat_history(db, app_id, user_id, session_id))

# Wait for both tasks to complete
facts, history = await asyncio.gather(facts_task, history_task)

chain = Dialectic(
agent_input=query,
retrieved_facts=facts,
Expand All @@ -139,17 +135,3 @@ async def chat(
return chain.stream_async()
response = chain.call()
return schemas.AgentChat(content=response.content)


# async def stream(
# app_id: uuid.UUID,
# user_id: uuid.UUID,
# query: str,
# db: AsyncSession,
# ):
# retrieved_facts = await prep_inference(db, app_id, user_id, query)
# chain = Dialectic(
# agent_input=query,
# retrieved_facts=retrieved_facts if retrieved_facts else "None",
# )
# return chain.stream_async()

0 comments on commit 0de3be8

Please sign in to comment.