Bird eval all zf (#240)
great work, thanks a lot  ~
moutozf authored Mar 12, 2024
1 parent c329ba1 commit e8fadb2
Showing 1 changed file with 125 additions and 38 deletions.
163 changes: 125 additions & 38 deletions dbgpt_hub/eval/evaluation_bird.py
@@ -2,13 +2,14 @@
Evaluate the predicted SQL for the BIRD dataset by comparing it with the default dev.sql ground truth.
--db
"""

import sys
import json
import argparse
import sqlite3
import multiprocessing as mp
from func_timeout import func_timeout, FunctionTimedOut
import math
import time


def load_json(dir):
@@ -25,32 +26,45 @@ def execute_sql(predicted_sql, ground_truth, db_path):
    # Connect to the database
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    pred_start_time = time.time()
    cursor.execute(predicted_sql)
    pred_exec_time = time.time() - pred_start_time
    predicted_res = cursor.fetchall()
    true_start_time = time.time()
    cursor.execute(ground_truth)
    true_exec_time = time.time() - true_start_time
    ground_truth_res = cursor.fetchall()
    res = 0
    time_ratio = 0
    if set(predicted_res) == set(ground_truth_res):
        res = 1
        time_ratio = true_exec_time / pred_exec_time if pred_exec_time > 0 else 0
    return res, time_ratio
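# time_ratio is the gold-to-predicted execution-time ratio; it stays 0 when the
# result sets differ (or the predicted query ran in effectively zero time), so
# only correct queries contribute to the VES computed in compute_ves() below.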


def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
    try:
        res, time_ratio = func_timeout(
            meta_time_out, execute_sql, args=(predicted_sql, ground_truth, db_place)
        )
    except KeyboardInterrupt:
        sys.exit(0)
    except FunctionTimedOut:
        result = [(f"timeout",)]
        res = 0
        time_ratio = 0
    except Exception as e:
        result = [(f"error",)]  # possibly len(query) > 512 or not executable
        res = 0
        time_ratio = 0
    # print(result)
    # result = str(set([ret[0] for ret in result]))
    result = {
        "sql_idx": idx,
        "res": res,
        "match": int(predicted_sql == ground_truth),
        "time_ratio": time_ratio,
    }
    # print(result)
    return result
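# Each per-query result dict carries three signals: "res" (1 if the predicted and
# gold result sets match), "match" (1 if the SQL strings are identical), and
# "time_ratio" (gold/predicted execution time, 0 on timeout, error, or mismatch).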

Expand All @@ -59,7 +73,7 @@ def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
    clean_sqls = []
    db_path_list = []
    if mode == "gpt":
        # sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', 'r'))
        # for idx, sql_str in sql_data.items():
        #     if type(sql_str) == str:
        #         sql, db_name = sql_str.split('\t----- bird -----\t')
@@ -104,9 +118,22 @@ def sort_results(list_of_dicts):
    return sorted(list_of_dicts, key=lambda x: x["sql_idx"])
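# Results coming back from the multiprocessing pool can arrive out of order, so
# they are re-sorted by "sql_idx" before compute_acc_by_diff() pairs them
# positionally with the entries of the difficulty JSON file.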


def compute_ves(exec_results):
    num_queries = len(exec_results)
    total_ratio = 0
    count = 0

    for i, result in enumerate(exec_results):
        if result["time_ratio"] != 0:
            count += 1
            total_ratio += math.sqrt(result["time_ratio"]) * 100
    ves = total_ratio / num_queries
    return ves
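# compute_ves() implements the Valid Efficiency Score used by the BIRD benchmark:
#     VES = (1 / N) * sum_i( sqrt(R_i) * 100 )
# where N is the total number of evaluated queries and R_i is the gold/predicted
# execution-time ratio of query i, which is non-zero only when its result sets
# match. Illustrative example with made-up numbers: for N = 4 queries of which
# two are correct with ratios 1.0 and 4.0, VES = (100 + 200) / 4 = 75. Note that
# the local variable `count` is accumulated but not used in the final score.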


def compute_acc_by_diff(exec_results, diff_json_path, metric):
    num_queries = len(exec_results)
    results = [res[metric] for res in exec_results]
    contents = load_json(diff_json_path)
    simple_results, moderate_results, challenging_results = [], [], []

@@ -119,35 +146,47 @@

if content["difficulty"] == "challenging":
challenging_results.append(exec_results[i])

    if metric in ["res", "match"]:
        simple_acc = sum([res[metric] for res in simple_results]) / len(simple_results)
        moderate_acc = sum([res[metric] for res in moderate_results]) / len(
            moderate_results
        )
        challenging_acc = sum([res[metric] for res in challenging_results]) / len(
            challenging_results
        )
        all_acc = sum(results) / num_queries
    elif metric in ["time_ratio"]:
        simple_acc = compute_ves(simple_results)
        moderate_acc = compute_ves(moderate_results)
        challenging_acc = compute_ves(challenging_results)
        all_acc = compute_ves(exec_results)
    else:
        raise NotImplementedError(f"metric: {metric} is not supported")
    count_lists = [
        len(simple_results),
        len(moderate_results),
        len(challenging_results),
        num_queries,
    ]
    if metric in ["res", "match"]:
        return (
            simple_acc * 100,
            moderate_acc * 100,
            challenging_acc * 100,
            all_acc * 100,
            count_lists,
        )
    else:
        return simple_acc, moderate_acc, challenging_acc, all_acc, count_lists
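# The accuracy-style metrics ("res", "match") are fractions here and are scaled
# to percentages on return; the "time_ratio" branch returns compute_ves() values,
# which are already multiplied by 100 inside compute_ves(), so they pass through
# unscaled.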


def print_data(score_lists, count_lists, metric="Exec ACCURACY"):
    levels = ["simple", "moderate", "challenging", "total"]
    print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
    print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists))

    print(
        f"====================================== {metric} ====================================="
    )
    print(
        "{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format("accuracy", *score_lists)
@@ -157,17 +196,33 @@ def print_data(score_lists, count_lists):
if __name__ == "__main__":
args_parser = argparse.ArgumentParser()
args_parser.add_argument(
"--predicted_sql_path", type=str, required=True, default=""
"--predicted_sql_path",
type=str,
default="../../pred_sql/pred_sql_bird_qwen14b_1212.sql",
)
args_parser.add_argument(
"--ground_truth_path", type=str, default="../../dbgpt_hub/data/bird/dev/dev.sql"
)
args_parser.add_argument("--data_mode", type=str, default="dev")
args_parser.add_argument(
"--db_root_path",
type=str,
default="../../dbgpt_hub/data/bird/dev/dev_databases/",
)
args_parser.add_argument("--ground_truth_path", type=str, required=True, default="")
args_parser.add_argument("--data_mode", type=str, required=True, default="dev")
args_parser.add_argument("--db_root_path", type=str, required=True, default="")
args_parser.add_argument("--num_cpus", type=int, default=1)
args_parser.add_argument("--meta_time_out", type=float, default=30.0)
args_parser.add_argument("--mode_gt", type=str, default="gt")
args_parser.add_argument("--mode_predict", type=str, default="gpt")
args_parser.add_argument("--difficulty", type=str, default="simple")
args_parser.add_argument("--diff_json_path", type=str, default="")
args_parser.add_argument(
"--etype",
dest="etype",
type=str,
default="match",
choices=("all", "exec", "match", "ves"),
)
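    # Illustrative invocation (the paths below are placeholders, not repository
    # defaults):
    #   python dbgpt_hub/eval/evaluation_bird.py \
    #       --predicted_sql_path ./pred/pred_dev.sql \
    #       --ground_truth_path ./data/bird/dev/dev.sql \
    #       --db_root_path ./data/bird/dev/dev_databases/ \
    #       --diff_json_path ./data/bird/dev/dev.json \
    #       --etype all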

    args = args_parser.parse_args()
    exec_result = []

@@ -186,20 +241,52 @@ def print_data(score_lists, count_lists):
    db_paths = db_paths_gt

    query_pairs = list(zip(pred_queries, gt_queries))
    if args.etype in ["all", "exec", "ves"]:
        run_sqls_parallel(
            query_pairs,
            db_places=db_paths,
            num_cpus=args.num_cpus,
            meta_time_out=args.meta_time_out,
        )
    else:
        for i, sql_pair in enumerate(query_pairs):
            predicted_sql, ground_truth = sql_pair
            exec_result.append(
                {"sql_idx": i, "match": int(predicted_sql == ground_truth)}
            )
    exec_result = sort_results(exec_result)

print("start calculate")
simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff(
exec_result, args.diff_json_path
)
score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
print_data(score_lists, count_lists)
if args.etype in ["all", "exec"]:
(
simple_acc,
moderate_acc,
challenging_acc,
acc,
count_lists,
) = compute_acc_by_diff(exec_result, args.diff_json_path, "res")
score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
print_data(score_lists, count_lists, metric="Exec Accuracy")
if args.etype in ["all", "match"]:
(
simple_acc,
moderate_acc,
challenging_acc,
acc,
count_lists,
) = compute_acc_by_diff(exec_result, args.diff_json_path, "match")
score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
print_data(score_lists, count_lists, metric="Match Accuracy")
if args.etype in ["all", "ves"]:
(
simple_acc,
moderate_acc,
challenging_acc,
acc,
count_lists,
) = compute_acc_by_diff(exec_result, args.diff_json_path, "time_ratio")
score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
print_data(score_lists, count_lists, metric="Ves")
print(
"==========================================================================================="
)
