-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathapp.py
148 lines (111 loc) · 4.44 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from urllib.parse import quote

import streamlit as st
from langchain_community.chat_models import ChatOllama
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import ChatPromptTemplate
def connectDatabase(username, port, host, password, database):
    """Open a MySQL connection and store it as ``st.session_state.db``.

    Any previously stored connection is replaced.

    Args:
        username: MySQL user name.
        port: Server port (the sidebar passes it as a string).
        host: Server host name or address.
        password: Password for *username*.
        database: Name of the database to open.

    Raises:
        Whatever ``SQLDatabase.from_uri`` raises, e.g. when the server is
        unreachable or the credentials are rejected.
    """
    # Credentials may contain URL-reserved characters ('@', ':', '/', '%', ...)
    # that would corrupt the URI; percent-encode them with no "safe" chars.
    # ``quote`` (not ``quote_plus``) keeps spaces as %20, which SQLAlchemy
    # decodes correctly.
    safe_user = quote(username, safe="")
    safe_password = quote(password, safe="")
    mysql_uri = (
        f"mysql+mysqlconnector://{safe_user}:{safe_password}"
        f"@{host}:{port}/{database}"
    )
    st.session_state.db = SQLDatabase.from_uri(mysql_uri)
def runQuery(query):
    """Execute *query* on the connected database and return its result.

    Args:
        query: SQL text to execute. In this app it is produced by the LLM,
            so treat it as untrusted: it runs with whatever privileges the
            connected MySQL user has. NOTE(review): consider connecting with
            a read-only account.

    Returns:
        The value ``SQLDatabase.run`` returns for *query*, or the string
        ``"Please connect to database"`` when no connection has been made.
    """
    # Attribute access (``st.session_state.db``) raises AttributeError for a
    # key that was never set, which made the fallback below unreachable.
    db = st.session_state.get("db")
    return db.run(query) if db else "Please connect to database"
def getDatabaseSchema():
    """Return the schema description of the connected database.

    Returns:
        The text produced by ``SQLDatabase.get_table_info()``, or the string
        ``"Please connect to database"`` when no connection has been made.
    """
    # Use ``.get`` so a missing "db" key falls through to the message instead
    # of raising AttributeError from ``st.session_state.db``.
    db = st.session_state.get("db")
    return db.get_table_info() if db else "Please connect to database"
# Ollama model tag used for both SQL generation and answer phrasing.
OLLAMA_MODEL = "llama3"
llm = ChatOllama(model=OLLAMA_MODEL)
def getQueryFromLLM(question):
    """Ask the LLM to translate a natural-language question into SQL.

    The prompt embeds the schema returned by ``getDatabaseSchema()`` plus two
    few-shot examples.

    Args:
        question: The user's question in natural language.

    Returns:
        The model's reply with surrounding whitespace and any enclosing
        Markdown code fence (e.g. ```sql ... ```) removed, so it can be handed
        directly to ``runQuery``.
    """
    # String literals in the examples are quoted: an unquoted value such as
    # country=Brazil is parsed by MySQL as a column name, and the model copies
    # whatever pattern the examples show.
    template = """below is the schema of MYSQL database, read the schema carefully about the table and column names. Also take care of table or column name case sensitivity.
Finally answer user's question in the form of SQL query.
{schema}
please only provide the SQL query and nothing else
for example:
question: how many albums we have in database
SQL query: SELECT COUNT(*) FROM album
question: how many customers are from Brazil in the database ?
SQL query: SELECT COUNT(*) FROM customer WHERE country='Brazil'
your turn :
question: {question}
SQL query :
please only provide the SQL query and nothing else
"""
    prompt = ChatPromptTemplate.from_template(template)
    chain = prompt | llm
    response = chain.invoke({
        "question": question,
        "schema": getDatabaseSchema()
    })
    return _stripCodeFences(response.content)


def _stripCodeFences(text):
    """Return *text* trimmed and without an enclosing Markdown code fence."""
    sql = text.strip()
    if sql.startswith("```"):
        # The opening fence line may carry a language tag such as ```sql.
        newline = sql.find("\n")
        sql = sql[newline + 1:] if newline != -1 else sql[3:]
        sql = sql.rstrip()
        if sql.endswith("```"):
            sql = sql[:-3]
    return sql.strip()
def getResponseForQueryResult(question, query, result):
    """Ask the LLM to phrase an executed query's result as a natural answer.

    Args:
        question: The user's original question.
        query: The SQL query that was executed for *question*.
        result: The value ``runQuery`` returned for *query*; it is formatted
            into the prompt as text.

    Returns:
        The model's reply text (``response.content``).
    """
    # Few-shot examples: the SQL string literal is quoted so the model is not
    # shown invalid SQL, and the filtered example's answer mentions the filter
    # (India) so the model does not learn to drop conditions from its answer.
    template2 = """below is the schema of MYSQL database, read the schema carefully about the table and column names of each table.
Also look into the conversation if available
Finally write a response in natural language by looking into the conversation and result.
{schema}
Here are some example for you:
question: how many albums we have in database
SQL query: SELECT COUNT(*) FROM album;
Result : [(34,)]
Response: There are 34 albums in the database.
question: how many users we have in database
SQL query: SELECT COUNT(*) FROM customer;
Result : [(59,)]
Response: There are 59 amazing users in the database.
question: how many users above are from india we have in database
SQL query: SELECT COUNT(*) FROM customer WHERE country='india';
Result : [(4,)]
Response: There are 4 amazing users from India in the database.
your turn to write response in natural language from the given result :
question: {question}
SQL query : {query}
Result : {result}
Response:
"""
    prompt2 = ChatPromptTemplate.from_template(template2)
    chain2 = prompt2 | llm
    response = chain2.invoke({
        "question": question,
        "schema": getDatabaseSchema(),
        "query": query,
        "result": result
    })
    return response.content
# Browser-tab title and icon, with a centered layout for the chat page.
st.set_page_config(page_title="Chat with MYSQL DB", page_icon="🤖", layout="centered")
# Chat input box; returns the submitted text on the run after submission.
question = st.chat_input('Chat with your mysql database')

# Conversation history persists across Streamlit reruns in session state.
if "chat" not in st.session_state:
    st.session_state.chat = []

if question:
    if "db" not in st.session_state:
        st.error('Please connect database first.')
    else:
        st.session_state.chat.append({
            "role": "user",
            "content": question
        })
        # Pipeline: question -> SQL (LLM) -> result (MySQL) -> answer (LLM).
        try:
            query = getQueryFromLLM(question)
            print(query)
            result = runQuery(query)
            print(result)
            response = getResponseForQueryResult(question, query, result)
        except Exception as err:
            # UI boundary: LLM-generated SQL is often invalid, and the model
            # or database may be unreachable. Report the failure as the
            # assistant's reply instead of aborting the run with a traceback
            # and leaving the user's turn unanswered.
            print(f"Error answering question: {err!r}")
            response = f"Sorry, I could not answer that question: {err}"
        st.session_state.chat.append({
            "role": "assistant",
            "content": response
        })

# Re-render the full conversation on every run.
for chat in st.session_state.chat:
    st.chat_message(chat['role']).markdown(chat['content'])
# Sidebar form for the connection settings; each input's value is stored in
# st.session_state under its ``key``.
with st.sidebar:
    st.title('Connect to database')
    st.text_input(label="Host", key="host", value="localhost")
    st.text_input(label="Port", key="port", value="3306")
    st.text_input(label="Username", key="username", value="root")
    st.text_input(label="Password", key="password", value="", type="password")
    st.text_input(label="Database", key="database", value="rag_test")
    connectBtn = st.button("Connect")
    if connectBtn:
        try:
            connectDatabase(
                username=st.session_state.username,
                port=st.session_state.port,
                host=st.session_state.host,
                password=st.session_state.password,
                database=st.session_state.database,
            )
        except Exception as err:
            # UI boundary: connection failures (e.g. wrong credentials or an
            # unreachable server) are shown to the user instead of crashing
            # the page with a traceback.
            st.error(f"Could not connect to database: {err}")
        else:
            st.success("Database connected")