From 89098fba7a44b48da00c42d44ca1c5a5a7a8b629 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 1 Nov 2025 17:53:43 +0000 Subject: [PATCH 1/6] Initial plan From 84316d1bbe30fc6612ac67aafd3667ac5df6d531 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 1 Nov 2025 17:58:20 +0000 Subject: [PATCH 2/6] Fix function tool duplication bug during plugin reload Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com> --- astrbot/core/star/star_manager.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index ef50917fe..7949eb52f 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -751,6 +751,18 @@ async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str): ]: del star_handlers_registry.star_handlers_map[k] + # 移除插件注册的函数调用工具 + removed_tools = [] + for func_tool in list(llm_tools.func_list): + if ( + func_tool.handler_module_path == plugin_module_path + and func_tool.origin != "mcp" + ): + llm_tools.func_list.remove(func_tool) + removed_tools.append(func_tool.name) + if removed_tools: + logger.info(f"移除了插件 {plugin_name} 的函数调用工具: {removed_tools}") + if plugin is None: return From 3dd467fe2b25822fe0972783bca7ad5f9646bb9f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 1 Nov 2025 18:04:22 +0000 Subject: [PATCH 3/6] Improve function tool removal to handle edge cases Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com> --- astrbot/core/star/star_manager.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 7949eb52f..998151474 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -754,10 +754,21 @@ async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str): # 移除插件注册的函数调用工具 removed_tools = [] for func_tool in list(llm_tools.func_list): - if ( - func_tool.handler_module_path == plugin_module_path - and func_tool.origin != "mcp" - ): + # 检查工具是否属于此插件: + # 1. 通过 handler_module_path 匹配(已绑定的工具) + # 2. 通过 handler.__module__ 匹配(未绑定的工具,例如在 __init__ 中通过 add_llm_tools 添加的) + should_remove = False + if func_tool.origin != "mcp": + if func_tool.handler_module_path == plugin_module_path: + should_remove = True + elif ( + func_tool.handler + and hasattr(func_tool.handler, "__module__") + and func_tool.handler.__module__ == plugin_module_path + ): + should_remove = True + + if should_remove: llm_tools.func_list.remove(func_tool) removed_tools.append(func_tool.name) if removed_tools: From fc742d99b302985117049dd3ac6c3644a34aab01 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 1 Nov 2025 18:06:21 +0000 Subject: [PATCH 4/6] Refactor tool removal to use list comprehension for consistency Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com> --- astrbot/core/star/star_manager.py | 38 ++++++++++++++++--------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 998151474..662c91e17 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -752,25 +752,27 @@ async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str): del star_handlers_registry.star_handlers_map[k] # 移除插件注册的函数调用工具 - removed_tools = [] - for func_tool in list(llm_tools.func_list): - # 检查工具是否属于此插件: - # 1. 通过 handler_module_path 匹配(已绑定的工具) - # 2. 通过 handler.__module__ 匹配(未绑定的工具,例如在 __init__ 中通过 add_llm_tools 添加的) - should_remove = False - if func_tool.origin != "mcp": - if func_tool.handler_module_path == plugin_module_path: - should_remove = True - elif ( - func_tool.handler - and hasattr(func_tool.handler, "__module__") - and func_tool.handler.__module__ == plugin_module_path - ): - should_remove = True + def _should_remove_tool(func_tool): + """检查工具是否属于此插件: + 1. 通过 handler_module_path 匹配(已绑定的工具) + 2. 通过 handler.__module__ 匹配(未绑定的工具,例如在 __init__ 中通过 add_llm_tools 添加的) + """ + if func_tool.origin == "mcp": + return False + if func_tool.handler_module_path == plugin_module_path: + return True + if ( + func_tool.handler + and hasattr(func_tool.handler, "__module__") + and func_tool.handler.__module__ == plugin_module_path + ): + return True + return False - if should_remove: - llm_tools.func_list.remove(func_tool) - removed_tools.append(func_tool.name) + removed_tools = [f.name for f in llm_tools.func_list if _should_remove_tool(f)] + llm_tools.func_list = [ + f for f in llm_tools.func_list if not _should_remove_tool(f) + ] if removed_tools: logger.info(f"移除了插件 {plugin_name} 的函数调用工具: {removed_tools}") From a82c1850be840e2cc936f5ba5dae4c92312058c9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 1 Nov 2025 18:09:08 +0000 Subject: [PATCH 5/6] Optimize tool removal to single pass and improve documentation Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com> --- astrbot/core/star/star_manager.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 662c91e17..4966735d5 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -753,9 +753,13 @@ async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str): # 移除插件注册的函数调用工具 def _should_remove_tool(func_tool): - """检查工具是否属于此插件: - 1. 通过 handler_module_path 匹配(已绑定的工具) - 2. 通过 handler.__module__ 匹配(未绑定的工具,例如在 __init__ 中通过 add_llm_tools 添加的) + """检查工具是否属于此插件 + + 通过两种方式匹配工具: + 1. handler_module_path 匹配:适用于已完成绑定的工具(在插件加载后,handler 被 functools.partial 包装) + 2. handler.__module__ 匹配:适用于尚未绑定的工具(在 __init__ 中通过 add_llm_tools 添加的工具) + + 注意:MCP 工具由 MCP 客户端管理,不在此移除 """ if func_tool.origin == "mcp": return False @@ -769,10 +773,16 @@ def _should_remove_tool(func_tool): return True return False - removed_tools = [f.name for f in llm_tools.func_list if _should_remove_tool(f)] - llm_tools.func_list = [ - f for f in llm_tools.func_list if not _should_remove_tool(f) - ] + # 一次遍历完成工具的分类和移除 + removed_tools = [] + remaining_tools = [] + for func_tool in llm_tools.func_list: + if _should_remove_tool(func_tool): + removed_tools.append(func_tool.name) + else: + remaining_tools.append(func_tool) + llm_tools.func_list = remaining_tools + if removed_tools: logger.info(f"移除了插件 {plugin_name} 的函数调用工具: {removed_tools}") From 6ebea8d538a204025dbebbbda4f03ddfd6dec2ed Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 2 Nov 2025 11:02:26 +0000 Subject: [PATCH 6/6] Merge master and resolve conflict in star_manager.py Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com> --- astrbot/cli/commands/cmd_conf.py | 2 +- astrbot/cli/commands/cmd_init.py | 3 +- astrbot/cli/utils/plugin.py | 8 +- astrbot/cli/utils/version_comparator.py | 9 +- astrbot/core/agent/hooks.py | 2 - astrbot/core/agent/mcp_client.py | 36 ++++ astrbot/core/agent/message.py | 168 ++++++++++++++++++ astrbot/core/agent/run_context.py | 4 +- .../agent/runners/tool_loop_agent_runner.py | 19 +- astrbot/core/agent/tool.py | 76 +++++--- astrbot/core/astr_agent_context.py | 3 +- astrbot/core/config/astrbot_config.py | 4 +- astrbot/core/conversation_mgr.py | 36 ++++ astrbot/core/db/migration/migra_3_to_4.py | 7 +- astrbot/core/db/migration/sqlite_v3.py | 10 +- astrbot/core/persona_mgr.py | 6 +- .../strategies/baidu_aip.py | 7 +- astrbot/core/pipeline/context.py | 3 +- astrbot/core/pipeline/context_utils.py | 65 +++++++ .../process_stage/method/llm_request.py | 113 ++++++++---- .../core/pipeline/result_decorate/stage.py | 5 +- astrbot/core/platform/astr_message_event.py | 39 ++-- astrbot/core/platform/astrbot_message.py | 12 +- astrbot/core/platform/message_session.py | 2 +- astrbot/core/platform/platform_metadata.py | 8 +- astrbot/core/platform/register.py | 6 +- .../aiocqhttp/aiocqhttp_message_event.py | 2 +- .../aiocqhttp/aiocqhttp_platform_adapter.py | 6 +- .../core/platform/sources/discord/client.py | 2 +- .../platform/sources/discord/components.py | 22 +-- .../discord/discord_platform_adapter.py | 4 +- .../sources/discord/discord_platform_event.py | 7 +- .../qqofficial/qqofficial_message_event.py | 22 +-- .../platform/sources/slack/slack_adapter.py | 16 +- .../platform/sources/slack/slack_event.py | 9 +- .../sources/wechatpadpro/xml_data_parser.py | 2 +- astrbot/core/platform_message_history_mgr.py | 4 +- astrbot/core/provider/entities.py | 88 ++++----- astrbot/core/provider/func_tool_manager.py | 21 ++- astrbot/core/provider/provider.py | 70 +++++--- astrbot/core/provider/register.py | 4 +- .../core/provider/sources/anthropic_source.py | 19 +- astrbot/core/provider/sources/coze_source.py | 1 + .../core/provider/sources/dashscope_source.py | 9 +- .../core/provider/sources/gemini_source.py | 21 ++- .../core/provider/sources/openai_source.py | 31 ++-- astrbot/core/star/context.py | 57 ++++-- astrbot/core/star/filter/command.py | 16 +- astrbot/core/star/filter/command_group.py | 29 +-- astrbot/core/star/star_manager.py | 16 +- astrbot/core/utils/dify_api_client.py | 8 +- astrbot/core/utils/io.py | 4 +- astrbot/core/utils/pip_installer.py | 8 +- astrbot/dashboard/routes/chat.py | 8 +- astrbot/dashboard/routes/config.py | 4 +- astrbot/dashboard/server.py | 11 +- main.py | 12 +- packages/astrbot/commands/conversation.py | 27 ++- packages/astrbot/commands/persona.py | 7 +- packages/astrbot/commands/plugin.py | 25 +-- packages/astrbot/commands/provider.py | 55 +++--- packages/astrbot/long_term_memory.py | 10 +- packages/reminder/main.py | 7 +- 63 files changed, 897 insertions(+), 420 deletions(-) create mode 100644 astrbot/core/agent/message.py diff --git a/astrbot/cli/commands/cmd_conf.py b/astrbot/cli/commands/cmd_conf.py index 86f78cbaa..a9bd40f00 100644 --- a/astrbot/cli/commands/cmd_conf.py +++ b/astrbot/cli/commands/cmd_conf.py @@ -178,7 +178,7 @@ def set_config(key: str, value: str): @conf.command(name="get") @click.argument("key", required=False) -def get_config(key: str = None): +def get_config(key: str | None = None): """获取配置项的值,不提供key则显示所有可配置项""" config = _load_config() diff --git a/astrbot/cli/commands/cmd_init.py b/astrbot/cli/commands/cmd_init.py index 993995a66..6c0c34b99 100644 --- a/astrbot/cli/commands/cmd_init.py +++ b/astrbot/cli/commands/cmd_init.py @@ -1,4 +1,5 @@ import asyncio +from pathlib import Path import click from filelock import FileLock, Timeout @@ -6,7 +7,7 @@ from ..utils import check_dashboard, get_astrbot_root -async def initialize_astrbot(astrbot_root) -> None: +async def initialize_astrbot(astrbot_root: Path) -> None: """执行 AstrBot 初始化逻辑""" dot_astrbot = astrbot_root / ".astrbot" diff --git a/astrbot/cli/utils/plugin.py b/astrbot/cli/utils/plugin.py index 55edf4de2..cd76a07c8 100644 --- a/astrbot/cli/utils/plugin.py +++ b/astrbot/cli/utils/plugin.py @@ -221,9 +221,9 @@ def manage_plugin( raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新") # 备份现有插件 - if is_update and backup_path.exists(): + if is_update and backup_path is not None and backup_path.exists(): shutil.rmtree(backup_path) - if is_update: + if is_update and backup_path is not None: shutil.copytree(target_path, backup_path) try: @@ -233,13 +233,13 @@ def manage_plugin( get_git_repo(repo_url, target_path, proxy) # 更新成功,删除备份 - if is_update and backup_path.exists(): + if is_update and backup_path is not None and backup_path.exists(): shutil.rmtree(backup_path) click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功") except Exception as e: if target_path.exists(): shutil.rmtree(target_path, ignore_errors=True) - if is_update and backup_path.exists(): + if is_update and backup_path is not None and backup_path.exists(): shutil.move(backup_path, target_path) raise click.ClickException( f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}", diff --git a/astrbot/cli/utils/version_comparator.py b/astrbot/cli/utils/version_comparator.py index 99d71d34d..0aaf8dcab 100644 --- a/astrbot/cli/utils/version_comparator.py +++ b/astrbot/cli/utils/version_comparator.py @@ -62,9 +62,12 @@ def split_version(version): return -1 if isinstance(p1, str) and isinstance(p2, int): return 1 - if (isinstance(p1, int) and isinstance(p2, int)) or ( - isinstance(p1, str) and isinstance(p2, str) - ): + if isinstance(p1, int) and isinstance(p2, int): + if p1 > p2: + return 1 + if p1 < p2: + return -1 + elif isinstance(p1, str) and isinstance(p2, str): if p1 > p2: return 1 if p1 < p2: diff --git a/astrbot/core/agent/hooks.py b/astrbot/core/agent/hooks.py index 949ebd3fe..d834240b7 100644 --- a/astrbot/core/agent/hooks.py +++ b/astrbot/core/agent/hooks.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import Generic import mcp @@ -9,7 +8,6 @@ from .run_context import ContextWrapper, TContext -@dataclass class BaseAgentRunHooks(Generic[TContext]): async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ... async def on_tool_start( diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index 303973a0d..05980b212 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -2,10 +2,15 @@ import logging from contextlib import AsyncExitStack from datetime import timedelta +from typing import Generic from astrbot import logger +from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.utils.log_pipe import LogPipe +from .run_context import TContext +from .tool import FunctionTool + try: import mcp from mcp.client.sse import sse_client @@ -221,3 +226,34 @@ async def cleanup(self): """Clean up resources""" await self.exit_stack.aclose() self.running_event.set() # Set the running event to indicate cleanup is done + + +class MCPTool(FunctionTool, Generic[TContext]): + """A function tool that calls an MCP service.""" + + def __init__( + self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs + ): + super().__init__( + name=mcp_tool.name, + description=mcp_tool.description or "", + parameters=mcp_tool.inputSchema, + ) + self.mcp_tool = mcp_tool + self.mcp_client = mcp_client + self.mcp_server_name = mcp_server_name + + async def call( + self, context: ContextWrapper[TContext], **kwargs + ) -> mcp.types.CallToolResult: + session = self.mcp_client.session + if not session: + raise ValueError("MCP session is not available for MCP function tools.") + res = await session.call_tool( + name=self.mcp_tool.name, + arguments=kwargs, + read_timeout_seconds=timedelta( + seconds=context.tool_call_timeout, + ), + ) + return res diff --git a/astrbot/core/agent/message.py b/astrbot/core/agent/message.py new file mode 100644 index 000000000..11128c0f6 --- /dev/null +++ b/astrbot/core/agent/message.py @@ -0,0 +1,168 @@ +# Inspired by MoonshotAI/kosong, credits to MoonshotAI/kosong authors for the original implementation. +# License: Apache License 2.0 + +from typing import Any, ClassVar, Literal, cast + +from pydantic import BaseModel, GetCoreSchemaHandler +from pydantic_core import core_schema + + +class ContentPart(BaseModel): + """A part of the content in a message.""" + + __content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {} + + type: str + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + + invalid_subclass_error_msg = f"ContentPart subclass {cls.__name__} must have a `type` field of type `str`" + + type_value = getattr(cls, "type", None) + if type_value is None or not isinstance(type_value, str): + raise ValueError(invalid_subclass_error_msg) + + cls.__content_part_registry[type_value] = cls + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + # If we're dealing with the base ContentPart class, use custom validation + if cls.__name__ == "ContentPart": + + def validate_content_part(value: Any) -> Any: + # if it's already an instance of a ContentPart subclass, return it + if hasattr(value, "__class__") and issubclass(value.__class__, cls): + return value + + # if it's a dict with a type field, dispatch to the appropriate subclass + if isinstance(value, dict) and "type" in value: + type_value: Any | None = cast(dict[str, Any], value).get("type") + if not isinstance(type_value, str): + raise ValueError(f"Cannot validate {value} as ContentPart") + target_class = cls.__content_part_registry[type_value] + return target_class.model_validate(value) + + raise ValueError(f"Cannot validate {value} as ContentPart") + + return core_schema.no_info_plain_validator_function(validate_content_part) + + # for subclasses, use the default schema + return handler(source_type) + + +class TextPart(ContentPart): + """ + >>> TextPart(text="Hello, world!").model_dump() + {'type': 'text', 'text': 'Hello, world!'} + """ + + type: str = "text" + text: str + + +class ImageURLPart(ContentPart): + """ + >>> ImageURLPart(image_url="http://example.com/image.jpg").model_dump() + {'type': 'image_url', 'image_url': 'http://example.com/image.jpg'} + """ + + class ImageURL(BaseModel): + url: str + """The URL of the image, can be data URI scheme like `data:image/png;base64,...`.""" + id: str | None = None + """The ID of the image, to allow LLMs to distinguish different images.""" + + type: str = "image_url" + image_url: str + + +class AudioURLPart(ContentPart): + """ + >>> AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3")).model_dump() + {'type': 'audio_url', 'audio_url': {'url': 'https://example.com/audio.mp3', 'id': None}} + """ + + class AudioURL(BaseModel): + url: str + """The URL of the audio, can be data URI scheme like `data:audio/aac;base64,...`.""" + id: str | None = None + """The ID of the audio, to allow LLMs to distinguish different audios.""" + + type: str = "audio_url" + audio_url: AudioURL + + +class ToolCall(BaseModel): + """ + A tool call requested by the assistant. + + >>> ToolCall( + ... id="123", + ... function=ToolCall.FunctionBody( + ... name="function", + ... arguments="{}" + ... ), + ... ).model_dump() + {'type': 'function', 'id': '123', 'function': {'name': 'function', 'arguments': '{}'}} + """ + + class FunctionBody(BaseModel): + name: str + arguments: str | None + + type: Literal["function"] = "function" + + id: str + """The ID of the tool call.""" + function: FunctionBody + """The function body of the tool call.""" + + +class ToolCallPart(BaseModel): + """A part of the tool call.""" + + arguments_part: str | None = None + """A part of the arguments of the tool call.""" + + +class Message(BaseModel): + """A message in a conversation.""" + + role: Literal[ + "system", + "user", + "assistant", + "tool", + ] + + content: str | list[ContentPart] + """The content of the message.""" + + +class AssistantMessageSegment(Message): + """A message segment from the assistant.""" + + role: Literal["assistant"] = "assistant" + tool_calls: list[ToolCall] | list[dict] | None = None + + +class ToolCallMessageSegment(Message): + """A message segment representing a tool call.""" + + role: Literal["tool"] = "tool" + tool_call_id: str + + +class UserMessageSegment(Message): + """A message segment from the user.""" + + role: Literal["user"] = "user" + + +class SystemMessageSegment(Message): + """A message segment from the system.""" + + role: Literal["system"] = "system" diff --git a/astrbot/core/agent/run_context.py b/astrbot/core/agent/run_context.py index 634735ccc..395817679 100644 --- a/astrbot/core/agent/run_context.py +++ b/astrbot/core/agent/run_context.py @@ -3,8 +3,6 @@ from typing_extensions import TypeVar -from astrbot.core.platform.astr_message_event import AstrMessageEvent - TContext = TypeVar("TContext", default=Any) @@ -13,7 +11,7 @@ class ContextWrapper(Generic[TContext]): """A context for running an agent, which can be used to pass additional data or state.""" context: TContext - event: AstrMessageEvent + tool_call_timeout: int = 60 # Default tool call timeout in seconds NoContext = ContextWrapper[None] diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index cb89fb612..23071d446 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -16,15 +16,14 @@ MessageChain, ) from astrbot.core.provider.entities import ( - AssistantMessageSegment, LLMResponse, ProviderRequest, - ToolCallMessageSegment, ToolCallsResult, ) from astrbot.core.provider.provider import Provider from ..hooks import BaseAgentRunHooks +from ..message import AssistantMessageSegment, ToolCallMessageSegment from ..response import AgentResponseData from ..run_context import ContextWrapper, TContext from ..tool_executor import BaseFunctionToolExecutor @@ -171,8 +170,7 @@ async def step(self): # 将结果添加到上下文中 tool_calls_result = ToolCallsResult( tool_calls_info=AssistantMessageSegment( - role="assistant", - tool_calls=llm_resp.to_openai_tool_calls(), + tool_calls=llm_resp.to_openai_to_calls_model(), content=llm_resp.completion_text, ), tool_calls_result=tool_call_result_blocks, @@ -238,7 +236,6 @@ async def _handle_function_tools( else: # 如果没有 handler(如 MCP 工具),使用所有参数 valid_params = func_tool_args - logger.warning(f"工具 {func_tool_name} 没有 handler,使用所有参数") try: await self.agent_hooks.on_tool_start( @@ -319,13 +316,11 @@ async def _handle_function_tools( elif resp is None: # Tool 直接请求发送消息给用户 # 这里我们将直接结束 Agent Loop。 + # 发送消息逻辑在 ToolExecutor 中处理了。 + logger.warning( + f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。" + ) self._transition_state(AgentState.DONE) - if res := self.run_context.event.get_result(): - if res.chain: - yield MessageChain( - chain=res.chain, - type="tool_direct_result", - ) else: # 不应该出现其他类型 logger.warning( @@ -341,8 +336,6 @@ async def _handle_function_tools( ) except Exception as e: logger.error(f"Error in on_tool_end hook: {e}", exc_info=True) - - self.run_context.event.clear_result() except Exception as e: logger.warning(traceback.format_exc()) tool_call_result_blocks.append( diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index 3c36def63..e9738dc0f 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -1,52 +1,76 @@ from collections.abc import Awaitable, Callable -from dataclasses import dataclass -from typing import Any, Literal +from typing import Any, Generic +import jsonschema +import mcp from deprecated import deprecated +from pydantic import model_validator +from pydantic.dataclasses import dataclass -from .mcp_client import MCPClient +from .run_context import ContextWrapper, TContext + +ParametersType = dict[str, Any] @dataclass -class FunctionTool: - """A class representing a function tool that can be used in function calling.""" +class ToolSchema: + """A class representing the schema of a tool for function calling.""" name: str - parameters: dict | None = None - description: str | None = None + """The name of the tool.""" + + description: str + """The description of the tool.""" + + parameters: ParametersType + """The parameters of the tool, in JSON Schema format.""" + + @model_validator(mode="after") + def validate_parameters(self) -> "ToolSchema": + jsonschema.validate( + self.parameters, jsonschema.Draft202012Validator.META_SCHEMA + ) + return self + + +@dataclass +class FunctionTool(ToolSchema, Generic[TContext]): + """A callable tool, for function calling.""" + handler: Callable[..., Awaitable[Any]] | None = None - """处理函数, 当 origin 为 mcp 时,这个为空""" - handler_module_path: str | None = None - """处理函数的模块路径,当 origin 为 mcp 时,这个为空 + """a callable that implements the tool's functionality. It should be an async function.""" - 必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools + handler_module_path: str | None = None + """ + The module path of the handler function. This is empty when the origin is mcp. + This field must be retained, as the handler will be wrapped in functools.partial during initialization, + causing the handler's __module__ to be functools """ active: bool = True - """是否激活""" - - origin: Literal["local", "mcp"] = "local" - """函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务""" - - # MCP 相关字段 - mcp_server_name: str | None = None - """MCP 服务名称,当 origin 为 mcp 时有效""" - mcp_client: MCPClient | None = None - """MCP 客户端,当 origin 为 mcp 时有效""" + """ + Whether the tool is active. This field is a special field for AstrBot. + You can ignore it when integrating with other frameworks. + """ def __repr__(self): - return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})" + return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})" def __dict__(self) -> dict[str, Any]: - """将 FunctionTool 转换为字典格式""" return { "name": self.name, "parameters": self.parameters, "description": self.description, "active": self.active, - "origin": self.origin, - "mcp_server_name": self.mcp_server_name, } + async def call( + self, context: ContextWrapper[TContext], **kwargs + ) -> str | mcp.types.CallToolResult: + """Run the tool with the given arguments. The handler field has priority.""" + raise NotImplementedError( + "FunctionTool.call() must be implemented by subclasses or set a handler." + ) + class ToolSet: """A set of function tools that can be used in function calling. @@ -225,7 +249,7 @@ def convert_schema(schema: dict) -> dict: tools = [] for tool in self.tools: - d = { + d: dict[str, Any] = { "name": tool.name, "description": tool.description, } diff --git a/astrbot/core/astr_agent_context.py b/astrbot/core/astr_agent_context.py index e21ddb9c6..28b242253 100644 --- a/astrbot/core/astr_agent_context.py +++ b/astrbot/core/astr_agent_context.py @@ -1,5 +1,6 @@ from dataclasses import dataclass +from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.provider import Provider from astrbot.core.provider.entities import ProviderRequest @@ -10,4 +11,4 @@ class AstrAgentContext: first_provider_request: ProviderRequest curr_provider_request: ProviderRequest streaming: bool - tool_call_timeout: int = 60 # Default tool call timeout in seconds + event: AstrMessageEvent diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 68b73cd29..786d29c81 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -28,7 +28,7 @@ def __init__( self, config_path: str = ASTRBOT_CONFIG_PATH, default_config: dict = DEFAULT_CONFIG, - schema: dict = None, + schema: dict | None = None, ): super().__init__() @@ -142,7 +142,7 @@ def check_config_integrity(self, refer_conf: dict, conf: dict, path=""): return has_new - def save_config(self, replace_config: dict = None): + def save_config(self, replace_config: dict | None = None): """将配置写入文件 如果传入 replace_config,则将配置替换为 replace_config diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 2be406100..287fe03c4 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -8,6 +8,7 @@ from collections.abc import Awaitable, Callable from astrbot.core import sp +from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment from astrbot.core.db import BaseDatabase from astrbot.core.db.po import Conversation, ConversationV2 @@ -319,6 +320,41 @@ async def update_conversation_persona_id( persona_id=persona_id, ) + async def add_message_pair( + self, + cid: str, + user_message: UserMessageSegment | dict, + assistant_message: AssistantMessageSegment | dict, + ) -> None: + """Add a user-assistant message pair to the conversation history. + + Args: + cid (str): Conversation ID + user_message (UserMessageSegment | dict): OpenAI-format user message object or dict + assistant_message (AssistantMessageSegment | dict): OpenAI-format assistant message object or dict + + Raises: + Exception: If the conversation with the given ID is not found + """ + conv = await self.db.get_conversation_by_id(cid=cid) + if not conv: + raise Exception(f"Conversation with id {cid} not found") + history = conv.content or [] + if isinstance(user_message, UserMessageSegment): + user_msg_dict = user_message.model_dump() + else: + user_msg_dict = user_message + if isinstance(assistant_message, AssistantMessageSegment): + assistant_msg_dict = assistant_message.model_dump() + else: + assistant_msg_dict = assistant_message + history.append(user_msg_dict) + history.append(assistant_msg_dict) + await self.db.update_conversation( + cid=cid, + content=history, + ) + async def get_human_readable_context( self, unified_msg_origin: str, diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py index 13a14c327..a75c60a1b 100644 --- a/astrbot/core/db/migration/migra_3_to_4.py +++ b/astrbot/core/db/migration/migra_3_to_4.py @@ -250,14 +250,15 @@ async def migration_persona_data( try: begin_dialogs = persona.get("begin_dialogs", []) mood_imitation_dialogs = persona.get("mood_imitation_dialogs", []) - mood_prompt = "" + parts = [] user_turn = True for mood_dialog in mood_imitation_dialogs: if user_turn: - mood_prompt += f"A: {mood_dialog}\n" + parts.append(f"A: {mood_dialog}\n") else: - mood_prompt += f"B: {mood_dialog}\n" + parts.append(f"B: {mood_dialog}\n") user_turn = not user_turn + mood_prompt = "".join(parts) system_prompt = persona.get("prompt", "") if mood_prompt: system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}" diff --git a/astrbot/core/db/migration/sqlite_v3.py b/astrbot/core/db/migration/sqlite_v3.py index 7b341c316..a301028d1 100644 --- a/astrbot/core/db/migration/sqlite_v3.py +++ b/astrbot/core/db/migration/sqlite_v3.py @@ -384,11 +384,11 @@ def get_filtered_conversations( self, page: int = 1, page_size: int = 20, - platforms: list[str] = None, - message_types: list[str] = None, - search_query: str = None, - exclude_ids: list[str] = None, - exclude_platforms: list[str] = None, + platforms: list[str] | None = None, + message_types: list[str] | None = None, + search_query: str | None = None, + exclude_ids: list[str] | None = None, + exclude_platforms: list[str] | None = None, ) -> tuple[list[dict[str, Any]], int]: """获取筛选后的对话列表""" try: diff --git a/astrbot/core/persona_mgr.py b/astrbot/core/persona_mgr.py index 482b5887c..5d1743ab9 100644 --- a/astrbot/core/persona_mgr.py +++ b/astrbot/core/persona_mgr.py @@ -68,9 +68,9 @@ async def delete_persona(self, persona_id: str): async def update_persona( self, persona_id: str, - system_prompt: str = None, - begin_dialogs: list[str] = None, - tools: list[str] = None, + system_prompt: str | None = None, + begin_dialogs: list[str] | None = None, + tools: list[str] | None = None, ): """更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具""" existing_persona = await self.db.get_persona_by_id(persona_id) diff --git a/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py b/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py index c11822896..bfa82de0e 100644 --- a/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +++ b/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py @@ -21,8 +21,9 @@ def check(self, content: str) -> tuple[bool, str]: if "data" not in res: return False, "" count = len(res["data"]) - info = f"百度审核服务发现 {count} 处违规:\n" + parts = [f"百度审核服务发现 {count} 处违规:\n"] for i in res["data"]: - info += f"{i['msg']};\n" - info += "\n判断结果:" + res["conclusion"] + parts.append(f"{i['msg']};\n") + parts.append("\n判断结果:" + res["conclusion"]) + info = "".join(parts) return False, info diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py index a6cd567e0..44186764e 100644 --- a/astrbot/core/pipeline/context.py +++ b/astrbot/core/pipeline/context.py @@ -3,7 +3,7 @@ from astrbot.core.config import AstrBotConfig from astrbot.core.star import PluginManager -from .context_utils import call_event_hook, call_handler +from .context_utils import call_event_hook, call_handler, call_local_llm_tool @dataclass @@ -15,3 +15,4 @@ class PipelineContext: astrbot_config_id: str call_handler = call_handler call_event_hook = call_event_hook + call_local_llm_tool = call_local_llm_tool diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py index 73d28c5d1..371816b6e 100644 --- a/astrbot/core/pipeline/context_utils.py +++ b/astrbot/core/pipeline/context_utils.py @@ -3,6 +3,8 @@ import typing as T from astrbot import logger +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.message.message_event_result import CommandResult, MessageEventResult from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.star import star_map @@ -105,3 +107,66 @@ async def call_event_hook( return True return event.is_stopped() + + +async def call_local_llm_tool( + context: ContextWrapper[AstrAgentContext], + handler: T.Callable[..., T.Awaitable[T.Any]], + method_name: str, + *args, + **kwargs, +) -> T.AsyncGenerator[T.Any, None]: + """执行本地 LLM 工具的处理函数并处理其返回结果""" + ready_to_call = None # 一个协程或者异步生成器 + + trace_ = None + + event = context.context.event + + try: + if method_name == "run" or method_name == "decorator_handler": + ready_to_call = handler(event, *args, **kwargs) + elif method_name == "call": + ready_to_call = handler(context, *args, **kwargs) + else: + raise ValueError(f"未知的方法名: {method_name}") + except ValueError as e: + logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True) + except TypeError: + logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True) + except Exception as e: + trace_ = traceback.format_exc() + logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}") + + if not ready_to_call: + return + + if inspect.isasyncgen(ready_to_call): + _has_yielded = False + try: + async for ret in ready_to_call: + # 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码 + # 返回值只能是 MessageEventResult 或者 None(无返回值) + _has_yielded = True + if isinstance(ret, (MessageEventResult, CommandResult)): + # 如果返回值是 MessageEventResult, 设置结果并继续 + event.set_result(ret) + yield + else: + # 如果返回值是 None, 则不设置结果并继续 + # 继续执行后续阶段 + yield ret + if not _has_yielded: + # 如果这个异步生成器没有执行到 yield 分支 + yield + except Exception as e: + logger.error(f"Previous Error: {trace_}") + raise e + elif inspect.iscoroutine(ready_to_call): + # 如果只是一个协程, 直接执行 + ret = await ready_to_call + if isinstance(ret, (MessageEventResult, CommandResult)): + event.set_result(ret) + yield + else: + yield ret diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index d1cffc43f..03352cc40 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -5,11 +5,14 @@ import json import traceback from collections.abc import AsyncGenerator -from datetime import timedelta +from typing import Any + +from mcp.types import CallToolResult from astrbot.core import logger from astrbot.core.agent.handoff import HandoffTool from astrbot.core.agent.hooks import BaseAgentRunHooks +from astrbot.core.agent.mcp_client import MCPTool from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner from astrbot.core.agent.tool import FunctionTool, ToolSet @@ -33,7 +36,7 @@ from astrbot.core.star.star_handler import EventType, star_map from astrbot.core.utils.metrics import Metric -from ...context import PipelineContext, call_event_hook, call_handler +from ...context import PipelineContext, call_event_hook, call_local_llm_tool from ..stage import Stage from ..utils import inject_kb_context @@ -65,18 +68,16 @@ async def execute(cls, tool, run_context, **tool_args): yield r return - if tool.origin == "local": - async for r in cls._execute_local(tool, run_context, **tool_args): + elif isinstance(tool, MCPTool): + async for r in cls._execute_mcp(tool, run_context, **tool_args): yield r return - elif tool.origin == "mcp": - async for r in cls._execute_mcp(tool, run_context, **tool_args): + else: + async for r in cls._execute_local(tool, run_context, **tool_args): yield r return - raise Exception(f"Unknown function origin: {tool.origin}") - @classmethod async def _execute_handoff( cls, @@ -113,10 +114,13 @@ async def _execute_handoff( first_provider_request=run_context.context.first_provider_request, curr_provider_request=request, streaming=run_context.context.streaming, + event=run_context.context.event, ) + event = run_context.context.event + logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}") - await run_context.event.send( + await event.send( MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name), ) @@ -125,7 +129,7 @@ async def _execute_handoff( request=request, run_context=AgentContextWrapper( context=astr_agent_ctx, - event=run_context.event, + tool_call_timeout=run_context.tool_call_timeout, ), tool_executor=FunctionToolExecutor(), agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](), @@ -175,25 +179,46 @@ async def _execute_local( run_context: ContextWrapper[AstrAgentContext], **tool_args, ): - if not run_context.event: + event = run_context.context.event + if not event: raise ValueError("Event must be provided for local function tools.") - # 检查 tool 下有没有 run 方法 - if not tool.handler and not hasattr(tool, "run"): - raise ValueError("Tool must have a valid handler or 'run' method.") - awaitable = tool.handler or tool.run + is_override_call = False + for ty in type(tool).mro(): + if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call: + logger.debug(f"Found call in: {ty}") + is_override_call = True + break - wrapper = call_handler( - event=run_context.event, + # 检查 tool 下有没有 run 方法 + if not tool.handler and not hasattr(tool, "run") and not is_override_call: + raise ValueError("Tool must have a valid handler or override 'run' method.") + + awaitable = None + method_name = "" + if tool.handler: + awaitable = tool.handler + method_name = "decorator_handler" + elif is_override_call: + awaitable = tool.call + method_name = "call" + elif hasattr(tool, "run"): + awaitable = getattr(tool, "run") + method_name = "run" + if awaitable is None: + raise ValueError("Tool must have a valid handler or override 'run' method.") + + wrapper = call_local_llm_tool( + context=run_context, handler=awaitable, + method_name=method_name, **tool_args, ) - # async for resp in wrapper: while True: try: resp = await asyncio.wait_for( anext(wrapper), - timeout=run_context.context.tool_call_timeout, + timeout=run_context.tool_call_timeout, ) if resp is not None: if isinstance(resp, mcp.types.CallToolResult): @@ -208,10 +233,24 @@ async def _execute_local( # NOTE: Tool 在这里直接请求发送消息给用户 # TODO: 是否需要判断 event.get_result() 是否为空? # 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容" + if res := run_context.context.event.get_result(): + if res.chain: + try: + await event.send( + MessageChain( + chain=res.chain, + type="tool_direct_result", + ) + ) + except Exception as e: + logger.error( + f"Tool 直接发送消息失败: {e}", + exc_info=True, + ) yield None except asyncio.TimeoutError: raise Exception( - f"tool {tool.name} execution timeout after {run_context.context.tool_call_timeout} seconds.", + f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.", ) except StopAsyncIteration: break @@ -223,19 +262,7 @@ async def _execute_mcp( run_context: ContextWrapper[AstrAgentContext], **tool_args, ): - if not tool.mcp_client: - raise ValueError("MCP client is not available for MCP function tools.") - - session = tool.mcp_client.session - if not session: - raise ValueError("MCP session is not available for MCP function tools.") - res = await session.call_tool( - name=tool.name, - arguments=tool_args, - read_timeout_seconds=timedelta( - seconds=run_context.context.tool_call_timeout, - ), - ) + res = await tool.call(run_context, **tool_args) if not res: return yield res @@ -245,11 +272,20 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]): async def on_agent_done(self, run_context, llm_response): # 执行事件钩子 await call_event_hook( - run_context.event, + run_context.context.event, EventType.OnLLMResponseEvent, llm_response, ) + async def on_tool_end( + self, + run_context: ContextWrapper[AstrAgentContext], + tool: FunctionTool[Any], + tool_args: dict | None, + tool_result: CallToolResult | None, + ): + run_context.context.event.clear_result() + MAIN_AGENT_HOOKS = MainAgentHooks() @@ -260,7 +296,7 @@ async def run_agent( show_tool_use: bool = True, ) -> AsyncGenerator[MessageChain, None]: step_idx = 0 - astr_event = agent_runner.run_context.event + astr_event = agent_runner.run_context.context.event while step_idx < max_step: step_idx += 1 try: @@ -513,12 +549,15 @@ async def process( first_provider_request=req, curr_provider_request=req, streaming=self.streaming_response, - tool_call_timeout=self.tool_call_timeout, + event=event, ) await agent_runner.reset( provider=provider, request=req, - run_context=AgentContextWrapper(context=astr_agent_ctx, event=event), + run_context=AgentContextWrapper( + context=astr_agent_ctx, + tool_call_timeout=self.tool_call_timeout, + ), tool_executor=FunctionToolExecutor(), agent_hooks=MAIN_AGENT_HOOKS, streaming=self.streaming_response, diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 08661a367..5dfb52f6f 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -246,12 +246,13 @@ async def process( elif ( result.use_t2i_ is None and self.ctx.astrbot_config["t2i"] ) or result.use_t2i_: - plain_str = "" + parts = [] for comp in result.chain: if isinstance(comp, Plain): - plain_str += "\n\n" + comp.text + parts.append("\n\n" + comp.text) else: break + plain_str = "".join(parts) if plain_str and len(plain_str) > self.t2i_word_threshold: render_start = time.time() try: diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 9605eaffb..6402aeaed 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -91,33 +91,34 @@ def get_message_str(self) -> str: return self.message_str def _outline_chain(self, chain: list[BaseMessageComponent] | None) -> str: - outline = "" if not chain: - return outline + return "" + + parts = [] for i in chain: if isinstance(i, Plain): - outline += i.text + parts.append(i.text) elif isinstance(i, Image): - outline += "[图片]" + parts.append("[图片]") elif isinstance(i, Face): - outline += f"[表情:{i.id}]" + parts.append(f"[表情:{i.id}]") elif isinstance(i, At): - outline += f"[At:{i.qq}]" + parts.append(f"[At:{i.qq}]") elif isinstance(i, AtAll): - outline += "[At:全体成员]" + parts.append("[At:全体成员]") elif isinstance(i, Forward): # 转发消息 - outline += "[转发消息]" + parts.append("[转发消息]") elif isinstance(i, Reply): # 引用回复 if i.message_str: - outline += f"[引用消息({i.sender_nickname}: {i.message_str})]" + parts.append(f"[引用消息({i.sender_nickname}: {i.message_str})]") else: - outline += "[引用消息]" + parts.append("[引用消息]") else: - outline += f"[{i.type}]" - outline += " " - return outline + parts.append(f"[{i.type}]") + parts.append(" ") + return "".join(parts) def get_message_outline(self) -> str: """获取消息概要。 @@ -320,10 +321,10 @@ def request_llm( prompt: str, func_tool_manager=None, session_id: str = None, - image_urls: list[str] = [], - contexts: list = [], + image_urls: list[str] | None = None, + contexts: list | None = None, system_prompt: str = "", - conversation: Conversation = None, + conversation: Conversation | None = None, ) -> ProviderRequest: """创建一个 LLM 请求。 @@ -346,6 +347,10 @@ def request_llm( conversation: 可选。如果指定,将在指定的对话中进行 LLM 请求。对话的人格会被用于 LLM 请求,并且结果将会被记录到对话中。 """ + if image_urls is None: + image_urls = [] + if contexts is None: + contexts = [] if len(contexts) > 0 and conversation: conversation = None @@ -389,7 +394,7 @@ async def react(self, emoji: str): """ await self.send(MessageChain([Plain(emoji)])) - async def get_group(self, group_id: str = None, **kwargs) -> Group | None: + async def get_group(self, group_id: str | None = None, **kwargs) -> Group | None: """获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息,返回当前群聊的数据。 适配情况: diff --git a/astrbot/core/platform/astrbot_message.py b/astrbot/core/platform/astrbot_message.py index dcc70b0f2..0ada18506 100644 --- a/astrbot/core/platform/astrbot_message.py +++ b/astrbot/core/platform/astrbot_message.py @@ -9,7 +9,7 @@ @dataclass class MessageMember: user_id: str # 发送者id - nickname: str = None + nickname: str | None = None def __str__(self): # 使用 f-string 来构建返回的字符串表示形式 @@ -23,15 +23,15 @@ def __str__(self): class Group: group_id: str """群号""" - group_name: str = None + group_name: str | None = None """群名称""" - group_avatar: str = None + group_avatar: str | None = None """群头像""" - group_owner: str = None + group_owner: str | None = None """群主 id""" - group_admins: list[str] = None + group_admins: list[str] | None = None """群管理员 id""" - members: list[MessageMember] = None + members: list[MessageMember] | None = None """所有群成员""" def __str__(self): diff --git a/astrbot/core/platform/message_session.py b/astrbot/core/platform/message_session.py index 62240b621..bca5300b8 100644 --- a/astrbot/core/platform/message_session.py +++ b/astrbot/core/platform/message_session.py @@ -13,7 +13,7 @@ class MessageSession: """平台适配器实例的唯一标识符。自 AstrBot v4.0.0 起,该字段实际为 platform_id。""" message_type: MessageType session_id: str - platform_id: str = None + platform_id: str | None = None def __str__(self): return f"{self.platform_id}:{self.message_type.value}:{self.session_id}" diff --git a/astrbot/core/platform/platform_metadata.py b/astrbot/core/platform/platform_metadata.py index 37f8527a1..d75811245 100644 --- a/astrbot/core/platform/platform_metadata.py +++ b/astrbot/core/platform/platform_metadata.py @@ -7,12 +7,12 @@ class PlatformMetadata: """平台的名称,即平台的类型,如 aiocqhttp, discord, slack""" description: str """平台的描述""" - id: str = None + id: str | None = None """平台的唯一标识符,用于配置中识别特定平台""" - default_config_tmpl: dict = None + default_config_tmpl: dict | None = None """平台的默认配置模板""" - adapter_display_name: str = None + adapter_display_name: str | None = None """显示在 WebUI 配置页中的平台名称,如空则是 name""" - logo_path: str = None + logo_path: str | None = None """平台适配器的 logo 文件路径(相对于插件目录)""" diff --git a/astrbot/core/platform/register.py b/astrbot/core/platform/register.py index 4cd62ede0..0c6267492 100644 --- a/astrbot/core/platform/register.py +++ b/astrbot/core/platform/register.py @@ -11,9 +11,9 @@ def register_platform_adapter( adapter_name: str, desc: str, - default_config_tmpl: dict = None, - adapter_display_name: str = None, - logo_path: str = None, + default_config_tmpl: dict | None = None, + adapter_display_name: str | None = None, + logo_path: str | None = None, ): """用于注册平台适配器的带参装饰器。 diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index 5fe605a74..ce8fd56df 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -94,7 +94,7 @@ async def send_message( message_chain: MessageChain, event: Event | None = None, is_group: bool = False, - session_id: str = None, + session_id: str | None = None, ): """发送消息至 QQ 协议端(aiocqhttp)。 diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index bb9c4474a..85bb2fead 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -315,6 +315,8 @@ async def _convert_handle_message_event( abm.message.append(a) elif t == "at": first_at_self_processed = False + # Accumulate @ mention text for efficient concatenation + at_parts = [] for m in m_group: try: @@ -354,13 +356,15 @@ async def _convert_handle_message_event( first_at_self_processed = True else: # 非第一个@机器人或@其他用户,添加到message_str - message_str += f" @{nickname}({m['data']['qq']}) " + at_parts.append(f" @{nickname}({m['data']['qq']}) ") else: abm.message.append(At(qq=str(m["data"]["qq"]), name="")) except ActionFailed as e: logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。") except BaseException as e: logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。") + + message_str += "".join(at_parts) else: for m in m_group: a = ComponentTypes[t](**m["data"]) diff --git a/astrbot/core/platform/sources/discord/client.py b/astrbot/core/platform/sources/discord/client.py index 0a2982ce8..5d29e3429 100644 --- a/astrbot/core/platform/sources/discord/client.py +++ b/astrbot/core/platform/sources/discord/client.py @@ -14,7 +14,7 @@ class DiscordBotClient(discord.Bot): """Discord客户端封装""" - def __init__(self, token: str, proxy: str = None): + def __init__(self, token: str, proxy: str | None = None): self.token = token self.proxy = proxy diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py index dbddd1686..d3e69e763 100644 --- a/astrbot/core/platform/sources/discord/components.py +++ b/astrbot/core/platform/sources/discord/components.py @@ -11,14 +11,14 @@ class DiscordEmbed(BaseMessageComponent): def __init__( self, - title: str = None, - description: str = None, - color: int = None, - url: str = None, - thumbnail: str = None, - image: str = None, - footer: str = None, - fields: list[dict] = None, + title: str | None = None, + description: str | None = None, + color: int | None = None, + url: str | None = None, + thumbnail: str | None = None, + image: str | None = None, + footer: str | None = None, + fields: list[dict] | None = None, ): self.title = title self.description = description @@ -66,10 +66,10 @@ class DiscordButton(BaseMessageComponent): def __init__( self, label: str, - custom_id: str = None, + custom_id: str | None = None, style: str = "primary", - emoji: str = None, - url: str = None, + emoji: str | None = None, + url: str | None = None, disabled: bool = False, ): self.label = label diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index 5dc1fd8a6..276d3dce5 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -386,7 +386,9 @@ async def _collect_and_register_commands(self): def _create_dynamic_callback(self, cmd_name: str): """为每个指令动态创建一个异步回调函数""" - async def dynamic_callback(ctx: discord.ApplicationContext, params: str = None): + async def dynamic_callback( + ctx: discord.ApplicationContext, params: str | None = None + ): # 将平台特定的前缀'/'剥离,以适配通用的CommandFilter logger.debug(f"[Discord] 回调函数触发: {cmd_name}") logger.debug(f"[Discord] 回调函数参数: {ctx}") diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index 3c701c4ce..06f921bc4 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -113,18 +113,18 @@ async def _parse_to_discord( message: MessageChain, ) -> tuple[str, list[discord.File], discord.ui.View | None, list[discord.Embed]]: """将 MessageChain 解析为 Discord 发送所需的内容""" - content = "" + content_parts = [] files = [] view = None embeds = [] reference_message_id = None for i in message.chain: # 遍历消息链 if isinstance(i, Plain): # 如果是文字类型的 - content += i.text + content_parts.append(i.text) elif isinstance(i, Reply): reference_message_id = i.id elif isinstance(i, At): - content += f"<@{i.qq}>" + content_parts.append(f"<@{i.qq}>") elif isinstance(i, Image): logger.debug(f"[Discord] 开始处理 Image 组件: {i}") try: @@ -238,6 +238,7 @@ async def _parse_to_discord( else: logger.debug(f"[Discord] 忽略了不支持的消息组件: {i.type}") + content = "".join(content_parts) if len(content) > 2000: logger.warning("[Discord] 消息内容超过2000字符,将被截断。") content = content[:2000] diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index f3c2ef0e5..fe1496644 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -76,7 +76,7 @@ async def send_streaming(self, generator, use_fallback: bool = False): return await super().send_streaming(generator, use_fallback) - async def _post_send(self, stream: dict = None): + async def _post_send(self, stream: dict | None = None): if not self.send_buffer: return None @@ -265,17 +265,17 @@ async def post_c2c_message( self, openid: str, msg_type: int = 0, - content: str = None, - embed: message.Embed = None, - ark: message.Ark = None, - message_reference: message.Reference = None, - media: message.Media = None, - msg_id: str = None, + content: str | None = None, + embed: message.Embed | None = None, + ark: message.Ark | None = None, + message_reference: message.Reference | None = None, + media: message.Media | None = None, + msg_id: str | None = None, msg_seq: str = 1, - event_id: str = None, - markdown: message.MarkdownPayload = None, - keyboard: message.Keyboard = None, - stream: dict = None, + event_id: str | None = None, + markdown: message.MarkdownPayload | None = None, + keyboard: message.Keyboard | None = None, + stream: dict | None = None, ) -> message.Message: payload = locals() payload.pop("self", None) diff --git a/astrbot/core/platform/sources/slack/slack_adapter.py b/astrbot/core/platform/sources/slack/slack_adapter.py index 9f21656ed..6c74a4713 100644 --- a/astrbot/core/platform/sources/slack/slack_adapter.py +++ b/astrbot/core/platform/sources/slack/slack_adapter.py @@ -222,39 +222,41 @@ def _parse_blocks(self, blocks: list) -> list: if element.get("type") == "rich_text_section": # 处理富文本段落 section_elements = element.get("elements", []) - text_content = "" - + text_parts = [] for section_element in section_elements: element_type = section_element.get("type", "") if element_type == "text": # 普通文本 - text_content += section_element.get("text", "") + text_parts.append(section_element.get("text", "")) elif element_type == "user": # @用户提及 user_id = section_element.get("user_id", "") if user_id: # 将之前的文本内容先添加到组件中 + text_content = "".join(text_parts) if text_content.strip(): message_components.append( Plain(text=text_content), ) - text_content = "" + text_parts = [] # 添加@提及组件 message_components.append(At(qq=user_id, name="")) elif element_type == "channel": # #频道提及 channel_id = section_element.get("channel_id", "") - text_content += f"#{channel_id}" + text_parts.append(f"#{channel_id}") elif element_type == "link": # 链接 url = section_element.get("url", "") link_text = section_element.get("text", url) - text_content += f"[{link_text}]({url})" + text_parts.append(f"[{link_text}]({url})") elif element_type == "emoji": # 表情符号 emoji_name = section_element.get("name", "") - text_content += f":{emoji_name}:" + text_parts.append(f":{emoji_name}:") + + text_content = "".join(text_parts) if text_content.strip(): message_components.append(Plain(text=text_content)) diff --git a/astrbot/core/platform/sources/slack/slack_event.py b/astrbot/core/platform/sources/slack/slack_event.py index 21c1b0fed..c918abbac 100644 --- a/astrbot/core/platform/sources/slack/slack_event.py +++ b/astrbot/core/platform/sources/slack/slack_event.py @@ -148,14 +148,15 @@ async def send(self, message: MessageChain): ) except Exception: # 如果块发送失败,尝试只发送文本 - fallback_text = "" + parts = [] for segment in message.chain: if isinstance(segment, Plain): - fallback_text += segment.text + parts.append(segment.text) elif isinstance(segment, File): - fallback_text += f" [文件: {segment.name}] " + parts.append(f" [文件: {segment.name}] ") elif isinstance(segment, Image): - fallback_text += " [图片] " + parts.append(" [图片] ") + fallback_text = "".join(parts) if self.get_group_id(): await self.web_client.chat_postMessage( diff --git a/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py b/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py index d372211c9..09924edb6 100644 --- a/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +++ b/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py @@ -18,7 +18,7 @@ def __init__( is_private_chat: bool = False, cached_texts=None, cached_images=None, - raw_message: dict = None, + raw_message: dict | None = None, downloader=None, ): self._xml = None diff --git a/astrbot/core/platform_message_history_mgr.py b/astrbot/core/platform_message_history_mgr.py index fa9a9733c..0e079e893 100644 --- a/astrbot/core/platform_message_history_mgr.py +++ b/astrbot/core/platform_message_history_mgr.py @@ -11,8 +11,8 @@ async def insert( platform_id: str, user_id: str, content: list[dict], # TODO: parse from message chain - sender_id: str = None, - sender_name: str = None, + sender_id: str | None = None, + sender_name: str | None = None, ): """Insert a new platform message history record.""" await self.db.insert_platform_message_history( diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 28dc63f72..2f1e84419 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -4,15 +4,17 @@ from dataclasses import dataclass, field from typing import Any -from anthropic.types import Message +from anthropic.types import Message as AnthropicMessage from google.genai.types import GenerateContentResponse from openai.types.chat.chat_completion import ChatCompletion -from openai.types.chat.chat_completion_message_tool_call import ( - ChatCompletionMessageToolCall, -) import astrbot.core.message.components as Comp from astrbot import logger +from astrbot.core.agent.message import ( + AssistantMessageSegment, + ToolCall, + ToolCallMessageSegment, +) from astrbot.core.agent.tool import ToolSet from astrbot.core.db.po import Conversation from astrbot.core.message.message_event_result import MessageChain @@ -32,9 +34,9 @@ class ProviderMetaData: type: str """提供商适配器名称,如 openai, ollama""" desc: str = "" - """提供商适配器描述.""" + """提供商适配器描述""" provider_type: ProviderType = ProviderType.CHAT_COMPLETION - cls_type: type | None = None + cls_type: Any = None default_config_tmpl: dict | None = None """平台的默认配置模板""" @@ -42,44 +44,6 @@ class ProviderMetaData: """显示在 WebUI 配置页中的提供商名称,如空则是 type""" -@dataclass -class ToolCallMessageSegment: - """OpenAI 格式的上下文中 role 为 tool 的消息段。参考: https://platform.openai.com/docs/guides/function-calling""" - - tool_call_id: str - content: str - role: str = "tool" - - def to_dict(self): - return { - "tool_call_id": self.tool_call_id, - "content": self.content, - "role": self.role, - } - - -@dataclass -class AssistantMessageSegment: - """OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling""" - - content: str | None = None - tool_calls: list[ChatCompletionMessageToolCall | dict] = field(default_factory=list) - role: str = "assistant" - - def to_dict(self): - ret: dict[str, str | list[dict]] = { - "role": self.role, - } - if self.content: - ret["content"] = self.content - if self.tool_calls: - tool_calls_dict = [ - tc if isinstance(tc, dict) else tc.to_dict() for tc in self.tool_calls - ] - ret["tool_calls"] = tool_calls_dict - return ret - - @dataclass class ToolCallsResult: """工具调用结果""" @@ -91,8 +55,8 @@ class ToolCallsResult: def to_openai_messages(self) -> list[dict]: ret = [ - self.tool_calls_info.to_dict(), - *[item.to_dict() for item in self.tool_calls_result], + self.tool_calls_info.model_dump(), + *[item.model_dump() for item in self.tool_calls_result], ] return ret @@ -108,16 +72,16 @@ class ProviderRequest: func_tool: ToolSet | None = None """可用的函数工具""" contexts: list[dict] = field(default_factory=list) - """上下文。格式与 openai 的上下文格式一致: + """ + OpenAI 格式上下文列表。 参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages """ system_prompt: str = "" """系统提示词""" conversation: Conversation | None = None - + """关联的对话对象""" tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None """附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls""" - model: str | None = None """模型名称,为 None 时使用提供商的默认模型""" @@ -227,7 +191,9 @@ class LLMResponse: tools_call_ids: list[str] = field(default_factory=list) """工具调用 ID""" - raw_completion: ChatCompletion | GenerateContentResponse | Message | None = None + raw_completion: ( + ChatCompletion | GenerateContentResponse | AnthropicMessage | None + ) = None _new_record: dict[str, Any] | None = None _completion_text: str = "" @@ -243,7 +209,10 @@ def __init__( tools_call_args: list[dict[str, Any]] | None = None, tools_call_name: list[str] | None = None, tools_call_ids: list[str] | None = None, - raw_completion: ChatCompletion | None = None, + raw_completion: ChatCompletion + | GenerateContentResponse + | AnthropicMessage + | None = None, _new_record: dict[str, Any] | None = None, is_chunk: bool = False, ): @@ -294,7 +263,7 @@ def completion_text(self, value): self._completion_text = value def to_openai_tool_calls(self) -> list[dict]: - """将工具调用信息转换为 OpenAI 格式""" + """Convert to OpenAI tool calls format. Deprecated, use to_openai_to_calls_model instead.""" ret = [] for idx, tool_call_arg in enumerate(self.tools_call_args): ret.append( @@ -309,6 +278,21 @@ def to_openai_tool_calls(self) -> list[dict]: ) return ret + def to_openai_to_calls_model(self) -> list[ToolCall]: + """The same as to_openai_tool_calls but return pydantic model.""" + ret = [] + for idx, tool_call_arg in enumerate(self.tools_call_args): + ret.append( + ToolCall( + id=self.tools_call_ids[idx], + function=ToolCall.FunctionBody( + name=self.tools_call_name[idx], + arguments=json.dumps(tool_call_arg), + ), + ), + ) + return ret + @dataclass class RerankResult: diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index b3ef1ed5c..36aad2ae9 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -10,7 +10,7 @@ from astrbot import logger from astrbot.core import sp -from astrbot.core.agent.mcp_client import MCPClient +from astrbot.core.agent.mcp_client import MCPClient, MCPTool from astrbot.core.agent.tool import FunctionTool, ToolSet from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -254,18 +254,15 @@ async def _init_mcp_client(self, name: str, config: dict) -> None: self.func_list = [ f for f in self.func_list - if not (f.origin == "mcp" and f.mcp_server_name == name) + if not (isinstance(f, MCPTool) and f.mcp_server_name == name) ] # 将 MCP 工具转换为 FuncTool 并添加到 func_list for tool in mcp_client.tools: - func_tool = FuncTool( - name=tool.name, - parameters=tool.inputSchema, - description=tool.description, - origin="mcp", - mcp_server_name=name, + func_tool = MCPTool( + mcp_tool=tool, mcp_client=mcp_client, + mcp_server_name=name, ) self.func_list.append(func_tool) @@ -284,7 +281,7 @@ async def _terminate_mcp_client(self, name: str) -> None: self.func_list = [ f for f in self.func_list - if not (f.origin == "mcp" and f.mcp_server_name == name) + if not (isinstance(f, MCPTool) and f.mcp_server_name == name) ] logger.info(f"已关闭 MCP 服务 {name}") @@ -374,7 +371,7 @@ async def disable_mcp_server( self.func_list = [ f for f in self.func_list - if f.origin != "mcp" or f.mcp_server_name != name + if not (isinstance(f, MCPTool) and f.mcp_server_name == name) ] else: running_events = [ @@ -388,7 +385,9 @@ async def disable_mcp_server( finally: self.mcp_client_event.clear() self.mcp_client_dict.clear() - self.func_list = [f for f in self.func_list if f.origin != "mcp"] + self.func_list = [ + f for f in self.func_list if not isinstance(f, MCPTool) + ] def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list: """获得 OpenAI API 风格的**已经激活**的工具描述""" diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 23abfcfc0..7ab8f00ba 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -3,6 +3,7 @@ from collections.abc import AsyncGenerator from dataclasses import dataclass +from astrbot.core.agent.message import Message from astrbot.core.agent.tool import ToolSet from astrbot.core.db.po import Personality from astrbot.core.provider.entities import ( @@ -23,24 +24,28 @@ class ProviderMeta: class AbstractProvider(abc.ABC): + """Provider Abstract Class""" + def __init__(self, provider_config: dict) -> None: super().__init__() self.model_name = "" self.provider_config = provider_config def set_model(self, model_name: str): - """设置当前使用的模型名称""" + """Set the current model name""" self.model_name = model_name def get_model(self) -> str: - """获得当前使用的模型名称""" + """Get the current model name""" return self.model_name def meta(self) -> ProviderMeta: - """获取 Provider 的元数据""" + """Get the provider metadata""" provider_type_name = self.provider_config["type"] meta_data = provider_cls_map.get(provider_type_name) provider_type = meta_data.provider_type if meta_data else None + if provider_type is None: + raise ValueError(f"Cannot find provider type: {provider_type_name}") return ProviderMeta( id=self.provider_config["id"], model=self.get_model(), @@ -50,6 +55,8 @@ def meta(self) -> ProviderMeta: class Provider(AbstractProvider): + """Chat Provider""" + def __init__( self, provider_config: dict, @@ -84,24 +91,24 @@ async def get_models(self) -> list[str]: @abc.abstractmethod async def text_chat( self, - prompt: str, - session_id: str = None, - image_urls: list[str] = None, - func_tool: ToolSet = None, - contexts: list = None, - system_prompt: str = None, - tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None, + prompt: str | None = None, + session_id: str | None = None, + image_urls: list[str] | None = None, + func_tool: ToolSet | None = None, + contexts: list[Message] | list[dict] | None = None, + system_prompt: str | None = None, + tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, model: str | None = None, **kwargs, ) -> LLMResponse: """获得 LLM 的文本对话结果。会使用当前的模型进行对话。 Args: - prompt: 提示词 + prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中 session_id: 会话 ID(此属性已经被废弃) image_urls: 图片 URL 列表 - tools: Function-calling 工具 - contexts: 上下文 + tools: tool set + contexts: 上下文,和 prompt 二选一使用 tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling kwargs: 其他参数 @@ -114,24 +121,24 @@ async def text_chat( async def text_chat_stream( self, - prompt: str, - session_id: str = None, - image_urls: list[str] = None, - func_tool: ToolSet = None, - contexts: list = None, - system_prompt: str = None, - tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None, + prompt: str | None = None, + session_id: str | None = None, + image_urls: list[str] | None = None, + func_tool: ToolSet | None = None, + contexts: list[Message] | list[dict] | None = None, + system_prompt: str | None = None, + tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, model: str | None = None, **kwargs, ) -> AsyncGenerator[LLMResponse, None]: """获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。 Args: - prompt: 提示词 + prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中 session_id: 会话 ID(此属性已经被废弃) image_urls: 图片 URL 列表 - tools: Function-calling 工具 - contexts: 上下文 + tools: tool set + contexts: 上下文,和 prompt 二选一使用 tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling kwargs: 其他参数 @@ -140,6 +147,7 @@ async def text_chat_stream( - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 """ + ... async def pop_record(self, context: list): """弹出 context 第一条非系统提示词对话记录""" @@ -156,6 +164,22 @@ async def pop_record(self, context: list): for idx in reversed(indexs_to_pop): context.pop(idx) + def _ensure_message_to_dicts( + self, + messages: list[dict] | list[Message] | None, + ) -> list[dict]: + """Convert a list of Message objects to a list of dictionaries.""" + if not messages: + return [] + dicts: list[dict] = [] + for message in messages: + if isinstance(message, Message): + dicts.append(message.model_dump()) + else: + dicts.append(message) + + return dicts + class STTProvider(AbstractProvider): def __init__(self, provider_config: dict, provider_settings: dict) -> None: diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py index eb8c72aea..1aead54df 100644 --- a/astrbot/core/provider/register.py +++ b/astrbot/core/provider/register.py @@ -15,8 +15,8 @@ def register_provider_adapter( provider_type_name: str, desc: str, provider_type: ProviderType = ProviderType.CHAT_COMPLETION, - default_config_tmpl: dict = None, - provider_display_name: str = None, + default_config_tmpl: dict | None = None, + provider_display_name: str | None = None, ): """用于注册平台适配器的带参装饰器""" diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 6f292f076..77c85cef4 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -243,7 +243,7 @@ async def _query_stream( async def text_chat( self, - prompt, + prompt=None, session_id=None, image_urls=None, func_tool=None, @@ -255,8 +255,13 @@ async def text_chat( ) -> LLMResponse: if contexts is None: contexts = [] - new_record = await self.assemble_context(prompt, image_urls) - context_query = [*contexts, new_record] + new_record = None + if prompt is not None: + new_record = await self.assemble_context(prompt, image_urls) + context_query = self._ensure_message_to_dicts(contexts) + if new_record: + context_query.append(new_record) + if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) @@ -306,8 +311,12 @@ async def text_chat_stream( ): if contexts is None: contexts = [] - new_record = await self.assemble_context(prompt, image_urls) - context_query = [*contexts, new_record] + new_record = None + if prompt is not None: + new_record = await self.assemble_context(prompt, image_urls) + context_query = self._ensure_message_to_dicts(contexts) + if new_record: + context_query.append(new_record) if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) diff --git a/astrbot/core/provider/sources/coze_source.py b/astrbot/core/provider/sources/coze_source.py index caee65020..23a8b3b76 100644 --- a/astrbot/core/provider/sources/coze_source.py +++ b/astrbot/core/provider/sources/coze_source.py @@ -331,6 +331,7 @@ async def text_chat_stream( }, ) + contexts = self._ensure_message_to_dicts(contexts) if not self.auto_save_history and contexts: # 如果关闭了自动保存历史,传入上下文 for ctx in contexts: diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py index 92613dc1a..9b262c001 100644 --- a/astrbot/core/provider/sources/dashscope_source.py +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -66,13 +66,15 @@ async def text_chat( self, prompt: str, session_id=None, - image_urls=[], + image_urls=None, func_tool=None, contexts=None, system_prompt=None, model=None, **kwargs, ) -> LLMResponse: + if image_urls is None: + image_urls = [] if contexts is None: contexts = [] # 获得会话变量 @@ -144,14 +146,15 @@ async def text_chat( # RAG 引用脚标格式化 output_text = re.sub(r"\[(\d+)\]", r"[\1]", output_text) if self.output_reference and response.output.get("doc_references", None): - ref_str = "" + ref_parts = [] for ref in response.output.get("doc_references", []) or []: ref_title = ( ref.get("title", "") if ref.get("title") else ref.get("doc_name", "") ) - ref_str += f"{ref['index_id']}. {ref_title}\n" + ref_parts.append(f"{ref['index_id']}. {ref_title}\n") + ref_str = "".join(ref_parts) output_text += f"\n\n回答来源:\n{ref_str}" llm_response = LLMResponse("assistant") diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index f9eef2e92..c3c9253a5 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -572,7 +572,7 @@ async def _query_stream( async def text_chat( self, - prompt: str, + prompt=None, session_id=None, image_urls=None, func_tool=None, @@ -584,8 +584,12 @@ async def text_chat( ) -> LLMResponse: if contexts is None: contexts = [] - new_record = await self.assemble_context(prompt, image_urls) - context_query = [*contexts, new_record] + new_record = None + if prompt is not None: + new_record = await self.assemble_context(prompt, image_urls) + context_query = self._ensure_message_to_dicts(contexts) + if new_record: + context_query.append(new_record) if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) @@ -621,7 +625,7 @@ async def text_chat( async def text_chat_stream( self, - prompt, + prompt=None, session_id=None, image_urls=None, func_tool=None, @@ -633,8 +637,12 @@ async def text_chat_stream( ) -> AsyncGenerator[LLMResponse, None]: if contexts is None: contexts = [] - new_record = await self.assemble_context(prompt, image_urls) - context_query = [*contexts, new_record] + new_record = None + if prompt is not None: + new_record = await self.assemble_context(prompt, image_urls) + context_query = self._ensure_message_to_dicts(contexts) + if new_record: + context_query.append(new_record) if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) @@ -726,7 +734,6 @@ async def encode_image_bs64(self, image_url: str) -> str: with open(image_url, "rb") as f: image_bs64 = base64.b64encode(f.read()).decode("utf-8") return "data:image/jpeg;base64," + image_bs64 - return "" async def terminate(self): logger.info("Google GenAI 适配器已终止。") diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 1020075af..076afc40f 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -14,9 +14,10 @@ import astrbot.core.message.components as Comp from astrbot import logger from astrbot.api.provider import Provider +from astrbot.core.agent.message import Message +from astrbot.core.agent.tool import ToolSet from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.entities import LLMResponse, ToolCallsResult -from astrbot.core.provider.func_tool_manager import ToolSet from astrbot.core.utils.io import download_image_by_url from ..register import register_provider_adapter @@ -102,7 +103,7 @@ async def get_models(self): except NotFoundError as e: raise Exception(f"获取模型列表失败:{e}") - async def _query(self, payloads: dict, tools: ToolSet) -> LLMResponse: + async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: if tools: model = payloads.get("model", "").lower() omit_empty_param_field = "gemini" in model @@ -153,7 +154,7 @@ async def _query(self, payloads: dict, tools: ToolSet) -> LLMResponse: async def _query_stream( self, payloads: dict, - tools: ToolSet, + tools: ToolSet | None, ) -> AsyncGenerator[LLMResponse, None]: """流式查询API,逐步返回结果""" if tools: @@ -212,7 +213,9 @@ async def _query_stream( yield llm_response - async def parse_openai_completion(self, completion: ChatCompletion, tools: ToolSet): + async def parse_openai_completion( + self, completion: ChatCompletion, tools: ToolSet | None + ) -> LLMResponse: """解析 OpenAI 的 ChatCompletion 响应""" llm_response = LLMResponse("assistant") @@ -225,7 +228,7 @@ async def parse_openai_completion(self, completion: ChatCompletion, tools: ToolS completion_text = str(choice.message.content).strip() llm_response.result_chain = MessageChain().message(completion_text) - if choice.message.tool_calls: + if choice.message.tool_calls and tools is not None: # tools call (function calling) args_ls = [] func_name_ls = [] @@ -267,9 +270,9 @@ async def parse_openai_completion(self, completion: ChatCompletion, tools: ToolS async def _prepare_chat_payload( self, - prompt: str, + prompt: str | None, image_urls: list[str] | None = None, - contexts: list | None = None, + contexts: list[dict] | list[Message] | None = None, system_prompt: str | None = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, model: str | None = None, @@ -278,8 +281,12 @@ async def _prepare_chat_payload( """准备聊天所需的有效载荷和上下文""" if contexts is None: contexts = [] - new_record = await self.assemble_context(prompt, image_urls) - context_query = [*contexts, new_record] + new_record = None + if prompt is not None: + new_record = await self.assemble_context(prompt, image_urls) + context_query = self._ensure_message_to_dicts(contexts) + if new_record: + context_query.append(new_record) if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) @@ -310,7 +317,7 @@ async def _handle_api_error( e: Exception, payloads: dict, context_query: list, - func_tool: ToolSet, + func_tool: ToolSet | None, chosen_key: str, available_api_keys: list[str], retry_cnt: int, @@ -390,7 +397,7 @@ async def _handle_api_error( async def text_chat( self, - prompt, + prompt=None, session_id=None, image_urls=None, func_tool=None, @@ -459,7 +466,7 @@ async def text_chat( async def text_chat_stream( self, - prompt: str, + prompt=None, session_id=None, image_urls=None, func_tool=None, diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 620e7e907..1a5bc53d9 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -1,3 +1,4 @@ +import logging from asyncio import Queue from collections.abc import Awaitable, Callable from typing import Any @@ -35,6 +36,8 @@ from .star import StarMetadata, star_map, star_registry from .star_handler import EventType, StarHandlerMetadata, star_handlers_registry +logger = logging.getLogger("astrbot") + class Context: """暴露给插件的接口上下文。""" @@ -255,9 +258,44 @@ async def send_message( def add_llm_tools(self, *tools: FunctionTool) -> None: """添加 LLM 工具。""" + tool_name = {tool.name for tool in self.provider_manager.llm_tools.func_list} + module_path = "" for tool in tools: + if not module_path: + _parts = [] + module_part = tool.__module__.split(".") + flags = ["packages", "plugins"] + for i, part in enumerate(module_part): + _parts.append(part) + if part in flags and i + 1 < len(module_part): + _parts.append(module_part[i + 1]) + break + tool.handler_module_path = ".".join(_parts) + module_path = tool.handler_module_path + else: + tool.handler_module_path = module_path + logger.info( + f"plugin(module_path {module_path}) added LLM tool: {tool.name}" + ) + + if tool.name in tool_name: + logger.warning("替换已存在的 LLM 工具: " + tool.name) + self.provider_manager.llm_tools.remove_func(tool.name) self.provider_manager.llm_tools.func_list.append(tool) + def register_web_api( + self, + route: str, + view_handler: Awaitable, + methods: list, + desc: str, + ): + for idx, api in enumerate(self.registered_web_apis): + if api[0] == route and methods == api[2]: + self.registered_web_apis[idx] = (route, view_handler, methods, desc) + return + self.registered_web_apis.append((route, view_handler, methods, desc)) + """ 以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。 """ @@ -269,7 +307,7 @@ def register_llm_tool( desc: str, func_obj: Callable[..., Awaitable[Any]], ) -> None: - """为函数调用(function-calling / tools-use)添加工具。 + """[DEPRECATED]为函数调用(function-calling / tools-use)添加工具。 @param name: 函数名 @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] @@ -291,7 +329,7 @@ def register_llm_tool( self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj) def unregister_llm_tool(self, name: str) -> None: - """删除一个函数调用工具。如果再要启用,需要重新注册。""" + """[DEPRECATED]删除一个函数调用工具。如果再要启用,需要重新注册。""" self.provider_manager.llm_tools.remove_func(name) def register_commands( @@ -333,18 +371,5 @@ def register_commands( star_handlers_registry.append(md) def register_task(self, task: Awaitable, desc: str): - """注册一个异步任务。""" + """[DEPRECATED]注册一个异步任务。""" self._register_tasks.append(task) - - def register_web_api( - self, - route: str, - view_handler: Awaitable, - methods: list, - desc: str, - ): - for idx, api in enumerate(self.registered_web_apis): - if api[0] == route and methods == api[2]: - self.registered_web_apis[idx] = (route, view_handler, methods, desc) - return - self.registered_web_apis.append((route, view_handler, methods, desc)) diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py index 6e0283a0e..2a9868fdc 100755 --- a/astrbot/core/star/filter/command.py +++ b/astrbot/core/star/filter/command.py @@ -36,11 +36,13 @@ def __init__( command_name: str, alias: set | None = None, handler_md: StarHandlerMetadata | None = None, - parent_command_names: list[str] = [""], + parent_command_names: list[str] | None = None, ): self.command_name = command_name self.alias = alias if alias else set() - self.parent_command_names = parent_command_names + self.parent_command_names = ( + parent_command_names if parent_command_names is not None else [""] + ) if handler_md: self.init_handler_md(handler_md) self.custom_filter_list: list[CustomFilter] = [] @@ -49,15 +51,15 @@ def __init__( self._cmpl_cmd_names: list | None = None def print_types(self): - result = "" + parts = [] for k, v in self.handler_params.items(): if isinstance(v, type): - result += f"{k}({v.__name__})," + parts.append(f"{k}({v.__name__}),") elif isinstance(v, types.UnionType) or typing.get_origin(v) is typing.Union: - result += f"{k}({v})," + parts.append(f"{k}({v}),") else: - result += f"{k}({type(v).__name__})={v}," - result = result.rstrip(",") + parts.append(f"{k}({type(v).__name__})={v},") + result = "".join(parts).rstrip(",") return result def init_handler_md(self, handle_md: StarHandlerMetadata): diff --git a/astrbot/core/star/filter/command_group.py b/astrbot/core/star/filter/command_group.py index 0f5c19ec5..e1c2efb22 100755 --- a/astrbot/core/star/filter/command_group.py +++ b/astrbot/core/star/filter/command_group.py @@ -66,7 +66,7 @@ def print_cmd_tree( event: AstrMessageEvent | None = None, cfg: AstrBotConfig | None = None, ) -> str: - result = "" + parts = [] for sub_filter in sub_command_filters: if isinstance(sub_filter, CommandFilter): custom_filter_pass = True @@ -74,31 +74,32 @@ def print_cmd_tree( custom_filter_pass = sub_filter.custom_filter_ok(event, cfg) if custom_filter_pass: cmd_th = sub_filter.print_types() - result += f"{prefix}├── {sub_filter.command_name}" + line = f"{prefix}├── {sub_filter.command_name}" if cmd_th: - result += f" ({cmd_th})" + line += f" ({cmd_th})" else: - result += " (无参数指令)" + line += " (无参数指令)" if sub_filter.handler_md and sub_filter.handler_md.desc: - result += f": {sub_filter.handler_md.desc}" + line += f": {sub_filter.handler_md.desc}" - result += "\n" + parts.append(line + "\n") elif isinstance(sub_filter, CommandGroupFilter): custom_filter_pass = True if event and cfg: custom_filter_pass = sub_filter.custom_filter_ok(event, cfg) if custom_filter_pass: - result += f"{prefix}├── {sub_filter.group_name}" - result += "\n" - result += sub_filter.print_cmd_tree( - sub_filter.sub_command_filters, - prefix + "│ ", - event=event, - cfg=cfg, + parts.append(f"{prefix}├── {sub_filter.group_name}\n") + parts.append( + sub_filter.print_cmd_tree( + sub_filter.sub_command_filters, + prefix + "│ ", + event=event, + cfg=cfg, + ) ) - return result + return "".join(parts) def custom_filter_ok(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: for custom_filter in self.custom_filter_list: diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 4966735d5..70a0f73f8 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -830,7 +830,13 @@ async def turn_off_plugin(self, plugin_name: str): # 禁用插件启用的 llm_tool for func_tool in llm_tools.func_list: - if func_tool.handler_module_path == plugin.module_path: + mp = func_tool.handler_module_path + if ( + plugin.module_path + and mp + and plugin.module_path.startswith(mp) + and not mp.endswith(("packages", "data.plugins")) + ): func_tool.active = False if func_tool.name not in inactivated_llm_tools: inactivated_llm_tools.append(func_tool.name) @@ -873,8 +879,12 @@ async def turn_on_plugin(self, plugin_name: str): # 启用插件启用的 llm_tool for func_tool in llm_tools.func_list: + mp = func_tool.handler_module_path if ( - func_tool.handler_module_path == plugin.module_path + plugin.module_path + and mp + and plugin.module_path.startswith(mp) + and not mp.endswith(("packages", "data.plugins")) and func_tool.name in inactivated_llm_tools ): inactivated_llm_tools.remove(func_tool.name) @@ -883,8 +893,6 @@ async def turn_on_plugin(self, plugin_name: str): await self.reload(plugin_name) - # plugin.activated = True - async def install_plugin_from_file(self, zip_file_path: str): dir_name = os.path.basename(zip_file_path).replace(".zip", "") dir_name = dir_name.removesuffix("-master").removesuffix("-main").lower() diff --git a/astrbot/core/utils/dify_api_client.py b/astrbot/core/utils/dify_api_client.py index 2500e69a5..ea8ff9dff 100644 --- a/astrbot/core/utils/dify_api_client.py +++ b/astrbot/core/utils/dify_api_client.py @@ -46,9 +46,11 @@ async def chat_messages( user: str, response_mode: str = "streaming", conversation_id: str = "", - files: list[dict[str, Any]] = [], + files: list[dict[str, Any]] | None = None, timeout: float = 60, ) -> AsyncGenerator[dict[str, Any], None]: + if files is None: + files = [] url = f"{self.api_base}/chat-messages" payload = locals() payload.pop("self") @@ -73,9 +75,11 @@ async def workflow_run( inputs: dict, user: str, response_mode: str = "streaming", - files: list[dict[str, Any]] = [], + files: list[dict[str, Any]] | None = None, timeout: float = 60, ): + if files is None: + files = [] url = f"{self.api_base}/workflows/run" payload = locals() payload.pop("self") diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index bd0bea920..03549dc97 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -77,8 +77,8 @@ def save_temp_img(img: Image.Image | str) -> str: async def download_image_by_url( url: str, post: bool = False, - post_data: dict = None, - path=None, + post_data: dict | None = None, + path: str | None = None, ) -> str: """下载图片, 返回 path""" try: diff --git a/astrbot/core/utils/pip_installer.py b/astrbot/core/utils/pip_installer.py index abe247146..6076a114a 100644 --- a/astrbot/core/utils/pip_installer.py +++ b/astrbot/core/utils/pip_installer.py @@ -6,15 +6,15 @@ class PipInstaller: - def __init__(self, pip_install_arg: str, pypi_index_url: str = None): + def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None): self.pip_install_arg = pip_install_arg self.pypi_index_url = pypi_index_url async def install( self, - package_name: str = None, - requirements_path: str = None, - mirror: str = None, + package_name: str | None = None, + requirements_path: str | None = None, + mirror: str | None = None, ): args = ["install"] if package_name: diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 6954f2d6a..5156e14e5 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -134,10 +134,10 @@ async def chat(self): if not conversation_id: return Response().error("conversation_id is empty").__dict__ - # append user message + # 追加用户消息 webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id) - # Get conversation-specific queues + # 获取会话特定的队列 back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id) new_his = {"type": "user", "message": message} @@ -200,7 +200,7 @@ async def stream(): or not streaming or type == "break" ): - # append bot message + # 追加机器人消息 new_his = {"type": "bot", "message": result_text} await self.platform_history_mgr.insert( platform_id="webchat", @@ -212,7 +212,7 @@ async def stream(): except BaseException as e: logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True) - # Put message to conversation-specific queue + # 将消息放入会话特定的队列 chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id) await chat_queue.put( ( diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 59b07e47d..b947d26f2 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -926,7 +926,9 @@ async def _get_plugin_config(self, plugin_name: str): return ret - async def _save_astrbot_configs(self, post_configs: dict, conf_id: str = None): + async def _save_astrbot_configs( + self, post_configs: dict, conf_id: str | None = None + ): try: if conf_id not in self.acm.confs: raise ValueError(f"配置文件 {conf_id} 不存在") diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 775983ef8..84976f2ba 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -105,7 +105,7 @@ async def auth_middleware(self): allowed_endpoints = ["/api/auth/login", "/api/file"] if any(request.path.startswith(prefix) for prefix in allowed_endpoints): return None - # claim jwt + # 声明 JWT token = request.headers.get("Authorization") if not token: r = jsonify(Response().error("未授权").__dict__) @@ -213,11 +213,12 @@ def run(self): raise Exception(f"端口 {port} 已被占用") - display = f"\n ✨✨✨\n AstrBot v{VERSION} WebUI 已启动,可访问\n\n" - display += f" ➜ 本地: http://localhost:{port}\n" + parts = [f"\n ✨✨✨\n AstrBot v{VERSION} WebUI 已启动,可访问\n\n"] + parts.append(f" ➜ 本地: http://localhost:{port}\n") for ip in ip_addr: - display += f" ➜ 网络: http://{ip}:{port}\n" - display += " ➜ 默认用户名和密码: astrbot\n ✨✨✨\n" + parts.append(f" ➜ 网络: http://{ip}:{port}\n") + parts.append(" ➜ 默认用户名和密码: astrbot\n ✨✨✨\n") + display = "".join(parts) if not ip_addr: display += ( diff --git a/main.py b/main.py index b453cdfb5..60879f065 100644 --- a/main.py +++ b/main.py @@ -11,7 +11,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.io import download_dashboard, get_dashboard_version -# add parent path to sys.path +# 将父目录添加到 sys.path sys.path.append(Path(__file__).parent.as_posix()) logo_tmpl = r""" @@ -34,7 +34,7 @@ def check_env(): os.makedirs("data/plugins", exist_ok=True) os.makedirs("data/temp", exist_ok=True) - # workaround for issue #181 + # 针对问题 #181 的临时解决方案 mimetypes.add_type("text/javascript", ".js") mimetypes.add_type("text/javascript", ".mjs") mimetypes.add_type("application/json", ".json") @@ -53,7 +53,7 @@ async def check_dashboard_files(webui_dir: str | None = None): if os.path.exists(data_dist_path): v = await get_dashboard_version() if v is not None: - # has file + # 存在文件 if v == f"v{VERSION}": logger.info("WebUI 版本已是最新。") else: @@ -88,16 +88,16 @@ async def check_dashboard_files(webui_dir: str | None = None): check_env() - # start log broker + # 启动日志代理 log_broker = LogBroker() LogManager.set_queue_handler(logger, log_broker) - # check dashboard files + # 检查仪表板文件 webui_dir = asyncio.run(check_dashboard_files(args.webui_dir)) db = db_helper - # print logo + # 打印 logo logger.info(logo_tmpl) core_lifecycle = InitialLoader(db, log_broker) diff --git a/packages/astrbot/commands/conversation.py b/packages/astrbot/commands/conversation.py index 82b661773..9538d8f53 100644 --- a/packages/astrbot/commands/conversation.py +++ b/packages/astrbot/commands/conversation.py @@ -134,12 +134,13 @@ async def his(self, message: AstrMessageEvent, page: int = 1): size_per_page, ) - history = "" + parts = [] for context in contexts: if len(context) > 150: context = context[:150] + "..." - history += f"{context}\n" + parts.append(f"{context}\n") + history = "".join(parts) ret = ( f"当前对话历史记录:" f"{history or '无历史记录'}\n\n" @@ -154,7 +155,7 @@ async def convs(self, message: AstrMessageEvent, page: int = 1): provider = self.context.get_using_provider(message.unified_msg_origin) if provider and provider.meta().type == "dify": """原有的Dify处理逻辑保持不变""" - ret = "Dify 对话列表:\n" + parts = ["Dify 对话列表:\n"] assert isinstance(provider, ProviderDify) data = await provider.api_client.get_chat_convs(message.unified_msg_origin) idx = 1 @@ -162,12 +163,17 @@ async def convs(self, message: AstrMessageEvent, page: int = 1): ts_h = datetime.datetime.fromtimestamp(conv["updated_at"]).strftime( "%m-%d %H:%M", ) - ret += f"{idx}. {conv['name']}({conv['id'][:4]})\n 上次更新:{ts_h}\n" + parts.append( + f"{idx}. {conv['name']}({conv['id'][:4]})\n 上次更新:{ts_h}\n" + ) idx += 1 if idx == 1: - ret += "没有找到任何对话。" + parts.append("没有找到任何对话。") dify_cid = provider.conversation_ids.get(message.unified_msg_origin, None) - ret += f"\n\n用户: {message.unified_msg_origin}\n当前对话: {dify_cid}\n使用 /switch <序号> 切换对话。" + parts.append( + f"\n\n用户: {message.unified_msg_origin}\n当前对话: {dify_cid}\n使用 /switch <序号> 切换对话。" + ) + ret = "".join(parts) message.set_result(MessageEventResult().message(ret)) return @@ -185,7 +191,7 @@ async def convs(self, message: AstrMessageEvent, page: int = 1): end_idx = start_idx + size_per_page conversations_paged = conversations_all[start_idx:end_idx] - ret = "对话列表:\n---\n" + parts = ["对话列表:\n---\n"] """全局序号从当前页的第一个开始""" global_index = start_idx + 1 @@ -204,10 +210,13 @@ async def convs(self, message: AstrMessageEvent, page: int = 1): ) persona_id = persona["name"] title = _titles.get(conv.cid, "新对话") - ret += f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_id}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n" + parts.append( + f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_id}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n" + ) global_index += 1 - ret += "---\n" + parts.append("---\n") + ret = "".join(parts) curr_cid = await self.context.conversation_manager.get_curr_conversation_id( message.unified_msg_origin, ) diff --git a/packages/astrbot/commands/persona.py b/packages/astrbot/commands/persona.py index 53582ce8e..1289cb569 100644 --- a/packages/astrbot/commands/persona.py +++ b/packages/astrbot/commands/persona.py @@ -59,10 +59,11 @@ async def persona(self, message: AstrMessageEvent): .use_t2i(False), ) elif l[1] == "list": - msg = "人格列表:\n" + parts = ["人格列表:\n"] for persona in self.context.provider_manager.personas: - msg += f"- {persona['name']}\n" - msg += "\n\n*输入 `/persona view 人格名` 查看人格详细信息" + parts.append(f"- {persona['name']}\n") + parts.append("\n\n*输入 `/persona view 人格名` 查看人格详细信息") + msg = "".join(parts) message.set_result(MessageEventResult().message(msg)) elif l[1] == "view": if len(l) == 2: diff --git a/packages/astrbot/commands/plugin.py b/packages/astrbot/commands/plugin.py index f9092ff97..ab45efc11 100644 --- a/packages/astrbot/commands/plugin.py +++ b/packages/astrbot/commands/plugin.py @@ -13,14 +13,17 @@ def __init__(self, context: star.Context): async def plugin_ls(self, event: AstrMessageEvent): """获取已经安装的插件列表。""" - plugin_list_info = "已加载的插件:\n" + parts = ["已加载的插件:\n"] for plugin in self.context.get_all_stars(): - plugin_list_info += f"- `{plugin.name}` By {plugin.author}: {plugin.desc}" + line = f"- `{plugin.name}` By {plugin.author}: {plugin.desc}" if not plugin.activated: - plugin_list_info += " (未启用)" - plugin_list_info += "\n" - if plugin_list_info.strip() == "": + line += " (未启用)" + parts.append(line + "\n") + + if len(parts) == 1: plugin_list_info = "没有加载任何插件。" + else: + plugin_list_info = "".join(parts) plugin_list_info += "\n使用 /plugin help <插件名> 查看插件帮助和加载的指令。\n使用 /plugin on/off <插件名> 启用或者禁用插件。" event.set_result( @@ -103,14 +106,14 @@ async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""): command_names.append(filter_.group_name) if len(command_handlers) > 0: - help_msg += "\n\n🔧 指令列表:\n" + parts = ["\n\n🔧 指令列表:\n"] for i in range(len(command_handlers)): - help_msg += f"- {command_names[i]}" + line = f"- {command_names[i]}" if command_handlers[i].desc: - help_msg += f": {command_handlers[i].desc}" - help_msg += "\n" - - help_msg += "\nTip: 指令的触发需要添加唤醒前缀,默认为 /。" + line += f": {command_handlers[i].desc}" + parts.append(line + "\n") + parts.append("\nTip: 指令的触发需要添加唤醒前缀,默认为 /。") + help_msg += "".join(parts) ret = f"🧩 插件 {plugin_name} 帮助信息:\n" + help_msg ret += "更多帮助信息请查看插件仓库 README。" diff --git a/packages/astrbot/commands/provider.py b/packages/astrbot/commands/provider.py index 750e9de5a..8db7324e4 100644 --- a/packages/astrbot/commands/provider.py +++ b/packages/astrbot/commands/provider.py @@ -19,38 +19,39 @@ async def provider( umo = event.unified_msg_origin if idx is None: - ret = "## 载入的 LLM 提供商\n" + parts = ["## 载入的 LLM 提供商\n"] for idx, llm in enumerate(self.context.get_all_providers()): id_ = llm.meta().id - ret += f"{idx + 1}. {id_} ({llm.meta().model})" + line = f"{idx + 1}. {id_} ({llm.meta().model})" provider_using = self.context.get_using_provider(umo=umo) if provider_using and provider_using.meta().id == id_: - ret += " (当前使用)" - ret += "\n" + line += " (当前使用)" + parts.append(line + "\n") tts_providers = self.context.get_all_tts_providers() if tts_providers: - ret += "\n## 载入的 TTS 提供商\n" + parts.append("\n## 载入的 TTS 提供商\n") for idx, tts in enumerate(tts_providers): id_ = tts.meta().id - ret += f"{idx + 1}. {id_}" + line = f"{idx + 1}. {id_}" tts_using = self.context.get_using_tts_provider(umo=umo) if tts_using and tts_using.meta().id == id_: - ret += " (当前使用)" - ret += "\n" + line += " (当前使用)" + parts.append(line + "\n") stt_providers = self.context.get_all_stt_providers() if stt_providers: - ret += "\n## 载入的 STT 提供商\n" + parts.append("\n## 载入的 STT 提供商\n") for idx, stt in enumerate(stt_providers): id_ = stt.meta().id - ret += f"{idx + 1}. {id_}" + line = f"{idx + 1}. {id_}" stt_using = self.context.get_using_stt_provider(umo=umo) if stt_using and stt_using.meta().id == id_: - ret += " (当前使用)" - ret += "\n" + line += " (当前使用)" + parts.append(line + "\n") - ret += "\n使用 /provider <序号> 切换 LLM 提供商。" + parts.append("\n使用 /provider <序号> 切换 LLM 提供商。") + ret = "".join(parts) if tts_providers: ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。" @@ -128,16 +129,17 @@ async def model_ls( .use_t2i(False), ) return - i = 1 - ret = "下面列出了此模型提供商可用模型:" - for model in models: - ret += f"\n{i}. {model}" - i += 1 + parts = ["下面列出了此模型提供商可用模型:"] + for i, model in enumerate(models, 1): + parts.append(f"\n{i}. {model}") curr_model = prov.get_model() or "无" - ret += f"\n当前模型: [{curr_model}]" + parts.append(f"\n当前模型: [{curr_model}]") + parts.append( + "\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。" + ) - ret += "\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。" + ret = "".join(parts) message.set_result(MessageEventResult().message(ret).use_t2i(False)) elif isinstance(idx_or_name, int): models = [] @@ -180,14 +182,15 @@ async def key(self, message: AstrMessageEvent, index: int | None = None): if index is None: keys_data = prov.get_keys() curr_key = prov.get_current_key() - ret = "Key:" - for i, k in enumerate(keys_data): - ret += f"\n{i + 1}. {k[:8]}" + parts = ["Key:"] + for i, k in enumerate(keys_data, 1): + parts.append(f"\n{i}. {k[:8]}") - ret += f"\n当前 Key: {curr_key[:8]}" - ret += "\n当前模型: " + prov.get_model() - ret += "\n使用 /key 切换 Key。" + parts.append(f"\n当前 Key: {curr_key[:8]}") + parts.append("\n当前模型: " + prov.get_model()) + parts.append("\n使用 /key 切换 Key。") + ret = "".join(parts) message.set_result(MessageEventResult().message(ret).use_t2i(False)) else: keys_data = prov.get_keys() diff --git a/packages/astrbot/long_term_memory.py b/packages/astrbot/long_term_memory.py index a686d35b2..ceca60ef7 100644 --- a/packages/astrbot/long_term_memory.py +++ b/packages/astrbot/long_term_memory.py @@ -119,13 +119,13 @@ async def handle_message(self, event: AstrMessageEvent): if event.get_message_type() == MessageType.GROUP_MESSAGE: datetime_str = datetime.datetime.now().strftime("%H:%M:%S") - final_message = f"[{event.message_obj.sender.nickname}/{datetime_str}]: " + parts = [f"[{event.message_obj.sender.nickname}/{datetime_str}]: "] cfg = self.cfg(event) for comp in event.get_messages(): if isinstance(comp, Plain): - final_message += f" {comp.text}" + parts.append(f" {comp.text}") elif isinstance(comp, Image): if cfg["image_caption"]: try: @@ -137,11 +137,13 @@ async def handle_message(self, event: AstrMessageEvent): cfg["image_caption_provider_id"], cfg["image_caption_prompt"], ) - final_message += f" [Image: {caption}]" + parts.append(f" [Image: {caption}]") except Exception as e: logger.error(f"获取图片描述失败: {e}") else: - final_message += " [Image]" + parts.append(" [Image]") + + final_message = "".join(parts) logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}") self.session_chats[event.unified_msg_origin].append(final_message) if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]: diff --git a/packages/reminder/main.py b/packages/reminder/main.py index 0349b9eb4..eaeec8d73 100644 --- a/packages/reminder/main.py +++ b/packages/reminder/main.py @@ -203,14 +203,15 @@ async def reminder_ls(self, event: AstrMessageEvent): if not reminders: yield event.plain_result("没有正在进行的待办事项。") else: - reminder_str = "正在进行的待办事项:\n" + parts = ["正在进行的待办事项:\n"] for i, reminder in enumerate(reminders): time_ = reminder.get("datetime", "") if not time_: cron_expr = reminder.get("cron", "") time_ = reminder.get("cron_h", "") + f"(Cron: {cron_expr})" - reminder_str += f"{i + 1}. {reminder['text']} - {time_}\n" - reminder_str += "\n使用 /reminder rm 删除待办事项。\n" + parts.append(f"{i + 1}. {reminder['text']} - {time_}\n") + parts.append("\n使用 /reminder rm 删除待办事项。\n") + reminder_str = "".join(parts) yield event.plain_result(reminder_str) @reminder.command("rm")