Skip to content

Commit

Permalink
Merge pull request #107 from cs3216-a3-group-4/seeleng/parallelise-th…
Browse files Browse the repository at this point in the history
…ird-layer

feat: parallelise user question answer generation
  • Loading branch information
seelengxd authored Sep 27, 2024
2 parents e2a3ad0 + 92339eb commit afb31de
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 29 deletions.
74 changes: 47 additions & 27 deletions backend/src/lm/generate_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from src.common.database import engine
from sqlalchemy import select
from src.events.models import Event
import asyncio


def get_event_by_id(event_id: int) -> Event:
Expand All @@ -35,36 +36,55 @@ def format_prompt_input(question: str, analysis: dict, point: str) -> str:
"""


def generate_response(question: str) -> dict:
async def get_elaborated_analysis(
question, analysis, point, elaborated_analyses, index
):
prompt_input = format_prompt_input(question, analysis, point)
messages = [
SystemMessage(content=SYSPROMPT),
HumanMessage(content=prompt_input),
]

result = lm_model.invoke(messages)

analysis["elaborations"] = result.content
if analysis["elaborations"] != "NOT RELEVANT":
elaborated_analyses.append((analysis, index))


async def process_point_dict(point_dict, question):
point = point_dict.get("point")
analyses = point_dict.get("analyses")
elaborated_analyses = []
await asyncio.gather(
*[
get_elaborated_analysis(
question, analysis, point, elaborated_analyses, index
)
for index, analysis in enumerate(analyses)
]
)
elaborated_analyses.sort(key=lambda item: item[1])
elaborated_analyses = [item[0] for item in elaborated_analyses]

point_dict["analyses"] = elaborated_analyses

if len(elaborated_analyses) == 0:
point_dict["fall_back_response"] = generate_fallback_response(question, point)


async def generate_response(question: str) -> dict:
relevant_analyses = get_relevant_analyses(question)
count = 0
for point_dict in (
relevant_analyses["for_points"] + relevant_analyses["against_points"]
):
count += 1
point = point_dict.get("point")
analyses = point_dict.get("analyses")
elaborated_analyses = []
for analysis in analyses:
prompt_input = format_prompt_input(question, analysis, point)
messages = [
SystemMessage(content=SYSPROMPT),
HumanMessage(content=prompt_input),
]

result = lm_model.invoke(messages)

analysis["elaborations"] = result.content
if analysis["elaborations"] != "NOT RELEVANT":
elaborated_analyses.append(analysis)
point_dict["analyses"] = elaborated_analyses

if len(elaborated_analyses) == 0:
point_dict["fall_back_response"] = generate_fallback_response(
question, point

await asyncio.gather(
*[
process_point_dict(point_dict, question)
for point_dict in (
relevant_analyses["for_points"] + relevant_analyses["against_points"]
)
]
)

print(count)
return relevant_analyses


Expand Down
7 changes: 5 additions & 2 deletions backend/src/user_questions/router.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from http import HTTPStatus
from pprint import pprint
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
Expand Down Expand Up @@ -120,7 +121,7 @@ def classify_question(question: str):


@router.post("/")
def create_user_question(
async def create_user_question(
data: CreateUserQuestion,
user: Annotated[User, Depends(get_current_user)],
session=Depends(get_session),
Expand All @@ -134,7 +135,9 @@ def create_user_question(
answer = Answer()
user_question.answer = answer

results = generate_response(data.question)
results = await generate_response(data.question)

pprint(results)

for row in results["for_points"]:
point = row["point"]
Expand Down

0 comments on commit afb31de

Please sign in to comment.