#!/usr/bin/env python3
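"""
Measure the perplexity of a Llama model (either a full-precision HuggingFace
checkpoint or a gptq_triton quantized model) on wikitext-2, ptb, and c4 using
a sliding-window evaluation.
"""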
import argparse
import torch
import torch.nn as nn
from datasets import load_dataset
from gptq_triton import load_quant
from tqdm import tqdm
from transformers import AutoTokenizer, LlamaForCausalLM

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, help='Path to model, either a HuggingFace model or a quantized model')
parser.add_argument('--quant', action='store_true', help='Whether the model is quantized')
parser.add_argument('--stride', type=int, default=512, help='Stride for calculating perplexity')
parser.add_argument('--context-length', type=int, default=2048, help='Length of context to use')
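
# Example usage (illustrative only; the model paths are placeholders):
#   python ppl.py --model /path/to/llama-hf
#   python ppl.py --model /path/to/quantized-model --quant --stride 512 --context-length 2048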

def main():
    args = parser.parse_args()

    if not args.quant:
        model = get_llama(args.model)
        model.eval()
        model.to('cuda')
    else:
        model = load_quant(args.model)
        model.eval()
        model.to('cuda')

    # NOTE: Setting use_fast=False for now, as the alternative was an order of magnitude slower on a recent `transformers` commit
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)

    context_length = model.seqlen if args.context_length is None else args.context_length

    for dataset in ['wikitext-2', 'ptb', 'c4']:
        ppl = calculate_perplexity(model, tokenizer, dataset, max_length=context_length, stride=args.stride)
        print(f"{dataset} perplexity: {ppl}")


def get_llama(model: str):
    """
    Load a pretrained Llama model
    """
    def skip(*args, **kwargs):
        pass

    # NOTE: This is a nasty hack, but it speeds up model building by a huge amount
    old_inits = (torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_)
    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip

    model = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto')
    model.seqlen = 2048

    # Restore the old initializers
    torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_ = old_inits

    return model


def get_dataset(dataset_name: str, tokenizer) -> torch.Tensor:
    if dataset_name == "wikitext-2":
        test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt").input_ids
    elif dataset_name == 'ptb':
        test = load_dataset("ptb_text_only", 'penn_treebank', split="validation")
        encodings = tokenizer("\n\n".join(test["sentence"]), return_tensors="pt").input_ids
    elif dataset_name == 'c4':
        # WARNING: Many of the files in the allenai/c4 repo are marked as "Unsafe" by HuggingFace, possibly containing a virus. This particular file is not, and I doubt it's an issue, but worth noting.
        test = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')
        encodings = [tokenizer(x, return_tensors="pt").input_ids for x in test['text'][:1000]]
        encodings = torch.cat(encodings, dim=1)
    else:
        raise ValueError(f"Unknown dataset {dataset_name}")

    return encodings


def calculate_perplexity(model, tokenizer, dataset: str, max_length: int, stride: int = 512) -> float:
    print("Loading dataset...")
    encodings = get_dataset(dataset, tokenizer)
    seq_len = encodings.size(1)

    print("Calculating perplexity...")
    print(f"Sequence length: {seq_len}")
    print(f"Max length: {max_length}")
    print(f"Stride: {stride}")

    nlls = []
    prev_end_loc = 0

    for begin_loc in (pbar := tqdm(range(0, seq_len - 1, stride))):
        end_loc = min(seq_len - 1, begin_loc + max_length)
        trg_len = end_loc - prev_end_loc  # How many tokens we want to predict
        input_ids = encodings[:, begin_loc:end_loc+1].to('cuda')  # +1 for the labels

        with torch.no_grad():
            # Ask the model for logits
            # NOTE: Instead of calling HF's model wrapper, we call the model directly to hopefully cut down on some memory overhead
            outputs = model.model(input_ids[:, :-1], use_cache=False)
            logits = model.lm_head(outputs[0][..., -trg_len:, :])

            # The last trg_len tokens are the labels
            labels = input_ids[:, -trg_len:].contiguous()

            # Compute the NLL for this batch using flattened logits and labels
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        nlls.append(loss.to('cpu').to(torch.float32))

        # Perplexity is the exponential of the mean NLL over all windows so far
        ppl = torch.exp(torch.stack(nlls).mean())
        pbar.set_description(f"Perplexity: {ppl:.2f}")

        prev_end_loc = end_loc

        if end_loc == (seq_len - 1):
            break

    ppl = torch.exp(torch.stack(nlls).mean())

    return ppl.item()


if __name__ == '__main__':
    main()