Bird eval all zf #240

Merged (3 commits) on Mar 12, 2024. Showing changes from all commits.
dbgpt_hub/eval/evaluation_bird.py: 163 changes (125 additions, 38 deletions)
@@ -2,13 +2,14 @@
 Evaluate the predicted SQL on the BIRD dataset, comparing against the default dev.sql.
 --db
 """
 
 import sys
 import json
 import argparse
 import sqlite3
 import multiprocessing as mp
 from func_timeout import func_timeout, FunctionTimedOut
+import math
+import time
 
 
 def load_json(dir):
@@ -25,32 +26,45 @@ def execute_sql(predicted_sql, ground_truth, db_path):
     conn = sqlite3.connect(db_path)
     # Connect to the database
     cursor = conn.cursor()
+    pred_start_time = time.time()
     cursor.execute(predicted_sql)
+    pred_exec_time = time.time() - pred_start_time
     predicted_res = cursor.fetchall()
+    true_start_time = time.time()
     cursor.execute(ground_truth)
+    true_exec_time = time.time() - true_start_time
     ground_truth_res = cursor.fetchall()
     res = 0
+    time_ratio = 0
     if set(predicted_res) == set(ground_truth_res):
         res = 1
-    return res
+    time_ratio = true_exec_time / pred_exec_time if pred_exec_time > 0 else 0
+    return res, time_ratio


 def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
     try:
-        res = func_timeout(
+        res, time_ratio = func_timeout(
             meta_time_out, execute_sql, args=(predicted_sql, ground_truth, db_place)
         )
     except KeyboardInterrupt:
         sys.exit(0)
     except FunctionTimedOut:
         result = [(f"timeout",)]
         res = 0
+        time_ratio = 0
     except Exception as e:
         result = [(f"error",)]  # possibly len(query) > 512 or not executable
         res = 0
+        time_ratio = 0
     # print(result)
     # result = str(set([ret[0] for ret in result]))
-    result = {"sql_idx": idx, "res": res}
+    result = {
+        "sql_idx": idx,
+        "res": res,
+        "match": int(predicted_sql == ground_truth),
+        "time_ratio": time_ratio,
+    }
     # print(result)
     return result

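The timing added above changes the contract of execute_sql: it now returns a (res, time_ratio) pair, where time_ratio is the ground-truth execution time divided by the predicted query's execution time (0 on timeout or error). A minimal smoke test of that contract, assuming the module is importable as dbgpt_hub.eval.evaluation_bird; the table and queries here are invented for the example:

import os
import sqlite3
import tempfile

from dbgpt_hub.eval.evaluation_bird import execute_sql  # assumed import path

# Build a throwaway SQLite database to run both queries against.
db_file = os.path.join(tempfile.mkdtemp(), "toy.sqlite")
conn = sqlite3.connect(db_file)
conn.execute("CREATE TABLE t (a INTEGER)")
conn.executemany("INSERT INTO t VALUES (?)", [(i,) for i in range(1000)])
conn.commit()
conn.close()

# Same result set either way, so res == 1; time_ratio compares the two timings.
res, time_ratio = execute_sql(
    "SELECT a FROM t WHERE a < 10",
    "SELECT a FROM t WHERE a < 10 ORDER BY a",
    db_file,
)
print(res, time_ratio)  # 1, some positive float
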
@@ -59,7 +73,7 @@ def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
     clean_sqls = []
     db_path_list = []
     if mode == "gpt":
-        # sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', 'r'))
+        # sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', "r"))
         # for idx, sql_str in sql_data.items():
         #     if type(sql_str) == str:
         #         sql, db_name = sql_str.split('\t----- bird -----\t')
@@ -104,9 +118,22 @@ def sort_results(list_of_dicts):
     return sorted(list_of_dicts, key=lambda x: x["sql_idx"])


-def compute_acc_by_diff(exec_results, diff_json_path):
+def compute_ves(exec_results):
+    num_queries = len(exec_results)
+    total_ratio = 0
+    count = 0
+
+    for i, result in enumerate(exec_results):
+        if result["time_ratio"] != 0:
+            count += 1
+            total_ratio += math.sqrt(result["time_ratio"]) * 100
+    ves = total_ratio / num_queries
+    return ves
+
+
+def compute_acc_by_diff(exec_results, diff_json_path, metric):
     num_queries = len(exec_results)
-    results = [res["res"] for res in exec_results]
+    results = [res[metric] for res in exec_results]
     contents = load_json(diff_json_path)
     simple_results, moderate_results, challenging_results = [], [], []

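The new compute_ves implements BIRD's Valid Efficiency Score: each query with a non-zero time_ratio contributes sqrt(time_ratio) * 100, and the sum is divided by the total number of queries, so timeouts and errors (time_ratio == 0) still drag the average down. (Note that count is incremented but never used in the final division.) A hand-checkable toy call, with result dicts invented for the example:

toy_results = [
    {"sql_idx": 0, "res": 1, "time_ratio": 1.0},  # same speed as ground truth -> 100
    {"sql_idx": 1, "res": 1, "time_ratio": 4.0},  # 4x faster than ground truth -> 200
    {"sql_idx": 2, "res": 0, "time_ratio": 0.0},  # timeout/error -> contributes 0
]
# (100 + 200 + 0) / 3 queries = 100.0
print(compute_ves(toy_results))  # 100.0
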
@@ -119,35 +146,47 @@ def compute_acc_by_diff(exec_results, diff_json_path):

         if content["difficulty"] == "challenging":
             challenging_results.append(exec_results[i])

-    simple_acc = sum([res["res"] for res in simple_results]) / len(simple_results)
-    moderate_acc = sum([res["res"] for res in moderate_results]) / len(moderate_results)
-    challenging_acc = sum([res["res"] for res in challenging_results]) / len(
-        challenging_results
-    )
-    all_acc = sum(results) / num_queries
+    if metric in ["res", "match"]:
+        simple_acc = sum([res[metric] for res in simple_results]) / len(simple_results)
+        moderate_acc = sum([res[metric] for res in moderate_results]) / len(
+            moderate_results
+        )
+        challenging_acc = sum([res[metric] for res in challenging_results]) / len(
+            challenging_results
+        )
+        all_acc = sum(results) / num_queries
+    elif metric in ["time_ratio"]:
+        simple_acc = compute_ves(simple_results)
+        moderate_acc = compute_ves(moderate_results)
+        challenging_acc = compute_ves(challenging_results)
+        all_acc = compute_ves(exec_results)
+    else:
+        raise NotImplementedError(f"metric: {metric} is not supported")
     count_lists = [
         len(simple_results),
         len(moderate_results),
         len(challenging_results),
         num_queries,
     ]
-    return (
-        simple_acc * 100,
-        moderate_acc * 100,
-        challenging_acc * 100,
-        all_acc * 100,
-        count_lists,
-    )
+    if metric in ["res", "match"]:
+        return (
+            simple_acc * 100,
+            moderate_acc * 100,
+            challenging_acc * 100,
+            all_acc * 100,
+            count_lists,
+        )
+    else:
+        return simple_acc, moderate_acc, challenging_acc, all_acc, count_lists


-def print_data(score_lists, count_lists):
+def print_data(score_lists, count_lists, metric="Exec ACCURACY"):
     levels = ["simple", "moderate", "challenging", "total"]
     print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
     print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists))

     print(
-        "====================================== ACCURACY ====================================="
+        f"====================================== {metric} ====================================="
     )
     print(
         "{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format("accuracy", *score_lists)
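
compute_acc_by_diff relies on the file behind --diff_json_path being index-aligned with the predictions: entry i must describe query i and carry a difficulty field of simple, moderate, or challenging. A minimal stand-in for that file's structure, inferred from the bucketing loop above (real BIRD dev JSON entries carry more keys):

# Index-aligned with exec_results; only "difficulty" is read by the bucketing loop.
contents = [
    {"question_id": 0, "difficulty": "simple"},
    {"question_id": 1, "difficulty": "moderate"},
    {"question_id": 2, "difficulty": "challenging"},
]
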
@@ -157,17 +196,33 @@ def print_data(score_lists, count_lists):
 if __name__ == "__main__":
     args_parser = argparse.ArgumentParser()
     args_parser.add_argument(
-        "--predicted_sql_path", type=str, required=True, default=""
+        "--predicted_sql_path",
+        type=str,
+        default="../../pred_sql/pred_sql_bird_qwen14b_1212.sql",
     )
+    args_parser.add_argument(
+        "--ground_truth_path", type=str, default="../../dbgpt_hub/data/bird/dev/dev.sql"
+    )
+    args_parser.add_argument("--data_mode", type=str, default="dev")
+    args_parser.add_argument(
+        "--db_root_path",
+        type=str,
+        default="../../dbgpt_hub/data/bird/dev/dev_databases/",
+    )
-    args_parser.add_argument("--ground_truth_path", type=str, required=True, default="")
-    args_parser.add_argument("--data_mode", type=str, required=True, default="dev")
-    args_parser.add_argument("--db_root_path", type=str, required=True, default="")
     args_parser.add_argument("--num_cpus", type=int, default=1)
     args_parser.add_argument("--meta_time_out", type=float, default=30.0)
     args_parser.add_argument("--mode_gt", type=str, default="gt")
     args_parser.add_argument("--mode_predict", type=str, default="gpt")
     args_parser.add_argument("--difficulty", type=str, default="simple")
     args_parser.add_argument("--diff_json_path", type=str, default="")
+    args_parser.add_argument(
+        "--etype",
+        dest="etype",
+        type=str,
+        default="match",
+        choices=("all", "exec", "match", "ves"),
+    )

     args = args_parser.parse_args()
     exec_result = []

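With the new --etype switch, one run can report execution accuracy, exact-match accuracy, and VES together (--etype all) or any single metric. A plausible invocation patterned on the defaults wired in above; the dev.json path is a guess based on the other default paths:

python dbgpt_hub/eval/evaluation_bird.py \
    --predicted_sql_path ../../pred_sql/pred_sql_bird_qwen14b_1212.sql \
    --ground_truth_path ../../dbgpt_hub/data/bird/dev/dev.sql \
    --db_root_path ../../dbgpt_hub/data/bird/dev/dev_databases/ \
    --diff_json_path ../../dbgpt_hub/data/bird/dev/dev.json \
    --num_cpus 8 \
    --etype all
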
@@ -186,20 +241,52 @@ def print_data(score_lists, count_lists):
     db_paths = db_paths_gt

     query_pairs = list(zip(pred_queries, gt_queries))
-    run_sqls_parallel(
-        query_pairs,
-        db_places=db_paths,
-        num_cpus=args.num_cpus,
-        meta_time_out=args.meta_time_out,
-    )
+    if args.etype in ["all", "exec", "ves"]:
+        run_sqls_parallel(
+            query_pairs,
+            db_places=db_paths,
+            num_cpus=args.num_cpus,
+            meta_time_out=args.meta_time_out,
+        )
+    else:
+        for i, sql_pair in enumerate(query_pairs):
+            predicted_sql, ground_truth = sql_pair
+            exec_result.append(
+                {"sql_idx": i, "match": int(predicted_sql == ground_truth)}
+            )
     exec_result = sort_results(exec_result)

     print("start calculate")
-    simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff(
-        exec_result, args.diff_json_path
-    )
-    score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
-    print_data(score_lists, count_lists)
+    if args.etype in ["all", "exec"]:
+        (
+            simple_acc,
+            moderate_acc,
+            challenging_acc,
+            acc,
+            count_lists,
+        ) = compute_acc_by_diff(exec_result, args.diff_json_path, "res")
+        score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
+        print_data(score_lists, count_lists, metric="Exec Accuracy")
+    if args.etype in ["all", "match"]:
+        (
+            simple_acc,
+            moderate_acc,
+            challenging_acc,
+            acc,
+            count_lists,
+        ) = compute_acc_by_diff(exec_result, args.diff_json_path, "match")
+        score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
+        print_data(score_lists, count_lists, metric="Match Accuracy")
+    if args.etype in ["all", "ves"]:
+        (
+            simple_acc,
+            moderate_acc,
+            challenging_acc,
+            acc,
+            count_lists,
+        ) = compute_acc_by_diff(exec_result, args.diff_json_path, "time_ratio")
+        score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
+        print_data(score_lists, count_lists, metric="Ves")
     print(
         "==========================================================================================="
     )
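
With --etype all, the three blocks above run in sequence, so the script prints three fixed-width tables (Exec Accuracy, Match Accuracy, Ves), each rendered by print_data. A toy call to show the shape of the report, assuming print_data from the diff above is in scope; the numbers are placeholders, not real BIRD results:

# Placeholder scores/counts purely to exercise the layout.
print_data(
    score_lists=[50.0, 100.0, 0.0, 50.0],  # simple, moderate, challenging, total
    count_lists=[2, 1, 1, 4],
    metric="Exec Accuracy",
)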