classify.py (forked from openai/evals)
"""
Generic eval that uses a prompt + classification.
"""
from collections import Counter
from random import Random
from typing import Any, Optional, Union

import evals
import evals.record
from evals.elsuite.modelgraded.classify_utils import classify, sample_and_concat_n_completions
from evals.elsuite.utils import PromptFn, scrub_formatting_from_prompt


class ModelBasedClassify(evals.Eval):
    def __init__(
        self,
        modelgraded_spec: str,
        *args,
        modelgraded_spec_args: Optional[dict[str, dict[str, str]]] = None,
        sample_kwargs: Optional[dict[str, Any]] = None,
        eval_kwargs: Optional[dict[str, Any]] = None,
        multicomp_n: Union[int, str] = 1,
        eval_type: Optional[str] = None,
        match_fn: Optional[str] = None,
        metaeval: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # treat last completion_fn as eval_completion_fn
        self.eval_completion_fn = self.completion_fns[-1]
        if len(self.completion_fns) > 1:
            self.completion_fns = self.completion_fns[:-1]
        n_models = len(self.completion_fns)
        self.sample_kwargs = {"max_tokens": 1024}
        self.sample_kwargs.update(sample_kwargs or {})
        self.eval_kwargs = {"max_tokens": 1024}
        self.eval_kwargs.update(eval_kwargs or {})
        self.metaeval = metaeval
        self.modelgraded_spec_args = modelgraded_spec_args or {}
        self.eval_type = eval_type
        self.match_fn = match_fn
        if multicomp_n == "from_models":
            assert n_models > 1
            self.multicomp_n = n_models
        else:
            assert isinstance(multicomp_n, int)
            self.multicomp_n = multicomp_n
        if len(self.completion_fns) > 1:
            assert self.multicomp_n == n_models

        self.mg = self.registry.get_modelgraded_spec(modelgraded_spec)
    def eval_sample(self, test_sample: dict, rng: Random) -> None:
        """Evaluate a single sample.

        Recorded metrics: `choice` (one of the grading spec's choice strings, or
        "__invalid__") and `score`, plus `metascore` when metaeval is enabled.
        """
        # process test_sample
        for k in self.mg.input_outputs:
            test_sample[k] = scrub_formatting_from_prompt(test_sample[k])

        # run policy completions
        completions = {}
        for k, v in self.mg.input_outputs.items():
            if v in test_sample:  # test_sample already has completion, skip.
                continue
            if self.multicomp_n > 1:
                completion = sample_and_concat_n_completions(
                    self.completion_fns,
                    prompt=test_sample[k],
                    template_i=self.mg.output_template,
                    sample_kwargs=self.sample_kwargs,
                    n=self.multicomp_n,
                )
            else:
                get_input_completion = PromptFn(
                    test_sample[k], completion_fn=self.completion_fn, **self.sample_kwargs
                )
                completion, _ = get_input_completion()
            completions[v] = completion

        # run modelgraded eval
        metrics = {}
        choice, info = classify(
            mg=self.mg,
            completion_fn=self.eval_completion_fn,
            completion_kwargs=self.eval_kwargs,
            eval_type=self.eval_type,
            n=self.multicomp_n,
            match_fn=self.match_fn,
            format_kwargs={**completions, **test_sample, **self.modelgraded_spec_args},
        )
        metrics.update(dict(choice=choice, score=info["score"]))

        # run metaeval if requested
        if self.metaeval:
            assert "choice" in test_sample
            metrics["metascore"] = choice == test_sample["choice"]

        evals.record.record_metrics(**metrics)

        return choice
    def run(self, recorder):
        """Run the eval on all samples and aggregate the per-sample metrics.

        Returns counts of each recorded choice, plus mean score and metascore
        when they are available.
        """
        samples = self.get_samples()
        self.eval_all_samples(recorder, samples)
        record_metrics = {}

        all_sample_metrics = recorder.get_metrics()
        if not all_sample_metrics:
            return record_metrics

        # record the counts
        choices = [m["choice"] for m in all_sample_metrics]
        counts = dict(Counter(choices))
        record_metrics.update({f"counts/{k}": v for k, v in counts.items()})

        # record the scores
        scores = [m["score"] for m in all_sample_metrics if m["score"] is not None]
        if scores:
            record_metrics["score"] = sum(scores) / len(scores)
        metascores = [m["metascore"] for m in all_sample_metrics if "metascore" in m]
        if metascores:
            record_metrics["metascore"] = sum(metascores) / len(metascores)

        return record_metrics
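

# Usage sketch (illustrative): in upstream openai/evals, ModelBasedClassify is
# normally wired up through a registry YAML entry rather than instantiated
# directly. The eval name, samples path, and spec name below are hypothetical
# placeholders, not entries taken from this repo:
#
#   example-modelgraded.dev.v0:
#     class: evals.elsuite.modelgraded.classify:ModelBasedClassify
#     args:
#       samples_jsonl: example/samples.jsonl
#       modelgraded_spec: fact
#
# Passing several completion fns makes the last one the grader (see __init__
# above), e.g.:
#
#   oaieval gpt-3.5-turbo,gpt-4 example-modelgraded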