-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
969 additions
and
262 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import os | ||
|
||
import streamlit as st | ||
from langchain.graphs import Neo4jGraph | ||
from langchain_openai.chat_models import ChatOpenAI | ||
from streamlit import session_state as ss | ||
|
||
from components import chat, display_chat_history, sidebar | ||
from src.ps_genai_agents.agents.graph import create_text2cypher_graph_agent | ||
|
||
|
||
def initialize_state() -> None: | ||
""" | ||
Initialize the application state. | ||
""" | ||
|
||
if "agent" not in ss: | ||
ss["llm"] = ChatOpenAI(model="gpt-4o") | ||
ss["graph"] = Neo4jGraph( | ||
url=os.environ.get("IQS_NEO4J_URI"), | ||
username=os.environ.get("IQS_NEO4J_USERNAME"), | ||
password=os.environ.get("IQS_NEO4J_PASSWORD"), | ||
enhanced_schema=True, | ||
driver_config={"liveness_check_timeout": 0}, | ||
) | ||
ss["agent"] = create_text2cypher_graph_agent( | ||
chat_llm=ss["llm"], neo4j_graph=ss["graph"] | ||
) | ||
ss["messages"] = list() | ||
ss["source"] = "IQS" | ||
|
||
|
||
def run_app(): | ||
""" | ||
Run the Streamlit application. | ||
""" | ||
|
||
st.title("PS GenAI Retreat Workshop") | ||
sidebar() | ||
display_chat_history() | ||
# Prompt for user input and save and display | ||
if question := st.chat_input(): | ||
ss["current_question"] = question | ||
|
||
if "current_question" in ss: | ||
chat(ss.current_question) | ||
|
||
|
||
if __name__ == "__main__": | ||
initialize_state() | ||
run_app() |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
"""This module contains the components used in the Streamlit app.""" | ||
|
||
from .chat import chat, display_chat_history | ||
from .sidebar import sidebar |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import io | ||
import zipfile | ||
from typing import Any, Dict, List | ||
from uuid import uuid4 | ||
|
||
import pandas as pd | ||
import streamlit as st | ||
from neo4j.exceptions import SessionExpired | ||
from streamlit import session_state as ss | ||
|
||
|
||
def append_user_question(question: str) -> None: | ||
ss.messages.append({"role": "user", "content": question}) | ||
st.chat_message("user").markdown(question) | ||
|
||
|
||
def append_llm_response(question: str) -> None: | ||
with st.chat_message("assistant"): | ||
message_placeholder = st.empty() | ||
message_placeholder.status("thinking...") | ||
print("question: ", question) | ||
response: Any = ss.agent.invoke({"input": question, "chat_history": []})[ | ||
"agent_outcome" | ||
] | ||
|
||
message_placeholder.markdown(response.answer) | ||
|
||
show_response_information(response=response) | ||
|
||
ss.messages.append({"role": "assistant", "content": response}) | ||
|
||
|
||
def show_response_information(response: Any) -> None: | ||
if ( | ||
hasattr(response, "cypher") | ||
and response.cypher | ||
and not response.answer.startswith("I can only generate queries based on") | ||
): | ||
download_csv_button(cypher_result=response.cypher_result) | ||
|
||
with st.expander("Cypher"): | ||
if isinstance(response.cypher, str): | ||
st.code(response.cypher, language="cypher") | ||
st.json(response.cypher_result, expanded=False) | ||
else: | ||
[ | ||
( | ||
st.write(response.sub_questions[i]), | ||
st.code(response.cypher[i], language="cypher"), | ||
st.json(response.cypher_result[i], expanded=False), | ||
) | ||
for i in range(len(response.cypher)) | ||
] | ||
|
||
if hasattr(response, "sources") and response.sources: | ||
with st.expander("Vector Search"): | ||
st.write("Source Node IDs") | ||
st.write(response.sources) | ||
|
||
|
||
def chat(question: str): | ||
try: | ||
append_user_question(question=question) | ||
append_llm_response(question=question) | ||
except SessionExpired as e: | ||
st.error("Neo4j Session expired. Please restart the application.") | ||
|
||
|
||
def display_chat_history() -> None: | ||
for message in ss.messages: | ||
print(message) | ||
with st.chat_message(message["role"]): | ||
if message["role"] == "user": | ||
st.markdown(message["content"]) | ||
else: | ||
st.markdown(message["content"].answer) | ||
if not isinstance(message["content"], str): | ||
show_response_information(response=message["content"]) | ||
|
||
|
||
def prepare_csv(cypher_result: List[Dict[str, Any]]) -> str: | ||
# if not cypher_result: return pd.DataFrame().to_csv().encode("utf-8") | ||
|
||
index = [i for i in range(len(cypher_result[0].values()))] | ||
return pd.DataFrame(data=cypher_result).to_csv(index=index).encode("utf-8") | ||
|
||
|
||
@st.experimental_fragment() | ||
def download_csv_button(cypher_result: List[Dict[str, Any]]) -> None: | ||
try: | ||
print("cypher result in button", cypher_result) | ||
if len(cypher_result) > 0 and isinstance(cypher_result[0], list): | ||
content = [prepare_csv(result) for result in cypher_result if result] | ||
buf = io.BytesIO() | ||
with zipfile.ZipFile(buf, "x") as zip: | ||
for file_num, csv in enumerate(content): | ||
zip.writestr(f"cypher_result_part_{str(file_num+1)}.csv", csv) | ||
|
||
st.download_button( | ||
label="Download Cypher Results Tables as CSV", | ||
data=buf.getvalue(), | ||
file_name="cypher_results.zip", | ||
mime="application/zip", | ||
help="The cypher results .csv files in a .zip.", | ||
key=str(uuid4()), | ||
) | ||
else: | ||
csv = prepare_csv(cypher_result=cypher_result) | ||
st.download_button( | ||
label="Download Cypher Results Table as CSV", | ||
data=csv, | ||
file_name=f"cypher_results.csv", | ||
mime="text/csv", | ||
help="The cypher results .csv file.", | ||
key=str(uuid4()), | ||
) | ||
except Exception as e: | ||
print("Unable to generate Download Button for most recent question.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from typing import List, Literal | ||
|
||
iqs_questions = [ | ||
"How many vehicles are there?", | ||
"Summarize the responses under fcd10 for honda pilot. What is the men to women proportion for these responses and what is the problem for fcd10?", | ||
"What are the top 5 most severe problems for women aged 30-34 for all Acura models?", | ||
"Please summarize the verbatims for 2023 RDX for question 010 Trunk/TG Touch-Free Sensor DTU and create categories for the problems. As an output, I want the summary, corresponding categories and their verbatims", | ||
"What are the top 5 problems about seats for each age buckets for men over the age of 53?", | ||
"What are the top 5 problems about seats for each age buckets over the age of 53? Summarize the responses for each bucket", | ||
"What is the customer with the most reported problems? Can you list the problems, summarize them and include the problem id's as well as the customer gender and age range.", | ||
"Summarize and compare the sentiment for responses related to noise and sound in Honda Accord and Honda Pilot for women and return the number of responses considered.", | ||
"What color is the sky?", | ||
] | ||
|
||
patient_journey_questions = [] | ||
|
||
|
||
def get_demo_questions(source: Literal["IQS", "Patient Journey"]) -> List[str]: | ||
return iqs_questions if source == "IQS" else patient_journey_questions |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import os | ||
|
||
from langchain.graphs import Neo4jGraph | ||
from streamlit import session_state as ss | ||
from streamlit import sidebar as sb | ||
|
||
from src.ps_genai_agents.agents.graph import create_text2cypher_graph_agent | ||
|
||
from .questions import get_demo_questions | ||
|
||
|
||
def sidebar() -> None: | ||
""" | ||
The Streamlit app side bar. | ||
""" | ||
|
||
source = sb.radio("Data Source", options=["IQS", "Patient Journey"]) | ||
if source != ss["source"]: | ||
ss["source"] = source | ||
if source == "IQS": | ||
ss["graph"] = Neo4jGraph( | ||
url=os.environ.get("IQS_NEO4J_URI"), | ||
username=os.environ.get("IQS_NEO4J_USERNAME"), | ||
password=os.environ.get("IQS_NEO4J_PASSWORD"), | ||
enhanced_schema=True, | ||
driver_config={"liveness_check_timeout": 0}, | ||
) | ||
else: | ||
ss["graph"] = Neo4jGraph( | ||
url=os.environ.get("PJ_NEO4J_URI"), | ||
username=os.environ.get("PJ_NEO4J_USERNAME"), | ||
password=os.environ.get("PJ_NEO4J_PASSWORD"), | ||
enhanced_schema=True, | ||
driver_config={"liveness_check_timeout": 0}, | ||
) | ||
ss["agent"] = create_text2cypher_graph_agent( | ||
chat_llm=ss["llm"], neo4j_graph=ss["graph"] | ||
) | ||
|
||
demo_questions = get_demo_questions(source=ss["source"]) | ||
|
||
sb.title("Demo Questions") | ||
with sb.expander("Demo Questions"): | ||
for ex in demo_questions: | ||
if sb.button(label=ex, key=ex): | ||
ss.current_question = ex | ||
|
||
sb.divider() | ||
if len(ss.messages) > 0: | ||
if sb.button("Reset Chat", type="primary"): | ||
ss.messages = list() | ||
del ss.current_question |
Oops, something went wrong.