Feature/bird zf (#205)
Co-authored-by: moutozf <[email protected]>
Co-authored-by: qidanrui <[email protected]>
3 people authored Jan 2, 2024
1 parent 22a1675 commit c65d263
Showing 3 changed files with 256 additions and 10 deletions.
14 changes: 13 additions & 1 deletion dbgpt_hub/configs/config.py
@@ -50,10 +50,22 @@
         "data_source": "spider",
         "train_file": ["train_spider.json", "train_others.json"],
         "dev_file": ["dev.json"],
-        "tables_file": "tables.json",
+        "train_tables_file": "tables.json",
+        "dev_tables_file": "tables.json",
         "db_id_name": "db_id",
         "output_name": "query",
         "is_multiple_turn": False,
     }
+    # {
+    #     "data_source": "bird",
+    #     "train_file": ["train/train.json"],
+    #     "dev_file": ["dev/dev.json"],
+    #     "train_tables_file": "train/train_tables.json",
+    #     "dev_tables_file": "dev/dev_tables.json",
+    #     "db_id_name": "db_id",
+    #     "output_name": "SQL",
+    #     "is_multiple_turn": False,
+    # }
     # ,
     # {
     #     "data_source": "chase",
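A single tables_file entry no longer suffices because BIRD ships separate schema files for its train and dev splits, so each dataset entry now declares train_tables_file and dev_tables_file (for Spider both simply point at the same tables.json). Enabling BIRD amounts to uncommenting the block above; as a rough sketch, the active entry would look like this (the name of the surrounding list is not shown in this hunk):

# Sketch of an enabled BIRD entry; keys and values mirror the commented block above.
{
    "data_source": "bird",
    "train_file": ["train/train.json"],
    "dev_file": ["dev/dev.json"],
    "train_tables_file": "train/train_tables.json",  # per-split schema files, new in this commit
    "dev_tables_file": "dev/dev_tables.json",
    "db_id_name": "db_id",
    "output_name": "SQL",          # BIRD stores the gold query under "SQL"; Spider uses "query"
    "is_multiple_turn": False,
}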
46 changes: 37 additions & 9 deletions dbgpt_hub/data_process/sql_data_process.py
@@ -24,7 +24,12 @@ def __init__(self, train_file=None, dev_file=None, num_shot=0) -> None:
         self.num_shot = num_shot
 
     def decode_json_file(
-        self, data_file_list, table_file, db_id_name, is_multiple_turn=False
+        self,
+        data_file_list,
+        table_file,
+        db_id_name,
+        output_name,
+        is_multiple_turn=False,
     ):
         """
         TO DO:
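decode_json_file gains an output_name parameter so the same routine can read the gold SQL from whichever field the dataset uses. A hedged sketch of a call with the new signature (the object name and literal paths are illustrative; the real call sites appear in create_sft_raw_data below):

# Illustrative only: argument values follow the spider entry in config.py,
# and "processor" stands in for the class instance, whose name this hunk does not show.
sft_examples = processor.decode_json_file(
    data_file_list=["data/spider/train_spider.json", "data/spider/train_others.json"],
    table_file="data/spider/tables.json",
    db_id_name="db_id",
    output_name="query",            # new: key that holds the ground-truth SQL
    is_multiple_turn=False,
)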
@@ -65,12 +70,29 @@
 
             # get primary key info
             for j in range(len(primary_key)):
-                if coloumns[primary_key[j] - 1][0] == i:
+                if type(primary_key[j]) == int:
+                    if coloumns[primary_key[j] - 1][0] == i:
+                        source += (
+                            coloumns[primary_key[j] - 1][1]
+                            + " is the primary key."
+                            + "\n"
+                        )
+                # combination primary key
+                elif type(primary_key[j]) == list:
+                    combine_p = "The combination of ("
+                    keys = []
+                    for k in range(len(primary_key[j])):
+                        if coloumns[primary_key[j][k] - 1][0] == i:
+                            keys.append(coloumns[primary_key[j][k] - 1][1])
                     source += (
-                        coloumns[primary_key[j] - 1][1]
-                        + " is the primary key."
+                        combine_p
+                        + ", ".join(keys)
+                        + ") are the primary key."
                         + "\n"
                     )
+                else:
+                    print("not support type", type(primary_key[j]))
+                    continue
 
             # get foreign key info
             for key in foreign_keys:
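BIRD's tables.json may describe a composite primary key as a list of column indexes, which the old single-index lookup could not handle (subtracting 1 from a list raises a TypeError); the new branch joins the member columns into one sentence. A minimal, self-contained sketch of the two shapes (schema data invented; the coloumns spelling mirrors the source):

# Invented schema: coloumns is a list of [table_index, column_name]; indexes are 1-based.
coloumns = [[0, "city_id"], [0, "city_name"], [1, "order_id"], [1, "product_id"]]

def primary_key_sentence(primary_key, i):
    """Render the primary-key description for table i, following the diff's logic."""
    source = ""
    for j in range(len(primary_key)):
        if type(primary_key[j]) == int:        # simple key: one column index
            if coloumns[primary_key[j] - 1][0] == i:
                source += coloumns[primary_key[j] - 1][1] + " is the primary key." + "\n"
        elif type(primary_key[j]) == list:     # composite key: list of column indexes
            keys = [coloumns[k - 1][1] for k in primary_key[j] if coloumns[k - 1][0] == i]
            source += "The combination of (" + ", ".join(keys) + ") are the primary key." + "\n"
    return source

print(primary_key_sentence([1], 0))       # -> city_id is the primary key.
print(primary_key_sentence([[3, 4]], 1))  # -> The combination of (order_id, product_id) are the primary key.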
Expand Down Expand Up @@ -104,14 +126,14 @@ def decode_json_file(
db_dict[data[db_id_name]]
),
"input": INPUT_PROMPT.format(interaction["utterance"]),
"output": interaction["query"],
"output": interaction[output_name],
"history": history,
}
res.append(input)
history.append(
(
INPUT_PROMPT.format(interaction["utterance"]),
interaction["query"],
interaction[output_name],
)
)
else: # 单轮
@@ -121,7 +143,7 @@
                         db_dict[data[db_id_name]]
                     ),
                     "input": INPUT_PROMPT.format(data["question"]),
-                    "output": data["query"],
+                    "output": data[output_name],
                     "history": [],
                 }
                 res.append(input)
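Whether the dataset is multi-turn or single-turn, each example ends up as one dict whose output field is read via output_name ("query" for Spider, "SQL" for BIRD, per the config). A hedged sketch of the resulting record (prompt templates paraphrased, values invented):

# Illustrative record shape; the real "instruction"/"input" strings come from
# INSTRUCTION_PROMPT / INPUT_PROMPT and are only paraphrased here.
{
    "instruction": "<INSTRUCTION_PROMPT rendered with the schema description for this db_id>",
    "input": "<INPUT_PROMPT rendered with the natural-language question>",
    "output": "SELECT count(*) FROM singer",  # taken from data[output_name]
    "history": [],                            # earlier (question, sql) pairs in multi-turn datasets
}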
@@ -139,9 +161,12 @@ def create_sft_raw_data(self):
                 self.decode_json_file(
                     data_file_list=train_data_file_list,
                     table_file=os.path.join(
-                        DATA_PATH, data_info["data_source"], data_info["tables_file"]
+                        DATA_PATH,
+                        data_info["data_source"],
+                        data_info["train_tables_file"],
                     ),
                     db_id_name=data_info["db_id_name"],
+                    output_name=data_info["output_name"],
                     is_multiple_turn=data_info["is_multiple_turn"],
                 )
             )
@@ -154,9 +179,12 @@
                 self.decode_json_file(
                     data_file_list=dev_data_file_list,
                     table_file=os.path.join(
-                        DATA_PATH, data_info["data_source"], data_info["tables_file"]
+                        DATA_PATH,
+                        data_info["data_source"],
+                        data_info["dev_tables_file"],
                     ),
                     db_id_name=data_info["db_id_name"],
+                    output_name=data_info["output_name"],
                     is_multiple_turn=data_info["is_multiple_turn"],
                 )
             )
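Both call sites now resolve a split-specific schema file and forward output_name, so every entry in the dataset config must define train_tables_file, dev_tables_file and output_name. A small sketch of the paths this produces (the DATA_PATH value is an assumption for illustration):

import os

DATA_PATH = "dbgpt_hub/data"  # assumed value; the real constant lives in the configs module

data_info = {
    "data_source": "bird",
    "train_tables_file": "train/train_tables.json",
    "dev_tables_file": "dev/dev_tables.json",
}
print(os.path.join(DATA_PATH, data_info["data_source"], data_info["train_tables_file"]))
# dbgpt_hub/data/bird/train/train_tables.json
print(os.path.join(DATA_PATH, data_info["data_source"], data_info["dev_tables_file"]))
# dbgpt_hub/data/bird/dev/dev_tables.json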
206 changes: 206 additions & 0 deletions dbgpt_hub/eval/evaluation_bird.py
@@ -0,0 +1,206 @@
"""
do evaluate about the predict sql in dataset BIRD,compare with default dev.sql
--db
"""

import sys
import json
import argparse
import sqlite3
import multiprocessing as mp
from func_timeout import func_timeout, FunctionTimedOut


def load_json(dir):
with open(dir, "r") as j:
contents = json.loads(j.read())
return contents


def result_callback(result):
    # collect one finished worker's result; exec_result is the module-level list created in __main__
    exec_result.append(result)


def execute_sql(predicted_sql, ground_truth, db_path):
    # connect to the SQLite database and execute both queries
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute(predicted_sql)
    predicted_res = cursor.fetchall()
    cursor.execute(ground_truth)
    ground_truth_res = cursor.fetchall()
    # execution accuracy: the result sets must match, ignoring row order and duplicates
    res = 0
    if set(predicted_res) == set(ground_truth_res):
        res = 1
    return res


def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
try:
res = func_timeout(
meta_time_out, execute_sql, args=(predicted_sql, ground_truth, db_place)
)
except KeyboardInterrupt:
sys.exit(0)
    except FunctionTimedOut:
        result = [("timeout",)]
        res = 0
    except Exception as e:
        result = [("error",)]  # possibly len(query) > 512 or not executable
        res = 0
# print(result)
# result = str(set([ret[0] for ret in result]))
result = {"sql_idx": idx, "res": res}
# print(result)
return result


def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
clean_sqls = []
db_path_list = []
if mode == "gpt":
# sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', 'r'))
# for idx, sql_str in sql_data.items():
# if type(sql_str) == str:
# sql, db_name = sql_str.split('\t----- bird -----\t')
# else:
# sql, db_name = " ", "financial"
# clean_sqls.append(sql)
# db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
with open(sql_path) as f:
for l in f.readlines():
# if len(l.strip()) == 0:
# sql, db_name = " ", "financial"
# else:
# sql, db_name = l.split('\t')
clean_sqls.append(l.strip())
# db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
elif mode == "gt":
sqls = open(sql_path)
sql_txt = sqls.readlines()
# sql_txt = [sql.split('\t')[0] for sql in sql_txt]
for idx, sql_str in enumerate(sql_txt):
sql, db_name = sql_str.strip().split("\t")
clean_sqls.append(sql)
db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite")

return clean_sqls, db_path_list
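
package_sqls reads two plain-text formats: in "gpt" mode the prediction file contributes one SQL string per line (the database paths are taken from the ground-truth side), while in "gt" mode each line carries the gold SQL and the database name separated by a tab. A minimal sketch of the "gt" parsing, with an invented database name:

# Hypothetical ground-truth line; "financial" is just an example database name.
line = "SELECT count(*) FROM loan\tfinancial"
db_root_path = "./data/dev_databases/"  # assumed layout, normally passed via --db_root_path
sql, db_name = line.strip().split("\t")
db_path = db_root_path + db_name + "/" + db_name + ".sqlite"
print(sql)      # SELECT count(*) FROM loan
print(db_path)  # ./data/dev_databases/financial/financial.sqlite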


def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):
pool = mp.Pool(processes=num_cpus)
for i, sql_pair in enumerate(sqls):
predicted_sql, ground_truth = sql_pair
pool.apply_async(
execute_model,
args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out),
callback=result_callback,
)
pool.close()
pool.join()


def sort_results(list_of_dicts):
return sorted(list_of_dicts, key=lambda x: x["sql_idx"])


def compute_acc_by_diff(exec_results, diff_json_path):
num_queries = len(exec_results)
results = [res["res"] for res in exec_results]
contents = load_json(diff_json_path)
simple_results, moderate_results, challenging_results = [], [], []

for i, content in enumerate(contents):
if content["difficulty"] == "simple":
simple_results.append(exec_results[i])

if content["difficulty"] == "moderate":
moderate_results.append(exec_results[i])

if content["difficulty"] == "challenging":
challenging_results.append(exec_results[i])

simple_acc = sum([res["res"] for res in simple_results]) / len(simple_results)
moderate_acc = sum([res["res"] for res in moderate_results]) / len(moderate_results)
challenging_acc = sum([res["res"] for res in challenging_results]) / len(
challenging_results
)
all_acc = sum(results) / num_queries
count_lists = [
len(simple_results),
len(moderate_results),
len(challenging_results),
num_queries,
]
return (
simple_acc * 100,
moderate_acc * 100,
challenging_acc * 100,
all_acc * 100,
count_lists,
)
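
compute_acc_by_diff buckets the execution results by the difficulty labels in diff_json_path (BIRD's dev.json) and reports each bucket, plus the overall score, as a percentage; it divides by the bucket sizes, so it assumes every difficulty level occurs at least once. A tiny worked example with invented numbers:

# Invented results: both simple queries and the challenging one executed correctly.
exec_results = [
    {"sql_idx": 0, "res": 1},  # simple
    {"sql_idx": 1, "res": 1},  # simple
    {"sql_idx": 2, "res": 0},  # moderate
    {"sql_idx": 3, "res": 1},  # challenging
]
# simple: 2/2 -> 100.0, moderate: 0/1 -> 0.0, challenging: 1/1 -> 100.0, overall: 3/4 -> 75.0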


def print_data(score_lists, count_lists):
levels = ["simple", "moderate", "challenging", "total"]
print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists))

print(
"====================================== ACCURACY ====================================="
)
print(
"{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format("accuracy", *score_lists)
)


if __name__ == "__main__":
args_parser = argparse.ArgumentParser()
args_parser.add_argument(
"--predicted_sql_path", type=str, required=True, default=""
)
args_parser.add_argument("--ground_truth_path", type=str, required=True, default="")
args_parser.add_argument("--data_mode", type=str, required=True, default="dev")
args_parser.add_argument("--db_root_path", type=str, required=True, default="")
args_parser.add_argument("--num_cpus", type=int, default=1)
args_parser.add_argument("--meta_time_out", type=float, default=30.0)
args_parser.add_argument("--mode_gt", type=str, default="gt")
args_parser.add_argument("--mode_predict", type=str, default="gpt")
args_parser.add_argument("--difficulty", type=str, default="simple")
args_parser.add_argument("--diff_json_path", type=str, default="")
args = args_parser.parse_args()
exec_result = []

pred_queries, db_paths = package_sqls(
args.predicted_sql_path,
args.db_root_path,
mode=args.mode_predict,
data_mode=args.data_mode,
)
# generate gt sqls:
gt_queries, db_paths_gt = package_sqls(
args.ground_truth_path, args.db_root_path, mode="gt", data_mode=args.data_mode
)

if len(db_paths) == 0:
db_paths = db_paths_gt

query_pairs = list(zip(pred_queries, gt_queries))
run_sqls_parallel(
query_pairs,
db_places=db_paths,
num_cpus=args.num_cpus,
meta_time_out=args.meta_time_out,
)
exec_result = sort_results(exec_result)

print("start calculate")
simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff(
exec_result, args.diff_json_path
)
score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
print_data(score_lists, count_lists)
print(
"==========================================================================================="
)
print("Finished evaluation")
