forked from jingyaogong/minimind
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfast_infenence.py
124 lines (101 loc) · 4.2 KB
/
fast_infenence.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import json
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
# Page metadata for the Streamlit app.
# NOTE(review): older Streamlit releases require set_page_config to be the first
# Streamlit command in the script - keep it above st.title; confirm for the pinned version.
st.set_page_config(page_title="MiniMind-V1 108M(无历史上文)")
st.title("MiniMind-V1 108M(无历史上文)")
# Identifier handed to every from_pretrained() call in load_model_tokenizer().
model_id = "minimind-v1"
# -----------------------------------------------------------------------------
# Generation settings read by main().
temperature = 0.7  # forwarded to model.generate(temperature=...)
top_k = 8  # forwarded to model.generate(top_k=...)
# Used twice in main(): as the character budget when trimming the rendered
# prompt string, and as max_new_tokens for model.generate.
max_seq_len = 1 * 1024
# -----------------------------------------------------------------------------
@st.cache_resource
def load_model_tokenizer():
    """Load the model, tokenizer and generation config identified by ``model_id``.

    The function is decorated with ``st.cache_resource`` so the loaded objects
    can be reused instead of being rebuilt on every script run.

    Returns:
        tuple: ``(model, tokenizer, generation_config)`` with the model already
        switched to evaluation mode.
    """
    # trust_remote_code=True allows transformers to run Python code shipped
    # with the checkpoint; only point model_id at a source you trust.
    shared_kwargs = {"trust_remote_code": True}

    model = AutoModelForCausalLM.from_pretrained(model_id, **shared_kwargs).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, **shared_kwargs)
    return model, tokenizer, GenerationConfig.from_pretrained(model_id)
def clear_chat_messages():
    """Drop the stored conversation from the Streamlit session state.

    Used as the ``on_click`` callback of the clear button. Safe to call when
    no conversation is stored: the key is only deleted if it is present, so a
    repeated or early invocation does not raise.
    """
    if "messages" in st.session_state:
        del st.session_state.messages
def init_chat_messages():
    """Render the greeting and any stored conversation, then return the history.

    On the first run of a session (no ``messages`` key in the session state) an
    empty list is created and stored; otherwise each stored message is replayed
    as a chat bubble with an avatar chosen by its role.

    Returns:
        list: the list held in ``st.session_state.messages`` (the same object,
        so appending to it updates the session state).
    """
    with st.chat_message("assistant", avatar='🤖'):
        st.markdown("您好,我是由Joya开发的MiniMind,很高兴为您服务😄")

    # Fresh session: nothing to replay, just create the history list.
    if "messages" not in st.session_state:
        st.session_state.messages = []
        return st.session_state.messages

    for entry in st.session_state.messages:
        role = entry["role"]
        icon = "🤖" if role != "user" else "🧑💻"
        with st.chat_message(role, avatar=icon):
            st.markdown(entry["content"])

    return st.session_state.messages
# max_new_tokens = st.sidebar.slider("max_new_tokens", 0, 1024, 512, step=1)
# top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
# top_k = st.sidebar.slider("top_k", 0, 100, 0, step=1)
# temperature = st.sidebar.slider("temperature", 0.0, 2.0, 1.0, step=0.01)
# do_sample = st.sidebar.checkbox("do_sample", value=False)
def _encode_prompt(tokenizer, prompt):
    """Render *prompt* as a single-turn chat and return its token ids as a tensor.

    Only the current user message is placed in the chat template; earlier
    turns are not sent to the model.
    """
    chat_messages = [{"role": "user", "content": prompt}]
    # NOTE: the slice trims the rendered template *string* (characters, not
    # tokens) to its last ``max_seq_len - 1`` characters.
    new_prompt = tokenizer.apply_chat_template(
        chat_messages,
        tokenize=False,
        add_generation_prompt=True,
    )[-(max_seq_len - 1):]
    input_ids = tokenizer(new_prompt).data['input_ids']
    # [None, ...] adds a leading axis in front of the id sequence.
    return torch.tensor(input_ids, dtype=torch.long)[None, ...]


def _stream_reply(model, tokenizer, x, placeholder):
    """Stream the model's answer into *placeholder* and return the final text.

    Returns:
        str | None: the last text written to the placeholder (``''`` if no
        chunk decoded to displayable text), or ``None`` when the generator
        produced no chunk at all.
    """
    response = ''
    received_chunk = False

    with torch.no_grad():
        token_stream = model.generate(
            x,
            tokenizer.eos_token_id,
            max_new_tokens=max_seq_len,
            temperature=temperature,
            top_k=top_k,
            stream=True,
        )
        while True:
            # Keep the try body to the single call that drives the generator.
            try:
                y = next(token_stream)
            except StopIteration:
                break
            except Exception as exc:
                # Best effort: keep whatever was already shown, but surface
                # the failure instead of silently discarding it.
                st.warning(f"生成中断: {exc}")
                break

            received_chunk = True
            if y is None:
                break

            answer = tokenizer.decode(y[0].tolist())
            # Skip empty text and text ending in U+FFFD (presumably a
            # multi-byte character that is not complete yet); the next chunk
            # is decoded in full again.
            if not answer or answer[-1] == "\ufffd":
                continue

            # Each chunk replaces the previous text rather than being appended.
            placeholder.markdown(answer)
            response = answer

    return response if received_chunk else None


def main():
    """Run one pass of the chat page: replay history, answer a new prompt.

    When the user submits a prompt it is shown and stored, the model's answer
    is streamed into the assistant bubble, and the final answer is appended to
    the session history together with a button that clears the conversation.
    If the generator yields nothing, the function returns without storing an
    assistant message.
    """
    # The generation config is returned by the loader but not used on this page.
    model, tokenizer, _generation_config = load_model_tokenizer()
    messages = init_chat_messages()

    prompt = st.chat_input("Shift + Enter 换行, Enter 发送")
    if not prompt:
        return

    with st.chat_message("user", avatar='🧑💻'):
        st.markdown(prompt)
    messages.append({"role": "user", "content": prompt})

    with st.chat_message("assistant", avatar='🤖'):
        placeholder = st.empty()
        x = _encode_prompt(tokenizer, prompt)
        response = _stream_reply(model, tokenizer, x, placeholder)

    if response is None:
        return

    messages.append({"role": "assistant", "content": response})
    st.button("清空对话", on_click=clear_chat_messages)


if __name__ == "__main__":
    main()