Skip to content

Commit 0f93169

Browse files
committed
update telegram
1 parent 6b5a535 commit 0f93169

File tree

10 files changed

+307
-77
lines changed

10 files changed

+307
-77
lines changed

app/components/__init__.py

Whitespace-only changes.

app/components/credential.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from urllib.parse import urlparse
2+
3+
import requests
4+
from pydantic import BaseModel
5+
6+
7+
class ProviderError(Exception):
8+
pass
9+
10+
11+
class Credential(BaseModel):
12+
api_key: str
13+
api_endpoint: str
14+
api_model: str
15+
16+
@classmethod
17+
def from_provider(cls, token, provider_url):
18+
"""
19+
使用 token POST 请求 provider_url 获取用户信息
20+
:param token: 用户 token
21+
:param provider_url: provider url
22+
:return: 用户信息
23+
:raises HTTPError: 请求失败
24+
:raises JSONDecodeError: 返回数据解析失败
25+
:raises ProviderError: provider 返回错误信息
26+
"""
27+
response = requests.post(provider_url, data={"token": token})
28+
response.raise_for_status()
29+
user_data = response.json()
30+
if user_data.get("error"):
31+
raise ProviderError(user_data["error"])
32+
return cls(
33+
api_key=user_data["api_key"],
34+
api_endpoint=user_data["api_endpoint"],
35+
api_model=user_data["api_model"],
36+
)
37+
38+
39+
def split_setting_string(input_string):
40+
if not isinstance(input_string, str):
41+
return None
42+
segments = input_string.split("$")
43+
44+
# 检查链接的有效性
45+
def is_valid_url(url):
46+
try:
47+
result = urlparse(url)
48+
return all([result.scheme, result.netloc])
49+
except ValueError:
50+
return False
51+
52+
# 开头为链接的情况
53+
if is_valid_url(segments[0]) and len(segments) >= 3:
54+
return segments[:3]
55+
# 第二个元素为链接,第一个元素为字符串的情况
56+
elif (
57+
len(segments) == 2
58+
and not is_valid_url(segments[0])
59+
and is_valid_url(segments[1])
60+
):
61+
return segments
62+
# 其他情况
63+
else:
64+
return None
+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# -*- coding: utf-8 -*-
2+
# @Time : 2024/2/8 下午10:56
3+
# @Author : sudoskys
4+
# @File : __init__.py.py
5+
# @Software: PyCharm
6+
import time
7+
from typing import Optional
8+
9+
from loguru import logger
10+
from pydantic import BaseModel
11+
12+
from app.components.credential import Credential
13+
from app.const import DBNAME
14+
from llmkira.doc_manager import global_doc_client
15+
16+
17+
class ChatCost(BaseModel):
18+
user_id: str
19+
cost_token: int = 0
20+
endpoint: str = ""
21+
cost_model: str = ""
22+
produce_time: int = time.time()
23+
24+
25+
class GenerateHistory(object):
26+
def __init__(self, db_name: str = DBNAME, collection: str = "cost_history"):
27+
""" """
28+
self.client = global_doc_client.update_db_collection(
29+
db_name=db_name, collection_name=collection
30+
)
31+
32+
async def save(self, history: ChatCost):
33+
return self.client.insert_one(history.model_dump(mode="json"))
34+
35+
36+
class User(BaseModel):
37+
user_id: str
38+
last_use_time: int = time.time()
39+
credential: Optional[Credential] = None
40+
41+
42+
class UserManager(object):
43+
def __init__(self, db_name: str = DBNAME, collection: str = "user"):
44+
""" """
45+
self.client = global_doc_client.update_db_collection(
46+
db_name=db_name, collection_name=collection
47+
)
48+
49+
async def read(self, user_id: str) -> User:
50+
user_id = str(user_id)
51+
database_read = self.client.find_one({"user_id": user_id})
52+
if not database_read:
53+
logger.info(f"Create new user: {user_id}")
54+
return User(user_id=user_id)
55+
# database_read.update({"user_id": user_id})
56+
return User.model_validate(database_read)
57+
58+
async def save(self, user_model: User):
59+
user_model = user_model.model_copy(update={"last_use_time": int(time.time())})
60+
# 如果存在记录则更新
61+
if self.client.find_one({"user_id": user_model.user_id}):
62+
return self.client.update_one(
63+
{"user_id": user_model.user_id},
64+
{"$set": user_model.model_dump(mode="json")},
65+
)
66+
# 如果不存在记录则插入
67+
else:
68+
return self.client.insert_one(user_model.model_dump(mode="json"))
69+
70+
71+
COST_MANAGER = GenerateHistory()
72+
USER_MANAGER = UserManager()
73+
74+
75+
async def record_cost(
76+
user_id: str, cost_token: int, endpoint: str, cost_model: str, success: bool = True
77+
):
78+
try:
79+
await COST_MANAGER.save(
80+
ChatCost(
81+
user_id=user_id,
82+
produce_time=int(time.time()),
83+
endpoint=endpoint,
84+
cost_model=cost_model,
85+
cost_token=cost_token if success else 0,
86+
)
87+
)
88+
except Exception as exc:
89+
logger.error(f"🔥 record_cost error: {exc}")
90+
91+
92+
if __name__ == "__main__":
93+
pass
File renamed without changes.

llmkira/middleware/llm_task.py app/middleware/llm_task.py

+14-12
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
# @Time : 2023/8/18 上午9:37
33
# @Author : sudoskys
44
# @File : llm_task.py
5-
import os
65
from typing import List, Optional
76

87
from loguru import logger
98
from pydantic import SecretStr
109

10+
from app.components.credential import Credential
11+
from app.components.user_manager import record_cost
1112
from llmkira.kv_manager.instruction import InstructionManager
1213
from llmkira.memory import global_message_runtime
1314
from llmkira.openai.cell import Tool, Message, active_cell_string, SystemMessage
@@ -48,14 +49,13 @@ def __init__(
4849
self.message_history = global_message_runtime.update_session(
4950
session_id=session_uid
5051
)
51-
# TODO:实现用户配置读取
5252

5353
async def remember(self, *, message: Optional[Message] = None):
5454
"""
5555
写回消息到历史消息
5656
"""
5757
if message:
58-
await self.message_history.append(message=message)
58+
await self.message_history.append(messages=[message])
5959

6060
async def build_message(self, remember=True):
6161
"""
@@ -87,18 +87,20 @@ async def build_message(self, remember=True):
8787
user_message = message.format_user_message()
8888
message_run.append(user_message)
8989
if remember:
90-
await self.message_history.append(message=user_message)
90+
await self.message_history.append(messages=[user_message])
9191
return message_run
9292

9393
async def request_openai(
9494
self,
9595
remember: bool,
96+
credential: Credential,
9697
disable_tool: bool = False,
9798
) -> OpenAIResult:
9899
"""
99100
处理消息转换和调用工具
100101
:param remember: 是否自动写回
101102
:param disable_tool: 禁用函数
103+
:param credential: 凭证
102104
:return: OpenaiResult 返回结果
103105
:raise RuntimeError: 无法处理消息
104106
:raise AssertionError: 无法处理消息
@@ -113,7 +115,6 @@ async def request_openai(
113115
messages.append(SystemMessage(content=self.task.task_sign.instruction))
114116
messages.extend(await self.build_message(remember=remember))
115117
# TODO:实现消息时序切片
116-
117118
# 日志
118119
logger.info(
119120
f"[x] Openai request" f"\n--message {messages} " f"\n--tools {tools}"
@@ -125,21 +126,22 @@ async def request_openai(
125126
# 根据模型选择不同的驱动a
126127
assert messages, RuntimeError("llm_task:message cant be none...")
127128
endpoint: OpenAI = OpenAI(
128-
messages=messages,
129-
tools=tools,
130-
model="gpt-3.5-turbo", # FIXME:从用户配置中获取
129+
messages=messages, tools=tools, model=credential.api_model
131130
)
132131
# 调用Openai
133132
result: OpenAIResult = await endpoint.request(
134133
session=OpenAICredential(
135-
api_key=SecretStr(
136-
os.getenv("OPENAI_API_KEY", None)
137-
), # FIXME:从用户配置中获取
138-
base_url=os.getenv("OPENAI_API_ENDPOINT"), # FIXME:从用户配置中获取
134+
api_key=SecretStr(credential.api_key), base_url=credential.api_endpoint
139135
)
140136
)
141137
_message = result.default_message
142138
_usage = result.usage.total_tokens
139+
await record_cost(
140+
cost_model=credential.api_model,
141+
cost_token=_usage,
142+
endpoint=credential.api_endpoint,
143+
user_id=self.session_uid,
144+
)
143145
# 写回数据库
144146
await self.remember(message=_message)
145147
return result

app/receiver/receiver_client.py

+41-26
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
from loguru import logger
1919
from telebot import formatting
2020

21+
from app.components.credential import Credential
22+
from app.components.user_manager import USER_MANAGER
23+
from app.middleware.llm_task import OpenaiMiddleware
2124
from llmkira.kv_manager.env import EnvManager
22-
from llmkira.middleware.llm_task import OpenaiMiddleware
2325
from llmkira.openai import OpenaiError
2426
from llmkira.openai.cell import ToolCall, Message, Tool
2527
from llmkira.openai.request import OpenAIResult
@@ -31,6 +33,11 @@
3133
from llmkira.task.snapshot import global_snapshot_storage
3234

3335

36+
async def read_user_credential(user_id: str) -> Optional[Credential]:
37+
user = await USER_MANAGER.read(user_id=user_id)
38+
return user.credential
39+
40+
3441
async def generate_authorization(
3542
secrets: Dict, tool_invocation: ToolCall
3643
) -> Tuple[dict, list, bool]:
@@ -147,7 +154,7 @@ async def forward(self, receiver: Location, message: list):
147154

148155
@abstractmethod
149156
async def reply(
150-
self, receiver: Location, message: Message, reply_to_message: bool = True
157+
self, receiver: Location, messages: List[Message], reply_to_message: bool = True
151158
):
152159
"""
153160
模型直转发,Message是Openai的类型
@@ -232,21 +239,28 @@ async def _flash(
232239
"""
233240
try:
234241
try:
242+
credentials = await read_user_credential(user_id=task.receiver.uid)
243+
assert credentials, "You need to /login first"
235244
llm_result = await llm.request_openai(
236245
remember=remember,
237246
disable_tool=disable_tool,
247+
credential=credentials,
238248
)
239249
assistant_message = llm_result.default_message
240250
logger.debug(f"Assistant:{assistant_message}")
241251
except OpenaiError as exc:
242252
await self.sender.error(receiver=task.receiver, text=exc.message)
243253
return exc
244-
except (RuntimeError, AssertionError) as exc:
254+
except RuntimeError as exc:
255+
logger.exception(exc)
245256
await self.sender.error(
246257
receiver=task.receiver,
247258
text="Can't get message validate from your history",
248259
)
249260
return exc
261+
except AssertionError as exc:
262+
await self.sender.error(receiver=task.receiver, text=str(exc))
263+
return exc
250264
except Exception as exc:
251265
logger.exception(exc)
252266
await self.sender.error(
@@ -269,7 +283,7 @@ async def _flash(
269283
)
270284
return logger.debug("Function loop ended")
271285
return await self.sender.reply(
272-
receiver=task.receiver, message=assistant_message
286+
receiver=task.receiver, messages=[assistant_message]
273287
)
274288
except Exception as e:
275289
raise e
@@ -364,30 +378,31 @@ async def on_message(self, message: AbstractIncomingMessage):
364378
snap_data = await global_snapshot_storage.read(
365379
user_id=task_head.receiver.uid
366380
)
367-
data = snap_data.data
368-
renew_snap_data = []
369-
for task in data:
370-
if not task.snapshot_credential and not task.processed:
371-
try:
372-
await Task.create_and_send(
373-
queue_name=task.channel, task=task.snapshot_data
374-
)
375-
except Exception as e:
376-
logger.exception(f"Response to snapshot error {e}")
381+
if snap_data is not None:
382+
data = snap_data.data
383+
renew_snap_data = []
384+
for task in data:
385+
if not task.snapshot_credential and not task.processed:
386+
try:
387+
await Task.create_and_send(
388+
queue_name=task.channel, task=task.snapshot_data
389+
)
390+
except Exception as e:
391+
logger.exception(f"Response to snapshot error {e}")
392+
else:
393+
logger.info(
394+
f"🧀 Response to snapshot {task.snap_uuid} at {router}"
395+
)
396+
finally:
397+
task.processed_at = int(time.time())
398+
renew_snap_data.append(task)
377399
else:
378-
logger.info(
379-
f"🧀 Response to snapshot {task.snap_uuid} at {router}"
380-
)
381-
finally:
382-
task.processed_at = int(time.time())
400+
task.processed_at = None
383401
renew_snap_data.append(task)
384-
else:
385-
task.processed_at = None
386-
renew_snap_data.append(task)
387-
snap_data.data = renew_snap_data
388-
await global_snapshot_storage.write(
389-
user_id=task_head.receiver.uid, snapshot=snap_data
390-
)
402+
snap_data.data = renew_snap_data
403+
await global_snapshot_storage.write(
404+
user_id=task_head.receiver.uid, snapshot=snap_data
405+
)
391406
except Exception as e:
392407
logger.exception(e)
393408
await message.reject(requeue=False)

app/receiver/telegram/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from app.receiver.receiver_client import BaseReceiver, BaseSender
1515
from app.setting.telegram import BotSetting
1616
from llmkira.kv_manager.file import File
17-
from llmkira.middleware.llm_task import OpenaiMiddleware
17+
from app.middleware.llm_task import OpenaiMiddleware
1818
from llmkira.openai.cell import Message
1919
from llmkira.openai.request import OpenAIResult
2020
from llmkira.task import Task, TaskHeader

0 commit comments

Comments
 (0)