eval_ppl.py — forked from Cornell-RelaxML/quip-sharp
67 lines (55 loc) · 2.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import math
import json
import argparse
import torch
import datasets
from lib.utils import gptq_data_utils
from lib.utils.unsafe_import import model_from_hf_path
import random
import glog
from tqdm import tqdm
# Evaluation only: disable autograd globally so no graph is built during the
# forward passes below.
torch.set_grad_enabled(False)

# Command-line interface for the perplexity evaluation script.
parser = argparse.ArgumentParser()
# RNG seed used for both Python `random` and torch (see __main__ guard).
parser.add_argument('--seed', default=0, type=int)
# Path or hub id of the (quantized) HuggingFace model to evaluate.
parser.add_argument('--hf_path', default='hfized/quantized_hada_70b', type=str)
# Context window length; the test set is chunked into windows of this size.
parser.add_argument('--seqlen', default=4096, type=int)
# Opt-out flags: CUDA-graph capture and flash attention are ON by default.
parser.add_argument('--no_use_cuda_graph', action='store_true')
parser.add_argument('--no_use_flash_attn', action='store_true')
def main(args):
    """Compute and log perplexity of a quantized HF model on wikitext2 and c4.

    Args:
        args: parsed argparse namespace with ``seed``, ``hf_path``, ``seqlen``,
            ``no_use_cuda_graph`` and ``no_use_flash_attn``.
    """
    # Renamed from `datasets`: the original name shadowed the imported
    # `datasets` module (HF datasets) at file scope.
    eval_datasets = ['wikitext2', 'c4']
    model, model_str = model_from_hf_path(
        args.hf_path,
        use_cuda_graph=not args.no_use_cuda_graph,
        use_flash_attn=not args.no_use_flash_attn)
    for dataset in eval_datasets:
        input_tok = gptq_data_utils.get_test_tokens(dataset,
                                                    seed=args.seed,
                                                    seqlen=args.seqlen,
                                                    model=model_str)
        # Drop the ragged tail and reshape into (nsamples, seqlen) windows.
        nsamples = input_tok.numel() // args.seqlen
        input_tok = input_tok[0, :(args.seqlen * nsamples)].view(
            nsamples, args.seqlen)

        if not args.no_use_cuda_graph:
            # Reset captured CUDA graphs between datasets (sequence layout
            # may differ) — presumably required by the graph wrapper; the
            # contract of model.reset() is defined in lib/, not visible here.
            model.reset()

        loss_fct = torch.nn.CrossEntropyLoss().cuda()
        acc_loss = 0.0
        progress = tqdm(range(nsamples))
        for ii in progress:
            # Renamed from `input`: avoid shadowing the builtin.
            input_ids = input_tok[ii, :].cuda().view(1, -1)
            output = model(input_ids,
                           use_cache=False,
                           output_hidden_states=False,
                           output_attentions=False)[0]
            # Next-token prediction: logits at position t predict token t+1.
            shift_logits = output[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:]
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            acc_loss += loss.item()
            progress.set_description(f"avg_loss = {acc_loss/(ii+1)}")
        avg_loss = acc_loss / nsamples
        # Perplexity = exp(mean cross-entropy). math.exp keeps float64
        # precision; the original torch.exp(torch.tensor(...)) round-trip
        # silently truncated to float32 first.
        ppl = math.exp(avg_loss)
        glog.info(f'{dataset} perplexity: {ppl}')
if __name__ == '__main__':
    # Inference-only script: make sure autograd stays off (also disabled at
    # import time above).
    torch.set_grad_enabled(False)
    cli_args = parser.parse_args()
    # Seed both RNG sources so data sampling is reproducible.
    random.seed(cli_args.seed)
    torch.random.manual_seed(cli_args.seed)
    main(cli_args)