From 0de3be81fd17ee0b7b6755e62333585a3f64adf0 Mon Sep 17 00:00:00 2001 From: Vineeth Voruganti <13438633+VVoruganti@users.noreply.github.com> Date: Wed, 4 Sep 2024 11:57:15 -0400 Subject: [PATCH] feat(dialectic) parallelize facts and history queries --- src/agent.py | 66 +++++++++++++++++++--------------------------------- 1 file changed, 24 insertions(+), 42 deletions(-) diff --git a/src/agent.py b/src/agent.py index e93ad82..bd6b1a5 100644 --- a/src/agent.py +++ b/src/agent.py @@ -1,3 +1,4 @@ +import asyncio import os import uuid @@ -85,26 +86,20 @@ async def prep_inference( return retrieved_facts -# async def chat( -# app_id: uuid.UUID, -# user_id: uuid.UUID, -# query: str, -# db: AsyncSession, -# stream: bool = False, -# ): -# retrieved_facts = await prep_inference(db, app_id, user_id, query) -# facts = "None" -# if retrieved_facts is not None: -# facts = "\n".join(retrieved_facts) -# chain = Dialectic( -# agent_input=query, -# retrieved_facts=facts, -# ) -# -# if stream: -# return chain.stream_async() -# response = chain.call() -# return schemas.AgentChat(content=response.content) +async def generate_facts(db, app_id, user_id, questions): + all_facts = set() + + async def fetch_facts(query): + retrieved_facts = await prep_inference(db, app_id, user_id, query) + if retrieved_facts is not None: + all_facts.update(retrieved_facts) + + await asyncio.gather(*[fetch_facts(query) for query in questions]) + + facts = "None" + if all_facts: + facts = "\n".join(all_facts) + return facts async def chat( @@ -120,16 +115,17 @@ async def chat( else: questions = queries.queries - all_facts = set() - for query in questions: - retrieved_facts = await prep_inference(db, app_id, user_id, query) - if retrieved_facts is not None: - all_facts.update(retrieved_facts) - facts = "None" - if len(all_facts) > 0: - facts = "\n".join(all_facts) query = "\n".join(questions) + history = await chat_history(db, app_id, user_id, session_id) + + # Run fact generation and chat history retrieval concurrently + facts_task = asyncio.create_task(generate_facts(db, app_id, user_id, questions)) + history_task = asyncio.create_task(chat_history(db, app_id, user_id, session_id)) + + # Wait for both tasks to complete + facts, history = await asyncio.gather(facts_task, history_task) + chain = Dialectic( agent_input=query, retrieved_facts=facts, @@ -139,17 +135,3 @@ async def chat( return chain.stream_async() response = chain.call() return schemas.AgentChat(content=response.content) - - -# async def stream( -# app_id: uuid.UUID, -# user_id: uuid.UUID, -# query: str, -# db: AsyncSession, -# ): -# retrieved_facts = await prep_inference(db, app_id, user_id, query) -# chain = Dialectic( -# agent_input=query, -# retrieved_facts=retrieved_facts if retrieved_facts else "None", -# ) -# return chain.stream_async()