import gc
import os
import time
from typing import Dict, List

import anthropic
import openai  # used by the GPT class below (legacy openai<1.0 SDK)
# import google.generativeai as palm
import torch


class LanguageModel():
    def __init__(self, model_name):
        self.model_name = model_name

    def batched_generate(self, prompts_list: List, max_n_tokens: int, temperature: float):
        """
        Generates responses for a batch of prompts using a language model.
        """
        raise NotImplementedError


class HuggingFace(LanguageModel):
    def __init__(self, model_name, model, tokenizer):
        self.model_name = model_name
        self.model = model
        self.tokenizer = tokenizer
        self.eos_token_ids = [self.tokenizer.eos_token_id]

    def batched_generate(self,
                         full_prompts_list,
                         max_n_tokens: int,
                         temperature: float,
                         top_p: float = 1.0):
        inputs = self.tokenizer(full_prompts_list, return_tensors='pt', padding=True)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        # Batch generation
        if temperature > 0:
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_n_tokens,
                do_sample=True,
                temperature=temperature,
                eos_token_id=self.eos_token_ids,
                top_p=top_p,
            )
        else:
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_n_tokens,
                do_sample=False,
                eos_token_id=self.eos_token_ids,
                top_p=1,
                temperature=1,  # To prevent warning messages
            )

        # If the model is not an encoder-decoder type, slice off the input tokens
        if not self.model.config.is_encoder_decoder:
            output_ids = output_ids[:, inputs["input_ids"].shape[1]:]

        # Batch decoding
        outputs_list = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        # Drop references to the batch tensors and reclaim GPU memory
        del inputs, output_ids
        gc.collect()
        torch.cuda.empty_cache()

        return outputs_list

    def extend_eos_tokens(self):
        # Add closing braces for Vicuna/Llama eos when using attacker model.
        # encode("}")[1] skips the BOS token the Llama tokenizer prepends; the
        # hard-coded ids are further closing-brace variants in that vocabulary.
        self.eos_token_ids.extend([
            self.tokenizer.encode("}")[1],
            29913,
            9092,
            16675])
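

# A minimal usage sketch (not part of the original file): wrapping a small
# causal LM from Hugging Face transformers in the class above. The model name
# "gpt2" and the sampling values are illustrative assumptions, not choices
# prescribed by this module.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
#   tokenizer.pad_token = tokenizer.eos_token  # GPT-2 defines no pad token
#   model = AutoModelForCausalLM.from_pretrained("gpt2")
#   lm = HuggingFace("gpt2", model, tokenizer)
#   print(lm.batched_generate(["The capital of France is"],
#                             max_n_tokens=16, temperature=0.7))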


class GPT(LanguageModel):
    API_RETRY_SLEEP = 10
    API_ERROR_OUTPUT = "$ERROR$"
    API_QUERY_SLEEP = 0.5
    API_MAX_RETRY = 5
    API_TIMEOUT = 20

    openai.api_key = os.getenv("OPENAI_API_KEY")

    def generate(self, conv: List[Dict],
                 max_n_tokens: int,
                 temperature: float,
                 top_p: float):
        '''
        Args:
            conv: List of dictionaries, OpenAI API format
            max_n_tokens: int, max number of tokens to generate
            temperature: float, temperature for sampling
            top_p: float, top p for sampling
        Returns:
            str: generated response
        '''
        output = self.API_ERROR_OUTPUT
        for _ in range(self.API_MAX_RETRY):
            try:
                # Legacy openai<1.0 SDK call (openai.ChatCompletion and
                # openai.error were removed in v1.x)
                response = openai.ChatCompletion.create(
                    model=self.model_name,
                    messages=conv,
                    max_tokens=max_n_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    request_timeout=self.API_TIMEOUT,
                )
                output = response["choices"][0]["message"]["content"]
                break
            except openai.error.OpenAIError as e:
                print(type(e), e)
                time.sleep(self.API_RETRY_SLEEP)

            time.sleep(self.API_QUERY_SLEEP)
        return output

    def batched_generate(self,
                         convs_list: List[List[Dict]],
                         max_n_tokens: int,
                         temperature: float,
                         top_p: float = 1.0):
        return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list]
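

# A minimal usage sketch (not part of the original file), assuming a valid
# OPENAI_API_KEY in the environment and the legacy openai<1.0 SDK that this
# class targets; the model name is illustrative.
#
#   gpt = GPT("gpt-3.5-turbo")
#   convs = [[{"role": "user", "content": "Say hello."}]]
#   print(gpt.batched_generate(convs, max_n_tokens=32, temperature=0.0))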