Commit

Merge remote-tracking branch 'origin/Agent_Hub_Dev' into Agent_Hub_Dev
Aries-ckt committed Oct 20, 2023
2 parents: 6cc194b + bb882a8; commit 8db497f
Showing 4 changed files with 73 additions and 33 deletions.
pilot/base_modules/agent/commands/command_mange.py (6 changes: 3 additions & 3 deletions)
@@ -227,7 +227,7 @@ def __is_need_wait_plugin_call(self, api_call_context):
             i += 1
         return False

-    def __check_last_plugin_call_ready(self, all_context):
+    def check_last_plugin_call_ready(self, all_context):
         start_agent_count = all_context.count(self.agent_prefix)
         end_agent_count = all_context.count(self.agent_end)
@@ -359,7 +359,7 @@ def to_view_text(self, api_status: PluginStatus):
     def run(self, llm_text):
         if self.__is_need_wait_plugin_call(llm_text):
             # wait api call generate complete
-            if self.__check_last_plugin_call_ready(llm_text):
+            if self.check_last_plugin_call_ready(llm_text):
                 self.update_from_context(llm_text)
                 for key, value in self.plugin_status_map.items():
                     if value.status == Status.TODO.value:
@@ -379,7 +379,7 @@ def run(self, llm_text):
     def run_display_sql(self, llm_text, sql_run_func):
         if self.__is_need_wait_plugin_call(llm_text):
             # wait api call generate complete
-            if self.__check_last_plugin_call_ready(llm_text):
+            if self.check_last_plugin_call_ready(llm_text):
                 self.update_from_context(llm_text)
                 for key, value in self.plugin_status_map.items():
                     if value.status == Status.TODO.value:
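Note on this change: __check_last_plugin_call_ready is made public (check_last_plugin_call_ready) so callers outside the class, such as the editor API below, can test whether a streamed response is complete before parsing it. A minimal stand-in for that readiness check, assuming the <api-call> markers used by the prompt further down rather than the class's real internals:

# Stand-in sketch, not the real ApiCall implementation: the streamed LLM text
# is treated as "ready" once every opening <api-call> marker has a matching
# closing marker, so the embedded SQL can be parsed safely.
API_CALL_START = "<api-call>"   # assumed marker, mirrors the prompt format below
API_CALL_END = "</api-call>"    # assumed end marker

def is_last_api_call_complete(llm_text: str) -> bool:
    """Return True when the last <api-call> block in llm_text is closed."""
    starts = llm_text.count(API_CALL_START)
    ends = llm_text.count(API_CALL_END)
    return starts > 0 and starts == ends

# A partially streamed chunk is not ready; the completed text is.
print(is_last_api_call_complete("thoughts <api-call><name>response_table</name>"))  # False
print(is_last_api_call_complete(
    "thoughts <api-call><name>response_table</name>"
    "<args><sql>SELECT 1</sql></args></api-call>"
))  # True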
pilot/openapi/api_v1/editor/api_editor_v1.py (34 changes: 20 additions & 14 deletions)
@@ -29,6 +29,8 @@
 from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
 from pilot.scene.chat_db.data_loader import DbDataLoader
 from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
+from pilot.base_modules.agent.commands.command_mange import ApiCall
+

 router = APIRouter()
 CFG = Config()
@@ -101,12 +103,15 @@ async def get_editor_sql(con_uid: str, round: int):
                 logger.info(
                     f'history ai json resp:{element["data"]["content"]}'
                 )
-                context = (
-                    element["data"]["content"]
-                    .replace("\\n", " ")
-                    .replace("\n", " ")
-                )
-                return Result.succ(json.loads(context))
+                api_call = ApiCall()
+                result = {}
+                result['thoughts'] = element["data"]["content"]
+                if api_call.check_last_plugin_call_ready(element["data"]["content"]):
+                    api_call.update_from_context(element["data"]["content"])
+                    if len(api_call.plugin_status_map) > 0:
+                        first_item = next(iter(api_call.plugin_status_map.items()))[1]
+                        result['sql'] = first_item.args["sql"]
+                return Result.succ(result)
     return Result.faild(msg="not have sql!")


@@ -156,17 +161,18 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
             )
         )[0]
         if edit_round:
+            new_ai_text = ""
             for element in edit_round["messages"]:
                 if element["type"] == "ai":
-                    db_resp = json.loads(element["data"]["content"])
-                    db_resp["thoughts"] = sql_edit_context.new_speak
-                    db_resp["sql"] = sql_edit_context.new_sql
-                    element["data"]["content"] = json.dumps(db_resp)
+                    new_ai_text = element["data"]["content"]
+                    new_ai_text.replace(sql_edit_context.old_sql, sql_edit_context.new_sql)
+                    element["data"]["content"] = new_ai_text

             for element in edit_round["messages"]:
                 if element["type"] == "view":
-                    data_loader = DbDataLoader()
-                    element["data"]["content"] = data_loader.get_table_view_by_conn(
-                        conn.run(sql_edit_context.new_sql), sql_edit_context.new_speak
-                    )
+                    api_call = ApiCall()
+                    new_view_text = api_call.run_display_sql(new_ai_text, conn.run_to_df)
+                    element["data"]["content"] = new_view_text
             history_mem.update(history_messages)
             return Result.succ(None)
     return Result.faild(msg="Edit Faild!")
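Note on this change: get_editor_sql no longer json.loads the stored assistant message; it treats the message as free text with an embedded <api-call> block and pulls the SQL out via ApiCall. A rough, self-contained sketch of that extraction using a regex stand-in (the helper name and sample message below are invented for illustration, not project code):

# Hypothetical helper: pull the thoughts and the SQL out of an assistant
# message that embeds an <api-call> block, mirroring the shape the endpoint
# now returns instead of a parsed JSON payload.
import re

def extract_editor_result(ai_content: str) -> dict:
    result = {"thoughts": ai_content}
    match = re.search(r"<api-call>.*?<sql>(.*?)</sql>.*?</api-call>", ai_content, re.S)
    if match:
        result["sql"] = match.group(1).strip()
    return result

sample = (
    "Here is the query you asked for."
    "<api-call><name>response_table</name>"
    "<args><sql>SELECT city, COUNT(*) FROM users GROUP BY city</sql></args></api-call>"
)
print(extract_editor_result(sample)["sql"])
# SELECT city, COUNT(*) FROM users GROUP BY city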
pilot/scene/chat_db/auto_execute/chat.py (13 changes: 10 additions & 3 deletions)
@@ -5,6 +5,7 @@
 from pilot.common.sql_database import Database
 from pilot.configs.config import Config
 from pilot.scene.chat_db.auto_execute.prompt import prompt
+from pilot.base_modules.agent.commands.command_mange import ApiCall

 CFG = Config()

@@ -37,6 +38,7 @@ def __init__(self, chat_param: Dict):

         self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
         self.top_k: int = 200
+        self.api_call = ApiCall(display_registry=CFG.command_disply)

     def generate_input_values(self):
         """
@@ -69,6 +71,11 @@ def generate_input_values(self):
         }
         return input_values

-    def do_action(self, prompt_response):
-        print(f"do_action:{prompt_response}")
-        return self.database.run(prompt_response.sql)
+    def stream_plugin_call(self, text):
+        text = text.replace("\n", " ")
+        print(f"stream_plugin_call:{text}")
+        return self.api_call.run_display_sql(text, self.database.run_to_df)
+    #
+    # def do_action(self, prompt_response):
+    #     print(f"do_action:{prompt_response}")
+    #     return self.database.run(prompt_response.sql)
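Note on this change: ChatWithDbAutoExecute switches from the non-streaming do_action path to a streaming one; each snapshot of the model output is flattened to a single line and handed to run_display_sql, which renders a table view once the embedded <api-call> block is complete. A minimal sketch of that loop with stand-ins (the generator, the render helper, and the wiring are assumptions, not the project's actual streaming code):

# Assumed stand-ins for illustration: chunks() imitates progressively longer
# snapshots of the model output, and render() stands in for
# api_call.run_display_sql(text, database.run_to_df).
def chunks():
    yield "thoughts summary to say to user."
    yield ("thoughts summary to say to user."
           "<api-call><name>response_table</name><args><sql>SELECT 1")
    yield ("thoughts summary to say to user."
           "<api-call><name>response_table</name><args><sql>SELECT 1</sql></args></api-call>")

def render(snapshot: str) -> str:
    # the real hook strips newlines first, then renders a table
    # only once the <api-call> block is closed
    flat = snapshot.replace("\n", " ")
    done = "<api-call>" in flat and flat.count("<api-call>") == flat.count("</api-call>")
    return f"[table view] {flat}" if done else flat

for snapshot in chunks():
    print(render(snapshot))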
pilot/scene/chat_db/auto_execute/prompt.py (53 changes: 40 additions & 13 deletions)
@@ -8,32 +8,59 @@

 CFG = Config()

-PROMPT_SCENE_DEFINE = "You are a SQL expert. "
-
-_DEFAULT_TEMPLATE = """
+_PROMPT_SCENE_DEFINE_EN = "You are a database expert. "
+_PROMPT_SCENE_DEFINE_ZH = "你是一个数据库专家. "
+
+_DEFAULT_TEMPLATE_EN = """
 Given an input question, create a syntactically correct {dialect} sql.
+Table structure information:
+{table_info}
+Constraint:
+1. You can only use the table provided in the table structure information to generate sql. If you cannot generate sql based on the provided table structure, please say: "The table structure information provided is not enough to generate sql query." It is prohibited to fabricate information at will.
+2. Do not query columns that do not exist. Pay attention to which column is in which table.
+3. Replace the corresponding sql into the sql field in the returned result
+4. Unless the user specifies in the question a specific number of examples he wishes to obtain, always limit the query to a maximum of {top_k} results.
+5. Please output the Sql content in the following format to execute the corresponding SQL to display the data:<api-call><name>response_table</name><args><sql>SQL Query to run</sql></args></api-call>
+Please make sure to respond as following format:
+thoughts summary to say to user.<api-call><name>response_table</name><args><sql>SQL Query to run</sql></args></api-call>
+Question: {input}
+"""

-Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
-Use as few tables as possible when querying.
-Only use the following tables schema to generate sql:
-{table_info}
-Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
+_DEFAULT_TEMPLATE_ZH = """
+给定一个输入问题,创建一个语法正确的 {dialect} sql。
+已知表结构信息:
+{table_info}

-Question: {input}
+约束:
+1. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构中生成 sql ,请说:“提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。
+2. 不要查询不存在的列,注意哪一列位于哪张表中。
+3.将对应的sql替换到返回结果中的sql字段中
+4.除非用户在问题中指定了他希望获得的具体示例数量,否则始终将查询限制为最多 {top_k} 个结果。

-Respond in JSON format as following format:
-{response}
-Ensure the response is correct json and can be parsed by Python json.loads
+请务必按照以下格式回复:
+对用户说的想法摘要。<api-call><name>response_table</name><args><sql>要运行的 SQL</sql></args></api-call>
+问题:{input}
 """

+_DEFAULT_TEMPLATE = (
+    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
+)
+
+PROMPT_SCENE_DEFINE = (
+    _PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
+)
+
 RESPONSE_FORMAT_SIMPLE = {
     "thoughts": "thoughts summary to say to user",
     "sql": "SQL Query to run",
 }

 PROMPT_SEP = SeparatorStyle.SINGLE.value

-PROMPT_NEED_NEED_STREAM_OUT = False
+PROMPT_NEED_NEED_STREAM_OUT = True

 # Temperature is a configuration hyperparameter that controls the randomness of language model output.
 # A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
@@ -43,7 +70,7 @@
 prompt = PromptTemplate(
     template_scene=ChatScene.ChatWithDbExecute.value(),
     input_variables=["input", "table_info", "dialect", "top_k", "response"],
-    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
+    # response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
     template_define=PROMPT_SCENE_DEFINE,
     template=_DEFAULT_TEMPLATE,
     stream_out=PROMPT_NEED_NEED_STREAM_OUT,
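Note on this change: the prompt drops the JSON response contract in favor of free-text thoughts followed by an <api-call> block, and turns streaming output on. A rough illustration of filling the template placeholders and of a reply that matches the required shape (the shortened template, schema, and question below are made up for the example):

# Shortened stand-in template (the real one is _DEFAULT_TEMPLATE_EN above);
# the values passed to format() are invented for the example.
template = (
    "Given an input question, create a syntactically correct {dialect} sql.\n"
    "Table structure information:\n{table_info}\n"
    "Limit the query to at most {top_k} results.\n"
    "Question: {input}\n"
)
print(template.format(
    dialect="mysql",
    table_info="users(id INT, city VARCHAR(64))",
    top_k=200,
    input="How many users are there per city?",
))

# A reply that follows the prompt's required shape: thoughts for the user,
# then one <api-call> block carrying the SQL to execute.
reply = (
    "Counting users per city."
    "<api-call><name>response_table</name>"
    "<args><sql>SELECT city, COUNT(*) FROM users GROUP BY city LIMIT 200</sql></args></api-call>"
)
print(reply)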
