Skip to content

Commit

Permalink
allow users to choose which openai llm to use
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyeh committed Jul 15, 2024
1 parent 97533eb commit 66219a8
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 36 deletions.
1 change: 0 additions & 1 deletion wren-ai-service/eval/.env.example
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
OPENAI_API_KEY=
OPENAI_GENERATION_MODEL=gpt-3.5-turbo
WREN_IBIS_ENDPOINT=http://localhost:8000
bigquery.project-id=
bigquery.dataset-id=
Expand Down
53 changes: 21 additions & 32 deletions wren-ai-service/eval/data_curation/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
DATA_SOURCES,
get_contexts_from_sqls,
get_eval_dataset_in_toml_string,
get_llm_client,
get_openai_client,
get_question_sql_pairs,
is_sql_valid,
prettify_sql,
Expand All @@ -23,10 +23,12 @@
st.title("WrenAI Data Curation App")


# Module-level async OpenAI client shared by all callbacks below.
# (The diff extraction fused the pre-rename `get_llm_client()` line with the
# post-rename one; only the renamed helper exists after this commit.)
llm_client = get_openai_client()


# session states
if "llm_model" not in st.session_state:
st.session_state["llm_model"] = None
if "deployment_id" not in st.session_state:
st.session_state["deployment_id"] = str(uuid.uuid4())
if "mdl_json" not in st.session_state:
Expand Down Expand Up @@ -55,7 +57,9 @@ def on_change_upload_eval_dataset():
def on_click_generate_question_sql_pairs(llm_client: AsyncClient):
    """Button callback: generate question/SQL pairs and stash them in session state.

    Runs the async ``get_question_sql_pairs`` helper to completion with the
    currently selected LLM model and the deployed MDL JSON, then stores the
    resulting list under ``st.session_state["llm_question_sql_pairs"]``.

    Args:
        llm_client: The shared async OpenAI client used for generation.
    """
    # NOTE: the extracted diff fused the old single-line call with the new
    # three-argument call; this is the post-change form, which threads the
    # user-selected model (st.session_state["llm_model"]) through to the helper.
    st.toast("Generating question-sql-pairs...")
    st.session_state["llm_question_sql_pairs"] = asyncio.run(
        get_question_sql_pairs(
            llm_client, st.session_state["llm_model"], st.session_state["mdl_json"]
        )
    )


Expand Down Expand Up @@ -89,6 +93,11 @@ def on_click_setup_uploaded_file():
st.session_state["connection_info"] = None


def on_change_llm_model():
    """Selectbox callback: copy the chosen model into ``st.session_state["llm_model"]``."""
    # Bug fix: the original used double quotes for the subscript inside a
    # double-quoted f-string (f"...{st.session_state["select_llm_model"]}..."),
    # which is a SyntaxError on Python < 3.12 (reusing the outer quote inside
    # the replacement field). Use single outer quotes instead.
    st.toast(f'Switching LLM model to {st.session_state["select_llm_model"]}')
    st.session_state["llm_model"] = st.session_state["select_llm_model"]


def on_change_sql(i: int, key: str):
sql = st.session_state[key]

Expand Down Expand Up @@ -140,19 +149,6 @@ def on_click_add_candidate_dataset(i: int, categories: list):
st.session_state["candidate_dataset"].append(dataset_to_add)


def on_change_sql_context(i: int):
    """Text-input callback: append the typed context entry and clear the field.

    ``i == -1`` targets the manually entered user question/SQL pair; any other
    index targets the ``i``-th LLM-generated pair in session state.
    """
    if i == -1:
        input_key = "user_context_input"
        contexts = st.session_state.get("user_question_sql_pair", {}).get(
            "context", []
        )
    else:
        input_key = f"context_input_{i}"
        contexts = st.session_state["llm_question_sql_pairs"][i]["context"]
    # Same list object as before, so the append mutates session state in place;
    # clearing the input key resets the widget for the next entry.
    contexts.append(st.session_state[input_key])
    st.session_state[input_key] = ""


def on_change_user_question():
if not st.session_state["user_question_sql_pair"]:
st.session_state["user_question_sql_pair"] = {
Expand Down Expand Up @@ -180,6 +176,13 @@ def on_click_remove_candidate_dataset_button(i: int):
on_change=on_click_setup_uploaded_file,
)

st.selectbox(
label="Select which LLM model you want to use",
options=["gpt-3.5-turbo", "gpt-4-turbo", "gpt-4o"],
index=0,
key="select_llm_model",
on_change=on_change_llm_model,
)

tab_create_dataset, tab_modify_dataset = st.tabs(
["Create New Evaluation Dataset", "Modify Saved Evaluation Dataset"]
Expand Down Expand Up @@ -249,20 +252,13 @@ def on_click_remove_candidate_dataset_button(i: int):
value=[],
key=f"categories_{i}",
)
st.text_input(
f"Context {i}",
placeholder="Enter the context of SQL manually with format <table_name>.<column_name>",
key=f"context_input_{i}",
on_change=on_change_sql_context,
args=(i,),
)
st.multiselect(
label=f"Context {i}",
options=question_sql_pair["context"],
default=question_sql_pair["context"],
key=f"context_{i}",
help="Contexts are automatically generated based on the SQL once you save the changes of the it(ctrl+enter or command+enter)",
label_visibility="hidden",
disabled=True,
)
st.text_area(
f"SQL {i}",
Expand Down Expand Up @@ -313,13 +309,6 @@ def on_click_remove_candidate_dataset_button(i: int):
value=[],
key="user_categories",
)
st.text_input(
"Context",
placeholder="Enter the context of SQL manually with format <table_name>.<column_name>",
key="user_context_input",
on_change=on_change_sql_context,
args=(-1,),
)
st.multiselect(
label="Context",
options=st.session_state.get("user_question_sql_pair", {}).get(
Expand All @@ -330,7 +319,7 @@ def on_click_remove_candidate_dataset_button(i: int):
),
key="user_context",
help="Contexts are automatically generated based on the SQL once you save the changes of the it(ctrl+enter or command+enter)",
label_visibility="hidden",
disabled=True,
)
st.text_area(
"SQL",
Expand Down
8 changes: 5 additions & 3 deletions wren-ai-service/eval/data_curation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
DATA_SOURCES = ["bigquery"]


# NOTE: the extracted diff fused the pre-rename `get_llm_client` header with
# the post-rename one; only the renamed function exists after this commit.
def get_openai_client() -> AsyncClient:
    """Build an async OpenAI client authenticated via the OPENAI_API_KEY env var.

    Returns:
        AsyncClient: a client ready for ``chat.completions`` calls. If the
        environment variable is unset the key is ``None`` and the client will
        fail at request time, not here.
    """
    return AsyncClient(
        api_key=os.getenv("OPENAI_API_KEY"),
    )
Expand Down Expand Up @@ -285,6 +285,8 @@ def _get_contexts_from_sql_analysis_results(sql_analysis_results: list[dict]):
f'SQL ANALYSIS RESULTS: {orjson.dumps(sql_analysis_results, option=orjson.OPT_INDENT_2).decode("utf-8")}'
)
print(f"CONTEXTS: {sorted(set(contexts))}")
print("\n\n")

return sorted(set(contexts))

async with aiohttp.ClientSession():
Expand All @@ -299,7 +301,7 @@ def _get_contexts_from_sql_analysis_results(sql_analysis_results: list[dict]):


async def get_question_sql_pairs(
llm_client: AsyncClient, mdl_json: dict, num_pairs: int = 10
llm_client: AsyncClient, llm_model: str, mdl_json: dict, num_pairs: int = 10
) -> list[dict]:
messages = [
{
Expand Down Expand Up @@ -338,7 +340,7 @@ async def get_question_sql_pairs(

try:
response = await llm_client.chat.completions.create(
model=os.getenv("GENERATION_MODEL", "gpt-3.5-turbo"),
model=llm_model,
messages=messages,
response_format={"type": "json_object"},
max_tokens=4096,
Expand Down

0 comments on commit 66219a8

Please sign in to comment.