# ifeval_eval.py
# Ref: https://github.com/kojima-takeshi188/zero_shot_cot
# Ref: https://github.com/sylinrl/TruthfulQA/blob/main/truthfulqa/metrics.py
# Ref: https://github.com/sylinrl/TruthfulQA/blob/main/truthfulqa/utilities.py
import os
import re
import json
import argparse

import transformers
from tqdm import tqdm

from dola_t5 import DoLa

transformers.logging.set_verbosity(40)
# Leftover constants from the GSM8K-style eval this script was adapted from;
# only DEBUG is referenced below.
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
INVALID_ANS = "[invalid]"
N_SHOT = 3
COT_FLAG = True
DEBUG = False
ANSWER_TRIGGER = "So the answer is"


def load_jsonl(file_path):
    """Read prompts from a JSONL file, one JSON object per line."""
    with open(file_path) as f:
        list_prompts = [json.loads(line)['prompt'] for line in f]
    return list_prompts
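

# The input file is assumed to follow the IFEval input format: one JSON object
# per line with at least a 'prompt' field, e.g. (illustrative line):
#   {"key": 1000, "prompt": "Write a 300+ word summary ...", "instruction_id_list": ["punctuation:no_comma"], "kwargs": [{}]}
# Only 'prompt' is consumed here; the remaining fields are used by the
# separate IFEval scoring script.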


def create_demo_text():
    """Build a short few-shot demonstration of instruction-following examples."""
    question, answer = [], []
    question.append("Write a sentence describing the flavor of coffee. Make sure the word 'roasted' appears at least two times in the sentence, and include a bolded word. Like: *this is bolded text*.")
    answer.append("The bold, *roasted* flavor of coffee envelopes the palate, infusing each sip with rich, *roasted* notes reminiscent of toasted caramel and dark chocolate.")
    question.append("List the months of the year using all capital letters.")
    answer.append("JANUARY, FEBRUARY, MARCH, APRIL, MAY, JUNE, JULY, AUGUST, SEPTEMBER, OCTOBER, NOVEMBER, DECEMBER.")
    demo_text = "Take note of the instructions and responses in the following examples:" + "\n\n"
    for i in range(len(question)):
        demo_text += f"Example {i + 1}:" + "\nInstruction: " + question[i] + "\nResponse: " + answer[i] + "\n\n"
    return demo_text


def build_prompt(input_text):
    demo = create_demo_text()
    input_text_prompt = demo + "Now your task is: " + input_text
    return input_text_prompt
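

# Note: build_prompt() is not called in the main loop below (the raw prompt is
# fed to the model directly); it is kept for few-shot prompting experiments,
# e.g. (hypothetical usage):
#   input_text = build_prompt("Write a haiku about rain, in all lowercase.")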


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="huggyllama/llama-7b")
    parser.add_argument("--num-gpus", type=str, default="1")
    parser.add_argument("--max_gpu_memory", type=int, default=27)
    parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
    parser.add_argument("--data-path", type=str, default="./tfqa")
    parser.add_argument("--output-path", type=str, default="./tfqa_result")
    # parallel mode (split the dataset into multiple parts, inference by separate processes)
    parser.add_argument("--early-exit-layers", type=str, default="-1")
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--total-shard", type=int, default=8)
    parser.add_argument("--shard-id", type=int, default=None)
    parser.add_argument("--do-rating", action="store_true")
    parser.add_argument("--gpt3-config", type=str, default=None)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--max-new-tokens", type=int, default=1024)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=0.9)
    parser.add_argument("--repetition_penalty", type=float, default=None)
    parser.add_argument("--relative_top", type=float, default=0.1)
    parser.add_argument("--print-logits", action="store_true")
    args = parser.parse_args()
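
    # Example invocations (model name, paths, and layer indices are illustrative):
    #   Baseline decoding from the final layer:
    #     python ifeval_eval.py --model-name huggyllama/llama-7b \
    #         --data-path ./tfqa --output-path ./ifeval_result
    #   DoLa decoding, contrasting the final layer (32) against early candidates:
    #     python ifeval_eval.py --model-name huggyllama/llama-7b \
    #         --early-exit-layers 0,2,4,6,8,10,12,14,32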
    model_name = args.model_name
    num_gpus = args.num_gpus
    device = args.device
    data_path = args.data_path

    # Get test file (join rather than concatenate, so a trailing slash is not required)
    fp = os.path.join(data_path, 'ifeval-input-data.jsonl')
    list_data_dict = load_jsonl(fp)

    if args.debug:
        list_data_dict = list_data_dict[:10]

    if args.parallel:
        chunk_size = len(list_data_dict) // args.total_shard
        list_data_dict = list_data_dict[args.shard_id * chunk_size: (args.shard_id + 1) * chunk_size]
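        # Sharding example: with 800 prompts, --total-shard 8 and --shard-id 2
        # select prompts 200..299; the shard's results are written to
        # "<output-path>_2.jsonl" (see the bottom of the file).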

    llm = DoLa(model_name, device, num_gpus, args.max_gpu_memory)
    # stop_word_list = ["Q:"]
    # llm.set_stop_words(stop_word_list)

    early_exit_layers = [int(x) for x in args.early_exit_layers.split(',')]
    if len(early_exit_layers) == 1:
        print("MODE: naive decoding from the last layer", flush=True)
        mode = "baseline"
        mature_layer = None
        premature_layer = None
        candidate_premature_layers = None
        if args.repetition_penalty is None:
            args.repetition_penalty = 1.2
    elif len(early_exit_layers) == 2:
        print(f"MODE: DoLa-static decoding with mature layer: {early_exit_layers[1]} and premature layer: {early_exit_layers[0]}")
        mode = "early_exit_contrastive"
        mature_layer = early_exit_layers[1]
        premature_layer = early_exit_layers[0]
        candidate_premature_layers = None
        if args.repetition_penalty is None:
            args.repetition_penalty = 1.2
    else:
        print(f"MODE: DoLa decoding with mature layer: {early_exit_layers[-1]} and premature layers: {early_exit_layers[:-1]}")
        mode = "dola"
        mature_layer = early_exit_layers[-1]
        premature_layer = None
        candidate_premature_layers = early_exit_layers[:-1]
        premature_layer_dist = {l: 0 for l in candidate_premature_layers}
        if args.repetition_penalty is None:
            args.repetition_penalty = 1.2
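
    # How --early-exit-layers selects the mode (layer indices are illustrative):
    #   "-1"          -> baseline: decode from the final layer only
    #   "2,32"        -> DoLa-static: contrast fixed premature layer 2 with mature layer 32
    #   "0,2,...,32"  -> DoLa: the last entry is the mature layer; the earlier
    #                    entries are candidate premature layers chosen per token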
    results = []
    for i, prompt in enumerate(tqdm(list_data_dict)):
        result_dict = {}
        # input_text = build_prompt(prompt)
        input_text = prompt
        generate_kwargs = dict(max_new_tokens=args.max_new_tokens, top_p=args.top_p,
                               top_k=args.top_k, temperature=args.temperature, repetition_penalty=args.repetition_penalty,
                               mode=mode, mature_layer=mature_layer, premature_layer=premature_layer,
                               candidate_premature_layers=candidate_premature_layers, print_logits=args.print_logits)
        model_completion, c_dist = llm.generate(input_text, **generate_kwargs)
        # for stop_word in stop_word_list:
        #     length_to_remove = len(stop_word)
        #     if model_completion[-length_to_remove:] == stop_word:
        #         model_completion = model_completion[:-length_to_remove]
        model_completion = model_completion.strip()
        if mode == "dola":
            for k, v in c_dist.items():
                premature_layer_dist[k] += v
        result_dict['prompt'] = prompt
        result_dict['response'] = model_completion
        results.append(result_dict)
        if DEBUG:
            print(f'Full input_text:\n{input_text}\n\n')
        print(f'Question: {prompt}\n\n'
              f'Model Completion: {model_completion}\n\n')

    if mode == "dola" and args.debug:
        total_tokens = sum(premature_layer_dist.values())
        if total_tokens > 0:
            for l in candidate_premature_layers:
                print('Premature layer {0} was used {1} times, {2}%'.format(l, premature_layer_dist[l], round(premature_layer_dist[l] / total_tokens * 100, 2)))

    # save results to a json file
    model_tag = model_name.split('/')[-1] if model_name[-1] != '/' else model_name.split('/')[-2]
    output_file = args.output_path if args.shard_id is None else (args.output_path + "_" + str(args.shard_id) + ".jsonl")
    # Write out in jsonl format
    with open(output_file, 'w') as f:
        for result in results:
            result_json_str = json.dumps(result)
            f.write(result_json_str + '\n')
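
    # A minimal sketch of reading the results back (one JSON object per line,
    # with 'prompt' and 'response' keys as written above):
    #   with open(output_file) as f:
    #       for line in f:
    #           record = json.loads(line)
    #           print(record['prompt'][:60], '->', record['response'][:60])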