Skip to content

Commit

Permalink
Remove Tht class from AgentScope and Unified typing into Msg rather…
Browse files Browse the repository at this point in the history
… than MessageBase (#313)
  • Loading branch information
DavdGao committed Jul 2, 2024
1 parent 2f38d15 commit 66126d8
Show file tree
Hide file tree
Showing 15 changed files with 77 additions and 200 deletions.
13 changes: 6 additions & 7 deletions src/agentscope/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Union
from typing import Callable

from ..message import MessageBase
from ..message import Msg


class MemoryBase(ABC):
Expand Down Expand Up @@ -62,14 +62,13 @@ def get_memory(
@abstractmethod
def add(
self,
memories: Union[Sequence[dict], dict, None],
memories: Union[Sequence[Msg], Msg, None],
) -> None:
"""
Adding new memory fragment, depending on how the memory are stored
Args:
memories (Union[Sequence[dict], dict, None]):
Memories to be added. If the memory is not in MessageBase,
it will first be converted into a message type.
memories (Union[Sequence[Msg], Msg, None]):
Memories to be added.
"""

@abstractmethod
Expand All @@ -85,14 +84,14 @@ def delete(self, index: Union[Iterable, int]) -> None:
@abstractmethod
def load(
self,
memories: Union[str, list[MessageBase], MessageBase],
memories: Union[str, list[Msg], Msg],
overwrite: bool = False,
) -> None:
"""
Load memory, depending on how the memory are passed, design to load
from both file or dict
Args:
memories (Union[str, list[MessageBase], MessageBase]):
memories (Union[str, list[Msg], Msg]):
memories to be loaded.
If it is in str type, it will be first checked if it is a
file; otherwise it will be deserialized as messages.
Expand Down
24 changes: 8 additions & 16 deletions src/agentscope/memory/temporary_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
serialize,
MessageBase,
Msg,
Tht,
PlaceholderMessage,
)

Expand Down Expand Up @@ -58,18 +57,17 @@ def __init__(

def add(
self,
memories: Union[Sequence[dict], dict, None],
memories: Union[Sequence[Msg], Msg, None],
embed: bool = False,
) -> None:
# pylint: disable=too-many-branches
"""
Adding new memory fragment, depending on how the memory are stored
Args:
memories (Union[Sequence[dict], dict, None]):
memories to be added. If the memory is not in MessageBase,
it will first be converted into a message type.
embed (bool):
whether to generate embedding for the new added memories
memories (`Union[Sequence[Msg], Msg, None]`):
Memories to be added.
embed (`bool`):
Whether to generate embedding for the new added memories
"""
if memories is None:
return
Expand All @@ -84,13 +82,7 @@ def add(
for memory_unit in record_memories:
if not issubclass(type(memory_unit), MessageBase):
try:
if (
"name" in memory_unit
and memory_unit["name"] == "thought"
):
memory_unit = Tht(**memory_unit)
else:
memory_unit = Msg(**memory_unit)
memory_unit = Msg(**memory_unit)
except Exception as exc:
raise ValueError(
f"Cannot add {memory_unit} to memory, "
Expand Down Expand Up @@ -186,14 +178,14 @@ def export(

def load(
self,
memories: Union[str, list[MessageBase], MessageBase],
memories: Union[str, list[Msg], Msg],
overwrite: bool = False,
) -> None:
"""
Load memory, depending on how the memory are passed, design to load
from both file or dict
Args:
memories (Union[str, list[MessageBase], MessageBase]):
memories (Union[str, list[Msg], Msg]):
memories to be loaded.
If it is in str type, it will be first checked if it is a
file; otherwise it will be deserialized as messages.
Expand Down
72 changes: 2 additions & 70 deletions src/agentscope/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,73 +173,6 @@ def serialize(self) -> str:
return json.dumps({"__type": "Msg", **self})


class Tht(MessageBase):
"""The Thought message is used to record the thought of the agent to
help them make decisions and responses. Generally, it shouldn't be
passed to or seen by the other agents.
In our framework, we formulate the thought in prompt as follows:
- For OpenAI API calling:
.. code-block:: python
[
...
{
"role": "assistant",
"name": "thought",
"content": "I should ..."
},
...
]
- For open-source models that accepts string as input:
.. code-block:: python
...
{self.name} thought: I should ...
...
We admit that there maybe better ways to formulate the thought. Users
are encouraged to create their own thought formulation methods by
inheriting `MessageBase` class and rewrite the `__init__` and `to_str`
function.
.. code-block:: python
class MyThought(MessageBase):
def to_str(self) -> str:
# implement your own thought formulation method
pass
"""

def __init__(
self,
content: Any,
timestamp: Optional[str] = None,
**kwargs: Any,
) -> None:
if "name" in kwargs:
kwargs.pop("name")
if "role" in kwargs:
kwargs.pop("role")
super().__init__(
name="thought",
content=content,
role="assistant",
timestamp=timestamp,
**kwargs,
)

def to_str(self) -> str:
"""Return the string representation of the message"""
return f"{self.name} thought: {self.content}"

def serialize(self) -> str:
return json.dumps({"__type": "Tht", **self})


class PlaceholderMessage(Msg):
"""A placeholder for the return message of RpcAgent."""

Expand Down Expand Up @@ -374,7 +307,7 @@ def update_value(self) -> MessageBase:
if status == "ERROR":
raise RuntimeError(msg.content)
self.update(msg)
# the actual value has been updated, not a placeholder any more
# the actual value has been updated, not a placeholder anymore
self._is_placeholder = False
return self

Expand Down Expand Up @@ -418,12 +351,11 @@ def serialize(self) -> str:

_MSGS = {
"Msg": Msg,
"Tht": Tht,
"PlaceholderMessage": PlaceholderMessage,
}


def deserialize(s: Union[str, bytes]) -> Union[MessageBase, Sequence]:
def deserialize(s: Union[str, bytes]) -> Union[Msg, Sequence]:
"""Deserialize json string into MessageBase"""
js_msg = json.loads(s)
msg_type = js_msg.pop("__type")
Expand Down
24 changes: 10 additions & 14 deletions src/agentscope/models/dashscope_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, Union, List, Sequence
from loguru import logger

from ..message import MessageBase
from ..message import Msg
from ..utils.tools import _convert_to_str, _guess_type_by_extension

try:
Expand Down Expand Up @@ -66,7 +66,7 @@ def __init__(

def format(
self,
*args: Union[MessageBase, Sequence[MessageBase]],
*args: Union[Msg, Sequence[Msg]],
) -> Union[List[dict], str]:
raise RuntimeError(
f"Model Wrapper [{type(self).__name__}] doesn't "
Expand Down Expand Up @@ -213,7 +213,7 @@ def __call__(

def format(
self,
*args: Union[MessageBase, Sequence[MessageBase]],
*args: Union[Msg, Sequence[Msg]],
) -> List:
"""Format the messages for DashScope Chat API.
Expand Down Expand Up @@ -254,7 +254,7 @@ def format(
Args:
args (`Union[MessageBase, Sequence[MessageBase]]`):
args (`Union[Msg, Sequence[Msg]]`):
The input arguments to be formatted, where each argument
should be a `Msg` object, or a list of `Msg` objects.
In distribution, placeholder is also allowed.
Expand All @@ -269,11 +269,9 @@ def format(
for _ in args:
if _ is None:
continue
if isinstance(_, MessageBase):
if isinstance(_, Msg):
input_msgs.append(_)
elif isinstance(_, list) and all(
isinstance(__, MessageBase) for __ in _
):
elif isinstance(_, list) and all(isinstance(__, Msg) for __ in _):
input_msgs.extend(_)
else:
raise TypeError(
Expand Down Expand Up @@ -655,7 +653,7 @@ def __call__(

def format(
self,
*args: Union[MessageBase, Sequence[MessageBase]],
*args: Union[Msg, Sequence[Msg]],
) -> List:
"""Format the messages for DashScope Multimodal API.
Expand Down Expand Up @@ -737,7 +735,7 @@ def format(
"file://", which will be attached in this format function.
Args:
args (`Union[MessageBase, Sequence[MessageBase]]`):
args (`Union[Msg, Sequence[Msg]]`):
The input arguments to be formatted, where each argument
should be a `Msg` object, or a list of `Msg` objects.
In distribution, placeholder is also allowed.
Expand All @@ -752,11 +750,9 @@ def format(
for _ in args:
if _ is None:
continue
if isinstance(_, MessageBase):
if isinstance(_, Msg):
input_msgs.append(_)
elif isinstance(_, list) and all(
isinstance(__, MessageBase) for __ in _
):
elif isinstance(_, list) and all(isinstance(__, Msg) for __ in _):
input_msgs.extend(_)
else:
raise TypeError(
Expand Down
12 changes: 5 additions & 7 deletions src/agentscope/models/gemini_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from loguru import logger

from agentscope.message import Msg, MessageBase
from agentscope.message import Msg
from agentscope.models import ModelWrapperBase, ModelResponse
from agentscope.utils.tools import _convert_to_str

Expand Down Expand Up @@ -250,7 +250,7 @@ def _register_default_metrics(self) -> None:

def format(
self,
*args: Union[MessageBase, Sequence[MessageBase]],
*args: Union[Msg, Sequence[Msg]],
) -> List[dict]:
"""This function provide a basic prompting strategy for Gemini Chat
API in multi-party conversation, which combines all input into a
Expand Down Expand Up @@ -279,7 +279,7 @@ def format(
https://github.com/agentscope/agentscope!
Args:
args (`Union[MessageBase, Sequence[MessageBase]]`):
args (`Union[Msg, Sequence[Msg]]`):
The input arguments to be formatted, where each argument
should be a `Msg` object, or a list of `Msg` objects.
In distribution, placeholder is also allowed.
Expand All @@ -292,11 +292,9 @@ def format(
for _ in args:
if _ is None:
continue
if isinstance(_, MessageBase):
if isinstance(_, Msg):
input_msgs.append(_)
elif isinstance(_, list) and all(
isinstance(__, MessageBase) for __ in _
):
elif isinstance(_, list) and all(isinstance(__, Msg) for __ in _):
input_msgs.extend(_)
else:
raise TypeError(
Expand Down
16 changes: 7 additions & 9 deletions src/agentscope/models/litellm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from loguru import logger

from .model import ModelWrapperBase, ModelResponse
from ..message import MessageBase
from ..message import Msg
from ..utils.tools import _convert_to_str

try:
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(

def format(
self,
*args: Union[MessageBase, Sequence[MessageBase]],
*args: Union[Msg, Sequence[Msg]],
) -> Union[List[dict], str]:
raise RuntimeError(
f"Model Wrapper [{type(self).__name__}] doesn't "
Expand Down Expand Up @@ -183,16 +183,16 @@ def __call__(

def format(
self,
*args: Union[MessageBase, Sequence[MessageBase]],
*args: Union[Msg, Sequence[Msg]],
) -> List[dict]:
"""Format the input string and dictionary into the unified format.
Note that the format function might not be the optimal way to contruct
Note that the format function might not be the optimal way to construct
prompt for every model, but a common way to do so.
Developers are encouraged to implement their own prompt
engineering strategies if have strong performance concerns.
Args:
args (`Union[MessageBase, Sequence[MessageBase]]`):
args (`Union[Msg, Sequence[Msg]]`):
The input arguments to be formatted, where each argument
should be a `Msg` object, or a list of `Msg` objects.
In distribution, placeholder is also allowed.
Expand All @@ -207,11 +207,9 @@ def format(
for _ in args:
if _ is None:
continue
if isinstance(_, MessageBase):
if isinstance(_, Msg):
input_msgs.append(_)
elif isinstance(_, list) and all(
isinstance(__, MessageBase) for __ in _
):
elif isinstance(_, list) and all(isinstance(__, Msg) for __ in _):
input_msgs.extend(_)
else:
raise TypeError(
Expand Down
4 changes: 2 additions & 2 deletions src/agentscope/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
from ..exception import ResponseParsingError

from ..file_manager import file_manager
from ..message import MessageBase
from ..message import Msg
from ..utils import MonitorFactory
from ..utils.monitor import get_full_name
from ..utils.tools import _get_timestamp
Expand Down Expand Up @@ -227,7 +227,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ModelResponse:

def format(
self,
*args: Union[MessageBase, Sequence[MessageBase]],
*args: Union[Msg, Sequence[Msg]],
) -> Union[List[dict], str]:
"""Format the input string or dict into the format that the model
API required."""
Expand Down
Loading

0 comments on commit 66126d8

Please sign in to comment.