Feature/bird zf #205

Merged 7 commits on Jan 2, 2024

14 changes: 13 additions & 1 deletion dbgpt_hub/configs/config.py
@@ -50,10 +50,22 @@
"data_source": "spider",
"train_file": ["train_spider.json", "train_others.json"],
"dev_file": ["dev.json"],
"tables_file": "tables.json",
"train_tables_file": "tables.json",
"dev_tables_file": "tables.json",
"db_id_name": "db_id",
"output_name": "query",
"is_multiple_turn": False,
}
# {
# "data_source": "bird",
# "train_file": ["train/train.json"],
# "dev_file": ["dev/dev.json"],
# "train_tables_file": "train/train_tables.json",
# "dev_tables_file": "dev/dev_tables.json",
# "db_id_name": "db_id",
# "output_name": "SQL",
# "is_multiple_turn": False,
# }
# ,
# {
# "data_source": "chase",
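
For context, enabling BIRD is just a matter of uncommenting the block above into SQL_DATA_INFO. A minimal sketch of the resulting entry (assuming the BIRD release is unpacked under dbgpt_hub/data/bird/ so the relative paths resolve; the values are copied from the commented block, not from additional code in this PR):

# Hypothetical BIRD entry as it would look once uncommented. BIRD ships
# separate schema files for train and dev, which is why the config now uses
# train_tables_file / dev_tables_file instead of a single tables_file, and
# its gold queries live under the "SQL" key rather than Spider's "query".
SQL_DATA_INFO = [
    {
        "data_source": "bird",
        "train_file": ["train/train.json"],
        "dev_file": ["dev/dev.json"],
        "train_tables_file": "train/train_tables.json",
        "dev_tables_file": "dev/dev_tables.json",
        "db_id_name": "db_id",
        "output_name": "SQL",
        "is_multiple_turn": False,
    }
]
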
46 changes: 37 additions & 9 deletions dbgpt_hub/data_process/sql_data_process.py
@@ -24,7 +24,12 @@ def __init__(self, train_file=None, dev_file=None, num_shot=0) -> None:
         self.num_shot = num_shot
 
     def decode_json_file(
-        self, data_file_list, table_file, db_id_name, is_multiple_turn=False
+        self,
+        data_file_list,
+        table_file,
+        db_id_name,
+        output_name,
+        is_multiple_turn=False,
     ):
         """
         TO DO:
@@ -65,12 +70,29 @@

                 # get primary key info
                 for j in range(len(primary_key)):
-                    if coloumns[primary_key[j] - 1][0] == i:
+                    if type(primary_key[j]) == int:
+                        if coloumns[primary_key[j] - 1][0] == i:
+                            source += (
+                                coloumns[primary_key[j] - 1][1]
+                                + " is the primary key."
+                                + "\n"
+                            )
+                    # combination primary key
+                    elif type(primary_key[j]) == list:
+                        combine_p = "The combination of ("
+                        keys = []
+                        for k in range(len(primary_key[j])):
+                            if coloumns[primary_key[j][k] - 1][0] == i:
+                                keys.append(coloumns[primary_key[j][k] - 1][1])
                         source += (
-                            coloumns[primary_key[j] - 1][1]
-                            + " is the primary key."
+                            combine_p
+                            + ", ".join(keys)
+                            + ") are the primary key."
                             + "\n"
                         )
+                    else:
+                        print("not support type", type(primary_key[j]))
+                        continue
 
                 # get foreign key info
                 for key in foreign_keys:
@@ -104,14 +126,14 @@
                             db_dict[data[db_id_name]]
                         ),
                         "input": INPUT_PROMPT.format(interaction["utterance"]),
-                        "output": interaction["query"],
+                        "output": interaction[output_name],
                         "history": history,
                     }
                     res.append(input)
                     history.append(
                         (
                             INPUT_PROMPT.format(interaction["utterance"]),
-                            interaction["query"],
+                            interaction[output_name],
                         )
                     )
             else:  # single turn
@@ -121,7 +143,7 @@
                         db_dict[data[db_id_name]]
                     ),
                     "input": INPUT_PROMPT.format(data["question"]),
-                    "output": data["query"],
+                    "output": data[output_name],
                     "history": [],
                 }
                 res.append(input)
@@ -139,9 +161,12 @@ def create_sft_raw_data(self):
                 self.decode_json_file(
                     data_file_list=train_data_file_list,
                     table_file=os.path.join(
-                        DATA_PATH, data_info["data_source"], data_info["tables_file"]
+                        DATA_PATH,
+                        data_info["data_source"],
+                        data_info["train_tables_file"],
                     ),
                     db_id_name=data_info["db_id_name"],
+                    output_name=data_info["output_name"],
                     is_multiple_turn=data_info["is_multiple_turn"],
                 )
             )
@@ -154,9 +179,12 @@
                 self.decode_json_file(
                     data_file_list=dev_data_file_list,
                     table_file=os.path.join(
-                        DATA_PATH, data_info["data_source"], data_info["tables_file"]
+                        DATA_PATH,
+                        data_info["data_source"],
+                        data_info["dev_tables_file"],
                     ),
                     db_id_name=data_info["db_id_name"],
+                    output_name=data_info["output_name"],
                     is_multiple_turn=data_info["is_multiple_turn"],
                 )
             )
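
The substantive change above is composite primary keys: in BIRD's tables JSON, an element of primary_keys may be a list of 1-indexed column positions instead of a single int. A self-contained sketch of the added branch on a made-up schema (toy data for illustration only; the real code builds source while walking the whole tables file):

# Toy illustration of the composite-primary-key branch added above.
# coloumns mirrors tables.json: (table_index, column_name) pairs that the
# primary_keys entries reference by 1-based position.
coloumns = [(0, "order_id"), (0, "product_id"), (0, "quantity")]
primary_key = [[1, 2]]  # columns 1 and 2 of table 0 form a composite key
i = 0  # index of the table currently being described
source = ""

for j in range(len(primary_key)):
    if type(primary_key[j]) == int:  # single-column key
        if coloumns[primary_key[j] - 1][0] == i:
            source += coloumns[primary_key[j] - 1][1] + " is the primary key.\n"
    elif type(primary_key[j]) == list:  # composite key
        keys = [
            coloumns[k - 1][1] for k in primary_key[j] if coloumns[k - 1][0] == i
        ]
        source += "The combination of (" + ", ".join(keys) + ") are the primary key.\n"

print(source, end="")
# -> The combination of (order_id, product_id) are the primary key.
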
206 changes: 206 additions & 0 deletions dbgpt_hub/eval/evaluation_bird.py
@@ -0,0 +1,206 @@
"""
do evaluate about the predict sql in dataset BIRD,compare with default dev.sql
--db
"""

import sys
import json
import argparse
import sqlite3
import multiprocessing as mp
from func_timeout import func_timeout, FunctionTimedOut


def load_json(path):
    with open(path, "r") as j:
        contents = json.loads(j.read())
    return contents


def result_callback(result):
exec_result.append(result)


def execute_sql(predicted_sql, ground_truth, db_path):
    # Connect to the database
    conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute(predicted_sql)
predicted_res = cursor.fetchall()
cursor.execute(ground_truth)
ground_truth_res = cursor.fetchall()
res = 0
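    # set() comparison makes the check order-insensitive and collapses
    # duplicate rows: e.g. set([(1,), (2,), (2,)]) == set([(2,), (1,)]).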
if set(predicted_res) == set(ground_truth_res):
res = 1
return res


def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
try:
res = func_timeout(
meta_time_out, execute_sql, args=(predicted_sql, ground_truth, db_place)
)
except KeyboardInterrupt:
sys.exit(0)
except FunctionTimedOut:
result = [(f"timeout",)]
res = 0
except Exception as e:
result = [(f"error",)] # possibly len(query) > 512 or not executable
res = 0
# print(result)
# result = str(set([ret[0] for ret in result]))
result = {"sql_idx": idx, "res": res}
# print(result)
return result


def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
clean_sqls = []
db_path_list = []
if mode == "gpt":
# sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', 'r'))
# for idx, sql_str in sql_data.items():
# if type(sql_str) == str:
# sql, db_name = sql_str.split('\t----- bird -----\t')
# else:
# sql, db_name = " ", "financial"
# clean_sqls.append(sql)
# db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
with open(sql_path) as f:
for l in f.readlines():
# if len(l.strip()) == 0:
# sql, db_name = " ", "financial"
# else:
# sql, db_name = l.split('\t')
clean_sqls.append(l.strip())
# db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
elif mode == "gt":
sqls = open(sql_path)
sql_txt = sqls.readlines()
# sql_txt = [sql.split('\t')[0] for sql in sql_txt]
for idx, sql_str in enumerate(sql_txt):
sql, db_name = sql_str.strip().split("\t")
clean_sqls.append(sql)
db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite")

return clean_sqls, db_path_list


def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):
pool = mp.Pool(processes=num_cpus)
for i, sql_pair in enumerate(sqls):
predicted_sql, ground_truth = sql_pair
pool.apply_async(
execute_model,
args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out),
callback=result_callback,
)
pool.close()
pool.join()


def sort_results(list_of_dicts):
return sorted(list_of_dicts, key=lambda x: x["sql_idx"])


def compute_acc_by_diff(exec_results, diff_json_path):
num_queries = len(exec_results)
results = [res["res"] for res in exec_results]
contents = load_json(diff_json_path)
simple_results, moderate_results, challenging_results = [], [], []

for i, content in enumerate(contents):
if content["difficulty"] == "simple":
simple_results.append(exec_results[i])

if content["difficulty"] == "moderate":
moderate_results.append(exec_results[i])

if content["difficulty"] == "challenging":
challenging_results.append(exec_results[i])

simple_acc = sum([res["res"] for res in simple_results]) / len(simple_results)
moderate_acc = sum([res["res"] for res in moderate_results]) / len(moderate_results)
challenging_acc = sum([res["res"] for res in challenging_results]) / len(
challenging_results
)
all_acc = sum(results) / num_queries
count_lists = [
len(simple_results),
len(moderate_results),
len(challenging_results),
num_queries,
]
return (
simple_acc * 100,
moderate_acc * 100,
challenging_acc * 100,
all_acc * 100,
count_lists,
)


def print_data(score_lists, count_lists):
levels = ["simple", "moderate", "challenging", "total"]
print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists))

print(
"====================================== ACCURACY ====================================="
)
print(
"{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format("accuracy", *score_lists)
)


if __name__ == "__main__":
args_parser = argparse.ArgumentParser()
args_parser.add_argument(
"--predicted_sql_path", type=str, required=True, default=""
)
args_parser.add_argument("--ground_truth_path", type=str, required=True, default="")
args_parser.add_argument("--data_mode", type=str, required=True, default="dev")
args_parser.add_argument("--db_root_path", type=str, required=True, default="")
args_parser.add_argument("--num_cpus", type=int, default=1)
args_parser.add_argument("--meta_time_out", type=float, default=30.0)
args_parser.add_argument("--mode_gt", type=str, default="gt")
args_parser.add_argument("--mode_predict", type=str, default="gpt")
args_parser.add_argument("--difficulty", type=str, default="simple")
args_parser.add_argument("--diff_json_path", type=str, default="")
args = args_parser.parse_args()
exec_result = []

pred_queries, db_paths = package_sqls(
args.predicted_sql_path,
args.db_root_path,
mode=args.mode_predict,
data_mode=args.data_mode,
)
# generate gt sqls:
gt_queries, db_paths_gt = package_sqls(
args.ground_truth_path, args.db_root_path, mode="gt", data_mode=args.data_mode
)

if len(db_paths) == 0:
db_paths = db_paths_gt

query_pairs = list(zip(pred_queries, gt_queries))
run_sqls_parallel(
query_pairs,
db_places=db_paths,
num_cpus=args.num_cpus,
meta_time_out=args.meta_time_out,
)
exec_result = sort_results(exec_result)

print("start calculate")
simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff(
exec_result, args.diff_json_path
)
score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
print_data(score_lists, count_lists)
print(
"==========================================================================================="
)
print("Finished evaluation")