Skip to content

Commit

Permalink
add streamlit app
Browse files Browse the repository at this point in the history
  • Loading branch information
a-s-g93 committed Sep 25, 2024
1 parent cb3fc4f commit b6e5e9a
Show file tree
Hide file tree
Showing 16 changed files with 969 additions and 262 deletions.
51 changes: 51 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os

import streamlit as st
from langchain.graphs import Neo4jGraph
from langchain_openai.chat_models import ChatOpenAI
from streamlit import session_state as ss

from components import chat, display_chat_history, sidebar
from src.ps_genai_agents.agents.graph import create_text2cypher_graph_agent


def initialize_state() -> None:
"""
Initialize the application state.
"""

if "agent" not in ss:
ss["llm"] = ChatOpenAI(model="gpt-4o")
ss["graph"] = Neo4jGraph(
url=os.environ.get("IQS_NEO4J_URI"),
username=os.environ.get("IQS_NEO4J_USERNAME"),
password=os.environ.get("IQS_NEO4J_PASSWORD"),
enhanced_schema=True,
driver_config={"liveness_check_timeout": 0},
)
ss["agent"] = create_text2cypher_graph_agent(
chat_llm=ss["llm"], neo4j_graph=ss["graph"]
)
ss["messages"] = list()
ss["source"] = "IQS"


def run_app():
"""
Run the Streamlit application.
"""

st.title("PS GenAI Retreat Workshop")
sidebar()
display_chat_history()
# Prompt for user input and save and display
if question := st.chat_input():
ss["current_question"] = question

if "current_question" in ss:
chat(ss.current_question)


if __name__ == "__main__":
initialize_state()
run_app()
1 change: 0 additions & 1 deletion app/README.md

This file was deleted.

1 change: 0 additions & 1 deletion app/components/__init__.py

This file was deleted.

4 changes: 4 additions & 0 deletions components/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""This module contains the components used in the Streamlit app."""

from .chat import chat, display_chat_history
from .sidebar import sidebar
118 changes: 118 additions & 0 deletions components/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import io
import zipfile
from typing import Any, Dict, List
from uuid import uuid4

import pandas as pd
import streamlit as st
from neo4j.exceptions import SessionExpired
from streamlit import session_state as ss


def append_user_question(question: str) -> None:
ss.messages.append({"role": "user", "content": question})
st.chat_message("user").markdown(question)


def append_llm_response(question: str) -> None:
with st.chat_message("assistant"):
message_placeholder = st.empty()
message_placeholder.status("thinking...")
print("question: ", question)
response: Any = ss.agent.invoke({"input": question, "chat_history": []})[
"agent_outcome"
]

message_placeholder.markdown(response.answer)

show_response_information(response=response)

ss.messages.append({"role": "assistant", "content": response})


def show_response_information(response: Any) -> None:
if (
hasattr(response, "cypher")
and response.cypher
and not response.answer.startswith("I can only generate queries based on")
):
download_csv_button(cypher_result=response.cypher_result)

with st.expander("Cypher"):
if isinstance(response.cypher, str):
st.code(response.cypher, language="cypher")
st.json(response.cypher_result, expanded=False)
else:
[
(
st.write(response.sub_questions[i]),
st.code(response.cypher[i], language="cypher"),
st.json(response.cypher_result[i], expanded=False),
)
for i in range(len(response.cypher))
]

if hasattr(response, "sources") and response.sources:
with st.expander("Vector Search"):
st.write("Source Node IDs")
st.write(response.sources)


def chat(question: str):
try:
append_user_question(question=question)
append_llm_response(question=question)
except SessionExpired as e:
st.error("Neo4j Session expired. Please restart the application.")


def display_chat_history() -> None:
for message in ss.messages:
print(message)
with st.chat_message(message["role"]):
if message["role"] == "user":
st.markdown(message["content"])
else:
st.markdown(message["content"].answer)
if not isinstance(message["content"], str):
show_response_information(response=message["content"])


def prepare_csv(cypher_result: List[Dict[str, Any]]) -> str:
# if not cypher_result: return pd.DataFrame().to_csv().encode("utf-8")

index = [i for i in range(len(cypher_result[0].values()))]
return pd.DataFrame(data=cypher_result).to_csv(index=index).encode("utf-8")


@st.experimental_fragment()
def download_csv_button(cypher_result: List[Dict[str, Any]]) -> None:
try:
print("cypher result in button", cypher_result)
if len(cypher_result) > 0 and isinstance(cypher_result[0], list):
content = [prepare_csv(result) for result in cypher_result if result]
buf = io.BytesIO()
with zipfile.ZipFile(buf, "x") as zip:
for file_num, csv in enumerate(content):
zip.writestr(f"cypher_result_part_{str(file_num+1)}.csv", csv)

st.download_button(
label="Download Cypher Results Tables as CSV",
data=buf.getvalue(),
file_name="cypher_results.zip",
mime="application/zip",
help="The cypher results .csv files in a .zip.",
key=str(uuid4()),
)
else:
csv = prepare_csv(cypher_result=cypher_result)
st.download_button(
label="Download Cypher Results Table as CSV",
data=csv,
file_name=f"cypher_results.csv",
mime="text/csv",
help="The cypher results .csv file.",
key=str(uuid4()),
)
except Exception as e:
print("Unable to generate Download Button for most recent question.")
19 changes: 19 additions & 0 deletions components/questions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import List, Literal

iqs_questions = [
"How many vehicles are there?",
"Summarize the responses under fcd10 for honda pilot. What is the men to women proportion for these responses and what is the problem for fcd10?",
"What are the top 5 most severe problems for women aged 30-34 for all Acura models?",
"Please summarize the verbatims for 2023 RDX for question 010 Trunk/TG Touch-Free Sensor DTU and create categories for the problems. As an output, I want the summary, corresponding categories and their verbatims",
"What are the top 5 problems about seats for each age buckets for men over the age of 53?",
"What are the top 5 problems about seats for each age buckets over the age of 53? Summarize the responses for each bucket",
"What is the customer with the most reported problems? Can you list the problems, summarize them and include the problem id's as well as the customer gender and age range.",
"Summarize and compare the sentiment for responses related to noise and sound in Honda Accord and Honda Pilot for women and return the number of responses considered.",
"What color is the sky?",
]

patient_journey_questions = []


def get_demo_questions(source: Literal["IQS", "Patient Journey"]) -> List[str]:
return iqs_questions if source == "IQS" else patient_journey_questions
52 changes: 52 additions & 0 deletions components/sidebar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os

from langchain.graphs import Neo4jGraph
from streamlit import session_state as ss
from streamlit import sidebar as sb

from src.ps_genai_agents.agents.graph import create_text2cypher_graph_agent

from .questions import get_demo_questions


def sidebar() -> None:
"""
The Streamlit app side bar.
"""

source = sb.radio("Data Source", options=["IQS", "Patient Journey"])
if source != ss["source"]:
ss["source"] = source
if source == "IQS":
ss["graph"] = Neo4jGraph(
url=os.environ.get("IQS_NEO4J_URI"),
username=os.environ.get("IQS_NEO4J_USERNAME"),
password=os.environ.get("IQS_NEO4J_PASSWORD"),
enhanced_schema=True,
driver_config={"liveness_check_timeout": 0},
)
else:
ss["graph"] = Neo4jGraph(
url=os.environ.get("PJ_NEO4J_URI"),
username=os.environ.get("PJ_NEO4J_USERNAME"),
password=os.environ.get("PJ_NEO4J_PASSWORD"),
enhanced_schema=True,
driver_config={"liveness_check_timeout": 0},
)
ss["agent"] = create_text2cypher_graph_agent(
chat_llm=ss["llm"], neo4j_graph=ss["graph"]
)

demo_questions = get_demo_questions(source=ss["source"])

sb.title("Demo Questions")
with sb.expander("Demo Questions"):
for ex in demo_questions:
if sb.button(label=ex, key=ex):
ss.current_question = ex

sb.divider()
if len(ss.messages) > 0:
if sb.button("Reset Chat", type="primary"):
ss.messages = list()
del ss.current_question
Loading

0 comments on commit b6e5e9a

Please sign in to comment.