-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_connector.py
63 lines (52 loc) · 2.05 KB
/
model_connector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import re

import pandas as pd

from config import client, OPENAI_MODEL
def generate_sql_from_nl(prompt, schema_info):
    """
    Generate an SQL query from a natural language prompt.

    The schema is rendered as one ``- table(col1, col2, ...)`` line per table
    and sent, together with the user's prompt, to the chat completion endpoint
    exposed by ``client``. The model's reply is stripped of surrounding
    whitespace and every standalone uppercase ``LIKE`` keyword is rewritten to
    ``ILIKE``.

    Args:
        prompt (str): The natural language prompt provided by the user.
        schema_info (dict): A dictionary containing the database schema
            information, mapping each table name to an iterable of
            column-name strings.

    Returns:
        str: The generated SQL query.
    """
    table_lines = "".join(
        f"- {table}({', '.join(cols)})\n" for table, cols in schema_info.items()
    )
    schema_description = f"The database has the following tables:\n{table_lines}"

    system_message = (
        "You are a helpful assistant that converts natural language into SQL queries. "
        "The SQL should be valid for the provided database schema. "
        "Don't include explanation text, just provide a single SQL query."
    )
    user_message = f"{schema_description}\nUser query: {prompt}\nReturn just the SQL without explanation."

    # temperature=0 keeps sampling randomness to a minimum for the same prompt.
    response = client.chat.completions.create(
        model=OPENAI_MODEL,
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message},
        ],
        temperature=0,
    )
    sql = response.choices[0].message.content.strip()

    # Rewrite LIKE only when it is a whole word. A plain substring replace
    # would turn an ILIKE the model already produced into "IILIKE" and would
    # also mangle identifiers that merely contain "LIKE" (e.g. "LIKES").
    # NOTE(review): presumably the target database supports ILIKE
    # (e.g. PostgreSQL) -- confirm against the code that executes the query.
    return re.sub(r"\bLIKE\b", "ILIKE", sql)
def format_response(data, sql):
    """
    Build the user-facing reply for a query result.

    :param data: The rows returned from the SQL query. It must support
        ``len()`` and indexing; a single row is read through ``.items()``.
    :param sql: The SQL query, echoed back at the end of every reply.
    :return: A message with no results, one result listed as ``key: value``
        lines, or several results rendered as a markdown table, each followed
        by the SQL query.
    """
    sql_footer = f"SQL Query:\n{sql}"
    row_count = len(data)

    # Guard clauses: handle the two small cases first, then the table case.
    if row_count == 0:
        return f"I found no matching results.\n\n{sql_footer}"

    if row_count == 1:
        only_row = data[0]
        field_lines = "".join(
            f"{column}: {value}\n" for column, value in only_row.items()
        )
        return f"Here is the single matching result:\n{field_lines}\n{sql_footer}"

    # Two or more rows: let pandas lay them out as a markdown table.
    markdown_table = pd.DataFrame(data).to_markdown(index=False)
    return f"Multiple results found:\n{markdown_table}\n\n{sql_footer}"