evaluate.py
"""
Doing evaluate stuff
"""
import itertools
import threading
from typing import Type, TypeVar

from loader import load_json, load_jsonl, Problem, MultiChoiceProblem, AnswerType, GSM8K
from solver import CoTSolver
from logger import ThreadLogger

logger = ThreadLogger()

P = TypeVar("P", bound=Problem)
M = TypeVar("M", bound=MultiChoiceProblem)
S = TypeVar("S", bound=CoTSolver)

def answer_equal(answer: str, output: str, answer_type: AnswerType) -> bool:
    """
    Check whether the solver output matches the reference answer for the given answer type.
    """

    def num_equal(lhs: str, rhs: str) -> bool:
        # Compare as floats within a small tolerance; non-numeric strings never match.
        eps = 1e-4
        try:
            return abs(float(lhs) - float(rhs)) < eps
        except ValueError:
            return False

    def option_equal(lhs: str, rhs: str) -> bool:
        # Option labels are compared case-insensitively.
        return lhs.lower() == rhs.lower()

    def boolean_equal(lhs: str, rhs: str) -> bool:
        """
        Assumes the answer is "yes" or "no". May change in the future.
        """
        return lhs.strip(".").lower() == rhs.strip(".").lower()

    if answer_type is AnswerType.Number:
        return num_equal(answer, output)
    elif answer_type is AnswerType.Option:
        return option_equal(answer, output)
    elif answer_type is AnswerType.Boolean:
        return boolean_equal(answer, output)
    else:
        assert False, f"Unsupported answer type: {answer_type}"
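
# Illustrative behavior of answer_equal (the inputs below are hypothetical examples,
# not taken from any dataset in this repository):
#   answer_equal("3.0", "3.00001", AnswerType.Number)  -> True  (within eps = 1e-4)
#   answer_equal("A", "a", AnswerType.Option)          -> True  (case-insensitive)
#   answer_equal("Yes.", "yes", AnswerType.Boolean)    -> True  (trailing "." stripped)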

def evaluate_dataset(
    file_path: str,
    dataset: Type[P],
    solver: Type[S],
    answer_type: AnswerType,
    range_arg: range | None = None,
    model_name: str | None = None,
):
    """
    Evaluate a solver's accuracy on a dataset; applicable to numerical,
    multiple-choice, and boolean questions.

    Args:
        file_path: Dataset file path
        dataset: Dataset type
        solver: Solver type
        answer_type: Answer type of the dataset's problems
        range_arg: The range of the problems to be loaded
        model_name: Model name
    """
    # Name the loaded problems separately so they don't shadow the dataset type.
    problems = (
        load_json(file_path, dataset, range_arg=range_arg)
        if dataset.file_format() == "json"
        else load_jsonl(file_path, dataset, range_arg=range_arg)
    )
    tot_cnt = len(problems)
    cot_solver = solver(model_name=model_name)
    correct_cnt = 0
    for index, problem in enumerate(problems):
        logger.info(f"Running case {index + 1}... Total {tot_cnt}")
        cot_solver.set_problem(problem.problem())
        # Dispatch on the answer type to pick the matching solving strategy.
        if answer_type is AnswerType.Number:
            output = cot_solver.solve_numerical()
        elif answer_type is AnswerType.Option:
            output = cot_solver.solve_multichoice(problem.options())
        elif answer_type is AnswerType.Boolean:
            output = cot_solver.solve_boolean(["yes", "no"])
        else:
            assert False, f"Unsupported answer type: {answer_type}"
        answer = problem.answer()
        if answer_equal(answer, output, answer_type):
            correct_cnt += 1
        else:
            logger.warning(
                f"Solving failed {index + 1}!!! Expected {answer}, Got {output}."
            )
            cot_solver.agent.debug()
        logger.info(f"In case {index + 1}, correct {correct_cnt}.")
    logger.info(f"{solver.__name__} solver accuracy: {correct_cnt / tot_cnt}")

def evaluate_in_threads(
    solvers: list[Type[S]],
    datasets: list[Type[P]],
    range_arg: range | None = None,
    model: str = "gpt-4o-mini",
    debug: bool = False,
):
    """
    Evaluate every (solver, dataset) pair simultaneously, one thread per pair.
    """
    group = itertools.product(solvers, datasets)
    threads = []
    for solver, dataset in group:
        log_file = f"./logs/{solver.__name__}_{dataset.__name__}.log"
        # Compute a per-pair range so the GSM8K default does not leak into later pairs.
        pair_range = range_arg
        if pair_range is None and dataset is GSM8K:
            pair_range = range(0, 400)
        dataset_path = f"./dataset/{dataset.__name__}.{dataset.file_format()}"
        evaluation_thread = threading.Thread(
            target=evaluate_dataset,
            kwargs={
                "file_path": dataset_path,
                "dataset": dataset,
                "solver": solver,
                "range_arg": pair_range,
                "answer_type": dataset.answer_type(),
                "model_name": model,
            },
        )
        threads.append(evaluation_thread)
        evaluation_thread.start()
        # Bind after start(), since the thread's ident is only set once it is running.
        logger.bind(
            evaluation_thread.ident,
            log_file,
            "DEBUG" if debug else "INFO",
        )
        print(f"Starting evaluation for {solver.__name__} on {dataset.__name__}")
    for thread in threads:
        thread.join()
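
if __name__ == "__main__":
    # Minimal usage sketch, assuming CoTSolver and GSM8K are concrete solver/dataset
    # types accepted by this pipeline; the range and model below are illustrative
    # defaults, not values prescribed by this file.
    evaluate_in_threads(
        solvers=[CoTSolver],
        datasets=[GSM8K],
        range_arg=range(0, 50),
        model="gpt-4o-mini",
        debug=True,
    )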