Skip to content

Commit

Permalink
Merge pull request #104 from cs3216-a3-group-4/feat-prompt-tuning
Browse files Browse the repository at this point in the history
Feat prompt tuning
  • Loading branch information
marcus-ny authored Sep 26, 2024
2 parents 445671e + 6a5b800 commit 06941c2
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 11 deletions.
1 change: 0 additions & 1 deletion backend/src/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def _get_env_var(name: str, default: str | None = None, required: bool = True):
GOOGLE_CLIENT_ID: str = _get_env_var("GOOGLE_CLIENT_ID")
GOOGLE_CLIENT_SECRET: str = _get_env_var("GOOGLE_CLIENT_SECRET")
GOOGLE_REDIRECT_URI: str = _get_env_var("GOOGLE_REDIRECT_URI")
LANGCHAIN_TRACING_V2: str = _get_env_var("LANGCHAIN_TRACING_V2")
LANGCHAIN_API_KEY: str = _get_env_var("LANGCHAIN_API_KEY")
OPENAI_API_KEY: str = _get_env_var("OPENAI_API_KEY")
PINECONE_API_KEY: str = _get_env_var("PINECONE_API_KEY")
Expand Down
4 changes: 1 addition & 3 deletions backend/src/embeddings/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from langchain_core.documents import Document

from src.common.constants import LANGCHAIN_API_KEY
from src.common.constants import LANGCHAIN_TRACING_V2
from src.common.constants import OPENAI_API_KEY
from src.common.constants import PINECONE_API_KEY

Expand All @@ -15,7 +14,6 @@
import os
import time

os.environ["LANGCHAIN_TRACING_V2"] = LANGCHAIN_TRACING_V2
os.environ["LANGCHAIN_API_KEY"] = LANGCHAIN_API_KEY
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
os.environ["PINECONE_API_KEY"] = PINECONE_API_KEY
Expand All @@ -25,7 +23,7 @@


def create_vector_store():
index_name = "langchain-test-index-5" # change to create a new index
index_name = "main-index-1" # change to create a new index

existing_indexes = [index_info["name"] for index_info in pc.list_indexes()]

Expand Down
12 changes: 7 additions & 5 deletions backend/src/lm/generate_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@
from src.scrapers.guardian.get_articles import get_articles
from typing import List
from pydantic import BaseModel
from src.common.constants import LANGCHAIN_API_KEY
from src.common.constants import LANGCHAIN_TRACING_V2
from src.common.constants import OPENAI_API_KEY
from src.common.constants import LANGCHAIN_API_KEY, OPENAI_API_KEY
from src.lm.prompts import EVENT_GEN_SYSPROMPT as SYSPROMPT
import asyncio

import os

os.environ["LANGCHAIN_API_KEY"] = LANGCHAIN_API_KEY
os.environ["LANGCHAIN_TRACING_V2"] = LANGCHAIN_TRACING_V2
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

lm_model = ChatOpenAI(model="gpt-4o-mini", temperature=0.7, max_retries=5)
lm_model = ChatOpenAI(model="gpt-4o-mini", temperature=0.3, max_retries=5)


lm_model_essay = ChatOpenAI(
model="gpt-4o-mini", temperature=0.7, frequency_penalty=0.5, max_retries=5
)


class CategoryAnalysis(BaseModel):
Expand Down
3 changes: 2 additions & 1 deletion backend/src/lm/generate_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import JsonOutputParser
from pydantic import BaseModel
from src.lm.generate_events import lm_model
from src.lm.generate_events import lm_model_essay as lm_model
from src.embeddings.vector_store import get_similar_results

from src.lm.prompts import QUESTION_POINT_GEN_SYSPROMPT as SYSPROMPT
Expand All @@ -23,6 +23,7 @@ def generate_points_from_question(question: str) -> dict:


def get_relevant_analyses(question: str, analyses_per_point: int = 5) -> dict:
print(f"Freq penalty: {lm_model.frequency_penalty}")
points = generate_points_from_question(question)

for_pts = points.get("for_points", [])
Expand Down
2 changes: 1 addition & 1 deletion backend/src/lm/generate_response.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from src.lm.generate_points import get_relevant_analyses
from src.lm.generate_events import lm_model
from src.lm.generate_events import lm_model_essay as lm_model
from langchain_core.messages import HumanMessage, SystemMessage
from src.lm.prompts import QUESTION_ANALYSIS_GEN_SYSPROMPT_2 as SYSPROMPT
from src.lm.prompts import (
Expand Down

0 comments on commit 06941c2

Please sign in to comment.