Merge pull request #99 from cs3216-a3-group-4/feat-modify-format-for-essay-elabs-for-frontend-integration

feat: reformat elaboration outputs
marcus-ny authored Sep 26, 2024
2 parents 65579e8 + 3882066 commit aa0ee80
Showing 10 changed files with 291 additions and 84 deletions.
15 changes: 3 additions & 12 deletions backend/src/cron/fetch_articles.py
@@ -123,26 +123,17 @@ async def populate_daily_articles_cna():
    await process_all_categories()


-def process_new_articles() -> list[dict]:
+def process_new_articles():
    with Session(engine) as session:
        result = session.scalars(
            select(Article).where(
                Article.id.not_in(
                    list(session.scalars(select(Event.original_article_id)))
                )
            )
-        ).all()
+        )

-        articles = []
-
-        for article in result:
-            data_dict = {
-                "id": article.id,
-                "bodyText": article.body,
-            }
-            articles.append(data_dict)
-
-        return articles
+        return result


# NOTE: this method should work with no issue as long as the number of calls is less than 500 which is the rate limit by OpenAI
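A minimal sketch (not part of this diff) of how a caller might rebuild the old {"id", "bodyText"} payload from the Article rows that process_new_articles() now returns; the helper name articles_to_payload and the import path are assumptions, and depending on how the lazy ScalarResult behaves once the session closes, callers may need to materialise the rows with .all() inside the with block.

from src.cron.fetch_articles import process_new_articles


def articles_to_payload(articles) -> list[dict]:
    # Mirrors the mapping the removed loop used to perform.
    return [{"id": article.id, "bodyText": article.body} for article in articles]


payload = articles_to_payload(process_new_articles())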
10 changes: 3 additions & 7 deletions backend/src/cron/process_daily.py
@@ -10,7 +10,6 @@
from src.events.models import Analysis

from src.lm.generate_events import (
-    CONCURRENCY,
    EventPublic,
    form_event_json,
    generate_events_from_article,

@@ -19,6 +18,8 @@

file_path = "daily_events.json"

+CONCURRENCY = 150
+

async def generate_daily_events(articles: list[dict]) -> List[EventPublic]:
    res = []

@@ -40,14 +41,9 @@ async def generate_daily_event(article: dict, res: list, semaphore: asyncio.Semaphore):
        await asyncio.sleep(1)


-# def store_daily_analyses(events: List[EventLLM]):
-#     for event in events:
-#         event.analysis_list.
-
-
async def process_daily_articles(articles: list[dict]):
    await generate_daily_events(articles)
-    events_ids = populate(file_path=file_path)
+    events_ids = populate()

    with Session(engine) as session:
        analyses = session.scalars(
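A minimal sketch of how generate_daily_events is presumably meant to use the relocated CONCURRENCY constant to cap simultaneous OpenAI calls; the full body is not shown in this diff, and the function name generate_daily_events_sketch and the import path are assumptions.

import asyncio

from src.cron.process_daily import CONCURRENCY, generate_daily_event


async def generate_daily_events_sketch(articles: list[dict]) -> list:
    res: list = []
    # Allow at most CONCURRENCY (150) coroutines to call OpenAI at once,
    # in line with the rate-limit note in fetch_articles.py.
    semaphore = asyncio.Semaphore(CONCURRENCY)
    tasks = [generate_daily_event(article, res, semaphore) for article in articles]
    await asyncio.gather(*tasks)
    return res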
2 changes: 1 addition & 1 deletion backend/src/lm/generate_events.py
@@ -18,7 +18,7 @@
os.environ["LANGCHAIN_TRACING_V2"] = LANGCHAIN_TRACING_V2
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

-lm_model = ChatOpenAI(model="gpt-4o-mini", temperature=0.3, max_retries=5)
+lm_model = ChatOpenAI(model="gpt-4o-mini", temperature=0.7, max_retries=5)


class CategoryAnalysis(BaseModel):
2 changes: 1 addition & 1 deletion backend/src/lm/generate_points.py
@@ -22,7 +22,7 @@ def generate_points_from_question(question: str) -> dict:
    return points


-def get_relevant_analyses(question: str, analyses_per_point: int = 3) -> List[str]:
+def get_relevant_analyses(question: str, analyses_per_point: int = 5) -> dict:
    points = generate_points_from_question(question)

    for_pts = points.get("for_points", [])
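For reference, a rough sketch of the dict that get_relevant_analyses now returns, inferred from how generate_response.py consumes it below; the literal values and the "score" field are illustrative assumptions.

relevant_analyses = {
    "for_points": [
        {
            "point": "Free education widens access to opportunity.",  # illustrative
            "analyses": [
                {"event_id": 42, "content": "Analysis of a related event ...", "score": 0.87},
            ],
        },
    ],
    "against_points": [
        # same structure as for_points
    ],
}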
107 changes: 51 additions & 56 deletions backend/src/lm/generate_response.py
@@ -3,7 +3,7 @@
from pydantic import BaseModel
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import JsonOutputParser
-from src.lm.prompts import QUESTION_ANALYSIS_GEN_SYSPROMPT as SYSPROMPT
+from src.lm.prompts import QUESTION_ANALYSIS_GEN_SYSPROMPT_2 as SYSPROMPT
import json

from sqlalchemy.orm import Session
@@ -12,66 +12,61 @@
from src.events.models import Event


-class Elaborations(BaseModel):
-    for_points: list[str]
-    against_points: list[str]
-
-
-def format_analyses(relevant_analyses: dict, question: str):
-    # Given relevant analyses
-    # for each point add an elaboration and delete score
-    return {
-        "question": question,
-        "for_points": [
-            {
-                "point": point["point"],
-                "examples": [
-                    {
-                        "event_title": get_event_by_id(analysis["event_id"]).title,
-                        "event_description": get_event_by_id(
-                            analysis["event_id"]
-                        ).description,
-                        "analysis": analysis["content"],
-                    }
-                    for analysis in point["analyses"]
-                ],
-            }
-            for point in relevant_analyses["for_points"]
-        ],
-        "against_points": [
-            {
-                "point": point["point"],
-                "examples": [
-                    {
-                        "event": get_event_by_id(analysis["event_id"]).title,
-                        "event_description": get_event_by_id(
-                            analysis["event_id"]
-                        ).description,
-                        "analysis": analysis["content"],
-                    }
-                    for analysis in point["analyses"]
-                ],
-            }
-            for point in relevant_analyses["against_points"]
-        ],
-    }
-
-
def get_event_by_id(event_id: int) -> Event:
    with Session(engine) as session:
        result = session.scalars(select(Event).where(Event.id == event_id)).first()
        return result


+def format_prompt_input(question: str, analysis: dict, point: str) -> str:
+    event_id = analysis.get("event_id")
+    event = get_event_by_id(event_id)
+    event_title = event.title
+    event_description = event.description
+    analysis_content = analysis.get("content")
+
+    return f"""
+    Question: {question}
+    Point: {point}
+    Event_Title: {event_title}
+    Event_Description: {event_description}
+    Analysis: {analysis_content}
+    """
+
+
def generate_response(question: str) -> dict:
    relevant_analyses = get_relevant_analyses(question)
-    formatted_analyses = format_analyses(relevant_analyses, question)
-    messages = [
-        SystemMessage(content=SYSPROMPT),
-        HumanMessage(content=json.dumps(formatted_analyses)),
-    ]
-
-    result = lm_model.invoke(messages)
-    parser = JsonOutputParser(pydantic_object=Elaborations)
-    elaborations = parser.invoke(result)
-    return elaborations
+
+    for point_dict in (
+        relevant_analyses["for_points"] + relevant_analyses["against_points"]
+    ):
+        point = point_dict.get("point")
+        analyses = point_dict.get("analyses")
+        for analysis in analyses:
+            prompt_input = format_prompt_input(question, analysis, point)
+            messages = [
+                SystemMessage(content=SYSPROMPT),
+                HumanMessage(content=prompt_input),
+            ]
+
+            result = lm_model.invoke(messages)
+
+            analysis["elaborations"] = result.content
+
+    return relevant_analyses
+
+    # formatted_analyses = format_analyses(relevant_analyses, question)
+    # messages = [
+    #     SystemMessage(content=SYSPROMPT),
+    #     HumanMessage(content=json.dumps(formatted_analyses)),
+    # ]
+
+    # result = lm_model.invoke(messages)
+    # parser = JsonOutputParser(pydantic_object=Elaborations)
+    # elaborations = parser.invoke(result)
+    # return elaborations


if __name__ == "__main__":
    question = "Should the government provide free education for all citizens?"
    print(generate_response(question))
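A minimal usage sketch of the reformatted output: each analysis dict now carries an "elaborations" string that is either the generated elaboration or the literal "NOT RELEVANT" (per the new system prompt below); the filtering shown here is illustrative, not part of this commit.

from src.lm.generate_response import generate_response

response = generate_response("Should the government provide free education for all citizens?")

for point_dict in response["for_points"] + response["against_points"]:
    for analysis in point_dict["analyses"]:
        # Skip examples the model judged irrelevant to the point.
        if analysis["elaborations"] != "NOT RELEVANT":
            print(point_dict["point"], "->", analysis["elaborations"])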
27 changes: 27 additions & 0 deletions backend/src/lm/prompts.py
@@ -70,6 +70,33 @@
The question:
"""

QUESTION_ANALYSIS_GEN_SYSPROMPT_2 = """
You are a Singaporean student studying for your GCE A Levels General Paper.
You will be given a General Paper essay question that is argumentative or discursive in nature.
You will also be given a point that either supports or refutes the argument in the question and the reason for the point.
You will be given an analysis of a potentially relevant example event that can be used to correspondingly refute or support the argument given in the point above.
Your task:
Given the example event, you should provide a detailed elaboration illustrating how this event can be used as an example to support or refute the argument in the question.
If the example event is relevant to the point, you should provide a coherent and detailed elaboration of the point using the example event and analysis as support for the argument.
The elaboration should be specific to the category of the event and should be tailored to the context of General Paper essays. Provide coherent arguments and insights. Be sure to give a detailed elaboration of 3-4 sentences.
For the elaboration, remember that this is in the context of General Paper which emphasises critical thinking and the ability to construct coherent arguments.
Important note: Structure your elaborations using this format: "<A statement that clearly supports/refutes the given question>. <clear reason based on the event supporting the statement>". The explanation should leave no ambiguity about why the event strengthens or weakens the argument.
If the example event given is unlikely to be relevant in supporting/refuting the argument, you must return "NOT RELEVANT" as the elaboration.
Important Note: In your analysis, you should not mention "General Paper" or "A Levels".
Important Note: Do not provide any new points or examples. You should only elaborate on the examples given in the input or skip them if they are not relevant to the question or the points given.
Final Check: Before generating an elaboration, verify whether the example *directly* reinforces or counters the argument made in the point. If the connection is very weak or unclear, return "NOT RELEVANT".
Final Check: Ensure that if the example is not directly relevant to the point or only tangentially related, you should return "NOT RELEVANT" as the elaboration.
Your response should be a single string that is either "NOT RELEVANT" or the elaboration of the point using the example event and analysis as support for the argument.
Given inputs:
"""

QUESTION_ANALYSIS_GEN_SYSPROMPT = """
You are a Singaporean student studying for your GCE A Levels General Paper.
You will be given a General Paper essay question that is argumentative or discursive in nature.
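An illustrative example (all values hypothetical) of the human-message block that format_prompt_input assembles to match the "Given inputs:" section of QUESTION_ANALYSIS_GEN_SYSPROMPT_2.

example_prompt_input = """
Question: Should the government provide free education for all citizens?
Point: Free education widens access to opportunity for lower-income groups.
Event_Title: Country X expands tuition-free vocational places
Event_Description: The government of Country X funds 10,000 additional tuition-free places ...
Analysis: The expansion illustrates how state funding can lower entry barriers to education ...
"""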
12 changes: 12 additions & 0 deletions backend/src/scripts/initial_populate.py
@@ -0,0 +1,12 @@
import asyncio
from src.lm.generate_events import generate_events
from src.scrapers.guardian.get_articles import get_articles


FILE_PATH = "backend/initial_events.json"
LIMIT = 1000


if __name__ == "__main__":
articles = get_articles(LIMIT)
asyncio.run(generate_events(articles))
4 changes: 2 additions & 2 deletions backend/src/scripts/populate.py
@@ -4,9 +4,9 @@


# Populate the db with events from lm_events_output.json
-def populate() -> list[int]:
+def populate(file_path: str) -> list[int]:
    ids = []
-    with open("backend/lm_events_output.json", "r") as f:
+    with open(file_path, "r") as f:
        events = json.load(f)

    for event in events:
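A minimal usage sketch of the parameterised populate; the path below is the previously hardcoded default, and the call site itself is an assumption rather than something shown in this diff.

from src.scripts.populate import populate

event_ids = populate(file_path="backend/lm_events_output.json")
print(f"Populated {len(event_ids)} events")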