-
Notifications
You must be signed in to change notification settings - Fork 17
/
eval.py
450 lines (386 loc) · 18.2 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from dataclasses import asdict, dataclass
from typing import Any, List, Optional, Tuple, Union
import torch
import transformers
from tqdm import tqdm
from lm_eval import utils
from lm_eval import simple_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM, TemplateLM
from lm_eval.models.utils import pad_and_concat, Collator
from arguments import Arguments, simple_parse_args_string
from self_speculation.autoregressive_generator import AutoRegressiveGenerationStrategy
from self_speculation.generator_base import (
GenerationConfig,
GenerationResult,
GenerationStrategy,
HuggingfaceLlamaGenerator,
)
from self_speculation.self_speculation_generator import SelfSpeculativeGenerationStrategy
from generate import load_model_and_tokenizer, setup
from benchmark import EvaluationMetrics
@dataclass
class EvalArguments:
    """Command-line options for an lm-eval harness run.

    ``main`` in this file expands an instance with ``asdict`` and passes every
    field as a keyword argument to ``lm_eval.simple_evaluate``, so each field
    name must match a ``simple_evaluate`` parameter name. Refer to the lm-eval
    documentation for the precise meaning of each option.
    """

    # Task selection and few-shot setup. ``tasks`` defaults to None, so it is
    # annotated Optional like the other nullable fields.
    tasks: Optional[List[str]] = None
    num_fewshot: Optional[int] = None
    device: Optional[str] = None

    # Caching controls.
    use_cache: Optional[str] = None
    cache_requests: bool = False
    rewrite_requests_cache: bool = False
    delete_requests_cache: bool = False

    # Evaluation size, statistics and logging.
    limit: Optional[int] = None
    bootstrap_iters: int = 100000
    check_integrity: bool = False
    write_out: bool = False
    log_samples: bool = True

    # Prompt construction.
    system_instruction: Optional[str] = None
    apply_chat_template: Union[bool, str] = False
    fewshot_as_multiturn: bool = False

    # Generation keyword arguments, kept as a raw string here.
    gen_kwargs: Optional[str] = None
    verbosity: str = "INFO"
    predict_only: bool = False

    # Seeds.
    random_seed: int = 0
    numpy_random_seed: int = 1234
    torch_random_seed: int = 1234
    fewshot_random_seed: int = 1234
def all_dicts_same(dict_list):
    """Return True when every entry of ``dict_list`` equals the first entry.

    A falsy ``dict_list`` (for example an empty list or tuple) counts as
    "all the same" and yields True.
    """
    if not dict_list:
        return True

    reference = dict_list[0]
    for candidate in dict_list:
        if not candidate == reference:
            return False
    return True
# Light wrapper around generator for lm-eval harness
class EvalHarnessLM(TemplateLM):
    """Expose a ``HuggingfaceLlamaGenerator`` through the lm-eval ``TemplateLM`` interface.

    Text generation requests go through ``generator.generate`` (and therefore
    through whichever generation strategy the generator was built with), while
    log-likelihood requests call ``generator.model`` directly. Only
    ``batch_size == 1`` is accepted by the constructor.
    """

    # Fallback context length when neither the model config nor the tokenizer
    # provides a usable value (see ``max_length``).
    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        generator: HuggingfaceLlamaGenerator,
        generation_config: GenerationConfig,
        device: Union[str, torch.device],
        logits_cache: bool = True,
        batch_size: Optional[Union[int, str]] = 1,
        add_bos_token: Optional[bool] = False,
        max_length: Optional[int] = None,
    ):
        """Store the generator and evaluation settings.

        Args:
            generator: Object providing ``generate``, ``model`` and ``tokenizer``.
            generation_config: Config passed to every ``generator.generate`` call;
                its ``stop_words`` attribute is overwritten by ``generate_until``.
            device: Device on which input tensors are created.
            logits_cache: When true, log-likelihood requests are grouped by
                context so one-token continuations can reuse logits.
            batch_size: Must be 1; anything else fails the assertion below.
            add_bos_token: Default for ``add_special_tokens`` in ``tok_encode``.
            max_length: Optional manual override of the context length.
        """
        super().__init__()
        assert batch_size == 1, "Currently we only support batch size 1"
        self.generator = generator
        self.generation_config = generation_config
        self.device = device
        self.logits_cache = logits_cache
        self.batch_size = batch_size
        self.add_bos_token = add_bos_token
        self._max_length = max_length
        # Filled by ``generate_until``; stays None if it is never called.
        self.metric_result = None

    def generate_until(self, requests: List[Instance]) -> List[str]:
        """Generate a completion for each request and cut it at the first stop string.

        All requests must carry identical generation arguments. The metrics
        computed over this call's generations are stored in ``self.metric_result``.

        Args:
            requests: Instances whose ``args`` are ``(prompt, gen_args)`` pairs.

        Returns:
            One truncated generation per request, in request order.
        """
        if not requests:
            # ``zip(*[])`` yields nothing, so the two-name unpacking below
            # would raise; there is nothing to generate anyway.
            return []

        prompts, gen_args = zip(*[req.args for req in requests])
        assert all_dicts_same(gen_args), "Doesn't support different gen args for now"
        gen_args = gen_args[0]
        # TODO: remove "temperature", "top_p", and "top_k" from "gen_args"
        until = gen_args.get("until", [])
        if isinstance(until, str):
            # A bare string is a single stop sequence; iterating over it would
            # otherwise treat each character as its own stop string.
            until = [until]
        self.generation_config.stop_words = until

        generations = []
        metrics = EvaluationMetrics.build_metrics()
        for prompt in tqdm(prompts):
            response: GenerationResult = self.generator.generate(
                prompt=prompt,
                generation_config=self.generation_config,
            )
            generations.append(response.decoded_prediction)
            metrics.update(None, response)
        self.metric_result = metrics.compute()

        filtered_gen = []
        for p, g in tqdm(zip(prompts, generations)):
            # Keep only the text before the first occurrence of each stop string.
            for e in until:
                g = g.split(e, 1)[0]
            filtered_gen.append(g)
            self.cache_hook.add_partial("generate_until", (p, gen_args), g)
        return filtered_gen

    # Copied from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/huggingface.py
    @property
    def max_length(self):
        """Maximum sequence length used when building model inputs.

        Resolution order: the manually supplied value (if truthy), the first
        of ``n_positions`` / ``max_position_embeddings`` / ``n_ctx`` present on
        the model config, the tokenizer's ``model_max_length``, and finally
        ``_DEFAULT_MAX_LENGTH``.
        """
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
        for attr in seqlen_config_attrs:
            if hasattr(self.generator.model.config, attr):
                return getattr(self.generator.model.config, attr)
        if hasattr(self.generator.tokenizer, "model_max_length"):
            # The literal equals int(1e30); presumably the tokenizer's
            # "no limit configured" placeholder, so fall back to the default.
            if self.generator.tokenizer.model_max_length == 1000000000000000019884624838656:
                return self._DEFAULT_MAX_LENGTH
            return self.generator.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH

    # Copied from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/huggingface.py
    def tok_encode(
        self, string: str, left_truncate_len=None, add_special_tokens=None
    ) -> List[int]:
        """Tokenize ``string`` with the generator's tokenizer.

        Args:
            string: Text to encode.
            left_truncate_len: If truthy, keep only the last this-many tokens.
            add_special_tokens: Explicit value for the tokenizer's
                ``add_special_tokens``; when None, ``self.add_bos_token`` is used.

        Returns:
            The (possibly left-truncated) token ids.
        """
        # by default for CausalLM - false or self.add_bos_token is set
        if add_special_tokens is None:
            special_tokens_kwargs = {
                "add_special_tokens": False or self.add_bos_token
            }
        # otherwise the method explicitly defines the value
        else:
            special_tokens_kwargs = {"add_special_tokens": add_special_tokens}

        encoding = self.generator.tokenizer.encode(string, **special_tokens_kwargs)

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]

        return encoding

    # Copied from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/huggingface.py
    @property
    def eot_token_id(self):
        """Id of the tokenizer's end-of-sequence token."""
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        # The tokenizer lives on the generator; this class has no ``tokenizer`` attribute.
        return self.generator.tokenizer.eos_token_id

    # Copied from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/huggingface.py
    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
        override_bs: int = None,
    ) -> List[Tuple[float, bool]]:
        """Score tokenized (context, continuation) pairs with the underlying model.

        Args:
            requests: Triples of ``(request key, context token ids,
                continuation token ids)``.
            disable_tqdm: Suppress the progress bar.
            override_bs: Batch size to use when ``self.batch_size == "auto"``.

        Returns:
            For each request, in the original order, a tuple of the summed
            log-probability of the continuation and whether the per-token
            argmax matches the continuation exactly.
        """
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = req[1] + req[2]
            return -len(toks), tuple(toks)

        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Defines the key to group and lookup one-token continuations"""
            # Use with group_by="contexts" (optional)"
            # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
            # speeds up some multiple-choice tasks proportionally to the number of choices.
            # groups requests by context+continuation[:-1] and infer on one request/group.
            return req[-2] + req[-1][:-1]

        re_ord = Collator(
            requests,
            sort_fn=_collate,
            group_by="contexts" if self.logits_cache else None,
            group_fn=_lookup_one_token_cont,
        )

        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
        n_reordered_requests = len(re_ord)
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else override_bs
            if override_bs is not None
            else 0
        )
        # NOTE(review): ``_batch_scheduler`` is not defined in this class; this
        # branch is presumably unreachable because ``__init__`` asserts
        # ``batch_size == 1``.
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto"
            and n_reordered_requests > 0
            and not override_bs
            else None
        )

        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running loglikelihood requests",
        )
        for chunk in chunks:
            inps = []
            cont_toks_list = []
            inplens = []

            padding_len_inp = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works (illustrated on a causal decoder-only setup):
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                    dtype=torch.long,
                    device=self.device,
                )
                (inplen,) = inp.shape

                padding_len_inp = (
                    max(padding_len_inp, inplen)
                    if padding_len_inp is not None
                    else inplen
                )

                inps.append(inp)  # [1, inp_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)

            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
            batched_inps = pad_and_concat(
                padding_len_inp, inps, padding_side="right"
            )  # [batch, padding_len_inp]

            # Scoring never backpropagates, so skip building the autograd graph.
            with torch.no_grad():
                multi_logits = torch.nn.functional.log_softmax(
                    self.generator.model(batched_inps, **call_kwargs).logits, dim=-1
                )  # [batch, padding_length (inp or cont), vocab]

            for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
                chunk, multi_logits, inplens, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
                # take only logits in the continuation
                # (discard context toks if decoder-only ; discard right-padding)
                # also discards + checks for "virtual tokens" in the causal LM's input window
                # from prompt/prefix tuning tokens, if applicable
                ctx_len = (
                    inplen + (logits.shape[0] - padding_len_inp)
                )
                logits = logits[inplen - contlen : ctx_len]
                logits = logits.unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)

                # check for one-token continuation cache hits.
                # noop in case group_by != "contexts" or no cache hit and returns the
                # original args. Otherwise, expands the logits batch dimension and yields each
                # batch along with matching continuation tokens and prompt strings.
                # logits -> [1, seq, vocab]
                for request_str, cont_toks, logits in re_ord.get_cache(
                    req_str=request_str,
                    cxt_toks=ctx_tokens,
                    cont_toks=cont_toks,
                    logits=logits,
                ):
                    cont_toks = torch.tensor(
                        cont_toks, dtype=torch.long, device=self.device
                    ).unsqueeze(0)  # [1, seq]
                    max_equal = (greedy_tokens == cont_toks).all()

                    # Obtain log-probs at the corresponding continuation token indices
                    # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                        -1
                    )  # [1, seq]

                    # Answer: (log prob, is-exact-match)
                    answer = (float(logits.sum()), bool(max_equal))

                    res.append(answer)

                    self.cache_hook.add_partial("loglikelihood", request_str, answer)
                    pbar.update(1)

        pbar.close()

        return re_ord.get_original(res)

    # Copied from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/huggingface.py
    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        """Compute the total log-likelihood of each full document.

        Each request string is tokenized, split into rolling windows of at most
        ``max_length`` tokens, scored with ``_loglikelihood_tokens`` and summed.

        Args:
            requests: Instances whose ``args`` are one-element tuples ``(string,)``.
            disable_tqdm: Suppress the progress bar.

        Returns:
            One summed log-likelihood per request, in request order.
        """
        loglikelihoods = []

        adaptive_batch_size = None
        if self.batch_size == "auto":
            # NOTE(review): ``_detect_batch_size`` is not defined in this class;
            # presumably unreachable because ``__init__`` asserts batch_size == 1.
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        for (string,) in tqdm(
            [req.args for req in requests], disable=(disable_tqdm or (self.rank != 0))
        ):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.prefix_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            # Prepend a None request key so each window matches the triple
            # layout ``_loglikelihood_tokens`` expects.
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            pad_amnt = 0
            if self.world_size > 1:
                # We pad out the external document-level iterator so the inner iterator doesn't hang
                mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
                gathered = (
                    self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
                )

                pad_amnt = max(gathered) - gathered[self.rank]
                if pad_amnt > 0:
                    rolling_token_windows += pad_amnt * [rolling_token_windows[0]]

            string_nll = self._loglikelihood_tokens(
                requests=rolling_token_windows,
                disable_tqdm=True,
                override_bs=adaptive_batch_size,
            )

            if (self.world_size > 1) and (pad_amnt > 0):
                # Drop the results that belong to the padding windows.
                string_nll = [x[0] for x in string_nll[:-pad_amnt]]
            else:
                # discard is_greedy
                string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods
def main(args: Arguments, eval_arguments: EvalArguments, generation_config: GenerationConfig):
    """Load the model, wrap it for lm-eval, run the evaluation and print the results.

    Args:
        args: General arguments passed to ``setup`` and ``load_model_and_tokenizer``.
        eval_arguments: Expanded with ``asdict`` into keyword arguments for
            ``simple_evaluate``.
        generation_config: Selects the generation strategy and is passed to the
            evaluation wrapper.

    Raises:
        Exception: If ``generation_config.generation_strategy`` is neither
            ``"autoregressive"`` nor ``"self_speculative"``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    setup(args, device=device)
    transformers.utils.logging.set_verbosity_error()
    model, tokenizer = load_model_and_tokenizer(args, device=device)

    if generation_config.generation_strategy == "autoregressive":
        generation_strategy: GenerationStrategy = AutoRegressiveGenerationStrategy()
    elif generation_config.generation_strategy == "self_speculative":
        generation_strategy: GenerationStrategy = SelfSpeculativeGenerationStrategy()
    else:
        raise Exception(
            f"Unsupported generation strategy: {generation_config.generation_strategy}"
        )

    # initialize generator
    generator = HuggingfaceLlamaGenerator(
        tokenizer=tokenizer, model=model, generation_strategy=generation_strategy
    )

    # create evaluator
    wrap = EvalHarnessLM(generator, generation_config, device)

    # Warmup
    warmup = 1
    for _ in range(warmup):
        model.generation_config.pad_token_id = tokenizer.eos_token_id
        model.generate(**tokenizer("This is a warmup prompt", return_tensors="pt").to(device), max_new_tokens=10)

    # Evaluate
    results = simple_evaluate(wrap, **asdict(eval_arguments))

    # TODO: log results, generation samples, etc.
    print(results["results"])
    # ``metric_result`` is only filled by ``generate_until``; it is still None
    # when the selected tasks issued no generation requests.
    if wrap.metric_result is not None:
        # Drop the generated text before printing; tolerate a missing key.
        wrap.metric_result.pop("predicted_text", None)
        print(wrap.metric_result)
def process_cli_arguments() -> Tuple[Arguments, EvalArguments, GenerationConfig]:
    """Parse the command line into the three config dataclasses.

    Unrecognized command-line strings are collected and ignored. The
    ``model_args`` string on the general arguments is converted with
    ``simple_parse_args_string``; a falsy value becomes an empty dict.

    Returns:
        ``(general_arguments, eval_arguments, generation_config)``.
    """
    dataclass_types = (Arguments, EvalArguments, GenerationConfig)
    parsed = transformers.HfArgumentParser(dataclass_types).parse_args_into_dataclasses(
        return_remaining_strings=True
    )
    general_arguments, eval_arguments, generation_config, _unparsed = parsed

    raw_model_args = general_arguments.model_args
    general_arguments.model_args = (
        simple_parse_args_string(raw_model_args) if raw_model_args else {}
    )

    return general_arguments, eval_arguments, generation_config
if __name__ == "__main__":
    # Parse the three configs from the command line and hand them straight to
    # main() in the order it expects: general args, eval args, generation config.
    main(*process_cli_arguments())