Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

模仿临时key增加临时chatgpt格式api接入方法。在输入框输入如下json可以临时切换使用该api,页面刷新即失效。 #1761

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
30 changes: 22 additions & 8 deletions request_llms/bridge_chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[],
try:
# make a POST request to the API endpoint, stream=False
from .bridge_all import model_info
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint'])
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint']) if not llm_kwargs['tmp_endpoint'] else llm_kwargs['tmp_endpoint']
response = requests.post(endpoint, headers=headers, proxies=proxies,
json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
except requests.exceptions.ReadTimeout as e:
Expand Down Expand Up @@ -147,7 +147,12 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[],
raise ConnectionAbortedError("正常结束,但显示Token不足,导致输出不完整,请削减单次输入的文本量。")
return result


def is_any_tmp_model(inputs):
    """Return True if *inputs* is a JSON object describing a temporary model.

    A valid temporary-model spec is a JSON object containing all three keys
    "tmp_key", "tmp_model" and "tmp_endpoint". Anything else — invalid JSON,
    a non-string input, JSON that is not an object, or an object missing one
    of the required keys — yields False.
    """
    try:
        parsed = json.loads(inputs)
    except (TypeError, ValueError):
        # Not a string/bytes, or not valid JSON: definitely not a tmp-model spec.
        # (json.JSONDecodeError is a ValueError subclass.)
        return False
    if not isinstance(parsed, dict):
        # e.g. a JSON array or scalar — has no keys to inspect
        return False
    return {"tmp_key", "tmp_model", "tmp_endpoint"} <= parsed.keys()
def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWithCookies,
history:list=[], system_prompt:str='', stream:bool=True, additional_fn:str=None):
"""
Expand All @@ -164,8 +169,15 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
chatbot.append(("输入已识别为openai的api_key", what_keys(inputs)))
yield from update_ui(chatbot=chatbot, history=history, msg="api_key已导入") # 刷新界面
return
elif not is_any_api_key(chatbot._cookies['api_key']):
chatbot.append((inputs, "缺少api_key。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。"))
elif is_any_tmp_model(inputs):
chatbot._cookies['tmp_key'] = json.loads(inputs)['tmp_key']
chatbot._cookies['tmp_model'] = json.loads(inputs)['tmp_model']
chatbot._cookies['tmp_endpoint'] = json.loads(inputs)['tmp_endpoint']
chatbot.append(("输入已识别为临时openai格式的模型,页面刷新后将失效", '临时模型:'+json.loads(inputs)['tmp_model']))
yield from update_ui(chatbot=chatbot, history=history, msg="临时模型已导入") # 刷新界面
return
elif not is_any_api_key(chatbot._cookies['api_key']) and not chatbot._cookies['tmp_key']:
chatbot.append((inputs, '缺少API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。\n\n3.接入临时模型:在输入区键入以下格式临时模型信息{"tmp_key":"xxx","tmp_endpoint":"https://xxxx.xxx","tmp_model":"gpt-3.5-turbo-16k"},然后回车提交'))
yield from update_ui(chatbot=chatbot, history=history, msg="缺少api_key") # 刷新界面
return

Expand Down Expand Up @@ -195,7 +207,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
# 检查endpoint是否合法
try:
from .bridge_all import model_info
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint'])
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint']) if not llm_kwargs['tmp_endpoint'] else llm_kwargs['tmp_endpoint']
except:
tb_str = '```\n' + trimmed_format_exc() + '```'
chatbot[-1] = (inputs, tb_str)
Expand Down Expand Up @@ -320,11 +332,13 @@ def generate_payload(inputs, llm_kwargs, history, system_prompt, stream):
"""
整合所有信息,选择LLM模型,生成http请求,为发送请求做准备
"""
if not is_any_api_key(llm_kwargs['api_key']):
raise AssertionError("你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。")
if not is_any_api_key(llm_kwargs['api_key']) and not llm_kwargs['tmp_key']:
raise AssertionError('你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。\n\n3.接入临时模型:在输入区键入以下格式临时模型信息{"tmp_key":"xxx","tmp_endpoint":"https://xxxx.xxx","tmp_model":"gpt-3.5-turbo-16k"},然后回车提交')

if llm_kwargs['llm_model'].startswith('vllm-'):
api_key = 'no-api-key'
elif llm_kwargs['tmp_key']:
api_key = llm_kwargs['tmp_key']
else:
api_key = select_api_key(llm_kwargs['api_key'], llm_kwargs['llm_model'])

Expand Down Expand Up @@ -383,7 +397,7 @@ def generate_payload(inputs, llm_kwargs, history, system_prompt, stream):
logging.info("Random select model:" + model)

payload = {
"model": model,
"model": model if not llm_kwargs['tmp_model'] else llm_kwargs['tmp_model'] ,
"messages": messages,
"temperature": llm_kwargs['temperature'], # 1.0,
"top_p": llm_kwargs['top_p'], # 1.0,
Expand Down
9 changes: 9 additions & 0 deletions toolbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,18 @@ def decorated(request: gradio.Request, cookies:dict, max_length:int, llm_model:s
cookies.update({
'top_p': top_p,
'api_key': cookies['api_key'],
'tmp_key': cookies['tmp_key'],
'tmp_model': cookies['tmp_model'],
'tmp_endpoint': cookies['tmp_endpoint'],
'llm_model': llm_model,
'temperature': temperature,
'user_name': user_name,
})
llm_kwargs = {
'api_key': cookies['api_key'],
'tmp_key': cookies['tmp_key'],
'tmp_model': cookies['tmp_model'],
'tmp_endpoint': cookies['tmp_endpoint'],
'llm_model': llm_model,
'top_p': top_p,
'max_length': max_length,
Expand Down Expand Up @@ -607,6 +613,9 @@ def load_chat_cookies():
"api_key": API_KEY,
"llm_model": LLM_MODEL,
"customize_fn_overwrite": customize_fn_overwrite_,
"tmp_key":'',
"tmp_model":'',
"tmp_endpoint":'',
}


Expand Down