From 89098fba7a44b48da00c42d44ca1c5a5a7a8b629 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Sat, 1 Nov 2025 17:53:43 +0000
Subject: [PATCH 1/6] Initial plan
From 84316d1bbe30fc6612ac67aafd3667ac5df6d531 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Sat, 1 Nov 2025 17:58:20 +0000
Subject: [PATCH 2/6] Fix function tool duplication bug during plugin reload
Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>
---
astrbot/core/star/star_manager.py | 12 ++++++++++++
1 file changed, 12 insertions(+)
diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py
index ef50917fe..7949eb52f 100644
--- a/astrbot/core/star/star_manager.py
+++ b/astrbot/core/star/star_manager.py
@@ -751,6 +751,18 @@ async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
]:
del star_handlers_registry.star_handlers_map[k]
+ # 移除插件注册的函数调用工具
+ removed_tools = []
+ for func_tool in list(llm_tools.func_list):
+ if (
+ func_tool.handler_module_path == plugin_module_path
+ and func_tool.origin != "mcp"
+ ):
+ llm_tools.func_list.remove(func_tool)
+ removed_tools.append(func_tool.name)
+ if removed_tools:
+ logger.info(f"移除了插件 {plugin_name} 的函数调用工具: {removed_tools}")
+
if plugin is None:
return
From 3dd467fe2b25822fe0972783bca7ad5f9646bb9f Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Sat, 1 Nov 2025 18:04:22 +0000
Subject: [PATCH 3/6] Improve function tool removal to handle edge cases
Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>
---
astrbot/core/star/star_manager.py | 19 +++++++++++++++----
1 file changed, 15 insertions(+), 4 deletions(-)
diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py
index 7949eb52f..998151474 100644
--- a/astrbot/core/star/star_manager.py
+++ b/astrbot/core/star/star_manager.py
@@ -754,10 +754,21 @@ async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
# 移除插件注册的函数调用工具
removed_tools = []
for func_tool in list(llm_tools.func_list):
- if (
- func_tool.handler_module_path == plugin_module_path
- and func_tool.origin != "mcp"
- ):
+ # 检查工具是否属于此插件:
+ # 1. 通过 handler_module_path 匹配(已绑定的工具)
+ # 2. 通过 handler.__module__ 匹配(未绑定的工具,例如在 __init__ 中通过 add_llm_tools 添加的)
+ should_remove = False
+ if func_tool.origin != "mcp":
+ if func_tool.handler_module_path == plugin_module_path:
+ should_remove = True
+ elif (
+ func_tool.handler
+ and hasattr(func_tool.handler, "__module__")
+ and func_tool.handler.__module__ == plugin_module_path
+ ):
+ should_remove = True
+
+ if should_remove:
llm_tools.func_list.remove(func_tool)
removed_tools.append(func_tool.name)
if removed_tools:
From fc742d99b302985117049dd3ac6c3644a34aab01 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Sat, 1 Nov 2025 18:06:21 +0000
Subject: [PATCH 4/6] Refactor tool removal to use list comprehension for
consistency
Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>
---
astrbot/core/star/star_manager.py | 38 ++++++++++++++++---------------
1 file changed, 20 insertions(+), 18 deletions(-)
diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py
index 998151474..662c91e17 100644
--- a/astrbot/core/star/star_manager.py
+++ b/astrbot/core/star/star_manager.py
@@ -752,25 +752,27 @@ async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
del star_handlers_registry.star_handlers_map[k]
# 移除插件注册的函数调用工具
- removed_tools = []
- for func_tool in list(llm_tools.func_list):
- # 检查工具是否属于此插件:
- # 1. 通过 handler_module_path 匹配(已绑定的工具)
- # 2. 通过 handler.__module__ 匹配(未绑定的工具,例如在 __init__ 中通过 add_llm_tools 添加的)
- should_remove = False
- if func_tool.origin != "mcp":
- if func_tool.handler_module_path == plugin_module_path:
- should_remove = True
- elif (
- func_tool.handler
- and hasattr(func_tool.handler, "__module__")
- and func_tool.handler.__module__ == plugin_module_path
- ):
- should_remove = True
+ def _should_remove_tool(func_tool):
+ """检查工具是否属于此插件:
+ 1. 通过 handler_module_path 匹配(已绑定的工具)
+ 2. 通过 handler.__module__ 匹配(未绑定的工具,例如在 __init__ 中通过 add_llm_tools 添加的)
+ """
+ if func_tool.origin == "mcp":
+ return False
+ if func_tool.handler_module_path == plugin_module_path:
+ return True
+ if (
+ func_tool.handler
+ and hasattr(func_tool.handler, "__module__")
+ and func_tool.handler.__module__ == plugin_module_path
+ ):
+ return True
+ return False
- if should_remove:
- llm_tools.func_list.remove(func_tool)
- removed_tools.append(func_tool.name)
+ removed_tools = [f.name for f in llm_tools.func_list if _should_remove_tool(f)]
+ llm_tools.func_list = [
+ f for f in llm_tools.func_list if not _should_remove_tool(f)
+ ]
if removed_tools:
logger.info(f"移除了插件 {plugin_name} 的函数调用工具: {removed_tools}")
From a82c1850be840e2cc936f5ba5dae4c92312058c9 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Sat, 1 Nov 2025 18:09:08 +0000
Subject: [PATCH 5/6] Optimize tool removal to single pass and improve
documentation
Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>
---
astrbot/core/star/star_manager.py | 24 +++++++++++++++++-------
1 file changed, 17 insertions(+), 7 deletions(-)
diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py
index 662c91e17..4966735d5 100644
--- a/astrbot/core/star/star_manager.py
+++ b/astrbot/core/star/star_manager.py
@@ -753,9 +753,13 @@ async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
# 移除插件注册的函数调用工具
def _should_remove_tool(func_tool):
- """检查工具是否属于此插件:
- 1. 通过 handler_module_path 匹配(已绑定的工具)
- 2. 通过 handler.__module__ 匹配(未绑定的工具,例如在 __init__ 中通过 add_llm_tools 添加的)
+ """检查工具是否属于此插件
+
+ 通过两种方式匹配工具:
+ 1. handler_module_path 匹配:适用于已完成绑定的工具(在插件加载后,handler 被 functools.partial 包装)
+ 2. handler.__module__ 匹配:适用于尚未绑定的工具(在 __init__ 中通过 add_llm_tools 添加的工具)
+
+ 注意:MCP 工具由 MCP 客户端管理,不在此移除
"""
if func_tool.origin == "mcp":
return False
@@ -769,10 +773,16 @@ def _should_remove_tool(func_tool):
return True
return False
- removed_tools = [f.name for f in llm_tools.func_list if _should_remove_tool(f)]
- llm_tools.func_list = [
- f for f in llm_tools.func_list if not _should_remove_tool(f)
- ]
+ # 一次遍历完成工具的分类和移除
+ removed_tools = []
+ remaining_tools = []
+ for func_tool in llm_tools.func_list:
+ if _should_remove_tool(func_tool):
+ removed_tools.append(func_tool.name)
+ else:
+ remaining_tools.append(func_tool)
+ llm_tools.func_list = remaining_tools
+
if removed_tools:
logger.info(f"移除了插件 {plugin_name} 的函数调用工具: {removed_tools}")
From 6ebea8d538a204025dbebbbda4f03ddfd6dec2ed Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Sun, 2 Nov 2025 11:02:26 +0000
Subject: [PATCH 6/6] Merge master and resolve conflict in star_manager.py
Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>
---
astrbot/cli/commands/cmd_conf.py | 2 +-
astrbot/cli/commands/cmd_init.py | 3 +-
astrbot/cli/utils/plugin.py | 8 +-
astrbot/cli/utils/version_comparator.py | 9 +-
astrbot/core/agent/hooks.py | 2 -
astrbot/core/agent/mcp_client.py | 36 ++++
astrbot/core/agent/message.py | 168 ++++++++++++++++++
astrbot/core/agent/run_context.py | 4 +-
.../agent/runners/tool_loop_agent_runner.py | 19 +-
astrbot/core/agent/tool.py | 76 +++++---
astrbot/core/astr_agent_context.py | 3 +-
astrbot/core/config/astrbot_config.py | 4 +-
astrbot/core/conversation_mgr.py | 36 ++++
astrbot/core/db/migration/migra_3_to_4.py | 7 +-
astrbot/core/db/migration/sqlite_v3.py | 10 +-
astrbot/core/persona_mgr.py | 6 +-
.../strategies/baidu_aip.py | 7 +-
astrbot/core/pipeline/context.py | 3 +-
astrbot/core/pipeline/context_utils.py | 65 +++++++
.../process_stage/method/llm_request.py | 113 ++++++++----
.../core/pipeline/result_decorate/stage.py | 5 +-
astrbot/core/platform/astr_message_event.py | 39 ++--
astrbot/core/platform/astrbot_message.py | 12 +-
astrbot/core/platform/message_session.py | 2 +-
astrbot/core/platform/platform_metadata.py | 8 +-
astrbot/core/platform/register.py | 6 +-
.../aiocqhttp/aiocqhttp_message_event.py | 2 +-
.../aiocqhttp/aiocqhttp_platform_adapter.py | 6 +-
.../core/platform/sources/discord/client.py | 2 +-
.../platform/sources/discord/components.py | 22 +--
.../discord/discord_platform_adapter.py | 4 +-
.../sources/discord/discord_platform_event.py | 7 +-
.../qqofficial/qqofficial_message_event.py | 22 +--
.../platform/sources/slack/slack_adapter.py | 16 +-
.../platform/sources/slack/slack_event.py | 9 +-
.../sources/wechatpadpro/xml_data_parser.py | 2 +-
astrbot/core/platform_message_history_mgr.py | 4 +-
astrbot/core/provider/entities.py | 88 ++++-----
astrbot/core/provider/func_tool_manager.py | 21 ++-
astrbot/core/provider/provider.py | 70 +++++---
astrbot/core/provider/register.py | 4 +-
.../core/provider/sources/anthropic_source.py | 19 +-
astrbot/core/provider/sources/coze_source.py | 1 +
.../core/provider/sources/dashscope_source.py | 9 +-
.../core/provider/sources/gemini_source.py | 21 ++-
.../core/provider/sources/openai_source.py | 31 ++--
astrbot/core/star/context.py | 57 ++++--
astrbot/core/star/filter/command.py | 16 +-
astrbot/core/star/filter/command_group.py | 29 +--
astrbot/core/star/star_manager.py | 16 +-
astrbot/core/utils/dify_api_client.py | 8 +-
astrbot/core/utils/io.py | 4 +-
astrbot/core/utils/pip_installer.py | 8 +-
astrbot/dashboard/routes/chat.py | 8 +-
astrbot/dashboard/routes/config.py | 4 +-
astrbot/dashboard/server.py | 11 +-
main.py | 12 +-
packages/astrbot/commands/conversation.py | 27 ++-
packages/astrbot/commands/persona.py | 7 +-
packages/astrbot/commands/plugin.py | 25 +--
packages/astrbot/commands/provider.py | 55 +++---
packages/astrbot/long_term_memory.py | 10 +-
packages/reminder/main.py | 7 +-
63 files changed, 897 insertions(+), 420 deletions(-)
create mode 100644 astrbot/core/agent/message.py
diff --git a/astrbot/cli/commands/cmd_conf.py b/astrbot/cli/commands/cmd_conf.py
index 86f78cbaa..a9bd40f00 100644
--- a/astrbot/cli/commands/cmd_conf.py
+++ b/astrbot/cli/commands/cmd_conf.py
@@ -178,7 +178,7 @@ def set_config(key: str, value: str):
@conf.command(name="get")
@click.argument("key", required=False)
-def get_config(key: str = None):
+def get_config(key: str | None = None):
"""获取配置项的值,不提供key则显示所有可配置项"""
config = _load_config()
diff --git a/astrbot/cli/commands/cmd_init.py b/astrbot/cli/commands/cmd_init.py
index 993995a66..6c0c34b99 100644
--- a/astrbot/cli/commands/cmd_init.py
+++ b/astrbot/cli/commands/cmd_init.py
@@ -1,4 +1,5 @@
import asyncio
+from pathlib import Path
import click
from filelock import FileLock, Timeout
@@ -6,7 +7,7 @@
from ..utils import check_dashboard, get_astrbot_root
-async def initialize_astrbot(astrbot_root) -> None:
+async def initialize_astrbot(astrbot_root: Path) -> None:
"""执行 AstrBot 初始化逻辑"""
dot_astrbot = astrbot_root / ".astrbot"
diff --git a/astrbot/cli/utils/plugin.py b/astrbot/cli/utils/plugin.py
index 55edf4de2..cd76a07c8 100644
--- a/astrbot/cli/utils/plugin.py
+++ b/astrbot/cli/utils/plugin.py
@@ -221,9 +221,9 @@ def manage_plugin(
raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新")
# 备份现有插件
- if is_update and backup_path.exists():
+ if is_update and backup_path is not None and backup_path.exists():
shutil.rmtree(backup_path)
- if is_update:
+ if is_update and backup_path is not None:
shutil.copytree(target_path, backup_path)
try:
@@ -233,13 +233,13 @@ def manage_plugin(
get_git_repo(repo_url, target_path, proxy)
# 更新成功,删除备份
- if is_update and backup_path.exists():
+ if is_update and backup_path is not None and backup_path.exists():
shutil.rmtree(backup_path)
click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功")
except Exception as e:
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
- if is_update and backup_path.exists():
+ if is_update and backup_path is not None and backup_path.exists():
shutil.move(backup_path, target_path)
raise click.ClickException(
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}",
diff --git a/astrbot/cli/utils/version_comparator.py b/astrbot/cli/utils/version_comparator.py
index 99d71d34d..0aaf8dcab 100644
--- a/astrbot/cli/utils/version_comparator.py
+++ b/astrbot/cli/utils/version_comparator.py
@@ -62,9 +62,12 @@ def split_version(version):
return -1
if isinstance(p1, str) and isinstance(p2, int):
return 1
- if (isinstance(p1, int) and isinstance(p2, int)) or (
- isinstance(p1, str) and isinstance(p2, str)
- ):
+ if isinstance(p1, int) and isinstance(p2, int):
+ if p1 > p2:
+ return 1
+ if p1 < p2:
+ return -1
+ elif isinstance(p1, str) and isinstance(p2, str):
if p1 > p2:
return 1
if p1 < p2:
diff --git a/astrbot/core/agent/hooks.py b/astrbot/core/agent/hooks.py
index 949ebd3fe..d834240b7 100644
--- a/astrbot/core/agent/hooks.py
+++ b/astrbot/core/agent/hooks.py
@@ -1,4 +1,3 @@
-from dataclasses import dataclass
from typing import Generic
import mcp
@@ -9,7 +8,6 @@
from .run_context import ContextWrapper, TContext
-@dataclass
class BaseAgentRunHooks(Generic[TContext]):
async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ...
async def on_tool_start(
diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py
index 303973a0d..05980b212 100644
--- a/astrbot/core/agent/mcp_client.py
+++ b/astrbot/core/agent/mcp_client.py
@@ -2,10 +2,15 @@
import logging
from contextlib import AsyncExitStack
from datetime import timedelta
+from typing import Generic
from astrbot import logger
+from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.utils.log_pipe import LogPipe
+from .run_context import TContext
+from .tool import FunctionTool
+
try:
import mcp
from mcp.client.sse import sse_client
@@ -221,3 +226,34 @@ async def cleanup(self):
"""Clean up resources"""
await self.exit_stack.aclose()
self.running_event.set() # Set the running event to indicate cleanup is done
+
+
+class MCPTool(FunctionTool, Generic[TContext]):
+ """A function tool that calls an MCP service."""
+
+ def __init__(
+ self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs
+ ):
+ super().__init__(
+ name=mcp_tool.name,
+ description=mcp_tool.description or "",
+ parameters=mcp_tool.inputSchema,
+ )
+ self.mcp_tool = mcp_tool
+ self.mcp_client = mcp_client
+ self.mcp_server_name = mcp_server_name
+
+ async def call(
+ self, context: ContextWrapper[TContext], **kwargs
+ ) -> mcp.types.CallToolResult:
+ session = self.mcp_client.session
+ if not session:
+ raise ValueError("MCP session is not available for MCP function tools.")
+ res = await session.call_tool(
+ name=self.mcp_tool.name,
+ arguments=kwargs,
+ read_timeout_seconds=timedelta(
+ seconds=context.tool_call_timeout,
+ ),
+ )
+ return res
diff --git a/astrbot/core/agent/message.py b/astrbot/core/agent/message.py
new file mode 100644
index 000000000..11128c0f6
--- /dev/null
+++ b/astrbot/core/agent/message.py
@@ -0,0 +1,168 @@
+# Inspired by MoonshotAI/kosong, credits to MoonshotAI/kosong authors for the original implementation.
+# License: Apache License 2.0
+
+from typing import Any, ClassVar, Literal, cast
+
+from pydantic import BaseModel, GetCoreSchemaHandler
+from pydantic_core import core_schema
+
+
+class ContentPart(BaseModel):
+ """A part of the content in a message."""
+
+ __content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {}
+
+ type: str
+
+ def __init_subclass__(cls, **kwargs: Any) -> None:
+ super().__init_subclass__(**kwargs)
+
+ invalid_subclass_error_msg = f"ContentPart subclass {cls.__name__} must have a `type` field of type `str`"
+
+ type_value = getattr(cls, "type", None)
+ if type_value is None or not isinstance(type_value, str):
+ raise ValueError(invalid_subclass_error_msg)
+
+ cls.__content_part_registry[type_value] = cls
+
+ @classmethod
+ def __get_pydantic_core_schema__(
+ cls, source_type: Any, handler: GetCoreSchemaHandler
+ ) -> core_schema.CoreSchema:
+ # If we're dealing with the base ContentPart class, use custom validation
+ if cls.__name__ == "ContentPart":
+
+ def validate_content_part(value: Any) -> Any:
+ # if it's already an instance of a ContentPart subclass, return it
+ if hasattr(value, "__class__") and issubclass(value.__class__, cls):
+ return value
+
+ # if it's a dict with a type field, dispatch to the appropriate subclass
+ if isinstance(value, dict) and "type" in value:
+ type_value: Any | None = cast(dict[str, Any], value).get("type")
+ if not isinstance(type_value, str):
+ raise ValueError(f"Cannot validate {value} as ContentPart")
+ target_class = cls.__content_part_registry[type_value]
+ return target_class.model_validate(value)
+
+ raise ValueError(f"Cannot validate {value} as ContentPart")
+
+ return core_schema.no_info_plain_validator_function(validate_content_part)
+
+ # for subclasses, use the default schema
+ return handler(source_type)
+
+
+class TextPart(ContentPart):
+ """
+ >>> TextPart(text="Hello, world!").model_dump()
+ {'type': 'text', 'text': 'Hello, world!'}
+ """
+
+ type: str = "text"
+ text: str
+
+
+class ImageURLPart(ContentPart):
+ """
+ >>> ImageURLPart(image_url="http://example.com/image.jpg").model_dump()
+ {'type': 'image_url', 'image_url': 'http://example.com/image.jpg'}
+ """
+
+ class ImageURL(BaseModel):
+ url: str
+ """The URL of the image, can be data URI scheme like `data:image/png;base64,...`."""
+ id: str | None = None
+ """The ID of the image, to allow LLMs to distinguish different images."""
+
+ type: str = "image_url"
+ image_url: str
+
+
+class AudioURLPart(ContentPart):
+ """
+ >>> AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3")).model_dump()
+ {'type': 'audio_url', 'audio_url': {'url': 'https://example.com/audio.mp3', 'id': None}}
+ """
+
+ class AudioURL(BaseModel):
+ url: str
+ """The URL of the audio, can be data URI scheme like `data:audio/aac;base64,...`."""
+ id: str | None = None
+ """The ID of the audio, to allow LLMs to distinguish different audios."""
+
+ type: str = "audio_url"
+ audio_url: AudioURL
+
+
+class ToolCall(BaseModel):
+ """
+ A tool call requested by the assistant.
+
+ >>> ToolCall(
+ ... id="123",
+ ... function=ToolCall.FunctionBody(
+ ... name="function",
+ ... arguments="{}"
+ ... ),
+ ... ).model_dump()
+ {'type': 'function', 'id': '123', 'function': {'name': 'function', 'arguments': '{}'}}
+ """
+
+ class FunctionBody(BaseModel):
+ name: str
+ arguments: str | None
+
+ type: Literal["function"] = "function"
+
+ id: str
+ """The ID of the tool call."""
+ function: FunctionBody
+ """The function body of the tool call."""
+
+
+class ToolCallPart(BaseModel):
+ """A part of the tool call."""
+
+ arguments_part: str | None = None
+ """A part of the arguments of the tool call."""
+
+
+class Message(BaseModel):
+ """A message in a conversation."""
+
+ role: Literal[
+ "system",
+ "user",
+ "assistant",
+ "tool",
+ ]
+
+ content: str | list[ContentPart]
+ """The content of the message."""
+
+
+class AssistantMessageSegment(Message):
+ """A message segment from the assistant."""
+
+ role: Literal["assistant"] = "assistant"
+ tool_calls: list[ToolCall] | list[dict] | None = None
+
+
+class ToolCallMessageSegment(Message):
+ """A message segment representing a tool call."""
+
+ role: Literal["tool"] = "tool"
+ tool_call_id: str
+
+
+class UserMessageSegment(Message):
+ """A message segment from the user."""
+
+ role: Literal["user"] = "user"
+
+
+class SystemMessageSegment(Message):
+ """A message segment from the system."""
+
+ role: Literal["system"] = "system"
diff --git a/astrbot/core/agent/run_context.py b/astrbot/core/agent/run_context.py
index 634735ccc..395817679 100644
--- a/astrbot/core/agent/run_context.py
+++ b/astrbot/core/agent/run_context.py
@@ -3,8 +3,6 @@
from typing_extensions import TypeVar
-from astrbot.core.platform.astr_message_event import AstrMessageEvent
-
TContext = TypeVar("TContext", default=Any)
@@ -13,7 +11,7 @@ class ContextWrapper(Generic[TContext]):
"""A context for running an agent, which can be used to pass additional data or state."""
context: TContext
- event: AstrMessageEvent
+ tool_call_timeout: int = 60 # Default tool call timeout in seconds
NoContext = ContextWrapper[None]
diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py
index cb89fb612..23071d446 100644
--- a/astrbot/core/agent/runners/tool_loop_agent_runner.py
+++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py
@@ -16,15 +16,14 @@
MessageChain,
)
from astrbot.core.provider.entities import (
- AssistantMessageSegment,
LLMResponse,
ProviderRequest,
- ToolCallMessageSegment,
ToolCallsResult,
)
from astrbot.core.provider.provider import Provider
from ..hooks import BaseAgentRunHooks
+from ..message import AssistantMessageSegment, ToolCallMessageSegment
from ..response import AgentResponseData
from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor
@@ -171,8 +170,7 @@ async def step(self):
# 将结果添加到上下文中
tool_calls_result = ToolCallsResult(
tool_calls_info=AssistantMessageSegment(
- role="assistant",
- tool_calls=llm_resp.to_openai_tool_calls(),
+ tool_calls=llm_resp.to_openai_to_calls_model(),
content=llm_resp.completion_text,
),
tool_calls_result=tool_call_result_blocks,
@@ -238,7 +236,6 @@ async def _handle_function_tools(
else:
# 如果没有 handler(如 MCP 工具),使用所有参数
valid_params = func_tool_args
- logger.warning(f"工具 {func_tool_name} 没有 handler,使用所有参数")
try:
await self.agent_hooks.on_tool_start(
@@ -319,13 +316,11 @@ async def _handle_function_tools(
elif resp is None:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。
+ # 发送消息逻辑在 ToolExecutor 中处理了。
+ logger.warning(
+ f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。"
+ )
self._transition_state(AgentState.DONE)
- if res := self.run_context.event.get_result():
- if res.chain:
- yield MessageChain(
- chain=res.chain,
- type="tool_direct_result",
- )
else:
# 不应该出现其他类型
logger.warning(
@@ -341,8 +336,6 @@ async def _handle_function_tools(
)
except Exception as e:
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
-
- self.run_context.event.clear_result()
except Exception as e:
logger.warning(traceback.format_exc())
tool_call_result_blocks.append(
diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py
index 3c36def63..e9738dc0f 100644
--- a/astrbot/core/agent/tool.py
+++ b/astrbot/core/agent/tool.py
@@ -1,52 +1,76 @@
from collections.abc import Awaitable, Callable
-from dataclasses import dataclass
-from typing import Any, Literal
+from typing import Any, Generic
+import jsonschema
+import mcp
from deprecated import deprecated
+from pydantic import model_validator
+from pydantic.dataclasses import dataclass
-from .mcp_client import MCPClient
+from .run_context import ContextWrapper, TContext
+
+ParametersType = dict[str, Any]
@dataclass
-class FunctionTool:
- """A class representing a function tool that can be used in function calling."""
+class ToolSchema:
+ """A class representing the schema of a tool for function calling."""
name: str
- parameters: dict | None = None
- description: str | None = None
+ """The name of the tool."""
+
+ description: str
+ """The description of the tool."""
+
+ parameters: ParametersType
+ """The parameters of the tool, in JSON Schema format."""
+
+ @model_validator(mode="after")
+ def validate_parameters(self) -> "ToolSchema":
+ jsonschema.validate(
+ self.parameters, jsonschema.Draft202012Validator.META_SCHEMA
+ )
+ return self
+
+
+@dataclass
+class FunctionTool(ToolSchema, Generic[TContext]):
+ """A callable tool, for function calling."""
+
handler: Callable[..., Awaitable[Any]] | None = None
- """处理函数, 当 origin 为 mcp 时,这个为空"""
- handler_module_path: str | None = None
- """处理函数的模块路径,当 origin 为 mcp 时,这个为空
+ """a callable that implements the tool's functionality. It should be an async function."""
- 必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
+ handler_module_path: str | None = None
+ """
+ The module path of the handler function. This is empty when the origin is mcp.
+ This field must be retained, as the handler will be wrapped in functools.partial during initialization,
+ causing the handler's __module__ to be functools
"""
active: bool = True
- """是否激活"""
-
- origin: Literal["local", "mcp"] = "local"
- """函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
-
- # MCP 相关字段
- mcp_server_name: str | None = None
- """MCP 服务名称,当 origin 为 mcp 时有效"""
- mcp_client: MCPClient | None = None
- """MCP 客户端,当 origin 为 mcp 时有效"""
+ """
+ Whether the tool is active. This field is a special field for AstrBot.
+ You can ignore it when integrating with other frameworks.
+ """
def __repr__(self):
- return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
+ return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
def __dict__(self) -> dict[str, Any]:
- """将 FunctionTool 转换为字典格式"""
return {
"name": self.name,
"parameters": self.parameters,
"description": self.description,
"active": self.active,
- "origin": self.origin,
- "mcp_server_name": self.mcp_server_name,
}
+ async def call(
+ self, context: ContextWrapper[TContext], **kwargs
+ ) -> str | mcp.types.CallToolResult:
+ """Run the tool with the given arguments. The handler field has priority."""
+ raise NotImplementedError(
+ "FunctionTool.call() must be implemented by subclasses or set a handler."
+ )
+
class ToolSet:
"""A set of function tools that can be used in function calling.
@@ -225,7 +249,7 @@ def convert_schema(schema: dict) -> dict:
tools = []
for tool in self.tools:
- d = {
+ d: dict[str, Any] = {
"name": tool.name,
"description": tool.description,
}
diff --git a/astrbot/core/astr_agent_context.py b/astrbot/core/astr_agent_context.py
index e21ddb9c6..28b242253 100644
--- a/astrbot/core/astr_agent_context.py
+++ b/astrbot/core/astr_agent_context.py
@@ -1,5 +1,6 @@
from dataclasses import dataclass
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderRequest
@@ -10,4 +11,4 @@ class AstrAgentContext:
first_provider_request: ProviderRequest
curr_provider_request: ProviderRequest
streaming: bool
- tool_call_timeout: int = 60 # Default tool call timeout in seconds
+ event: AstrMessageEvent
diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py
index 68b73cd29..786d29c81 100644
--- a/astrbot/core/config/astrbot_config.py
+++ b/astrbot/core/config/astrbot_config.py
@@ -28,7 +28,7 @@ def __init__(
self,
config_path: str = ASTRBOT_CONFIG_PATH,
default_config: dict = DEFAULT_CONFIG,
- schema: dict = None,
+ schema: dict | None = None,
):
super().__init__()
@@ -142,7 +142,7 @@ def check_config_integrity(self, refer_conf: dict, conf: dict, path=""):
return has_new
- def save_config(self, replace_config: dict = None):
+ def save_config(self, replace_config: dict | None = None):
"""将配置写入文件
如果传入 replace_config,则将配置替换为 replace_config
diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py
index 2be406100..287fe03c4 100644
--- a/astrbot/core/conversation_mgr.py
+++ b/astrbot/core/conversation_mgr.py
@@ -8,6 +8,7 @@
from collections.abc import Awaitable, Callable
from astrbot.core import sp
+from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation, ConversationV2
@@ -319,6 +320,41 @@ async def update_conversation_persona_id(
persona_id=persona_id,
)
+ async def add_message_pair(
+ self,
+ cid: str,
+ user_message: UserMessageSegment | dict,
+ assistant_message: AssistantMessageSegment | dict,
+ ) -> None:
+ """Add a user-assistant message pair to the conversation history.
+
+ Args:
+ cid (str): Conversation ID
+ user_message (UserMessageSegment | dict): OpenAI-format user message object or dict
+ assistant_message (AssistantMessageSegment | dict): OpenAI-format assistant message object or dict
+
+ Raises:
+ Exception: If the conversation with the given ID is not found
+ """
+ conv = await self.db.get_conversation_by_id(cid=cid)
+ if not conv:
+ raise Exception(f"Conversation with id {cid} not found")
+ history = conv.content or []
+ if isinstance(user_message, UserMessageSegment):
+ user_msg_dict = user_message.model_dump()
+ else:
+ user_msg_dict = user_message
+ if isinstance(assistant_message, AssistantMessageSegment):
+ assistant_msg_dict = assistant_message.model_dump()
+ else:
+ assistant_msg_dict = assistant_message
+ history.append(user_msg_dict)
+ history.append(assistant_msg_dict)
+ await self.db.update_conversation(
+ cid=cid,
+ content=history,
+ )
+
async def get_human_readable_context(
self,
unified_msg_origin: str,
diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py
index 13a14c327..a75c60a1b 100644
--- a/astrbot/core/db/migration/migra_3_to_4.py
+++ b/astrbot/core/db/migration/migra_3_to_4.py
@@ -250,14 +250,15 @@ async def migration_persona_data(
try:
begin_dialogs = persona.get("begin_dialogs", [])
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
- mood_prompt = ""
+ parts = []
user_turn = True
for mood_dialog in mood_imitation_dialogs:
if user_turn:
- mood_prompt += f"A: {mood_dialog}\n"
+ parts.append(f"A: {mood_dialog}\n")
else:
- mood_prompt += f"B: {mood_dialog}\n"
+ parts.append(f"B: {mood_dialog}\n")
user_turn = not user_turn
+ mood_prompt = "".join(parts)
system_prompt = persona.get("prompt", "")
if mood_prompt:
system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}"
diff --git a/astrbot/core/db/migration/sqlite_v3.py b/astrbot/core/db/migration/sqlite_v3.py
index 7b341c316..a301028d1 100644
--- a/astrbot/core/db/migration/sqlite_v3.py
+++ b/astrbot/core/db/migration/sqlite_v3.py
@@ -384,11 +384,11 @@ def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
- platforms: list[str] = None,
- message_types: list[str] = None,
- search_query: str = None,
- exclude_ids: list[str] = None,
- exclude_platforms: list[str] = None,
+ platforms: list[str] | None = None,
+ message_types: list[str] | None = None,
+ search_query: str | None = None,
+ exclude_ids: list[str] | None = None,
+ exclude_platforms: list[str] | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""获取筛选后的对话列表"""
try:
diff --git a/astrbot/core/persona_mgr.py b/astrbot/core/persona_mgr.py
index 482b5887c..5d1743ab9 100644
--- a/astrbot/core/persona_mgr.py
+++ b/astrbot/core/persona_mgr.py
@@ -68,9 +68,9 @@ async def delete_persona(self, persona_id: str):
async def update_persona(
self,
persona_id: str,
- system_prompt: str = None,
- begin_dialogs: list[str] = None,
- tools: list[str] = None,
+ system_prompt: str | None = None,
+ begin_dialogs: list[str] | None = None,
+ tools: list[str] | None = None,
):
"""更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
existing_persona = await self.db.get_persona_by_id(persona_id)
diff --git a/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py b/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py
index c11822896..bfa82de0e 100644
--- a/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py
+++ b/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py
@@ -21,8 +21,9 @@ def check(self, content: str) -> tuple[bool, str]:
if "data" not in res:
return False, ""
count = len(res["data"])
- info = f"百度审核服务发现 {count} 处违规:\n"
+ parts = [f"百度审核服务发现 {count} 处违规:\n"]
for i in res["data"]:
- info += f"{i['msg']};\n"
- info += "\n判断结果:" + res["conclusion"]
+ parts.append(f"{i['msg']};\n")
+ parts.append("\n判断结果:" + res["conclusion"])
+ info = "".join(parts)
return False, info
diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py
index a6cd567e0..44186764e 100644
--- a/astrbot/core/pipeline/context.py
+++ b/astrbot/core/pipeline/context.py
@@ -3,7 +3,7 @@
from astrbot.core.config import AstrBotConfig
from astrbot.core.star import PluginManager
-from .context_utils import call_event_hook, call_handler
+from .context_utils import call_event_hook, call_handler, call_local_llm_tool
@dataclass
@@ -15,3 +15,4 @@ class PipelineContext:
astrbot_config_id: str
call_handler = call_handler
call_event_hook = call_event_hook
+ call_local_llm_tool = call_local_llm_tool
diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py
index 73d28c5d1..371816b6e 100644
--- a/astrbot/core/pipeline/context_utils.py
+++ b/astrbot/core/pipeline/context_utils.py
@@ -3,6 +3,8 @@
import typing as T
from astrbot import logger
+from astrbot.core.agent.run_context import ContextWrapper
+from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import CommandResult, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star import star_map
@@ -105,3 +107,66 @@ async def call_event_hook(
return True
return event.is_stopped()
+
+
+async def call_local_llm_tool(
+ context: ContextWrapper[AstrAgentContext],
+ handler: T.Callable[..., T.Awaitable[T.Any]],
+ method_name: str,
+ *args,
+ **kwargs,
+) -> T.AsyncGenerator[T.Any, None]:
+ """执行本地 LLM 工具的处理函数并处理其返回结果"""
+ ready_to_call = None # 一个协程或者异步生成器
+
+ trace_ = None
+
+ event = context.context.event
+
+ try:
+ if method_name == "run" or method_name == "decorator_handler":
+ ready_to_call = handler(event, *args, **kwargs)
+ elif method_name == "call":
+ ready_to_call = handler(context, *args, **kwargs)
+ else:
+ raise ValueError(f"未知的方法名: {method_name}")
+ except ValueError as e:
+ logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
+ except TypeError:
+ logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
+ except Exception as e:
+ trace_ = traceback.format_exc()
+ logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
+
+ if not ready_to_call:
+ return
+
+ if inspect.isasyncgen(ready_to_call):
+ _has_yielded = False
+ try:
+ async for ret in ready_to_call:
+ # 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
+ # 返回值只能是 MessageEventResult 或者 None(无返回值)
+ _has_yielded = True
+ if isinstance(ret, (MessageEventResult, CommandResult)):
+ # 如果返回值是 MessageEventResult, 设置结果并继续
+ event.set_result(ret)
+ yield
+ else:
+ # 如果返回值是 None, 则不设置结果并继续
+ # 继续执行后续阶段
+ yield ret
+ if not _has_yielded:
+ # 如果这个异步生成器没有执行到 yield 分支
+ yield
+ except Exception as e:
+ logger.error(f"Previous Error: {trace_}")
+ raise e
+ elif inspect.iscoroutine(ready_to_call):
+ # 如果只是一个协程, 直接执行
+ ret = await ready_to_call
+ if isinstance(ret, (MessageEventResult, CommandResult)):
+ event.set_result(ret)
+ yield
+ else:
+ yield ret
diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py
index d1cffc43f..03352cc40 100644
--- a/astrbot/core/pipeline/process_stage/method/llm_request.py
+++ b/astrbot/core/pipeline/process_stage/method/llm_request.py
@@ -5,11 +5,14 @@
import json
import traceback
from collections.abc import AsyncGenerator
-from datetime import timedelta
+from typing import Any
+
+from mcp.types import CallToolResult
from astrbot.core import logger
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
+from astrbot.core.agent.mcp_client import MCPTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.agent.tool import FunctionTool, ToolSet
@@ -33,7 +36,7 @@
from astrbot.core.star.star_handler import EventType, star_map
from astrbot.core.utils.metrics import Metric
-from ...context import PipelineContext, call_event_hook, call_handler
+from ...context import PipelineContext, call_event_hook, call_local_llm_tool
from ..stage import Stage
from ..utils import inject_kb_context
@@ -65,18 +68,16 @@ async def execute(cls, tool, run_context, **tool_args):
yield r
return
- if tool.origin == "local":
- async for r in cls._execute_local(tool, run_context, **tool_args):
+ elif isinstance(tool, MCPTool):
+ async for r in cls._execute_mcp(tool, run_context, **tool_args):
yield r
return
- elif tool.origin == "mcp":
- async for r in cls._execute_mcp(tool, run_context, **tool_args):
+ else:
+ async for r in cls._execute_local(tool, run_context, **tool_args):
yield r
return
- raise Exception(f"Unknown function origin: {tool.origin}")
-
@classmethod
async def _execute_handoff(
cls,
@@ -113,10 +114,13 @@ async def _execute_handoff(
first_provider_request=run_context.context.first_provider_request,
curr_provider_request=request,
streaming=run_context.context.streaming,
+ event=run_context.context.event,
)
+ event = run_context.context.event
+
logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
- await run_context.event.send(
+ await event.send(
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name),
)
@@ -125,7 +129,7 @@ async def _execute_handoff(
request=request,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
- event=run_context.event,
+ tool_call_timeout=run_context.tool_call_timeout,
),
tool_executor=FunctionToolExecutor(),
agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
@@ -175,25 +179,46 @@ async def _execute_local(
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
- if not run_context.event:
+ event = run_context.context.event
+ if not event:
raise ValueError("Event must be provided for local function tools.")
- # 检查 tool 下有没有 run 方法
- if not tool.handler and not hasattr(tool, "run"):
- raise ValueError("Tool must have a valid handler or 'run' method.")
- awaitable = tool.handler or tool.run
+ is_override_call = False
+ for ty in type(tool).mro():
+ if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
+ logger.debug(f"Found call in: {ty}")
+ is_override_call = True
+ break
- wrapper = call_handler(
- event=run_context.event,
+ # 检查 tool 下有没有 run 方法
+ if not tool.handler and not hasattr(tool, "run") and not is_override_call:
+ raise ValueError("Tool must have a valid handler or override 'run' method.")
+
+ awaitable = None
+ method_name = ""
+ if tool.handler:
+ awaitable = tool.handler
+ method_name = "decorator_handler"
+ elif is_override_call:
+ awaitable = tool.call
+ method_name = "call"
+ elif hasattr(tool, "run"):
+ awaitable = getattr(tool, "run")
+ method_name = "run"
+ if awaitable is None:
+ raise ValueError("Tool must have a valid handler or override 'run' method.")
+
+ wrapper = call_local_llm_tool(
+ context=run_context,
handler=awaitable,
+ method_name=method_name,
**tool_args,
)
- # async for resp in wrapper:
while True:
try:
resp = await asyncio.wait_for(
anext(wrapper),
- timeout=run_context.context.tool_call_timeout,
+ timeout=run_context.tool_call_timeout,
)
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
@@ -208,10 +233,24 @@ async def _execute_local(
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
+ if res := run_context.context.event.get_result():
+ if res.chain:
+ try:
+ await event.send(
+ MessageChain(
+ chain=res.chain,
+ type="tool_direct_result",
+ )
+ )
+ except Exception as e:
+ logger.error(
+ f"Tool 直接发送消息失败: {e}",
+ exc_info=True,
+ )
yield None
except asyncio.TimeoutError:
raise Exception(
- f"tool {tool.name} execution timeout after {run_context.context.tool_call_timeout} seconds.",
+ f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
)
except StopAsyncIteration:
break
@@ -223,19 +262,7 @@ async def _execute_mcp(
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
- if not tool.mcp_client:
- raise ValueError("MCP client is not available for MCP function tools.")
-
- session = tool.mcp_client.session
- if not session:
- raise ValueError("MCP session is not available for MCP function tools.")
- res = await session.call_tool(
- name=tool.name,
- arguments=tool_args,
- read_timeout_seconds=timedelta(
- seconds=run_context.context.tool_call_timeout,
- ),
- )
+ res = await tool.call(run_context, **tool_args)
if not res:
return
yield res
@@ -245,11 +272,20 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
async def on_agent_done(self, run_context, llm_response):
# 执行事件钩子
await call_event_hook(
- run_context.event,
+ run_context.context.event,
EventType.OnLLMResponseEvent,
llm_response,
)
+ async def on_tool_end(
+ self,
+ run_context: ContextWrapper[AstrAgentContext],
+ tool: FunctionTool[Any],
+ tool_args: dict | None,
+ tool_result: CallToolResult | None,
+ ):
+ run_context.context.event.clear_result()
+
MAIN_AGENT_HOOKS = MainAgentHooks()
@@ -260,7 +296,7 @@ async def run_agent(
show_tool_use: bool = True,
) -> AsyncGenerator[MessageChain, None]:
step_idx = 0
- astr_event = agent_runner.run_context.event
+ astr_event = agent_runner.run_context.context.event
while step_idx < max_step:
step_idx += 1
try:
@@ -513,12 +549,15 @@ async def process(
first_provider_request=req,
curr_provider_request=req,
streaming=self.streaming_response,
- tool_call_timeout=self.tool_call_timeout,
+ event=event,
)
await agent_runner.reset(
provider=provider,
request=req,
- run_context=AgentContextWrapper(context=astr_agent_ctx, event=event),
+ run_context=AgentContextWrapper(
+ context=astr_agent_ctx,
+ tool_call_timeout=self.tool_call_timeout,
+ ),
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=self.streaming_response,
diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py
index 08661a367..5dfb52f6f 100644
--- a/astrbot/core/pipeline/result_decorate/stage.py
+++ b/astrbot/core/pipeline/result_decorate/stage.py
@@ -246,12 +246,13 @@ async def process(
elif (
result.use_t2i_ is None and self.ctx.astrbot_config["t2i"]
) or result.use_t2i_:
- plain_str = ""
+ parts = []
for comp in result.chain:
if isinstance(comp, Plain):
- plain_str += "\n\n" + comp.text
+ parts.append("\n\n" + comp.text)
else:
break
+ plain_str = "".join(parts)
if plain_str and len(plain_str) > self.t2i_word_threshold:
render_start = time.time()
try:
diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py
index 9605eaffb..6402aeaed 100644
--- a/astrbot/core/platform/astr_message_event.py
+++ b/astrbot/core/platform/astr_message_event.py
@@ -91,33 +91,34 @@ def get_message_str(self) -> str:
return self.message_str
def _outline_chain(self, chain: list[BaseMessageComponent] | None) -> str:
- outline = ""
if not chain:
- return outline
+ return ""
+
+ parts = []
for i in chain:
if isinstance(i, Plain):
- outline += i.text
+ parts.append(i.text)
elif isinstance(i, Image):
- outline += "[图片]"
+ parts.append("[图片]")
elif isinstance(i, Face):
- outline += f"[表情:{i.id}]"
+ parts.append(f"[表情:{i.id}]")
elif isinstance(i, At):
- outline += f"[At:{i.qq}]"
+ parts.append(f"[At:{i.qq}]")
elif isinstance(i, AtAll):
- outline += "[At:全体成员]"
+ parts.append("[At:全体成员]")
elif isinstance(i, Forward):
# 转发消息
- outline += "[转发消息]"
+ parts.append("[转发消息]")
elif isinstance(i, Reply):
# 引用回复
if i.message_str:
- outline += f"[引用消息({i.sender_nickname}: {i.message_str})]"
+ parts.append(f"[引用消息({i.sender_nickname}: {i.message_str})]")
else:
- outline += "[引用消息]"
+ parts.append("[引用消息]")
else:
- outline += f"[{i.type}]"
- outline += " "
- return outline
+ parts.append(f"[{i.type}]")
+ parts.append(" ")
+ return "".join(parts)
def get_message_outline(self) -> str:
"""获取消息概要。
@@ -320,10 +321,10 @@ def request_llm(
prompt: str,
func_tool_manager=None,
session_id: str = None,
- image_urls: list[str] = [],
- contexts: list = [],
+ image_urls: list[str] | None = None,
+ contexts: list | None = None,
system_prompt: str = "",
- conversation: Conversation = None,
+ conversation: Conversation | None = None,
) -> ProviderRequest:
"""创建一个 LLM 请求。
@@ -346,6 +347,10 @@ def request_llm(
conversation: 可选。如果指定,将在指定的对话中进行 LLM 请求。对话的人格会被用于 LLM 请求,并且结果将会被记录到对话中。
"""
+ if image_urls is None:
+ image_urls = []
+ if contexts is None:
+ contexts = []
if len(contexts) > 0 and conversation:
conversation = None
@@ -389,7 +394,7 @@ async def react(self, emoji: str):
"""
await self.send(MessageChain([Plain(emoji)]))
- async def get_group(self, group_id: str = None, **kwargs) -> Group | None:
+ async def get_group(self, group_id: str | None = None, **kwargs) -> Group | None:
"""获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息,返回当前群聊的数据。
适配情况:
diff --git a/astrbot/core/platform/astrbot_message.py b/astrbot/core/platform/astrbot_message.py
index dcc70b0f2..0ada18506 100644
--- a/astrbot/core/platform/astrbot_message.py
+++ b/astrbot/core/platform/astrbot_message.py
@@ -9,7 +9,7 @@
@dataclass
class MessageMember:
user_id: str # 发送者id
- nickname: str = None
+ nickname: str | None = None
def __str__(self):
# 使用 f-string 来构建返回的字符串表示形式
@@ -23,15 +23,15 @@ def __str__(self):
class Group:
group_id: str
"""群号"""
- group_name: str = None
+ group_name: str | None = None
"""群名称"""
- group_avatar: str = None
+ group_avatar: str | None = None
"""群头像"""
- group_owner: str = None
+ group_owner: str | None = None
"""群主 id"""
- group_admins: list[str] = None
+ group_admins: list[str] | None = None
"""群管理员 id"""
- members: list[MessageMember] = None
+ members: list[MessageMember] | None = None
"""所有群成员"""
def __str__(self):
diff --git a/astrbot/core/platform/message_session.py b/astrbot/core/platform/message_session.py
index 62240b621..bca5300b8 100644
--- a/astrbot/core/platform/message_session.py
+++ b/astrbot/core/platform/message_session.py
@@ -13,7 +13,7 @@ class MessageSession:
"""平台适配器实例的唯一标识符。自 AstrBot v4.0.0 起,该字段实际为 platform_id。"""
message_type: MessageType
session_id: str
- platform_id: str = None
+ platform_id: str | None = None
def __str__(self):
return f"{self.platform_id}:{self.message_type.value}:{self.session_id}"
diff --git a/astrbot/core/platform/platform_metadata.py b/astrbot/core/platform/platform_metadata.py
index 37f8527a1..d75811245 100644
--- a/astrbot/core/platform/platform_metadata.py
+++ b/astrbot/core/platform/platform_metadata.py
@@ -7,12 +7,12 @@ class PlatformMetadata:
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
description: str
"""平台的描述"""
- id: str = None
+ id: str | None = None
"""平台的唯一标识符,用于配置中识别特定平台"""
- default_config_tmpl: dict = None
+ default_config_tmpl: dict | None = None
"""平台的默认配置模板"""
- adapter_display_name: str = None
+ adapter_display_name: str | None = None
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
- logo_path: str = None
+ logo_path: str | None = None
"""平台适配器的 logo 文件路径(相对于插件目录)"""
diff --git a/astrbot/core/platform/register.py b/astrbot/core/platform/register.py
index 4cd62ede0..0c6267492 100644
--- a/astrbot/core/platform/register.py
+++ b/astrbot/core/platform/register.py
@@ -11,9 +11,9 @@
def register_platform_adapter(
adapter_name: str,
desc: str,
- default_config_tmpl: dict = None,
- adapter_display_name: str = None,
- logo_path: str = None,
+ default_config_tmpl: dict | None = None,
+ adapter_display_name: str | None = None,
+ logo_path: str | None = None,
):
"""用于注册平台适配器的带参装饰器。
diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py
index 5fe605a74..ce8fd56df 100644
--- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py
+++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py
@@ -94,7 +94,7 @@ async def send_message(
message_chain: MessageChain,
event: Event | None = None,
is_group: bool = False,
- session_id: str = None,
+ session_id: str | None = None,
):
"""发送消息至 QQ 协议端(aiocqhttp)。
diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py
index bb9c4474a..85bb2fead 100644
--- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py
+++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py
@@ -315,6 +315,8 @@ async def _convert_handle_message_event(
abm.message.append(a)
elif t == "at":
first_at_self_processed = False
+ # Accumulate @ mention text for efficient concatenation
+ at_parts = []
for m in m_group:
try:
@@ -354,13 +356,15 @@ async def _convert_handle_message_event(
first_at_self_processed = True
else:
# 非第一个@机器人或@其他用户,添加到message_str
- message_str += f" @{nickname}({m['data']['qq']}) "
+ at_parts.append(f" @{nickname}({m['data']['qq']}) ")
else:
abm.message.append(At(qq=str(m["data"]["qq"]), name=""))
except ActionFailed as e:
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
except BaseException as e:
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
+
+ message_str += "".join(at_parts)
else:
for m in m_group:
a = ComponentTypes[t](**m["data"])
diff --git a/astrbot/core/platform/sources/discord/client.py b/astrbot/core/platform/sources/discord/client.py
index 0a2982ce8..5d29e3429 100644
--- a/astrbot/core/platform/sources/discord/client.py
+++ b/astrbot/core/platform/sources/discord/client.py
@@ -14,7 +14,7 @@
class DiscordBotClient(discord.Bot):
"""Discord客户端封装"""
- def __init__(self, token: str, proxy: str = None):
+ def __init__(self, token: str, proxy: str | None = None):
self.token = token
self.proxy = proxy
diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py
index dbddd1686..d3e69e763 100644
--- a/astrbot/core/platform/sources/discord/components.py
+++ b/astrbot/core/platform/sources/discord/components.py
@@ -11,14 +11,14 @@ class DiscordEmbed(BaseMessageComponent):
def __init__(
self,
- title: str = None,
- description: str = None,
- color: int = None,
- url: str = None,
- thumbnail: str = None,
- image: str = None,
- footer: str = None,
- fields: list[dict] = None,
+ title: str | None = None,
+ description: str | None = None,
+ color: int | None = None,
+ url: str | None = None,
+ thumbnail: str | None = None,
+ image: str | None = None,
+ footer: str | None = None,
+ fields: list[dict] | None = None,
):
self.title = title
self.description = description
@@ -66,10 +66,10 @@ class DiscordButton(BaseMessageComponent):
def __init__(
self,
label: str,
- custom_id: str = None,
+ custom_id: str | None = None,
style: str = "primary",
- emoji: str = None,
- url: str = None,
+ emoji: str | None = None,
+ url: str | None = None,
disabled: bool = False,
):
self.label = label
diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py
index 5dc1fd8a6..276d3dce5 100644
--- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py
+++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py
@@ -386,7 +386,9 @@ async def _collect_and_register_commands(self):
def _create_dynamic_callback(self, cmd_name: str):
"""为每个指令动态创建一个异步回调函数"""
- async def dynamic_callback(ctx: discord.ApplicationContext, params: str = None):
+ async def dynamic_callback(
+ ctx: discord.ApplicationContext, params: str | None = None
+ ):
# 将平台特定的前缀'/'剥离,以适配通用的CommandFilter
logger.debug(f"[Discord] 回调函数触发: {cmd_name}")
logger.debug(f"[Discord] 回调函数参数: {ctx}")
diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py
index 3c701c4ce..06f921bc4 100644
--- a/astrbot/core/platform/sources/discord/discord_platform_event.py
+++ b/astrbot/core/platform/sources/discord/discord_platform_event.py
@@ -113,18 +113,18 @@ async def _parse_to_discord(
message: MessageChain,
) -> tuple[str, list[discord.File], discord.ui.View | None, list[discord.Embed]]:
"""将 MessageChain 解析为 Discord 发送所需的内容"""
- content = ""
+ content_parts = []
files = []
view = None
embeds = []
reference_message_id = None
for i in message.chain: # 遍历消息链
if isinstance(i, Plain): # 如果是文字类型的
- content += i.text
+ content_parts.append(i.text)
elif isinstance(i, Reply):
reference_message_id = i.id
elif isinstance(i, At):
- content += f"<@{i.qq}>"
+ content_parts.append(f"<@{i.qq}>")
elif isinstance(i, Image):
logger.debug(f"[Discord] 开始处理 Image 组件: {i}")
try:
@@ -238,6 +238,7 @@ async def _parse_to_discord(
else:
logger.debug(f"[Discord] 忽略了不支持的消息组件: {i.type}")
+ content = "".join(content_parts)
if len(content) > 2000:
logger.warning("[Discord] 消息内容超过2000字符,将被截断。")
content = content[:2000]
diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py
index f3c2ef0e5..fe1496644 100644
--- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py
+++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py
@@ -76,7 +76,7 @@ async def send_streaming(self, generator, use_fallback: bool = False):
return await super().send_streaming(generator, use_fallback)
- async def _post_send(self, stream: dict = None):
+ async def _post_send(self, stream: dict | None = None):
if not self.send_buffer:
return None
@@ -265,17 +265,17 @@ async def post_c2c_message(
self,
openid: str,
msg_type: int = 0,
- content: str = None,
- embed: message.Embed = None,
- ark: message.Ark = None,
- message_reference: message.Reference = None,
- media: message.Media = None,
- msg_id: str = None,
+ content: str | None = None,
+ embed: message.Embed | None = None,
+ ark: message.Ark | None = None,
+ message_reference: message.Reference | None = None,
+ media: message.Media | None = None,
+ msg_id: str | None = None,
msg_seq: str = 1,
- event_id: str = None,
- markdown: message.MarkdownPayload = None,
- keyboard: message.Keyboard = None,
- stream: dict = None,
+ event_id: str | None = None,
+ markdown: message.MarkdownPayload | None = None,
+ keyboard: message.Keyboard | None = None,
+ stream: dict | None = None,
) -> message.Message:
payload = locals()
payload.pop("self", None)
diff --git a/astrbot/core/platform/sources/slack/slack_adapter.py b/astrbot/core/platform/sources/slack/slack_adapter.py
index 9f21656ed..6c74a4713 100644
--- a/astrbot/core/platform/sources/slack/slack_adapter.py
+++ b/astrbot/core/platform/sources/slack/slack_adapter.py
@@ -222,39 +222,41 @@ def _parse_blocks(self, blocks: list) -> list:
if element.get("type") == "rich_text_section":
# 处理富文本段落
section_elements = element.get("elements", [])
- text_content = ""
-
+ text_parts = []
for section_element in section_elements:
element_type = section_element.get("type", "")
if element_type == "text":
# 普通文本
- text_content += section_element.get("text", "")
+ text_parts.append(section_element.get("text", ""))
elif element_type == "user":
# @用户提及
user_id = section_element.get("user_id", "")
if user_id:
# 将之前的文本内容先添加到组件中
+ text_content = "".join(text_parts)
if text_content.strip():
message_components.append(
Plain(text=text_content),
)
- text_content = ""
+ text_parts = []
# 添加@提及组件
message_components.append(At(qq=user_id, name=""))
elif element_type == "channel":
# #频道提及
channel_id = section_element.get("channel_id", "")
- text_content += f"#{channel_id}"
+ text_parts.append(f"#{channel_id}")
elif element_type == "link":
# 链接
url = section_element.get("url", "")
link_text = section_element.get("text", url)
- text_content += f"[{link_text}]({url})"
+ text_parts.append(f"[{link_text}]({url})")
elif element_type == "emoji":
# 表情符号
emoji_name = section_element.get("name", "")
- text_content += f":{emoji_name}:"
+ text_parts.append(f":{emoji_name}:")
+
+ text_content = "".join(text_parts)
if text_content.strip():
message_components.append(Plain(text=text_content))
diff --git a/astrbot/core/platform/sources/slack/slack_event.py b/astrbot/core/platform/sources/slack/slack_event.py
index 21c1b0fed..c918abbac 100644
--- a/astrbot/core/platform/sources/slack/slack_event.py
+++ b/astrbot/core/platform/sources/slack/slack_event.py
@@ -148,14 +148,15 @@ async def send(self, message: MessageChain):
)
except Exception:
# 如果块发送失败,尝试只发送文本
- fallback_text = ""
+ parts = []
for segment in message.chain:
if isinstance(segment, Plain):
- fallback_text += segment.text
+ parts.append(segment.text)
elif isinstance(segment, File):
- fallback_text += f" [文件: {segment.name}] "
+ parts.append(f" [文件: {segment.name}] ")
elif isinstance(segment, Image):
- fallback_text += " [图片] "
+ parts.append(" [图片] ")
+ fallback_text = "".join(parts)
if self.get_group_id():
await self.web_client.chat_postMessage(
diff --git a/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py b/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py
index d372211c9..09924edb6 100644
--- a/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py
+++ b/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py
@@ -18,7 +18,7 @@ def __init__(
is_private_chat: bool = False,
cached_texts=None,
cached_images=None,
- raw_message: dict = None,
+ raw_message: dict | None = None,
downloader=None,
):
self._xml = None
diff --git a/astrbot/core/platform_message_history_mgr.py b/astrbot/core/platform_message_history_mgr.py
index fa9a9733c..0e079e893 100644
--- a/astrbot/core/platform_message_history_mgr.py
+++ b/astrbot/core/platform_message_history_mgr.py
@@ -11,8 +11,8 @@ async def insert(
platform_id: str,
user_id: str,
content: list[dict], # TODO: parse from message chain
- sender_id: str = None,
- sender_name: str = None,
+ sender_id: str | None = None,
+ sender_name: str | None = None,
):
"""Insert a new platform message history record."""
await self.db.insert_platform_message_history(
diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py
index 28dc63f72..2f1e84419 100644
--- a/astrbot/core/provider/entities.py
+++ b/astrbot/core/provider/entities.py
@@ -4,15 +4,17 @@
from dataclasses import dataclass, field
from typing import Any
-from anthropic.types import Message
+from anthropic.types import Message as AnthropicMessage
from google.genai.types import GenerateContentResponse
from openai.types.chat.chat_completion import ChatCompletion
-from openai.types.chat.chat_completion_message_tool_call import (
- ChatCompletionMessageToolCall,
-)
import astrbot.core.message.components as Comp
from astrbot import logger
+from astrbot.core.agent.message import (
+ AssistantMessageSegment,
+ ToolCall,
+ ToolCallMessageSegment,
+)
from astrbot.core.agent.tool import ToolSet
from astrbot.core.db.po import Conversation
from astrbot.core.message.message_event_result import MessageChain
@@ -32,9 +34,9 @@ class ProviderMetaData:
type: str
"""提供商适配器名称,如 openai, ollama"""
desc: str = ""
- """提供商适配器描述."""
+ """提供商适配器描述"""
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
- cls_type: type | None = None
+ cls_type: Any = None
default_config_tmpl: dict | None = None
"""平台的默认配置模板"""
@@ -42,44 +44,6 @@ class ProviderMetaData:
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
-@dataclass
-class ToolCallMessageSegment:
- """OpenAI 格式的上下文中 role 为 tool 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
-
- tool_call_id: str
- content: str
- role: str = "tool"
-
- def to_dict(self):
- return {
- "tool_call_id": self.tool_call_id,
- "content": self.content,
- "role": self.role,
- }
-
-
-@dataclass
-class AssistantMessageSegment:
- """OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
-
- content: str | None = None
- tool_calls: list[ChatCompletionMessageToolCall | dict] = field(default_factory=list)
- role: str = "assistant"
-
- def to_dict(self):
- ret: dict[str, str | list[dict]] = {
- "role": self.role,
- }
- if self.content:
- ret["content"] = self.content
- if self.tool_calls:
- tool_calls_dict = [
- tc if isinstance(tc, dict) else tc.to_dict() for tc in self.tool_calls
- ]
- ret["tool_calls"] = tool_calls_dict
- return ret
-
-
@dataclass
class ToolCallsResult:
"""工具调用结果"""
@@ -91,8 +55,8 @@ class ToolCallsResult:
def to_openai_messages(self) -> list[dict]:
ret = [
- self.tool_calls_info.to_dict(),
- *[item.to_dict() for item in self.tool_calls_result],
+ self.tool_calls_info.model_dump(),
+ *[item.model_dump() for item in self.tool_calls_result],
]
return ret
@@ -108,16 +72,16 @@ class ProviderRequest:
func_tool: ToolSet | None = None
"""可用的函数工具"""
contexts: list[dict] = field(default_factory=list)
- """上下文。格式与 openai 的上下文格式一致:
+ """
+ OpenAI 格式上下文列表。
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
"""
system_prompt: str = ""
"""系统提示词"""
conversation: Conversation | None = None
-
+ """关联的对话对象"""
tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
-
model: str | None = None
"""模型名称,为 None 时使用提供商的默认模型"""
@@ -227,7 +191,9 @@ class LLMResponse:
tools_call_ids: list[str] = field(default_factory=list)
"""工具调用 ID"""
- raw_completion: ChatCompletion | GenerateContentResponse | Message | None = None
+ raw_completion: (
+ ChatCompletion | GenerateContentResponse | AnthropicMessage | None
+ ) = None
_new_record: dict[str, Any] | None = None
_completion_text: str = ""
@@ -243,7 +209,10 @@ def __init__(
tools_call_args: list[dict[str, Any]] | None = None,
tools_call_name: list[str] | None = None,
tools_call_ids: list[str] | None = None,
- raw_completion: ChatCompletion | None = None,
+ raw_completion: ChatCompletion
+ | GenerateContentResponse
+ | AnthropicMessage
+ | None = None,
_new_record: dict[str, Any] | None = None,
is_chunk: bool = False,
):
@@ -294,7 +263,7 @@ def completion_text(self, value):
self._completion_text = value
def to_openai_tool_calls(self) -> list[dict]:
- """将工具调用信息转换为 OpenAI 格式"""
+ """Convert to OpenAI tool calls format. Deprecated, use to_openai_to_calls_model instead."""
ret = []
for idx, tool_call_arg in enumerate(self.tools_call_args):
ret.append(
@@ -309,6 +278,21 @@ def to_openai_tool_calls(self) -> list[dict]:
)
return ret
+ def to_openai_to_calls_model(self) -> list[ToolCall]:
+ """The same as to_openai_tool_calls but return pydantic model."""
+ ret = []
+ for idx, tool_call_arg in enumerate(self.tools_call_args):
+ ret.append(
+ ToolCall(
+ id=self.tools_call_ids[idx],
+ function=ToolCall.FunctionBody(
+ name=self.tools_call_name[idx],
+ arguments=json.dumps(tool_call_arg),
+ ),
+ ),
+ )
+ return ret
+
@dataclass
class RerankResult:
diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py
index b3ef1ed5c..36aad2ae9 100644
--- a/astrbot/core/provider/func_tool_manager.py
+++ b/astrbot/core/provider/func_tool_manager.py
@@ -10,7 +10,7 @@
from astrbot import logger
from astrbot.core import sp
-from astrbot.core.agent.mcp_client import MCPClient
+from astrbot.core.agent.mcp_client import MCPClient, MCPTool
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
@@ -254,18 +254,15 @@ async def _init_mcp_client(self, name: str, config: dict) -> None:
self.func_list = [
f
for f in self.func_list
- if not (f.origin == "mcp" and f.mcp_server_name == name)
+ if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
for tool in mcp_client.tools:
- func_tool = FuncTool(
- name=tool.name,
- parameters=tool.inputSchema,
- description=tool.description,
- origin="mcp",
- mcp_server_name=name,
+ func_tool = MCPTool(
+ mcp_tool=tool,
mcp_client=mcp_client,
+ mcp_server_name=name,
)
self.func_list.append(func_tool)
@@ -284,7 +281,7 @@ async def _terminate_mcp_client(self, name: str) -> None:
self.func_list = [
f
for f in self.func_list
- if not (f.origin == "mcp" and f.mcp_server_name == name)
+ if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
logger.info(f"已关闭 MCP 服务 {name}")
@@ -374,7 +371,7 @@ async def disable_mcp_server(
self.func_list = [
f
for f in self.func_list
- if f.origin != "mcp" or f.mcp_server_name != name
+ if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
else:
running_events = [
@@ -388,7 +385,9 @@ async def disable_mcp_server(
finally:
self.mcp_client_event.clear()
self.mcp_client_dict.clear()
- self.func_list = [f for f in self.func_list if f.origin != "mcp"]
+ self.func_list = [
+ f for f in self.func_list if not isinstance(f, MCPTool)
+ ]
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
"""获得 OpenAI API 风格的**已经激活**的工具描述"""
diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py
index 23abfcfc0..7ab8f00ba 100644
--- a/astrbot/core/provider/provider.py
+++ b/astrbot/core/provider/provider.py
@@ -3,6 +3,7 @@
from collections.abc import AsyncGenerator
from dataclasses import dataclass
+from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolSet
from astrbot.core.db.po import Personality
from astrbot.core.provider.entities import (
@@ -23,24 +24,28 @@ class ProviderMeta:
class AbstractProvider(abc.ABC):
+ """Provider Abstract Class"""
+
def __init__(self, provider_config: dict) -> None:
super().__init__()
self.model_name = ""
self.provider_config = provider_config
def set_model(self, model_name: str):
- """设置当前使用的模型名称"""
+ """Set the current model name"""
self.model_name = model_name
def get_model(self) -> str:
- """获得当前使用的模型名称"""
+ """Get the current model name"""
return self.model_name
def meta(self) -> ProviderMeta:
- """获取 Provider 的元数据"""
+ """Get the provider metadata"""
provider_type_name = self.provider_config["type"]
meta_data = provider_cls_map.get(provider_type_name)
provider_type = meta_data.provider_type if meta_data else None
+ if provider_type is None:
+ raise ValueError(f"Cannot find provider type: {provider_type_name}")
return ProviderMeta(
id=self.provider_config["id"],
model=self.get_model(),
@@ -50,6 +55,8 @@ def meta(self) -> ProviderMeta:
class Provider(AbstractProvider):
+ """Chat Provider"""
+
def __init__(
self,
provider_config: dict,
@@ -84,24 +91,24 @@ async def get_models(self) -> list[str]:
@abc.abstractmethod
async def text_chat(
self,
- prompt: str,
- session_id: str = None,
- image_urls: list[str] = None,
- func_tool: ToolSet = None,
- contexts: list = None,
- system_prompt: str = None,
- tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
+ prompt: str | None = None,
+ session_id: str | None = None,
+ image_urls: list[str] | None = None,
+ func_tool: ToolSet | None = None,
+ contexts: list[Message] | list[dict] | None = None,
+ system_prompt: str | None = None,
+ tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
**kwargs,
) -> LLMResponse:
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
Args:
- prompt: 提示词
+ prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中
session_id: 会话 ID(此属性已经被废弃)
image_urls: 图片 URL 列表
- tools: Function-calling 工具
- contexts: 上下文
+ tools: tool set
+ contexts: 上下文,和 prompt 二选一使用
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
kwargs: 其他参数
@@ -114,24 +121,24 @@ async def text_chat(
async def text_chat_stream(
self,
- prompt: str,
- session_id: str = None,
- image_urls: list[str] = None,
- func_tool: ToolSet = None,
- contexts: list = None,
- system_prompt: str = None,
- tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
+ prompt: str | None = None,
+ session_id: str | None = None,
+ image_urls: list[str] | None = None,
+ func_tool: ToolSet | None = None,
+ contexts: list[Message] | list[dict] | None = None,
+ system_prompt: str | None = None,
+ tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
Args:
- prompt: 提示词
+ prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中
session_id: 会话 ID(此属性已经被废弃)
image_urls: 图片 URL 列表
- tools: Function-calling 工具
- contexts: 上下文
+ tools: tool set
+ contexts: 上下文,和 prompt 二选一使用
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
kwargs: 其他参数
@@ -140,6 +147,7 @@ async def text_chat_stream(
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
"""
+ ...
async def pop_record(self, context: list):
"""弹出 context 第一条非系统提示词对话记录"""
@@ -156,6 +164,22 @@ async def pop_record(self, context: list):
for idx in reversed(indexs_to_pop):
context.pop(idx)
+ def _ensure_message_to_dicts(
+ self,
+ messages: list[dict] | list[Message] | None,
+ ) -> list[dict]:
+ """Convert a list of Message objects to a list of dictionaries."""
+ if not messages:
+ return []
+ dicts: list[dict] = []
+ for message in messages:
+ if isinstance(message, Message):
+ dicts.append(message.model_dump())
+ else:
+ dicts.append(message)
+
+ return dicts
+
class STTProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py
index eb8c72aea..1aead54df 100644
--- a/astrbot/core/provider/register.py
+++ b/astrbot/core/provider/register.py
@@ -15,8 +15,8 @@ def register_provider_adapter(
provider_type_name: str,
desc: str,
provider_type: ProviderType = ProviderType.CHAT_COMPLETION,
- default_config_tmpl: dict = None,
- provider_display_name: str = None,
+ default_config_tmpl: dict | None = None,
+ provider_display_name: str | None = None,
):
"""用于注册平台适配器的带参装饰器"""
diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py
index 6f292f076..77c85cef4 100644
--- a/astrbot/core/provider/sources/anthropic_source.py
+++ b/astrbot/core/provider/sources/anthropic_source.py
@@ -243,7 +243,7 @@ async def _query_stream(
async def text_chat(
self,
- prompt,
+ prompt=None,
session_id=None,
image_urls=None,
func_tool=None,
@@ -255,8 +255,13 @@ async def text_chat(
) -> LLMResponse:
if contexts is None:
contexts = []
- new_record = await self.assemble_context(prompt, image_urls)
- context_query = [*contexts, new_record]
+ new_record = None
+ if prompt is not None:
+ new_record = await self.assemble_context(prompt, image_urls)
+ context_query = self._ensure_message_to_dicts(contexts)
+ if new_record:
+ context_query.append(new_record)
+
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
@@ -306,8 +311,12 @@ async def text_chat_stream(
):
if contexts is None:
contexts = []
- new_record = await self.assemble_context(prompt, image_urls)
- context_query = [*contexts, new_record]
+ new_record = None
+ if prompt is not None:
+ new_record = await self.assemble_context(prompt, image_urls)
+ context_query = self._ensure_message_to_dicts(contexts)
+ if new_record:
+ context_query.append(new_record)
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
diff --git a/astrbot/core/provider/sources/coze_source.py b/astrbot/core/provider/sources/coze_source.py
index caee65020..23a8b3b76 100644
--- a/astrbot/core/provider/sources/coze_source.py
+++ b/astrbot/core/provider/sources/coze_source.py
@@ -331,6 +331,7 @@ async def text_chat_stream(
},
)
+ contexts = self._ensure_message_to_dicts(contexts)
if not self.auto_save_history and contexts:
# 如果关闭了自动保存历史,传入上下文
for ctx in contexts:
diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py
index 92613dc1a..9b262c001 100644
--- a/astrbot/core/provider/sources/dashscope_source.py
+++ b/astrbot/core/provider/sources/dashscope_source.py
@@ -66,13 +66,15 @@ async def text_chat(
self,
prompt: str,
session_id=None,
- image_urls=[],
+ image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
model=None,
**kwargs,
) -> LLMResponse:
+ if image_urls is None:
+ image_urls = []
if contexts is None:
contexts = []
# 获得会话变量
@@ -144,14 +146,15 @@ async def text_chat(
# RAG 引用脚标格式化
output_text = re.sub(r"[\[(\d+)\]]", r"[\1]", output_text)
if self.output_reference and response.output.get("doc_references", None):
- ref_str = ""
+ ref_parts = []
for ref in response.output.get("doc_references", []) or []:
ref_title = (
ref.get("title", "")
if ref.get("title")
else ref.get("doc_name", "")
)
- ref_str += f"{ref['index_id']}. {ref_title}\n"
+ ref_parts.append(f"{ref['index_id']}. {ref_title}\n")
+ ref_str = "".join(ref_parts)
output_text += f"\n\n回答来源:\n{ref_str}"
llm_response = LLMResponse("assistant")
diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py
index f9eef2e92..c3c9253a5 100644
--- a/astrbot/core/provider/sources/gemini_source.py
+++ b/astrbot/core/provider/sources/gemini_source.py
@@ -572,7 +572,7 @@ async def _query_stream(
async def text_chat(
self,
- prompt: str,
+ prompt=None,
session_id=None,
image_urls=None,
func_tool=None,
@@ -584,8 +584,12 @@ async def text_chat(
) -> LLMResponse:
if contexts is None:
contexts = []
- new_record = await self.assemble_context(prompt, image_urls)
- context_query = [*contexts, new_record]
+ new_record = None
+ if prompt is not None:
+ new_record = await self.assemble_context(prompt, image_urls)
+ context_query = self._ensure_message_to_dicts(contexts)
+ if new_record:
+ context_query.append(new_record)
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
@@ -621,7 +625,7 @@ async def text_chat(
async def text_chat_stream(
self,
- prompt,
+ prompt=None,
session_id=None,
image_urls=None,
func_tool=None,
@@ -633,8 +637,12 @@ async def text_chat_stream(
) -> AsyncGenerator[LLMResponse, None]:
if contexts is None:
contexts = []
- new_record = await self.assemble_context(prompt, image_urls)
- context_query = [*contexts, new_record]
+ new_record = None
+ if prompt is not None:
+ new_record = await self.assemble_context(prompt, image_urls)
+ context_query = self._ensure_message_to_dicts(contexts)
+ if new_record:
+ context_query.append(new_record)
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
@@ -726,7 +734,6 @@ async def encode_image_bs64(self, image_url: str) -> str:
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
- return ""
async def terminate(self):
logger.info("Google GenAI 适配器已终止。")
diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py
index 1020075af..076afc40f 100644
--- a/astrbot/core/provider/sources/openai_source.py
+++ b/astrbot/core/provider/sources/openai_source.py
@@ -14,9 +14,10 @@
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.api.provider import Provider
+from astrbot.core.agent.message import Message
+from astrbot.core.agent.tool import ToolSet
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
-from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.utils.io import download_image_by_url
from ..register import register_provider_adapter
@@ -102,7 +103,7 @@ async def get_models(self):
except NotFoundError as e:
raise Exception(f"获取模型列表失败:{e}")
- async def _query(self, payloads: dict, tools: ToolSet) -> LLMResponse:
+ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools:
model = payloads.get("model", "").lower()
omit_empty_param_field = "gemini" in model
@@ -153,7 +154,7 @@ async def _query(self, payloads: dict, tools: ToolSet) -> LLMResponse:
async def _query_stream(
self,
payloads: dict,
- tools: ToolSet,
+ tools: ToolSet | None,
) -> AsyncGenerator[LLMResponse, None]:
"""流式查询API,逐步返回结果"""
if tools:
@@ -212,7 +213,9 @@ async def _query_stream(
yield llm_response
- async def parse_openai_completion(self, completion: ChatCompletion, tools: ToolSet):
+ async def parse_openai_completion(
+ self, completion: ChatCompletion, tools: ToolSet | None
+ ) -> LLMResponse:
"""解析 OpenAI 的 ChatCompletion 响应"""
llm_response = LLMResponse("assistant")
@@ -225,7 +228,7 @@ async def parse_openai_completion(self, completion: ChatCompletion, tools: ToolS
completion_text = str(choice.message.content).strip()
llm_response.result_chain = MessageChain().message(completion_text)
- if choice.message.tool_calls:
+ if choice.message.tool_calls and tools is not None:
# tools call (function calling)
args_ls = []
func_name_ls = []
@@ -267,9 +270,9 @@ async def parse_openai_completion(self, completion: ChatCompletion, tools: ToolS
async def _prepare_chat_payload(
self,
- prompt: str,
+ prompt: str | None,
image_urls: list[str] | None = None,
- contexts: list | None = None,
+ contexts: list[dict] | list[Message] | None = None,
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
@@ -278,8 +281,12 @@ async def _prepare_chat_payload(
"""准备聊天所需的有效载荷和上下文"""
if contexts is None:
contexts = []
- new_record = await self.assemble_context(prompt, image_urls)
- context_query = [*contexts, new_record]
+ new_record = None
+ if prompt is not None:
+ new_record = await self.assemble_context(prompt, image_urls)
+ context_query = self._ensure_message_to_dicts(contexts)
+ if new_record:
+ context_query.append(new_record)
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
@@ -310,7 +317,7 @@ async def _handle_api_error(
e: Exception,
payloads: dict,
context_query: list,
- func_tool: ToolSet,
+ func_tool: ToolSet | None,
chosen_key: str,
available_api_keys: list[str],
retry_cnt: int,
@@ -390,7 +397,7 @@ async def _handle_api_error(
async def text_chat(
self,
- prompt,
+ prompt=None,
session_id=None,
image_urls=None,
func_tool=None,
@@ -459,7 +466,7 @@ async def text_chat(
async def text_chat_stream(
self,
- prompt: str,
+ prompt=None,
session_id=None,
image_urls=None,
func_tool=None,
diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py
index 620e7e907..1a5bc53d9 100644
--- a/astrbot/core/star/context.py
+++ b/astrbot/core/star/context.py
@@ -1,3 +1,4 @@
+import logging
from asyncio import Queue
from collections.abc import Awaitable, Callable
from typing import Any
@@ -35,6 +36,8 @@
from .star import StarMetadata, star_map, star_registry
from .star_handler import EventType, StarHandlerMetadata, star_handlers_registry
+logger = logging.getLogger("astrbot")
+
class Context:
"""暴露给插件的接口上下文。"""
@@ -255,9 +258,44 @@ async def send_message(
def add_llm_tools(self, *tools: FunctionTool) -> None:
"""添加 LLM 工具。"""
+ tool_name = {tool.name for tool in self.provider_manager.llm_tools.func_list}
+ module_path = ""
for tool in tools:
+ if not module_path:
+ _parts = []
+ module_part = tool.__module__.split(".")
+ flags = ["packages", "plugins"]
+ for i, part in enumerate(module_part):
+ _parts.append(part)
+ if part in flags and i + 1 < len(module_part):
+ _parts.append(module_part[i + 1])
+ break
+ tool.handler_module_path = ".".join(_parts)
+ module_path = tool.handler_module_path
+ else:
+ tool.handler_module_path = module_path
+ logger.info(
+ f"plugin(module_path {module_path}) added LLM tool: {tool.name}"
+ )
+
+ if tool.name in tool_name:
+ logger.warning("替换已存在的 LLM 工具: " + tool.name)
+ self.provider_manager.llm_tools.remove_func(tool.name)
self.provider_manager.llm_tools.func_list.append(tool)
+ def register_web_api(
+ self,
+ route: str,
+ view_handler: Awaitable,
+ methods: list,
+ desc: str,
+ ):
+ for idx, api in enumerate(self.registered_web_apis):
+ if api[0] == route and methods == api[2]:
+ self.registered_web_apis[idx] = (route, view_handler, methods, desc)
+ return
+ self.registered_web_apis.append((route, view_handler, methods, desc))
+
"""
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
"""
@@ -269,7 +307,7 @@ def register_llm_tool(
desc: str,
func_obj: Callable[..., Awaitable[Any]],
) -> None:
- """为函数调用(function-calling / tools-use)添加工具。
+ """[DEPRECATED]为函数调用(function-calling / tools-use)添加工具。
@param name: 函数名
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
@@ -291,7 +329,7 @@ def register_llm_tool(
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj)
def unregister_llm_tool(self, name: str) -> None:
- """删除一个函数调用工具。如果再要启用,需要重新注册。"""
+ """[DEPRECATED]删除一个函数调用工具。如果再要启用,需要重新注册。"""
self.provider_manager.llm_tools.remove_func(name)
def register_commands(
@@ -333,18 +371,5 @@ def register_commands(
star_handlers_registry.append(md)
def register_task(self, task: Awaitable, desc: str):
- """注册一个异步任务。"""
+ """[DEPRECATED]注册一个异步任务。"""
self._register_tasks.append(task)
-
- def register_web_api(
- self,
- route: str,
- view_handler: Awaitable,
- methods: list,
- desc: str,
- ):
- for idx, api in enumerate(self.registered_web_apis):
- if api[0] == route and methods == api[2]:
- self.registered_web_apis[idx] = (route, view_handler, methods, desc)
- return
- self.registered_web_apis.append((route, view_handler, methods, desc))
diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py
index 6e0283a0e..2a9868fdc 100755
--- a/astrbot/core/star/filter/command.py
+++ b/astrbot/core/star/filter/command.py
@@ -36,11 +36,13 @@ def __init__(
command_name: str,
alias: set | None = None,
handler_md: StarHandlerMetadata | None = None,
- parent_command_names: list[str] = [""],
+ parent_command_names: list[str] | None = None,
):
self.command_name = command_name
self.alias = alias if alias else set()
- self.parent_command_names = parent_command_names
+ self.parent_command_names = (
+ parent_command_names if parent_command_names is not None else [""]
+ )
if handler_md:
self.init_handler_md(handler_md)
self.custom_filter_list: list[CustomFilter] = []
@@ -49,15 +51,15 @@ def __init__(
self._cmpl_cmd_names: list | None = None
def print_types(self):
- result = ""
+ parts = []
for k, v in self.handler_params.items():
if isinstance(v, type):
- result += f"{k}({v.__name__}),"
+ parts.append(f"{k}({v.__name__}),")
elif isinstance(v, types.UnionType) or typing.get_origin(v) is typing.Union:
- result += f"{k}({v}),"
+ parts.append(f"{k}({v}),")
else:
- result += f"{k}({type(v).__name__})={v},"
- result = result.rstrip(",")
+ parts.append(f"{k}({type(v).__name__})={v},")
+ result = "".join(parts).rstrip(",")
return result
def init_handler_md(self, handle_md: StarHandlerMetadata):
diff --git a/astrbot/core/star/filter/command_group.py b/astrbot/core/star/filter/command_group.py
index 0f5c19ec5..e1c2efb22 100755
--- a/astrbot/core/star/filter/command_group.py
+++ b/astrbot/core/star/filter/command_group.py
@@ -66,7 +66,7 @@ def print_cmd_tree(
event: AstrMessageEvent | None = None,
cfg: AstrBotConfig | None = None,
) -> str:
- result = ""
+ parts = []
for sub_filter in sub_command_filters:
if isinstance(sub_filter, CommandFilter):
custom_filter_pass = True
@@ -74,31 +74,32 @@ def print_cmd_tree(
custom_filter_pass = sub_filter.custom_filter_ok(event, cfg)
if custom_filter_pass:
cmd_th = sub_filter.print_types()
- result += f"{prefix}├── {sub_filter.command_name}"
+ line = f"{prefix}├── {sub_filter.command_name}"
if cmd_th:
- result += f" ({cmd_th})"
+ line += f" ({cmd_th})"
else:
- result += " (无参数指令)"
+ line += " (无参数指令)"
if sub_filter.handler_md and sub_filter.handler_md.desc:
- result += f": {sub_filter.handler_md.desc}"
+ line += f": {sub_filter.handler_md.desc}"
- result += "\n"
+ parts.append(line + "\n")
elif isinstance(sub_filter, CommandGroupFilter):
custom_filter_pass = True
if event and cfg:
custom_filter_pass = sub_filter.custom_filter_ok(event, cfg)
if custom_filter_pass:
- result += f"{prefix}├── {sub_filter.group_name}"
- result += "\n"
- result += sub_filter.print_cmd_tree(
- sub_filter.sub_command_filters,
- prefix + "│ ",
- event=event,
- cfg=cfg,
+ parts.append(f"{prefix}├── {sub_filter.group_name}\n")
+ parts.append(
+ sub_filter.print_cmd_tree(
+ sub_filter.sub_command_filters,
+ prefix + "│ ",
+ event=event,
+ cfg=cfg,
+ )
)
- return result
+ return "".join(parts)
def custom_filter_ok(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
for custom_filter in self.custom_filter_list:
diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py
index 4966735d5..70a0f73f8 100644
--- a/astrbot/core/star/star_manager.py
+++ b/astrbot/core/star/star_manager.py
@@ -830,7 +830,13 @@ async def turn_off_plugin(self, plugin_name: str):
# 禁用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
- if func_tool.handler_module_path == plugin.module_path:
+ mp = func_tool.handler_module_path
+ if (
+ plugin.module_path
+ and mp
+ and plugin.module_path.startswith(mp)
+ and not mp.endswith(("packages", "data.plugins"))
+ ):
func_tool.active = False
if func_tool.name not in inactivated_llm_tools:
inactivated_llm_tools.append(func_tool.name)
@@ -873,8 +879,12 @@ async def turn_on_plugin(self, plugin_name: str):
# 启用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
+ mp = func_tool.handler_module_path
if (
- func_tool.handler_module_path == plugin.module_path
+ plugin.module_path
+ and mp
+ and plugin.module_path.startswith(mp)
+ and not mp.endswith(("packages", "data.plugins"))
and func_tool.name in inactivated_llm_tools
):
inactivated_llm_tools.remove(func_tool.name)
@@ -883,8 +893,6 @@ async def turn_on_plugin(self, plugin_name: str):
await self.reload(plugin_name)
- # plugin.activated = True
-
async def install_plugin_from_file(self, zip_file_path: str):
dir_name = os.path.basename(zip_file_path).replace(".zip", "")
dir_name = dir_name.removesuffix("-master").removesuffix("-main").lower()
diff --git a/astrbot/core/utils/dify_api_client.py b/astrbot/core/utils/dify_api_client.py
index 2500e69a5..ea8ff9dff 100644
--- a/astrbot/core/utils/dify_api_client.py
+++ b/astrbot/core/utils/dify_api_client.py
@@ -46,9 +46,11 @@ async def chat_messages(
user: str,
response_mode: str = "streaming",
conversation_id: str = "",
- files: list[dict[str, Any]] = [],
+ files: list[dict[str, Any]] | None = None,
timeout: float = 60,
) -> AsyncGenerator[dict[str, Any], None]:
+ if files is None:
+ files = []
url = f"{self.api_base}/chat-messages"
payload = locals()
payload.pop("self")
@@ -73,9 +75,11 @@ async def workflow_run(
inputs: dict,
user: str,
response_mode: str = "streaming",
- files: list[dict[str, Any]] = [],
+ files: list[dict[str, Any]] | None = None,
timeout: float = 60,
):
+ if files is None:
+ files = []
url = f"{self.api_base}/workflows/run"
payload = locals()
payload.pop("self")
diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py
index bd0bea920..03549dc97 100644
--- a/astrbot/core/utils/io.py
+++ b/astrbot/core/utils/io.py
@@ -77,8 +77,8 @@ def save_temp_img(img: Image.Image | str) -> str:
async def download_image_by_url(
url: str,
post: bool = False,
- post_data: dict = None,
- path=None,
+ post_data: dict | None = None,
+ path: str | None = None,
) -> str:
"""下载图片, 返回 path"""
try:
diff --git a/astrbot/core/utils/pip_installer.py b/astrbot/core/utils/pip_installer.py
index abe247146..6076a114a 100644
--- a/astrbot/core/utils/pip_installer.py
+++ b/astrbot/core/utils/pip_installer.py
@@ -6,15 +6,15 @@
class PipInstaller:
- def __init__(self, pip_install_arg: str, pypi_index_url: str = None):
+ def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None):
self.pip_install_arg = pip_install_arg
self.pypi_index_url = pypi_index_url
async def install(
self,
- package_name: str = None,
- requirements_path: str = None,
- mirror: str = None,
+ package_name: str | None = None,
+ requirements_path: str | None = None,
+ mirror: str | None = None,
):
args = ["install"]
if package_name:
diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py
index 6954f2d6a..5156e14e5 100644
--- a/astrbot/dashboard/routes/chat.py
+++ b/astrbot/dashboard/routes/chat.py
@@ -134,10 +134,10 @@ async def chat(self):
if not conversation_id:
return Response().error("conversation_id is empty").__dict__
- # append user message
+ # 追加用户消息
webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id)
- # Get conversation-specific queues
+ # 获取会话特定的队列
back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id)
new_his = {"type": "user", "message": message}
@@ -200,7 +200,7 @@ async def stream():
or not streaming
or type == "break"
):
- # append bot message
+ # 追加机器人消息
new_his = {"type": "bot", "message": result_text}
await self.platform_history_mgr.insert(
platform_id="webchat",
@@ -212,7 +212,7 @@ async def stream():
except BaseException as e:
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
- # Put message to conversation-specific queue
+ # 将消息放入会话特定的队列
chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id)
await chat_queue.put(
(
diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py
index 59b07e47d..b947d26f2 100644
--- a/astrbot/dashboard/routes/config.py
+++ b/astrbot/dashboard/routes/config.py
@@ -926,7 +926,9 @@ async def _get_plugin_config(self, plugin_name: str):
return ret
- async def _save_astrbot_configs(self, post_configs: dict, conf_id: str = None):
+ async def _save_astrbot_configs(
+ self, post_configs: dict, conf_id: str | None = None
+ ):
try:
if conf_id not in self.acm.confs:
raise ValueError(f"配置文件 {conf_id} 不存在")
diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py
index 775983ef8..84976f2ba 100644
--- a/astrbot/dashboard/server.py
+++ b/astrbot/dashboard/server.py
@@ -105,7 +105,7 @@ async def auth_middleware(self):
allowed_endpoints = ["/api/auth/login", "/api/file"]
if any(request.path.startswith(prefix) for prefix in allowed_endpoints):
return None
- # claim jwt
+ # 声明 JWT
token = request.headers.get("Authorization")
if not token:
r = jsonify(Response().error("未授权").__dict__)
@@ -213,11 +213,12 @@ def run(self):
raise Exception(f"端口 {port} 已被占用")
- display = f"\n ✨✨✨\n AstrBot v{VERSION} WebUI 已启动,可访问\n\n"
- display += f" ➜ 本地: http://localhost:{port}\n"
+ parts = [f"\n ✨✨✨\n AstrBot v{VERSION} WebUI 已启动,可访问\n\n"]
+ parts.append(f" ➜ 本地: http://localhost:{port}\n")
for ip in ip_addr:
- display += f" ➜ 网络: http://{ip}:{port}\n"
- display += " ➜ 默认用户名和密码: astrbot\n ✨✨✨\n"
+ parts.append(f" ➜ 网络: http://{ip}:{port}\n")
+ parts.append(" ➜ 默认用户名和密码: astrbot\n ✨✨✨\n")
+ display = "".join(parts)
if not ip_addr:
display += (
diff --git a/main.py b/main.py
index b453cdfb5..60879f065 100644
--- a/main.py
+++ b/main.py
@@ -11,7 +11,7 @@
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
-# add parent path to sys.path
+# 将父目录添加到 sys.path
sys.path.append(Path(__file__).parent.as_posix())
logo_tmpl = r"""
@@ -34,7 +34,7 @@ def check_env():
os.makedirs("data/plugins", exist_ok=True)
os.makedirs("data/temp", exist_ok=True)
- # workaround for issue #181
+ # 针对问题 #181 的临时解决方案
mimetypes.add_type("text/javascript", ".js")
mimetypes.add_type("text/javascript", ".mjs")
mimetypes.add_type("application/json", ".json")
@@ -53,7 +53,7 @@ async def check_dashboard_files(webui_dir: str | None = None):
if os.path.exists(data_dist_path):
v = await get_dashboard_version()
if v is not None:
- # has file
+ # 存在文件
if v == f"v{VERSION}":
logger.info("WebUI 版本已是最新。")
else:
@@ -88,16 +88,16 @@ async def check_dashboard_files(webui_dir: str | None = None):
check_env()
- # start log broker
+ # 启动日志代理
log_broker = LogBroker()
LogManager.set_queue_handler(logger, log_broker)
- # check dashboard files
+ # 检查仪表板文件
webui_dir = asyncio.run(check_dashboard_files(args.webui_dir))
db = db_helper
- # print logo
+ # 打印 logo
logger.info(logo_tmpl)
core_lifecycle = InitialLoader(db, log_broker)
diff --git a/packages/astrbot/commands/conversation.py b/packages/astrbot/commands/conversation.py
index 82b661773..9538d8f53 100644
--- a/packages/astrbot/commands/conversation.py
+++ b/packages/astrbot/commands/conversation.py
@@ -134,12 +134,13 @@ async def his(self, message: AstrMessageEvent, page: int = 1):
size_per_page,
)
- history = ""
+ parts = []
for context in contexts:
if len(context) > 150:
context = context[:150] + "..."
- history += f"{context}\n"
+ parts.append(f"{context}\n")
+ history = "".join(parts)
ret = (
f"当前对话历史记录:"
f"{history or '无历史记录'}\n\n"
@@ -154,7 +155,7 @@ async def convs(self, message: AstrMessageEvent, page: int = 1):
provider = self.context.get_using_provider(message.unified_msg_origin)
if provider and provider.meta().type == "dify":
"""原有的Dify处理逻辑保持不变"""
- ret = "Dify 对话列表:\n"
+ parts = ["Dify 对话列表:\n"]
assert isinstance(provider, ProviderDify)
data = await provider.api_client.get_chat_convs(message.unified_msg_origin)
idx = 1
@@ -162,12 +163,17 @@ async def convs(self, message: AstrMessageEvent, page: int = 1):
ts_h = datetime.datetime.fromtimestamp(conv["updated_at"]).strftime(
"%m-%d %H:%M",
)
- ret += f"{idx}. {conv['name']}({conv['id'][:4]})\n 上次更新:{ts_h}\n"
+ parts.append(
+ f"{idx}. {conv['name']}({conv['id'][:4]})\n 上次更新:{ts_h}\n"
+ )
idx += 1
if idx == 1:
- ret += "没有找到任何对话。"
+ parts.append("没有找到任何对话。")
dify_cid = provider.conversation_ids.get(message.unified_msg_origin, None)
- ret += f"\n\n用户: {message.unified_msg_origin}\n当前对话: {dify_cid}\n使用 /switch <序号> 切换对话。"
+ parts.append(
+ f"\n\n用户: {message.unified_msg_origin}\n当前对话: {dify_cid}\n使用 /switch <序号> 切换对话。"
+ )
+ ret = "".join(parts)
message.set_result(MessageEventResult().message(ret))
return
@@ -185,7 +191,7 @@ async def convs(self, message: AstrMessageEvent, page: int = 1):
end_idx = start_idx + size_per_page
conversations_paged = conversations_all[start_idx:end_idx]
- ret = "对话列表:\n---\n"
+ parts = ["对话列表:\n---\n"]
"""全局序号从当前页的第一个开始"""
global_index = start_idx + 1
@@ -204,10 +210,13 @@ async def convs(self, message: AstrMessageEvent, page: int = 1):
)
persona_id = persona["name"]
title = _titles.get(conv.cid, "新对话")
- ret += f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_id}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n"
+ parts.append(
+ f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_id}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n"
+ )
global_index += 1
- ret += "---\n"
+ parts.append("---\n")
+ ret = "".join(parts)
curr_cid = await self.context.conversation_manager.get_curr_conversation_id(
message.unified_msg_origin,
)
diff --git a/packages/astrbot/commands/persona.py b/packages/astrbot/commands/persona.py
index 53582ce8e..1289cb569 100644
--- a/packages/astrbot/commands/persona.py
+++ b/packages/astrbot/commands/persona.py
@@ -59,10 +59,11 @@ async def persona(self, message: AstrMessageEvent):
.use_t2i(False),
)
elif l[1] == "list":
- msg = "人格列表:\n"
+ parts = ["人格列表:\n"]
for persona in self.context.provider_manager.personas:
- msg += f"- {persona['name']}\n"
- msg += "\n\n*输入 `/persona view 人格名` 查看人格详细信息"
+ parts.append(f"- {persona['name']}\n")
+ parts.append("\n\n*输入 `/persona view 人格名` 查看人格详细信息")
+ msg = "".join(parts)
message.set_result(MessageEventResult().message(msg))
elif l[1] == "view":
if len(l) == 2:
diff --git a/packages/astrbot/commands/plugin.py b/packages/astrbot/commands/plugin.py
index f9092ff97..ab45efc11 100644
--- a/packages/astrbot/commands/plugin.py
+++ b/packages/astrbot/commands/plugin.py
@@ -13,14 +13,17 @@ def __init__(self, context: star.Context):
async def plugin_ls(self, event: AstrMessageEvent):
"""获取已经安装的插件列表。"""
- plugin_list_info = "已加载的插件:\n"
+ parts = ["已加载的插件:\n"]
for plugin in self.context.get_all_stars():
- plugin_list_info += f"- `{plugin.name}` By {plugin.author}: {plugin.desc}"
+ line = f"- `{plugin.name}` By {plugin.author}: {plugin.desc}"
if not plugin.activated:
- plugin_list_info += " (未启用)"
- plugin_list_info += "\n"
- if plugin_list_info.strip() == "":
+ line += " (未启用)"
+ parts.append(line + "\n")
+
+ if len(parts) == 1:
plugin_list_info = "没有加载任何插件。"
+ else:
+ plugin_list_info = "".join(parts)
plugin_list_info += "\n使用 /plugin help <插件名> 查看插件帮助和加载的指令。\n使用 /plugin on/off <插件名> 启用或者禁用插件。"
event.set_result(
@@ -103,14 +106,14 @@ async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""):
command_names.append(filter_.group_name)
if len(command_handlers) > 0:
- help_msg += "\n\n🔧 指令列表:\n"
+ parts = ["\n\n🔧 指令列表:\n"]
for i in range(len(command_handlers)):
- help_msg += f"- {command_names[i]}"
+ line = f"- {command_names[i]}"
if command_handlers[i].desc:
- help_msg += f": {command_handlers[i].desc}"
- help_msg += "\n"
-
- help_msg += "\nTip: 指令的触发需要添加唤醒前缀,默认为 /。"
+ line += f": {command_handlers[i].desc}"
+ parts.append(line + "\n")
+ parts.append("\nTip: 指令的触发需要添加唤醒前缀,默认为 /。")
+ help_msg += "".join(parts)
ret = f"🧩 插件 {plugin_name} 帮助信息:\n" + help_msg
ret += "更多帮助信息请查看插件仓库 README。"
diff --git a/packages/astrbot/commands/provider.py b/packages/astrbot/commands/provider.py
index 750e9de5a..8db7324e4 100644
--- a/packages/astrbot/commands/provider.py
+++ b/packages/astrbot/commands/provider.py
@@ -19,38 +19,39 @@ async def provider(
umo = event.unified_msg_origin
if idx is None:
- ret = "## 载入的 LLM 提供商\n"
+ parts = ["## 载入的 LLM 提供商\n"]
for idx, llm in enumerate(self.context.get_all_providers()):
id_ = llm.meta().id
- ret += f"{idx + 1}. {id_} ({llm.meta().model})"
+ line = f"{idx + 1}. {id_} ({llm.meta().model})"
provider_using = self.context.get_using_provider(umo=umo)
if provider_using and provider_using.meta().id == id_:
- ret += " (当前使用)"
- ret += "\n"
+ line += " (当前使用)"
+ parts.append(line + "\n")
tts_providers = self.context.get_all_tts_providers()
if tts_providers:
- ret += "\n## 载入的 TTS 提供商\n"
+ parts.append("\n## 载入的 TTS 提供商\n")
for idx, tts in enumerate(tts_providers):
id_ = tts.meta().id
- ret += f"{idx + 1}. {id_}"
+ line = f"{idx + 1}. {id_}"
tts_using = self.context.get_using_tts_provider(umo=umo)
if tts_using and tts_using.meta().id == id_:
- ret += " (当前使用)"
- ret += "\n"
+ line += " (当前使用)"
+ parts.append(line + "\n")
stt_providers = self.context.get_all_stt_providers()
if stt_providers:
- ret += "\n## 载入的 STT 提供商\n"
+ parts.append("\n## 载入的 STT 提供商\n")
for idx, stt in enumerate(stt_providers):
id_ = stt.meta().id
- ret += f"{idx + 1}. {id_}"
+ line = f"{idx + 1}. {id_}"
stt_using = self.context.get_using_stt_provider(umo=umo)
if stt_using and stt_using.meta().id == id_:
- ret += " (当前使用)"
- ret += "\n"
+ line += " (当前使用)"
+ parts.append(line + "\n")
- ret += "\n使用 /provider <序号> 切换 LLM 提供商。"
+ parts.append("\n使用 /provider <序号> 切换 LLM 提供商。")
+ ret = "".join(parts)
if tts_providers:
ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。"
@@ -128,16 +129,17 @@ async def model_ls(
.use_t2i(False),
)
return
- i = 1
- ret = "下面列出了此模型提供商可用模型:"
- for model in models:
- ret += f"\n{i}. {model}"
- i += 1
+ parts = ["下面列出了此模型提供商可用模型:"]
+ for i, model in enumerate(models, 1):
+ parts.append(f"\n{i}. {model}")
curr_model = prov.get_model() or "无"
- ret += f"\n当前模型: [{curr_model}]"
+ parts.append(f"\n当前模型: [{curr_model}]")
+ parts.append(
+ "\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。"
+ )
- ret += "\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。"
+ ret = "".join(parts)
message.set_result(MessageEventResult().message(ret).use_t2i(False))
elif isinstance(idx_or_name, int):
models = []
@@ -180,14 +182,15 @@ async def key(self, message: AstrMessageEvent, index: int | None = None):
if index is None:
keys_data = prov.get_keys()
curr_key = prov.get_current_key()
- ret = "Key:"
- for i, k in enumerate(keys_data):
- ret += f"\n{i + 1}. {k[:8]}"
+ parts = ["Key:"]
+ for i, k in enumerate(keys_data, 1):
+ parts.append(f"\n{i}. {k[:8]}")
- ret += f"\n当前 Key: {curr_key[:8]}"
- ret += "\n当前模型: " + prov.get_model()
- ret += "\n使用 /key 切换 Key。"
+ parts.append(f"\n当前 Key: {curr_key[:8]}")
+ parts.append("\n当前模型: " + prov.get_model())
+ parts.append("\n使用 /key 切换 Key。")
+ ret = "".join(parts)
message.set_result(MessageEventResult().message(ret).use_t2i(False))
else:
keys_data = prov.get_keys()
diff --git a/packages/astrbot/long_term_memory.py b/packages/astrbot/long_term_memory.py
index a686d35b2..ceca60ef7 100644
--- a/packages/astrbot/long_term_memory.py
+++ b/packages/astrbot/long_term_memory.py
@@ -119,13 +119,13 @@ async def handle_message(self, event: AstrMessageEvent):
if event.get_message_type() == MessageType.GROUP_MESSAGE:
datetime_str = datetime.datetime.now().strftime("%H:%M:%S")
- final_message = f"[{event.message_obj.sender.nickname}/{datetime_str}]: "
+ parts = [f"[{event.message_obj.sender.nickname}/{datetime_str}]: "]
cfg = self.cfg(event)
for comp in event.get_messages():
if isinstance(comp, Plain):
- final_message += f" {comp.text}"
+ parts.append(f" {comp.text}")
elif isinstance(comp, Image):
if cfg["image_caption"]:
try:
@@ -137,11 +137,13 @@ async def handle_message(self, event: AstrMessageEvent):
cfg["image_caption_provider_id"],
cfg["image_caption_prompt"],
)
- final_message += f" [Image: {caption}]"
+ parts.append(f" [Image: {caption}]")
except Exception as e:
logger.error(f"获取图片描述失败: {e}")
else:
- final_message += " [Image]"
+ parts.append(" [Image]")
+
+ final_message = "".join(parts)
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
self.session_chats[event.unified_msg_origin].append(final_message)
if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:
diff --git a/packages/reminder/main.py b/packages/reminder/main.py
index 0349b9eb4..eaeec8d73 100644
--- a/packages/reminder/main.py
+++ b/packages/reminder/main.py
@@ -203,14 +203,15 @@ async def reminder_ls(self, event: AstrMessageEvent):
if not reminders:
yield event.plain_result("没有正在进行的待办事项。")
else:
- reminder_str = "正在进行的待办事项:\n"
+ parts = ["正在进行的待办事项:\n"]
for i, reminder in enumerate(reminders):
time_ = reminder.get("datetime", "")
if not time_:
cron_expr = reminder.get("cron", "")
time_ = reminder.get("cron_h", "") + f"(Cron: {cron_expr})"
- reminder_str += f"{i + 1}. {reminder['text']} - {time_}\n"
- reminder_str += "\n使用 /reminder rm 删除待办事项。\n"
+ parts.append(f"{i + 1}. {reminder['text']} - {time_}\n")
+ parts.append("\n使用 /reminder rm 删除待办事项。\n")
+ reminder_str = "".join(parts)
yield event.plain_result(reminder_str)
@reminder.command("rm")