from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase


- def build_chat_input(tokenizer, messages: List[dict], max_new_tokens: int = 0):
-     user_token_id = 195
-     assistant_token_id = 196
-
-     def _parse_messages(messages, split_role="user"):
-         system, rounds = "", []
-         round = []
-         for i, message in enumerate(messages):
-             if message["role"] == "system":
-                 assert i == 0
-                 system = message["content"]
-                 continue
-             if message["role"] == split_role and round:
-                 rounds.append(round)
-                 round = []
-             round.append(message)
-         if round:
-             rounds.append(round)
-         return system, rounds
-
-     max_new_tokens = max_new_tokens or 2048
-     max_input_tokens = 4096 - max_new_tokens
-     system, rounds = _parse_messages(messages, split_role="user")
-     system_tokens = tokenizer.encode(system)
-     max_history_tokens = max_input_tokens - len(system_tokens)
-
-     history_tokens = []
-     for round in rounds[::-1]:
-         round_tokens = []
-         for message in round:
-             if message["role"] == "user":
-                 round_tokens.append(user_token_id)
-             else:
-                 round_tokens.append(assistant_token_id)
-             round_tokens.extend(tokenizer.encode(message["content"]))
-         if (
-             len(history_tokens) == 0
-             or len(history_tokens) + len(round_tokens) <= max_history_tokens
-         ):
-             history_tokens = round_tokens + history_tokens  # concat left
-             if len(history_tokens) < max_history_tokens:
-                 continue
-         break
-
-     input_tokens = system_tokens + history_tokens
-     if messages[-1]["role"] != "assistant":
-         input_tokens.append(assistant_token_id)
-     input_tokens = input_tokens[-max_input_tokens:]  # truncate left
-     return torch.LongTensor([input_tokens])
-
-
class BaiChuanWorker(ModelWorkerBase):
    def __init__(
        self,
@@ -78,9 +27,7 @@ def __init__(
            conv_template,
            model_type="AutoModelForCausalLM",
        )
-         self.stop_words_ids = [
-             2,  # </s>
-         ]
+         self.stop_words_ids = []
        self.stop = [
            self.tokenizer.decode(skip_word) for skip_word in self.stop_words_ids
        ]
@@ -89,29 +36,11 @@ def __init__(
    async def generate_stream_gate(self, params):
        self.call_ct += 1
        try:
-             messages = params["messages"]
-             if isinstance(messages, list):
-                 task = "chat"
-             elif isinstance(messages, str):
-                 task = "completion"
-             if task == "chat":
-                 input_ids = build_chat_input(
-                     tokenizer=self.tokenizer, messages=messages
-                 )
-                 text = self.tokenizer.decode(input_ids.tolist()[0])
-             elif task == "completion":
-                 text = messages
-                 input_ids = self.tokenizer([text], return_tensors="pt").input_ids
-
-             params["messages"] = messages
-             params["prompt"] = text
            params["stop"].extend(self.stop)
            params["stop_words_ids"] = self.stop_words_ids
-             params["input_ids"] = input_ids

            async for ret in self.backend.stream_chat(params=params):
                response = ret["text"]
-
                yield json.dumps(ret).encode() + b"\0"

        except torch.cuda.OutOfMemoryError as e: