From 66126d85ac1bde5896c3fa4991849ddcbb66ae77 Mon Sep 17 00:00:00 2001 From: DavdGao Date: Tue, 2 Jul 2024 16:54:10 +0800 Subject: [PATCH] Remove `Tht` class from AgentScope and Unified typing into Msg rather than MessageBase (#313) --- src/agentscope/memory/memory.py | 13 ++-- src/agentscope/memory/temporary_memory.py | 24 +++----- src/agentscope/message.py | 72 +---------------------- src/agentscope/models/dashscope_model.py | 24 ++++---- src/agentscope/models/gemini_model.py | 12 ++-- src/agentscope/models/litellm_model.py | 16 +++-- src/agentscope/models/model.py | 4 +- src/agentscope/models/ollama_model.py | 24 ++++---- src/agentscope/models/openai_model.py | 12 ++-- src/agentscope/models/post_model.py | 12 ++-- src/agentscope/models/zhipu_model.py | 14 ++--- tests/memory_test.py | 28 +-------- tests/model_test.py | 4 +- tests/retrieval_from_list_test.py | 14 ++--- tests/web_digest_test.py | 4 +- 15 files changed, 77 insertions(+), 200 deletions(-) diff --git a/src/agentscope/memory/memory.py b/src/agentscope/memory/memory.py index 14b82ee25..bf457a3e5 100644 --- a/src/agentscope/memory/memory.py +++ b/src/agentscope/memory/memory.py @@ -12,7 +12,7 @@ from typing import Union from typing import Callable -from ..message import MessageBase +from ..message import Msg class MemoryBase(ABC): @@ -62,14 +62,13 @@ def get_memory( @abstractmethod def add( self, - memories: Union[Sequence[dict], dict, None], + memories: Union[Sequence[Msg], Msg, None], ) -> None: """ Adding new memory fragment, depending on how the memory are stored Args: - memories (Union[Sequence[dict], dict, None]): - Memories to be added. If the memory is not in MessageBase, - it will first be converted into a message type. + memories (Union[Sequence[Msg], Msg, None]): + Memories to be added. """ @abstractmethod @@ -85,14 +84,14 @@ def delete(self, index: Union[Iterable, int]) -> None: @abstractmethod def load( self, - memories: Union[str, list[MessageBase], MessageBase], + memories: Union[str, list[Msg], Msg], overwrite: bool = False, ) -> None: """ Load memory, depending on how the memory are passed, design to load from both file or dict Args: - memories (Union[str, list[MessageBase], MessageBase]): + memories (Union[str, list[Msg], Msg]): memories to be loaded. If it is in str type, it will be first checked if it is a file; otherwise it will be deserialized as messages. diff --git a/src/agentscope/memory/temporary_memory.py b/src/agentscope/memory/temporary_memory.py index 356fa4d96..11a004265 100644 --- a/src/agentscope/memory/temporary_memory.py +++ b/src/agentscope/memory/temporary_memory.py @@ -21,7 +21,6 @@ serialize, MessageBase, Msg, - Tht, PlaceholderMessage, ) @@ -58,18 +57,17 @@ def __init__( def add( self, - memories: Union[Sequence[dict], dict, None], + memories: Union[Sequence[Msg], Msg, None], embed: bool = False, ) -> None: # pylint: disable=too-many-branches """ Adding new memory fragment, depending on how the memory are stored Args: - memories (Union[Sequence[dict], dict, None]): - memories to be added. If the memory is not in MessageBase, - it will first be converted into a message type. - embed (bool): - whether to generate embedding for the new added memories + memories (`Union[Sequence[Msg], Msg, None]`): + Memories to be added. + embed (`bool`): + Whether to generate embedding for the new added memories """ if memories is None: return @@ -84,13 +82,7 @@ def add( for memory_unit in record_memories: if not issubclass(type(memory_unit), MessageBase): try: - if ( - "name" in memory_unit - and memory_unit["name"] == "thought" - ): - memory_unit = Tht(**memory_unit) - else: - memory_unit = Msg(**memory_unit) + memory_unit = Msg(**memory_unit) except Exception as exc: raise ValueError( f"Cannot add {memory_unit} to memory, " @@ -186,14 +178,14 @@ def export( def load( self, - memories: Union[str, list[MessageBase], MessageBase], + memories: Union[str, list[Msg], Msg], overwrite: bool = False, ) -> None: """ Load memory, depending on how the memory are passed, design to load from both file or dict Args: - memories (Union[str, list[MessageBase], MessageBase]): + memories (Union[str, list[Msg], Msg]): memories to be loaded. If it is in str type, it will be first checked if it is a file; otherwise it will be deserialized as messages. diff --git a/src/agentscope/message.py b/src/agentscope/message.py index 23d2dd7a0..e59fef97f 100644 --- a/src/agentscope/message.py +++ b/src/agentscope/message.py @@ -173,73 +173,6 @@ def serialize(self) -> str: return json.dumps({"__type": "Msg", **self}) -class Tht(MessageBase): - """The Thought message is used to record the thought of the agent to - help them make decisions and responses. Generally, it shouldn't be - passed to or seen by the other agents. - - In our framework, we formulate the thought in prompt as follows: - - For OpenAI API calling: - - .. code-block:: python - - [ - ... - { - "role": "assistant", - "name": "thought", - "content": "I should ..." - }, - ... - ] - - - For open-source models that accepts string as input: - - .. code-block:: python - - ... - {self.name} thought: I should ... - ... - - We admit that there maybe better ways to formulate the thought. Users - are encouraged to create their own thought formulation methods by - inheriting `MessageBase` class and rewrite the `__init__` and `to_str` - function. - - .. code-block:: python - - class MyThought(MessageBase): - def to_str(self) -> str: - # implement your own thought formulation method - pass - """ - - def __init__( - self, - content: Any, - timestamp: Optional[str] = None, - **kwargs: Any, - ) -> None: - if "name" in kwargs: - kwargs.pop("name") - if "role" in kwargs: - kwargs.pop("role") - super().__init__( - name="thought", - content=content, - role="assistant", - timestamp=timestamp, - **kwargs, - ) - - def to_str(self) -> str: - """Return the string representation of the message""" - return f"{self.name} thought: {self.content}" - - def serialize(self) -> str: - return json.dumps({"__type": "Tht", **self}) - - class PlaceholderMessage(Msg): """A placeholder for the return message of RpcAgent.""" @@ -374,7 +307,7 @@ def update_value(self) -> MessageBase: if status == "ERROR": raise RuntimeError(msg.content) self.update(msg) - # the actual value has been updated, not a placeholder any more + # the actual value has been updated, not a placeholder anymore self._is_placeholder = False return self @@ -418,12 +351,11 @@ def serialize(self) -> str: _MSGS = { "Msg": Msg, - "Tht": Tht, "PlaceholderMessage": PlaceholderMessage, } -def deserialize(s: Union[str, bytes]) -> Union[MessageBase, Sequence]: +def deserialize(s: Union[str, bytes]) -> Union[Msg, Sequence]: """Deserialize json string into MessageBase""" js_msg = json.loads(s) msg_type = js_msg.pop("__type") diff --git a/src/agentscope/models/dashscope_model.py b/src/agentscope/models/dashscope_model.py index e7ef3411b..dcd727f0e 100644 --- a/src/agentscope/models/dashscope_model.py +++ b/src/agentscope/models/dashscope_model.py @@ -6,7 +6,7 @@ from typing import Any, Union, List, Sequence from loguru import logger -from ..message import MessageBase +from ..message import Msg from ..utils.tools import _convert_to_str, _guess_type_by_extension try: @@ -66,7 +66,7 @@ def __init__( def format( self, - *args: Union[MessageBase, Sequence[MessageBase]], + *args: Union[Msg, Sequence[Msg]], ) -> Union[List[dict], str]: raise RuntimeError( f"Model Wrapper [{type(self).__name__}] doesn't " @@ -213,7 +213,7 @@ def __call__( def format( self, - *args: Union[MessageBase, Sequence[MessageBase]], + *args: Union[Msg, Sequence[Msg]], ) -> List: """Format the messages for DashScope Chat API. @@ -254,7 +254,7 @@ def format( Args: - args (`Union[MessageBase, Sequence[MessageBase]]`): + args (`Union[Msg, Sequence[Msg]]`): The input arguments to be formatted, where each argument should be a `Msg` object, or a list of `Msg` objects. In distribution, placeholder is also allowed. @@ -269,11 +269,9 @@ def format( for _ in args: if _ is None: continue - if isinstance(_, MessageBase): + if isinstance(_, Msg): input_msgs.append(_) - elif isinstance(_, list) and all( - isinstance(__, MessageBase) for __ in _ - ): + elif isinstance(_, list) and all(isinstance(__, Msg) for __ in _): input_msgs.extend(_) else: raise TypeError( @@ -655,7 +653,7 @@ def __call__( def format( self, - *args: Union[MessageBase, Sequence[MessageBase]], + *args: Union[Msg, Sequence[Msg]], ) -> List: """Format the messages for DashScope Multimodal API. @@ -737,7 +735,7 @@ def format( "file://", which will be attached in this format function. Args: - args (`Union[MessageBase, Sequence[MessageBase]]`): + args (`Union[Msg, Sequence[Msg]]`): The input arguments to be formatted, where each argument should be a `Msg` object, or a list of `Msg` objects. In distribution, placeholder is also allowed. @@ -752,11 +750,9 @@ def format( for _ in args: if _ is None: continue - if isinstance(_, MessageBase): + if isinstance(_, Msg): input_msgs.append(_) - elif isinstance(_, list) and all( - isinstance(__, MessageBase) for __ in _ - ): + elif isinstance(_, list) and all(isinstance(__, Msg) for __ in _): input_msgs.extend(_) else: raise TypeError( diff --git a/src/agentscope/models/gemini_model.py b/src/agentscope/models/gemini_model.py index 3deca77d1..0281c1e91 100644 --- a/src/agentscope/models/gemini_model.py +++ b/src/agentscope/models/gemini_model.py @@ -7,7 +7,7 @@ from loguru import logger -from agentscope.message import Msg, MessageBase +from agentscope.message import Msg from agentscope.models import ModelWrapperBase, ModelResponse from agentscope.utils.tools import _convert_to_str @@ -250,7 +250,7 @@ def _register_default_metrics(self) -> None: def format( self, - *args: Union[MessageBase, Sequence[MessageBase]], + *args: Union[Msg, Sequence[Msg]], ) -> List[dict]: """This function provide a basic prompting strategy for Gemini Chat API in multi-party conversation, which combines all input into a @@ -279,7 +279,7 @@ def format( https://github.com/agentscope/agentscope! Args: - args (`Union[MessageBase, Sequence[MessageBase]]`): + args (`Union[Msg, Sequence[Msg]]`): The input arguments to be formatted, where each argument should be a `Msg` object, or a list of `Msg` objects. In distribution, placeholder is also allowed. @@ -292,11 +292,9 @@ def format( for _ in args: if _ is None: continue - if isinstance(_, MessageBase): + if isinstance(_, Msg): input_msgs.append(_) - elif isinstance(_, list) and all( - isinstance(__, MessageBase) for __ in _ - ): + elif isinstance(_, list) and all(isinstance(__, Msg) for __ in _): input_msgs.extend(_) else: raise TypeError( diff --git a/src/agentscope/models/litellm_model.py b/src/agentscope/models/litellm_model.py index 7a9309c07..a42ca89e5 100644 --- a/src/agentscope/models/litellm_model.py +++ b/src/agentscope/models/litellm_model.py @@ -6,7 +6,7 @@ from loguru import logger from .model import ModelWrapperBase, ModelResponse -from ..message import MessageBase +from ..message import Msg from ..utils.tools import _convert_to_str try: @@ -73,7 +73,7 @@ def __init__( def format( self, - *args: Union[MessageBase, Sequence[MessageBase]], + *args: Union[Msg, Sequence[Msg]], ) -> Union[List[dict], str]: raise RuntimeError( f"Model Wrapper [{type(self).__name__}] doesn't " @@ -183,16 +183,16 @@ def __call__( def format( self, - *args: Union[MessageBase, Sequence[MessageBase]], + *args: Union[Msg, Sequence[Msg]], ) -> List[dict]: """Format the input string and dictionary into the unified format. - Note that the format function might not be the optimal way to contruct + Note that the format function might not be the optimal way to construct prompt for every model, but a common way to do so. Developers are encouraged to implement their own prompt engineering strategies if have strong performance concerns. Args: - args (`Union[MessageBase, Sequence[MessageBase]]`): + args (`Union[Msg, Sequence[Msg]]`): The input arguments to be formatted, where each argument should be a `Msg` object, or a list of `Msg` objects. In distribution, placeholder is also allowed. @@ -207,11 +207,9 @@ def format( for _ in args: if _ is None: continue - if isinstance(_, MessageBase): + if isinstance(_, Msg): input_msgs.append(_) - elif isinstance(_, list) and all( - isinstance(__, MessageBase) for __ in _ - ): + elif isinstance(_, list) and all(isinstance(__, Msg) for __ in _): input_msgs.extend(_) else: raise TypeError( diff --git a/src/agentscope/models/model.py b/src/agentscope/models/model.py index 2d9d50a8b..072afcad4 100644 --- a/src/agentscope/models/model.py +++ b/src/agentscope/models/model.py @@ -67,7 +67,7 @@ from ..exception import ResponseParsingError from ..file_manager import file_manager -from ..message import MessageBase +from ..message import Msg from ..utils import MonitorFactory from ..utils.monitor import get_full_name from ..utils.tools import _get_timestamp @@ -227,7 +227,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ModelResponse: def format( self, - *args: Union[MessageBase, Sequence[MessageBase]], + *args: Union[Msg, Sequence[Msg]], ) -> Union[List[dict], str]: """Format the input string or dict into the format that the model API required.""" diff --git a/src/agentscope/models/ollama_model.py b/src/agentscope/models/ollama_model.py index 9bdc07869..3a9f9b89d 100644 --- a/src/agentscope/models/ollama_model.py +++ b/src/agentscope/models/ollama_model.py @@ -3,7 +3,7 @@ from abc import ABC from typing import Sequence, Any, Optional, List, Union -from agentscope.message import MessageBase +from agentscope.message import Msg from agentscope.models import ModelWrapperBase, ModelResponse from agentscope.utils.tools import _convert_to_str @@ -166,7 +166,7 @@ def _register_default_metrics(self) -> None: def format( self, - *args: Union[MessageBase, Sequence[MessageBase]], + *args: Union[Msg, Sequence[Msg]], ) -> List[dict]: """Format the messages for ollama Chat API. @@ -207,7 +207,7 @@ def format( Args: - args (`Union[MessageBase, Sequence[MessageBase]]`): + args (`Union[Msg, Sequence[Msg]]`): The input arguments to be formatted, where each argument should be a `Msg` object, or a list of `Msg` objects. In distribution, placeholder is also allowed. @@ -222,11 +222,9 @@ def format( for _ in args: if _ is None: continue - if isinstance(_, MessageBase): + if isinstance(_, Msg): input_msgs.append(_) - elif isinstance(_, list) and all( - isinstance(__, MessageBase) for __ in _ - ): + elif isinstance(_, list) and all(isinstance(__, Msg) for __ in _): input_msgs.extend(_) else: raise TypeError( @@ -354,7 +352,7 @@ def _register_default_metrics(self) -> None: def format( self, - *args: Union[MessageBase, Sequence[MessageBase]], + *args: Union[Msg, Sequence[Msg]], ) -> Union[List[dict], str]: raise RuntimeError( f"Model Wrapper [{type(self).__name__}] doesn't " @@ -458,11 +456,11 @@ def _register_default_metrics(self) -> None: metric_unit="token", ) - def format(self, *args: Union[MessageBase, Sequence[MessageBase]]) -> str: + def format(self, *args: Union[Msg, Sequence[Msg]]) -> str: """Forward the input to the model. Args: - args (`Union[MessageBase, Sequence[MessageBase]]`): + args (`Union[Msg, Sequence[Msg]]`): The input arguments to be formatted, where each argument should be a `Msg` object, or a list of `Msg` objects. In distribution, placeholder is also allowed. @@ -475,11 +473,9 @@ def format(self, *args: Union[MessageBase, Sequence[MessageBase]]) -> str: for _ in args: if _ is None: continue - if isinstance(_, MessageBase): + if isinstance(_, Msg): input_msgs.append(_) - elif isinstance(_, list) and all( - isinstance(__, MessageBase) for __ in _ - ): + elif isinstance(_, list) and all(isinstance(__, Msg) for __ in _): input_msgs.extend(_) else: raise TypeError( diff --git a/src/agentscope/models/openai_model.py b/src/agentscope/models/openai_model.py index 9bbdc59c1..affda4f23 100644 --- a/src/agentscope/models/openai_model.py +++ b/src/agentscope/models/openai_model.py @@ -7,7 +7,7 @@ from .model import ModelWrapperBase, ModelResponse from ..file_manager import file_manager -from ..message import MessageBase +from ..message import Msg from ..utils.tools import _convert_to_str, _to_openai_image_url try: @@ -91,7 +91,7 @@ def __init__( def format( self, - *args: Union[MessageBase, Sequence[MessageBase]], + *args: Union[Msg, Sequence[Msg]], ) -> Union[List[dict], str]: raise RuntimeError( f"Model Wrapper [{type(self).__name__}] doesn't " @@ -217,7 +217,7 @@ def __call__( def _format_msg_with_url( self, - msg: MessageBase, + msg: Msg, ) -> Dict: """Format a message with image urls into openai chat format. This format method is used for gpt-4o, gpt-4-turbo, gpt-4-vision and @@ -288,13 +288,13 @@ def _format_msg_with_url( def format( self, - *args: Union[MessageBase, Sequence[MessageBase]], + *args: Union[Msg, Sequence[Msg]], ) -> List[dict]: """Format the input string and dictionary into the format that OpenAI Chat API required. Args: - args (`Union[MessageBase, Sequence[MessageBase]]`): + args (`Union[Msg, Sequence[Msg]]`): The input arguments to be formatted, where each argument should be a `Msg` object, or a list of `Msg` objects. In distribution, placeholder is also allowed. @@ -308,7 +308,7 @@ def format( for arg in args: if arg is None: continue - if isinstance(arg, MessageBase): + if isinstance(arg, Msg): if arg.url is not None: messages.append(self._format_msg_with_url(arg)) else: diff --git a/src/agentscope/models/post_model.py b/src/agentscope/models/post_model.py index fe61be221..7167fd6c6 100644 --- a/src/agentscope/models/post_model.py +++ b/src/agentscope/models/post_model.py @@ -12,7 +12,7 @@ from ..constants import _DEFAULT_MAX_RETRIES from ..constants import _DEFAULT_MESSAGES_KEY from ..constants import _DEFAULT_RETRY_INTERVAL -from ..message import MessageBase +from ..message import Msg from ..utils.tools import _convert_to_str @@ -175,13 +175,13 @@ def _parse_response(self, response: dict) -> ModelResponse: def format( self, - *args: Union[MessageBase, Sequence[MessageBase]], + *args: Union[Msg, Sequence[Msg]], ) -> Union[List[dict]]: """Format the input messages into a list of dict, which is compatible to OpenAI Chat API. Args: - args (`Union[MessageBase, Sequence[MessageBase]]`): + args (`Union[Msg, Sequence[Msg]]`): The input arguments to be formatted, where each argument should be a `Msg` object, or a list of `Msg` objects. In distribution, placeholder is also allowed. @@ -194,7 +194,7 @@ def format( for arg in args: if arg is None: continue - if isinstance(arg, MessageBase): + if isinstance(arg, Msg): messages.append( { "role": arg.role, @@ -233,7 +233,7 @@ def _parse_response(self, response: dict) -> ModelResponse: def format( self, - *args: Union[MessageBase, Sequence[MessageBase]], + *args: Union[Msg, Sequence[Msg]], ) -> Union[List[dict], str]: raise RuntimeError( f"Model Wrapper [{type(self).__name__}] doesn't " @@ -293,7 +293,7 @@ def _parse_response(self, response: dict) -> ModelResponse: def format( self, - *args: Union[MessageBase, Sequence[MessageBase]], + *args: Union[Msg, Sequence[Msg]], ) -> Union[List[dict], str]: raise RuntimeError( f"Model Wrapper [{type(self).__name__}] doesn't " diff --git a/src/agentscope/models/zhipu_model.py b/src/agentscope/models/zhipu_model.py index 5c33e2b45..658a71d53 100644 --- a/src/agentscope/models/zhipu_model.py +++ b/src/agentscope/models/zhipu_model.py @@ -6,7 +6,7 @@ from loguru import logger from .model import ModelWrapperBase, ModelResponse -from ..message import MessageBase +from ..message import Msg from ..utils.tools import _convert_to_str try: @@ -71,7 +71,7 @@ def __init__( def format( self, - *args: Union[MessageBase, Sequence[MessageBase]], + *args: Union[Msg, Sequence[Msg]], ) -> Union[List[dict], str]: raise RuntimeError( f"Model Wrapper [{type(self).__name__}] doesn't " @@ -186,7 +186,7 @@ def __call__( def format( self, - *args: Union[MessageBase, Sequence[MessageBase]], + *args: Union[Msg, Sequence[Msg]], ) -> List[dict]: """Format the input string and dictionary into the format that ZhipuAI Chat API required. @@ -198,7 +198,7 @@ def format( engineering strategies. Args: - args (`Union[MessageBase, Sequence[MessageBase]]`): + args (`Union[Msg, Sequence[Msg]]`): The input arguments to be formatted, where each argument should be a `Msg` object, or a list of `Msg` objects. In distribution, placeholder is also allowed. @@ -214,11 +214,9 @@ def format( for _ in args: if _ is None: continue - if isinstance(_, MessageBase): + if isinstance(_, Msg): input_msgs.append(_) - elif isinstance(_, list) and all( - isinstance(__, MessageBase) for __ in _ - ): + elif isinstance(_, list) and all(isinstance(__, Msg) for __ in _): input_msgs.extend(_) else: raise TypeError( diff --git a/tests/memory_test.py b/tests/memory_test.py index 629fb45d9..55e02c109 100644 --- a/tests/memory_test.py +++ b/tests/memory_test.py @@ -7,7 +7,7 @@ import unittest from unittest.mock import patch, MagicMock -from agentscope.message import Msg, Tht +from agentscope.message import Msg from agentscope.memory import TemporaryMemory @@ -112,32 +112,6 @@ def test_load_export(self) -> None: [user_input, agent_input], ) - def test_tht_memory(self) -> None: - """ - Test temporary memory with Tht, - add, clear, export, loading - """ - memory = TemporaryMemory() - thought = Tht("testing") - memory.add(thought) - - self.assertEqual( - memory.get_memory(), - [thought], - ) - - memory.export(file_path=self.file_name_2) - memory.clear() - self.assertEqual( - memory.get_memory(), - [], - ) - memory.load(self.file_name_2) - self.assertEqual( - memory.get_memory(), - [thought], - ) - if __name__ == "__main__": unittest.main() diff --git a/tests/model_test.py b/tests/model_test.py index f361b7f0d..8701f97d4 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -7,7 +7,7 @@ import unittest from unittest.mock import patch, MagicMock -from agentscope.message import MessageBase +from agentscope.message import Msg from agentscope.models import ( ModelResponse, ModelWrapperBase, @@ -28,7 +28,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ModelResponse: def format( self, - *args: Union[MessageBase, Sequence[MessageBase]], + *args: Union[Msg, Sequence[Msg]], ) -> Union[List[dict], str]: return "" diff --git a/tests/retrieval_from_list_test.py b/tests/retrieval_from_list_test.py index db04468e8..52b30720b 100644 --- a/tests/retrieval_from_list_test.py +++ b/tests/retrieval_from_list_test.py @@ -6,7 +6,7 @@ from agentscope.service import retrieve_from_list, cos_sim from agentscope.service.service_status import ServiceExecStatus -from agentscope.message import MessageBase, Msg, Tht +from agentscope.message import MessageBase, Msg from agentscope.memory.temporary_memory import TemporaryMemory from agentscope.models import OpenAIEmbeddingWrapper, ModelResponse @@ -40,13 +40,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> ModelResponse: m2 = Msg(name="env", content="test2", role="assistant") m2.embedding = [0.5, 0.5] m2.timestamp = "2023-12-18 21:50:59" - m3 = Tht(content="test3") - m3.embedding = [0.2, 0.8] - m3.timestamp = "2023-12-18 21:42:59" memory = TemporaryMemory(config={}, embedding_model=dummy_model) memory.add(m1) memory.add(m2) - memory.add(m3) def score_func(m1: MessageBase, m2: MessageBase) -> float: relevance = cos_sim(m1.embedding, m2.embedding).content @@ -65,9 +61,8 @@ def score_func(m1: MessageBase, m2: MessageBase) -> float: preserve_order=False, ) self.assertEqual(retrieved.status, ServiceExecStatus.SUCCESS) - self.assertEqual(retrieved.content[0][2], m3) - self.assertEqual(retrieved.content[1][2], m2) - self.assertEqual(retrieved.content[2][2], m1) + self.assertEqual(retrieved.content[0][2], m2) + self.assertEqual(retrieved.content[1][2], m1) retrieved = retrieve_from_list( query, @@ -78,8 +73,7 @@ def score_func(m1: MessageBase, m2: MessageBase) -> float: preserve_order=True, ) self.assertEqual(retrieved.status, ServiceExecStatus.SUCCESS) - self.assertEqual(retrieved.content[0][2], m2) - self.assertEqual(retrieved.content[1][2], m3) + self.assertEqual(retrieved.content[0][2], m1) # This allows the tests to be run from the command line diff --git a/tests/web_digest_test.py b/tests/web_digest_test.py index 3eb0f6d7b..3b62b0e36 100644 --- a/tests/web_digest_test.py +++ b/tests/web_digest_test.py @@ -9,7 +9,7 @@ from agentscope.service import load_web, digest_webpage from agentscope.service.service_status import ServiceExecStatus from agentscope.models import ModelWrapperBase, ModelResponse -from agentscope.message import Msg, MessageBase +from agentscope.message import Msg class TestWebDigest(unittest.TestCase): @@ -85,7 +85,7 @@ def __call__(self, messages: list[Msg]) -> ModelResponse: def format( self, - *args: Union[MessageBase, Sequence[MessageBase]], + *args: Union[Msg, Sequence[Msg]], ) -> Union[List[dict], str]: return str(args)