Commit

Merge branch 'main' into new_tt_dev

csunny committed Nov 28, 2023
2 parents 9919cd3 + 51a6830 commit 911b47d
Showing 5 changed files with 95 additions and 82 deletions.
4 changes: 2 additions & 2 deletions pilot/configs/config.py
@@ -66,7 +66,7 @@ def __init__(self) -> None:

         # wenxin
         self.wenxin_proxy_api_key = os.getenv("WEN_XIN_API_KEY")
-        self.wenxin_proxy_api_secret = os.getenv("WEN_XIN_SECRET_KEY")
+        self.wenxin_proxy_api_secret = os.getenv("WEN_XIN_API_SECRET")
         self.wenxin_model_version = os.getenv("WEN_XIN_MODEL_VERSION")
         if self.wenxin_proxy_api_key and self.wenxin_proxy_api_secret:
             os.environ["wenxin_proxyllm_proxy_api_key"] = self.wenxin_proxy_api_key
@@ -84,7 +84,7 @@ def __init__(self) -> None:
os.environ["spark_proxyllm_proxy_api_key"] = self.spark_proxy_api_key
os.environ["spark_proxyllm_proxy_api_secret"] = self.spark_proxy_api_secret
os.environ["spark_proxyllm_proxyllm_backend"] = self.spark_api_version
os.environ["spark_proxyllm_proxy_app_id"] = self.spark_proxy_api_appid
os.environ["spark_proxyllm_proxy_api_app_id"] = self.spark_proxy_api_appid

# baichuan proxy
self.bc_proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY")
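For reference, a minimal sketch of the environment a deployment would now set for the Wenxin proxy; the key, secret, and version values below are placeholders, not real credentials:

```python
# Hypothetical setup reflecting the renamed variable above: the secret is now
# read from WEN_XIN_API_SECRET instead of WEN_XIN_SECRET_KEY.
import os

os.environ["WEN_XIN_API_KEY"] = "my-wenxin-key"        # placeholder
os.environ["WEN_XIN_API_SECRET"] = "my-wenxin-secret"  # placeholder
os.environ["WEN_XIN_MODEL_VERSION"] = "ernie-bot"      # illustrative version string
```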
7 changes: 7 additions & 0 deletions pilot/model/parameter.py
@@ -423,6 +423,13 @@ class ProxyModelParameters(BaseModelParameters):
         },
     )
 
+    proxy_api_secret: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The app secret for current proxy LLM(Just for spark proxy LLM now)."
+        },
+    )
+
     proxy_api_type: Optional[str] = field(
         default=None,
         metadata={
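The new proxy_api_secret field follows the same dataclasses field(default=..., metadata={"help": ...}) pattern as the rest of parameter.py. A self-contained sketch of that pattern, with a made-up class name standing in for ProxyModelParameters:

```python
from dataclasses import dataclass, field, fields
from typing import Optional

@dataclass
class DemoProxyParameters:
    """Illustrative stand-in for ProxyModelParameters."""

    proxy_api_secret: Optional[str] = field(
        default=None,
        metadata={"help": "The app secret for the current proxy LLM."},
    )

p = DemoProxyParameters(proxy_api_secret="app-secret")
# The metadata "help" strings can be recovered later, e.g. for CLI help text:
for f in fields(p):
    print(f.name, "->", f.metadata.get("help"))
```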
11 changes: 11 additions & 0 deletions pilot/model/proxy/llms/chatgpt.py
@@ -172,6 +172,11 @@ def chatgpt_generate_stream(
     res = client.chat.completions.create(messages=history, **payloads)
     text = ""
     for r in res:
+        # logger.info(str(r))
+        # Azure OpenAI response may have empty choices body in the first chunk
+        # to avoid index out of range error
+        if not r.get("choices"):
+            continue
         if r.choices[0].delta.content is not None:
             content = r.choices[0].delta.content
             text += content
@@ -186,6 +191,8 @@ def chatgpt_generate_stream(

text = ""
for r in res:
if not r.get("choices"):
continue
if r["choices"][0]["delta"].get("content") is not None:
content = r["choices"][0]["delta"]["content"]
text += content
@@ -220,6 +227,8 @@ async def async_chatgpt_generate_stream(
     res = await client.chat.completions.create(messages=history, **payloads)
     text = ""
     for r in res:
+        if not r.get("choices"):
+            continue
         if r.choices[0].delta.content is not None:
             content = r.choices[0].delta.content
             text += content
@@ -233,6 +242,8 @@ async def async_chatgpt_generate_stream(

text = ""
async for r in res:
if not r.get("choices"):
continue
if r["choices"][0]["delta"].get("content") is not None:
content = r["choices"][0]["delta"]["content"]
text += content
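All four hunks add the same guard: Azure OpenAI streaming responses may begin with a chunk whose choices list is empty, so indexing choices[0] unconditionally raises an IndexError. A minimal sketch of the guard using plain dicts in place of response chunks (the chunk contents are illustrative):

```python
chunks = [
    {"choices": []},  # e.g. a leading metadata-only chunk from Azure
    {"choices": [{"delta": {"content": "Hello"}}]},
    {"choices": [{"delta": {"content": " world"}}]},
]

text = ""
for r in chunks:
    if not r.get("choices"):  # skip chunks with an empty choices body
        continue
    content = r["choices"][0]["delta"].get("content")
    if content is not None:
        text += content

print(text)  # -> "Hello world"
```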
118 changes: 72 additions & 46 deletions pilot/model/proxy/llms/spark.py
@@ -1,9 +1,8 @@
-import os
 import json
 import base64
 import hmac
 import hashlib
-import websockets
+from websockets.sync.client import connect
 from datetime import datetime
 from typing import List
 from time import mktime
@@ -13,7 +12,22 @@
 from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
 from pilot.model.proxy.llms.proxy_model import ProxyModel
 
-SPARK_DEFAULT_API_VERSION = "v2"
+SPARK_DEFAULT_API_VERSION = "v3"
+
+
+def getlength(text):
+    length = 0
+    for content in text:
+        temp = content["content"]
+        leng = len(temp)
+        length += leng
+    return length
+
+
+def checklen(text):
+    while getlength(text) > 8192:
+        del text[0]
+    return text
 
 
 def spark_generate_stream(
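A quick illustration of how the two helpers above behave, assuming the dict message format used later in this file (each entry carries a "content" string):

```python
history = [
    {"role": "user", "content": "x" * 5000},
    {"role": "assistant", "content": "y" * 5000},
]
trimmed = checklen(history)  # 10000 > 8192, so the oldest message is dropped
assert getlength(trimmed) <= 8192
assert trimmed[0]["role"] == "assistant"
```

Note that checklen trims in place from the front of the list, so the oldest turns are discarded first.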
@@ -23,41 +37,41 @@ def spark_generate_stream(
     proxy_api_version = model_params.proxyllm_backend or SPARK_DEFAULT_API_VERSION
     proxy_api_key = model_params.proxy_api_key
     proxy_api_secret = model_params.proxy_api_secret
-    proxy_app_id = model_params.proxy_app_id
+    proxy_app_id = model_params.proxy_api_app_id
 
     if proxy_api_version == SPARK_DEFAULT_API_VERSION:
+        url = "ws://spark-api.xf-yun.com/v3.1/chat"
+        domain = "generalv3"
+    else:
         url = "ws://spark-api.xf-yun.com/v2.1/chat"
         domain = "generalv2"
-    else:
-        domain = "general"
-        url = "ws://spark-api.xf-yun.com/v1.1/chat"
 
     messages: List[ModelMessage] = params["messages"]
 
-    last_user_input = None
-    for index in range(len(messages) - 1, -1, -1):
-        print(f"index: {index}")
-        if messages[index].role == ModelMessageRoleType.HUMAN:
-            last_user_input = {"role": "user", "content": messages[index].content}
-            del messages[index]
-            break
-
     history = []
     # Add history conversation
     for message in messages:
-        if message.role == ModelMessageRoleType.HUMAN:
+        # There is no role for system in spark LLM
+        if message.role == ModelMessageRoleType.HUMAN or ModelMessageRoleType.SYSTEM:
             history.append({"role": "user", "content": message.content})
-        elif message.role == ModelMessageRoleType.SYSTEM:
-            history.append({"role": "system", "content": message.content})
         elif message.role == ModelMessageRoleType.AI:
             history.append({"role": "assistant", "content": message.content})
         else:
             pass
 
-    spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url)
-    request_url = spark_api.gen_url()
+    temp_his = history[::-1]
+    last_user_input = None
+    for m in temp_his:
+        if m["role"] == "user":
+            last_user_input = m
+            break
+    question = checklen(history + [last_user_input])
 
-    print('last_user_input.get("content")', last_user_input.get("content"))
     data = {
-        "header": {"app_id": proxy_app_id, "uid": params.get("request_id", 1)},
+        "header": {"app_id": proxy_app_id, "uid": str(params.get("request_id", 1))},
        "parameter": {
             "chat": {
                 "domain": domain,
@@ -67,23 +81,31 @@ def spark_generate_stream(
"temperature": params.get("temperature"),
}
},
"payload": {"message": {"text": last_user_input.get("content")}},
"payload": {"message": {"text": question}},
}

async_call(request_url, data)


async def async_call(request_url, data):
async with websockets.connect(request_url) as ws:
await ws.send(json.dumps(data, ensure_ascii=False))
finish = False
while not finish:
chunk = ws.recv()
response = json.loads(chunk)
if response.get("header", {}).get("status") == 2:
finish = True
if text := response.get("payload", {}).get("choices", {}).get("text"):
yield text[0]["content"]
spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url)
request_url = spark_api.gen_url()
return get_response(request_url, data)


def get_response(request_url, data):
with connect(request_url) as ws:
ws.send(json.dumps(data, ensure_ascii=False))
result = ""
while True:
try:
chunk = ws.recv()
response = json.loads(chunk)
print("look out the response: ", response)
choices = response.get("payload", {}).get("choices", {})
if text := choices.get("text"):
result += text[0]["content"]
if choices.get("status") == 2:
break
except Exception:
break
yield result


class SparkAPI:
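The accumulation loop in get_response can be exercised offline by substituting hard-coded frames for websocket messages; the frame shape below follows the Spark protocol as used in this file, with made-up values:

```python
frames = [
    {"payload": {"choices": {"status": 0, "text": [{"content": "Hello"}]}}},
    {"payload": {"choices": {"status": 2, "text": [{"content": "!"}]}}},
]

result = ""
for response in frames:
    choices = response.get("payload", {}).get("choices", {})
    if text := choices.get("text"):
        result += text[0]["content"]
    if choices.get("status") == 2:  # status 2 marks the final frame
        break

print(result)  # -> "Hello!"
```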
@@ -99,29 +121,33 @@ def __init__(
         self.spark_url = spark_url
 
     def gen_url(self):
+        # Generate an RFC1123-formatted timestamp
         now = datetime.now()
         date = format_date_time(mktime(now.timetuple()))
 
-        _signature = "host: " + self.host + "\n"
-        _signature += "data: " + date + "\n"
-        _signature += "GET " + self.path + " HTTP/1.1"
+        # Concatenate the fields that make up the string to sign
+        signature_origin = "host: " + self.host + "\n"
+        signature_origin += "date: " + date + "\n"
+        signature_origin += "GET " + self.path + " HTTP/1.1"
 
-        _signature_sha = hmac.new(
+        # Sign it with HMAC-SHA256
+        signature_sha = hmac.new(
             self.api_secret.encode("utf-8"),
-            _signature.encode("utf-8"),
+            signature_origin.encode("utf-8"),
             digestmod=hashlib.sha256,
         ).digest()
 
-        _signature_sha_base64 = base64.b64encode(_signature_sha).decode(
-            encoding="utf-8"
-        )
-        _authorization = f"api_key='{self.api_key}', algorithm='hmac-sha256', headers='host date request-line', signature='{_signature_sha_base64}'"
+        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
 
-        authorization = base64.b64encode(_authorization.encode("utf-8")).decode(
+        authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
+
+        authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
             encoding="utf-8"
         )
 
+        # Assemble the request's auth parameters into a dict
         v = {"authorization": authorization, "date": date, "host": self.host}
 
+        # Urlencode the auth parameters and append them to the base url
         url = self.spark_url + "?" + urlencode(v)
+        # Here you can print the url used to establish the connection; when
+        # following this demo, check whether the url generated for the same
+        # parameters matches the one produced by your own code.
         return url
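For reference, a self-contained sketch of the signing flow in gen_url, runnable with dummy credentials (the key, secret, host, and path values are placeholders):

```python
import base64
import hashlib
import hmac
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

api_key, api_secret = "demo-key", "demo-secret"  # placeholder credentials
host, path = "spark-api.xf-yun.com", "/v3.1/chat"

# RFC1123 timestamp, e.g. "Tue, 28 Nov 2023 08:00:00 GMT"
date = format_date_time(mktime(datetime.now().timetuple()))

# Sign host, date, and request line with HMAC-SHA256
signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
signature_sha = hmac.new(
    api_secret.encode("utf-8"), signature_origin.encode("utf-8"), hashlib.sha256
).digest()
signature = base64.b64encode(signature_sha).decode("utf-8")

# Base64-encode the whole authorization header, then urlencode everything
authorization_origin = (
    f'api_key="{api_key}", algorithm="hmac-sha256", '
    f'headers="host date request-line", signature="{signature}"'
)
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")

url = f"ws://{host}{path}?" + urlencode(
    {"authorization": authorization, "date": date, "host": host}
)
print(url)
```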
37 changes: 3 additions & 34 deletions pilot/model/proxy/llms/wenxin.py
@@ -75,9 +75,8 @@ def wenxin_generate_stream(
     if not model_version:
         yield f"Unsupport model version {model_name}"
 
-    keys: [] = model_params.proxy_api_key.split(";")
-    proxy_api_key = keys[0]
-    proxy_api_secret = keys[1]
+    proxy_api_key = model_params.proxy_api_key
+    proxy_api_secret = model_params.proxy_api_secret
     access_token = _build_access_token(proxy_api_key, proxy_api_secret)
 
     headers = {"Content-Type": "application/json", "Accept": "application/json"}
@@ -88,37 +87,7 @@ def wenxin_generate_stream(
yield "Failed to get access token. please set the correct api_key and secret key."

messages: List[ModelMessage] = params["messages"]
# Add history conversation
# system = ""
# if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
# role_define = messages.pop(0)
# system = role_define.content
# else:
# message = messages.pop(0)
# if message.role == ModelMessageRoleType.HUMAN:
# history.append({"role": "user", "content": message.content})
# for message in messages:
# if message.role == ModelMessageRoleType.SYSTEM:
# history.append({"role": "user", "content": message.content})
# # elif message.role == ModelMessageRoleType.HUMAN:
# # history.append({"role": "user", "content": message.content})
# elif message.role == ModelMessageRoleType.AI:
# history.append({"role": "assistant", "content": message.content})
# else:
# pass
#
# # temp_his = history[::-1]
# temp_his = history
# last_user_input = None
# for m in temp_his:
# if m["role"] == "user":
# last_user_input = m
# break
#
# if last_user_input:
# history.remove(last_user_input)
# history.append(last_user_input)
#

history, systems = __convert_2_wenxin_messages(messages)
system = ""
if systems and len(systems) > 0:
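The practical effect of this change: previously the Wenxin key and secret were packed into a single field separated by ";", while now each travels in its own parameter (see the proxy_api_secret field added to parameter.py above). A sketch of the two conventions with placeholder values:

```python
# Old convention: one combined field, split on ";"
combined = "my-api-key;my-secret-key"
proxy_api_key, proxy_api_secret = combined.split(";")

# New convention: two dedicated parameters
proxy_api_key = "my-api-key"
proxy_api_secret = "my-secret-key"
```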
