From 07f9dbbf6915b15520c99b37e91ab3bfa7f63684 Mon Sep 17 00:00:00 2001 From: Xuchen Pan <32844285+pan-x-c@users.noreply.github.com> Date: Wed, 3 Apr 2024 18:02:44 +0800 Subject: [PATCH] Support agent_id in rpc agent server (#94) --- notebook/distributed_debate.ipynb | 27 +- src/agentscope/agents/agent.py | 5 +- src/agentscope/agents/rpc_agent.py | 308 +++++++++++++++++------ src/agentscope/message.py | 31 ++- src/agentscope/rpc/__init__.py | 4 +- src/agentscope/rpc/rpc_agent.proto | 1 + src/agentscope/rpc/rpc_agent_client.py | 113 ++++++++- src/agentscope/rpc/rpc_agent_pb2.py | 8 +- src/agentscope/rpc/rpc_agent_pb2_grpc.py | 2 +- src/agentscope/utils/logging_utils.py | 4 +- tests/rpc_agent_test.py | 121 +++++++++ 11 files changed, 506 insertions(+), 118 deletions(-) diff --git a/notebook/distributed_debate.ipynb b/notebook/distributed_debate.ipynb index 5ce024b4a..e2a0c21f4 100644 --- a/notebook/distributed_debate.ipynb +++ b/notebook/distributed_debate.ipynb @@ -146,19 +146,20 @@ "\n", "\n", "\"\"\"Setup the main debate competition process\"\"\"\n", - "participants = [pro_agent, con_agent, judge_agent]\n", - "hint = Msg(name=\"System\", content=ANNOUNCEMENT)\n", - "x = None\n", - "with msghub(participants=participants, announcement=hint):\n", - " for _ in range(3):\n", - " pro_resp = pro_agent(x)\n", - " logger.chat(pro_resp)\n", - " con_resp = con_agent(pro_resp)\n", - " logger.chat(con_resp)\n", - " x = judge_agent(con_resp)\n", - " logger.chat(x)\n", - " x = judge_agent(x)\n", - " logger.chat(x)\n" + "if __name__ == \"__main__\":\n", + " participants = [pro_agent, con_agent, judge_agent]\n", + " hint = Msg(name=\"System\", content=ANNOUNCEMENT)\n", + " x = None\n", + " with msghub(participants=participants, announcement=hint):\n", + " for _ in range(3):\n", + " pro_resp = pro_agent(x)\n", + " logger.chat(pro_resp)\n", + " con_resp = con_agent(pro_resp)\n", + " logger.chat(con_resp)\n", + " x = judge_agent(con_resp)\n", + " logger.chat(x)\n", + " x = judge_agent(x)\n", + " logger.chat(x)\n" ] }, { diff --git a/src/agentscope/agents/agent.py b/src/agentscope/agents/agent.py index 300bc78c4..dc32dcd5c 100644 --- a/src/agentscope/agents/agent.py +++ b/src/agentscope/agents/agent.py @@ -204,7 +204,7 @@ def to_dist( self, host: str = "localhost", port: int = None, - max_pool_size: int = 100, + max_pool_size: int = 8192, max_timeout_seconds: int = 1800, launch_server: bool = True, local_mode: bool = True, @@ -217,7 +217,7 @@ def to_dist( Hostname of the rpc agent server. port (`int`, defaults to `None`): Port of the rpc agent server. - max_pool_size (`int`, defaults to `100`): + max_pool_size (`int`, defaults to `8192`): Max number of task results that the server can accommodate. max_timeout_seconds (`int`, defaults to `1800`): Timeout for task results. @@ -246,4 +246,5 @@ def to_dist( launch_server=launch_server, local_mode=local_mode, lazy_launch=lazy_launch, + agent_id=self.agent_id, ) diff --git a/src/agentscope/agents/rpc_agent.py b/src/agentscope/agents/rpc_agent.py index 5ab3102a7..57fb91a68 100644 --- a/src/agentscope/agents/rpc_agent.py +++ b/src/agentscope/agents/rpc_agent.py @@ -1,22 +1,12 @@ # -*- coding: utf-8 -*- """ Base class for Rpc Agent """ -from multiprocessing import ( - Process, - Event, - Pipe, -) +from multiprocessing import Process, Event, Pipe, cpu_count from multiprocessing.synchronize import Event as EventClass import socket import threading -import time import json -from typing import Any -from typing import Optional -from typing import Union -from typing import Type -from typing import Sequence -from queue import Queue +from typing import Any, Optional, Union, Type, Sequence from concurrent import futures from loguru import logger @@ -73,47 +63,66 @@ def __init__( self, name: str, agent_class: Type[AgentBase], - agent_configs: dict, + agent_configs: Optional[dict] = None, host: str = "localhost", port: int = None, - max_pool_size: int = 100, - max_timeout_seconds: int = 1800, launch_server: bool = True, + max_pool_size: int = 8192, + max_timeout_seconds: int = 1800, local_mode: bool = True, lazy_launch: bool = True, + agent_id: str = None, + create_with_agent_configs: bool = True, ) -> None: """Initialize a RpcAgent instance. Args: - agent_class (`Type[AgentBase]`, defaults to `None`): - The AgentBase subclass encapsulated by this wrapper. - agent_configs (`dict`): The args used to initialize the - agent_class. name (`str`): Name of the agent. + agent_class (`Type[AgentBase]`): + The AgentBase subclass encapsulated by this wrapper. + agent_configs (`dict`, defaults to `None`): The args used to + initialize the agent_class. host (`str`, defaults to `"localhost"`): Hostname of the rpc agent server. port (`int`, defaults to `None`): Port of the rpc agent server. - max_pool_size (`int`, defaults to `100`): + launch_server (`bool`, defaults to `True`): + Whether to launch the gRPC agent server. + max_pool_size (`int`, defaults to `8192`): Max number of task results that the server can accommodate. max_timeout_seconds (`int`, defaults to `1800`): Timeout for task results. local_mode (`bool`, defaults to `True`): - Whether the started rpc server only listens to local + Whether the started gRPC server only listens to local requests. lazy_launch (`bool`, defaults to `True`): Only launch the server when the agent is called. + agent_id (`str`, defaults to `None`): + The agent id of this instance. If `None`, it will + be generated randomly. + create_with_agent_configs (`bool`, defaults to `True`): + Only takes effect when `agent_configs` is provided. + If true, create the agent instance for the agent with + provided `agent_configs`, otherwise uses the agent server's + default parameters. """ super().__init__(name=name) self.host = host self.port = port self.server_launcher = None self.client = None + if agent_id is not None: + self._agent_id = agent_id + else: + self._agent_id = agent_class.generate_agent_id() + self.agent_class = agent_class if launch_server: self.server_launcher = RpcAgentServerLauncher( agent_class=agent_class, - agent_args=agent_configs["args"], - agent_kwargs=agent_configs["kwargs"], + agent_args=agent_configs["args"] if agent_configs else None, + agent_kwargs=( + agent_configs["kwargs"] if agent_configs else None + ), host=host, port=port, max_pool_size=max_pool_size, @@ -123,23 +132,33 @@ def __init__( if not lazy_launch: self._launch_server() else: - self.client = RpcAgentClient(host=self.host, port=self.port) + self.client = RpcAgentClient( + host=self.host, + port=self.port, + agent_id=self.agent_id, + ) + self.client.create_agent( + agent_configs if create_with_agent_configs else None, + ) def _launch_server(self) -> None: """Launch a rpc server and update the port and the client""" self.server_launcher.launch() self.port = self.server_launcher.port - self.client = RpcAgentClient(host=self.host, port=self.port) + self.client = RpcAgentClient( + host=self.host, + port=self.port, + agent_id=self.agent_id, + ) def reply(self, x: dict = None) -> dict: if self.client is None: self._launch_server() - res_msg = self.client.call_func( - func_name="_call", - value=x.serialize() if x is not None else "", - ) return PlaceholderMessage( - **deserialize(res_msg), # type: ignore[arg-type] + name=self.name, + content=None, + client=self.client, + x=x, ) def observe(self, x: Union[dict, Sequence[dict]]) -> None: @@ -150,17 +169,61 @@ def observe(self, x: Union[dict, Sequence[dict]]) -> None: value=serialize(x), # type: ignore[arg-type] ) + def clone_instances( + self, + num_instances: int, + including_self: bool = True, + ) -> Sequence[AgentBase]: + """ + Clone a series of this instance with different agent_id and + return them as a list. + + Args: + num_instances (`int`): The number of instances in the returned + list. + including_self (`bool`): Whether to include the instance calling + this method in the returned list. + + Returns: + `Sequence[AgentBase]`: A list of agent instances. + """ + generated_instance_number = ( + num_instances - 1 if including_self else num_instances + ) + generated_instances = [] + + # launch the server before clone instances + if self.client is None: + self._launch_server() + + # put itself as the first element of the returned list + if including_self: + generated_instances.append(self) + + # clone instances without agent server + for _ in range(generated_instance_number): + generated_instances.append( + RpcAgent( + name=self.name, + agent_class=self.agent_class, + host=self.host, + port=self.port, + launch_server=False, + create_with_agent_configs=False, + ), + ) + return generated_instances + def stop(self) -> None: - """Stop the RpcAgent and the launched rpc server.""" + """Stop the RpcAgent and the rpc server.""" if self.server_launcher is not None: self.server_launcher.shutdown() def __del__(self) -> None: - if self.server_launcher is not None: - self.server_launcher.shutdown() + self.stop() -def setup_rcp_agent_server( +def setup_rpc_agent_server( agent_class: Type[AgentBase], agent_args: tuple, agent_kwargs: dict, @@ -171,9 +234,8 @@ def setup_rcp_agent_server( stop_event: EventClass = None, pipe: int = None, local_mode: bool = True, - max_pool_size: int = 100, + max_pool_size: int = 8192, max_timeout_seconds: int = 1800, - max_workers: int = 4, ) -> None: """Setup gRPC server rpc agent. @@ -200,18 +262,18 @@ def setup_rcp_agent_server( A pipe instance used to pass the actual port of the server. local_mode (`bool`, defaults to `None`): Only listen to local requests. - max_pool_size (`int`, defaults to `100`): + max_pool_size (`int`, defaults to `8192`): Max number of task results that the server can accommodate. max_timeout_seconds (`int`, defaults to `1800`): Timeout for task results. - max_workers (`int`, defaults to `4`): - max worker number of grpc server. """ if init_settings is not None: init_process(**init_settings) servicer = RpcServerSideWrapper( - agent_class(*agent_args, **agent_kwargs), + agent_class, + agent_args, + agent_kwargs, host=host, port=port, max_pool_size=max_pool_size, @@ -226,7 +288,7 @@ def setup_rcp_agent_server( f" [{port}]...", ) server = grpc.server( - futures.ThreadPoolExecutor(max_workers=max_workers), + futures.ThreadPoolExecutor(max_workers=cpu_count()), ) add_RpcAgentServicer_to_server(servicer, server) if local_mode: @@ -248,12 +310,12 @@ def setup_rcp_agent_server( pipe.send(port) start_event.set() stop_event.wait() + logger.info( + f"Stopping rpc server [{agent_class.__name__}] at port [{port}]", + ) + server.stop(1.0).wait() else: server.wait_for_termination() - logger.info( - f"Stopping rpc server [{agent_class.__name__}] at port [{port}]", - ) - server.stop(0) logger.info( f"rpc server [{agent_class.__name__}] at port [{port}] stopped " "successfully", @@ -306,7 +368,7 @@ def __init__( agent_kwargs: dict = None, host: str = "localhost", port: int = None, - max_pool_size: int = 100, + max_pool_size: int = 8192, max_timeout_seconds: int = 1800, local_mode: bool = False, ) -> None: @@ -323,7 +385,7 @@ def __init__( Hostname of the rpc agent server. port (`int`, defaults to `None`): Port of the rpc agent server. - max_pool_size (`int`, defaults to `100`): + max_pool_size (`int`, defaults to `8192`): Max number of task results that the server can accommodate. max_timeout_seconds (`int`, defaults to `1800`): Timeout for task results. @@ -346,7 +408,7 @@ def __init__( def _launch_in_main(self) -> None: """Launch gRPC server in main-process""" server_thread = threading.Thread( - target=setup_rcp_agent_server, + target=setup_rpc_agent_server, kwargs={ "agent_class": self.agent_class, "agent_args": self.agent_args, @@ -371,7 +433,7 @@ def _launch_in_sub(self) -> None: self.parent_con, child_con = Pipe() start_event = Event() server_process = Process( - target=setup_rcp_agent_server, + target=setup_rpc_agent_server, kwargs={ "agent_class": self.agent_class, "agent_args": self.agent_args, @@ -420,8 +482,7 @@ def shutdown(self) -> None: if self.stop_event is not None: self.stop_event.set() self.stop_event = None - self.server.join(timeout=5) - self.server.terminate() + self.server.join() if self.server.is_alive(): self.server.kill() logger.info( @@ -436,21 +497,28 @@ class RpcServerSideWrapper(RpcAgentServicer): def __init__( self, - agent_instance: AgentBase, + agent_class: Type[AgentBase], + agent_args: tuple, + agent_kwargs: dict, host: str = "localhost", port: int = None, - max_pool_size: int = 100, + max_pool_size: int = 8192, max_timeout_seconds: int = 1800, ): """Init the service side wrapper. Args: - agent_instance (`AgentBase`): an instance of `AgentBase`. + agent_class (`Type[AgentBase]`): The AgentBase subclass + encapsulated by this wrapper. + agent_args (`tuple`): The args tuple used to initialize the + agent_class. + agent_kwargs (`dict`): The args dict used to initialize the + agent_class. host (`str`, defaults to "localhost"): Hostname of the rpc agent server. port (`int`, defaults to `None`): Port of the rpc agent server. - max_pool_size (`int`, defaults to `100`): + max_pool_size (`int`, defaults to `8192`): The max number of task results that the server can accommodate. Note that the oldest result will be deleted after exceeding the pool size. @@ -458,18 +526,20 @@ def __init__( Timeout for task results. Note that expired results will be deleted. """ + self.agent_class = agent_class + self.agent_args = agent_args + self.agent_kwargs = agent_kwargs self.host = host self.port = port self.result_pool = ExpiringDict( max_len=max_pool_size, max_age_seconds=max_timeout_seconds, ) - self.task_queue = Queue() - self.worker_thread = threading.Thread(target=self.process_tasks) - self.worker_thread.start() + self.executor = futures.ThreadPoolExecutor(max_workers=cpu_count()) self.task_id_lock = threading.Lock() + self.agent_id_lock = threading.Lock() self.task_id_counter = 0 - self.agent = agent_instance + self.agent_pool: dict[str, AgentBase] = {} def get_task_id(self) -> int: """Get the auto-increment task id.""" @@ -477,21 +547,65 @@ def get_task_id(self) -> int: self.task_id_counter += 1 return self.task_id_counter + def check_and_generate_agent( + self, + agent_id: str, + agent_configs: dict = None, + ) -> None: + """ + Check whether the agent exists, and create new agent instance + for new agent. + + Args: + agent_id (`str`): the agent id. + """ + with self.agent_id_lock: + if agent_id not in self.agent_pool: + if agent_configs is not None: + agent_instance = self.agent_class( + *agent_configs["args"], + **agent_configs["kwargs"], + ) + else: + agent_instance = self.agent_class( + *self.agent_args, + **self.agent_kwargs, + ) + agent_instance._agent_id = agent_id # pylint: disable=W0212 + self.agent_pool[agent_id] = agent_instance + logger.info(f"create agent instance [{agent_id}]") + + def check_and_delete_agent(self, agent_id: str) -> None: + """ + Check whether the agent exists, and delete the agent instance + for the agent_id. + + Args: + agent_id (`str`): the agent id. + """ + with self.agent_id_lock: + if agent_id in self.agent_pool: + self.agent_pool.pop(agent_id) + logger.info(f"delete agent instance [{agent_id}]") + def call_func(self, request: RpcMsg, _: ServicerContext) -> RpcMsg: """Call the specific servicer function.""" if hasattr(self, request.target_func): + if request.target_func not in ["_create_agent", "_get"]: + self.check_and_generate_agent(request.agent_id) return getattr(self, request.target_func)(request) else: + # TODO: support other user defined method logger.error(f"Unsupported method {request.target_func}") return RpcMsg( value=Msg( - name=self.agent.name, + name=self.agent_pool[request.agent_id].name, content=f"Unsupported method {request.target_func}", role="assistant", ).serialize(), ) - def _call(self, request: RpcMsg) -> RpcMsg: + def _reply(self, request: RpcMsg) -> RpcMsg: """Call function of RpcAgentService Args: @@ -508,14 +622,17 @@ def _call(self, request: RpcMsg) -> RpcMsg: else: msg = None task_id = self.get_task_id() - self.task_queue.put((task_id, msg)) + self.result_pool[task_id] = threading.Condition() + self.executor.submit( + self.process_messages, + task_id, + request.agent_id, + msg, # type: ignore[arg-type] + ) return RpcMsg( value=Msg( - name=self.agent.name, + name=self.agent_pool[request.agent_id].name, content=None, - role="assistant", - host=self.host, - port=self.port, task_id=task_id, ).serialize(), ) @@ -534,14 +651,15 @@ def _get(self, request: RpcMsg) -> RpcMsg: Returns: `RpcMsg`: Concrete values of the specific message (or part of it). """ - # todo: add format specification of request msg = json.loads(request.value) - # todo: implement the waiting in a more elegant way, add timeout while True: - result = self.result_pool.get(msg["task_id"], None) - if result is not None: - return RpcMsg(value=result.serialize()) - time.sleep(0.1) + result = self.result_pool.get(msg["task_id"]) + if isinstance(result, threading.Condition): + with result: + result.wait(timeout=1) + else: + break + return RpcMsg(value=result.serialize()) def _observe(self, request: RpcMsg) -> RpcMsg: """Observe function of RpcAgentService @@ -557,15 +675,41 @@ def _observe(self, request: RpcMsg) -> RpcMsg: for msg in msgs: if isinstance(msg, PlaceholderMessage): msg.update_value() - self.agent.observe(msgs) + self.agent_pool[request.agent_id].observe(msgs) return RpcMsg() - def process_tasks(self) -> None: - """Task processing thread.""" - while True: - task_id, task_msg = self.task_queue.get() - # TODO: optimize this and avoid blocking - if isinstance(task_msg, PlaceholderMessage): - task_msg.update_value() - result = self.agent.reply(task_msg) - self.result_pool[task_id] = result + def _create_agent(self, request: RpcMsg) -> RpcMsg: + """Create a new agent instance for the agent_id. + + Args: + request (RpcMsg): request message with a `agent_id` field. + """ + self.check_and_generate_agent( + request.agent_id, + agent_configs=json.loads(request.value) if request.value else None, + ) + return RpcMsg() + + def _delete_agent(self, request: RpcMsg) -> RpcMsg: + """Delete the agent instance of the specific sesssion_id. + + Args: + request (RpcMsg): request message with a `agent_id` field. + """ + self.check_and_delete_agent(request.agent_id) + return RpcMsg() + + def process_messages( + self, + task_id: int, + agent_id: str, + task_msg: dict = None, + ) -> None: + """Task processing.""" + if isinstance(task_msg, PlaceholderMessage): + task_msg.update_value() + cond = self.result_pool[task_id] + result = self.agent_pool[agent_id].reply(task_msg) + self.result_pool[task_id] = result + with cond: + cond.notify_all() diff --git a/src/agentscope/message.py b/src/agentscope/message.py index 029131514..8e3c8a67b 100644 --- a/src/agentscope/message.py +++ b/src/agentscope/message.py @@ -7,7 +7,7 @@ from loguru import logger -from .rpc import RpcAgentClient +from .rpc import RpcAgentClient, ResponseStub, call_in_thread from .utils.tools import _get_timestamp @@ -219,6 +219,7 @@ class PlaceholderMessage(MessageBase): "_port", "_client", "_task_id", + "_stub", "_is_placeholder", } @@ -237,6 +238,8 @@ def __init__( host: str = None, port: int = None, task_id: int = None, + client: Optional[RpcAgentClient] = None, + x: dict = None, **kwargs: Any, ) -> None: """A placeholder message, records the address of the real message. @@ -266,6 +269,11 @@ def __init__( The port of the rpc server where the real message is located. task_id (`int`, defaults to `None`): The task id of the real message in the rpc server. + client (`RpcAgentClient`, defaults to `None`): + An RpcAgentClient instance used to connect to the generator of + this placeholder. + x (`dict`, defaults to `None`): + Input parameters used to call rpc methods on the client. """ super().__init__( name=name, @@ -276,9 +284,16 @@ def __init__( ) # placeholder indicates whether the real message is still in rpc server self._is_placeholder = True - self._host = host - self._port = port - self._task_id = task_id + if client is None: + self._stub: ResponseStub = None + self._host: str = host + self._port: int = port + self._task_id: int = task_id + else: + self._stub = call_in_thread(client, x, "_reply") + self._host = client.host + self._port = client.port + self._task_id = None def __is_local(self, key: Any) -> bool: return ( @@ -316,6 +331,7 @@ def update_value(self) -> MessageBase: """Get attribute values from rpc agent server immediately""" if self._is_placeholder: # retrieve real message from rpc agent server + self.__update_task_id() client = RpcAgentClient(self._host, self._port) result = client.call_func( func_name="_get", @@ -326,8 +342,15 @@ def update_value(self) -> MessageBase: self._is_placeholder = False return self + def __update_task_id(self) -> None: + if self._stub is not None: + resp = deserialize(self._stub.get_response()) + self._task_id = resp["task_id"] # type: ignore[call-overload] + self._stub = None + def serialize(self) -> str: if self._is_placeholder: + self.__update_task_id() return json.dumps( { "__type": "PlaceholderMessage", diff --git a/src/agentscope/rpc/__init__.py b/src/agentscope/rpc/__init__.py index 283b2cdba..ead4650eb 100644 --- a/src/agentscope/rpc/__init__.py +++ b/src/agentscope/rpc/__init__.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """Import all rpc related modules in the package.""" from typing import Any -from .rpc_agent_client import RpcAgentClient +from .rpc_agent_client import RpcAgentClient, ResponseStub, call_in_thread try: from .rpc_agent_pb2 import RpcMsg # pylint: disable=E0611 @@ -19,6 +19,8 @@ __all__ = [ "RpcAgentClient", + "ResponseStub", + "call_in_thread", "RpcMsg", "RpcAgentServicer", "RpcAgentStub", diff --git a/src/agentscope/rpc/rpc_agent.proto b/src/agentscope/rpc/rpc_agent.proto index f11bbf2f2..fe27e0d1e 100644 --- a/src/agentscope/rpc/rpc_agent.proto +++ b/src/agentscope/rpc/rpc_agent.proto @@ -9,4 +9,5 @@ service RpcAgent { message RpcMsg { string value = 1; string target_func = 2; + string agent_id = 3; } \ No newline at end of file diff --git a/src/agentscope/rpc/rpc_agent_client.py b/src/agentscope/rpc/rpc_agent_client.py index 768a81515..98b82a6d5 100644 --- a/src/agentscope/rpc/rpc_agent_client.py +++ b/src/agentscope/rpc/rpc_agent_client.py @@ -1,7 +1,10 @@ # -*- coding: utf-8 -*- """ Client of rpc agent server """ -from typing import Any +import json +import threading +from typing import Any, Optional +from loguru import logger try: import grpc @@ -12,31 +15,38 @@ from agentscope.rpc.rpc_agent_pb2 import RpcMsg # pylint: disable=E0611 from agentscope.rpc.rpc_agent_pb2_grpc import RpcAgentStub except ModuleNotFoundError: - RpcMsg = Any + RpcMsg = Any # type: ignore[misc] RpcAgentStub = Any class RpcAgentClient: """A client of Rpc agent server""" - def __init__(self, host: str, port: int) -> None: + def __init__(self, host: str, port: int, agent_id: str = "") -> None: """Init a rpc agent client Args: - host (str): the hostname of the rpc agent server which the + host (`str`): the hostname of the rpc agent server which the client is connected. - port (int): the port of the rpc agent server which the client + port (`int`): the port of the rpc agent server which the client is connected. + agent_id (`str`): the agent id of the agent being called. """ self.host = host self.port = port + self.agent_id = agent_id - def call_func(self, func_name: str, value: str = None) -> str: + def call_func( + self, + func_name: str, + value: Optional[str] = None, + timeout: int = 300, + ) -> str: """Call the specific function of rpc server. Args: - func_name (str): the name of the function being called. - x (str, optional): the seralized input value. Defaults to None. + func_name (`str`): the name of the function being called. + x (`str`, optional): the seralized input value. Defaults to None. Returns: str: serialized return data. @@ -44,6 +54,91 @@ def call_func(self, func_name: str, value: str = None) -> str: with grpc.insecure_channel(f"{self.host}:{self.port}") as channel: stub = RpcAgentStub(channel) result_msg = stub.call_func( - RpcMsg(value=value, target_func=func_name), + RpcMsg( + value=value, + target_func=func_name, + agent_id=self.agent_id, + ), + timeout=timeout, ) return result_msg.value + + def create_agent(self, agent_configs: Optional[dict]) -> None: + """Create a new agent for this client.""" + try: + if self.agent_id is None or len(self.agent_id) == 0: + return + self.call_func( + func_name="_create_agent", + value=( + None + if agent_configs is None + else json.dumps(agent_configs) + ), + ) + except Exception as e: + logger.error( + f"Fail to create agent with id [{self.agent_id}]: {e}", + ) + + def delete_agent(self) -> None: + """ + Delete the agent created by this client. + """ + try: + if self.agent_id is not None and len(self.agent_id) > 0: + self.call_func("_delete_agent", timeout=5) + except Exception: + logger.warning( + f"Fail to delete agent with id [{self.agent_id}]", + ) + + +class ResponseStub: + """A stub used to save the response of an rpc call in a sub-thread.""" + + def __init__(self) -> None: + self.response = None + self.condition = threading.Condition() + + def set_response(self, response: str) -> None: + """Set the message.""" + with self.condition: + self.response = response + self.condition.notify_all() + + def get_response(self) -> str: + """Get the message.""" + with self.condition: + while self.response is None: + self.condition.wait() + return self.response + + +def call_in_thread( + client: RpcAgentClient, + x: dict, + func_name: str, +) -> ResponseStub: + """Call rpc function in a sub-thread. + + Args: + client (`RpcAgentClient`): the rpc client. + x (`dict`): the value of the reqeust. + func_name (`str`): the name of the function being called. + + Returns: + `ResponseStub`: a stub to get the response. + """ + stub = ResponseStub() + + def wrapper() -> None: + resp = client.call_func( + func_name=func_name, + value=x.serialize() if x is not None else "", + ) + stub.set_response(resp) # type: ignore[arg-type] + + thread = threading.Thread(target=wrapper) + thread.start() + return stub diff --git a/src/agentscope/rpc/rpc_agent_pb2.py b/src/agentscope/rpc/rpc_agent_pb2.py index 8dc9c6111..3480a543c 100644 --- a/src/agentscope/rpc/rpc_agent_pb2.py +++ b/src/agentscope/rpc/rpc_agent_pb2.py @@ -14,7 +14,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0frpc_agent.proto",\n\x06RpcMsg\x12\r\n\x05value\x18\x01 \x01(\t\x12\x13\n\x0btarget_func\x18\x02 \x01(\t2+\n\x08RpcAgent\x12\x1f\n\tcall_func\x12\x07.RpcMsg\x1a\x07.RpcMsg"\x00\x62\x06proto3', + b'\n\x0frpc_agent.proto">\n\x06RpcMsg\x12\r\n\x05value\x18\x01 \x01(\t\x12\x13\n\x0btarget_func\x18\x02 \x01(\t\x12\x10\n\x08\x61gent_id\x18\x03 \x01(\t2+\n\x08RpcAgent\x12\x1f\n\tcall_func\x12\x07.RpcMsg\x1a\x07.RpcMsg"\x00\x62\x06proto3', ) _globals = globals() @@ -23,7 +23,7 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None _globals["_RPCMSG"]._serialized_start = 19 - _globals["_RPCMSG"]._serialized_end = 63 - _globals["_RPCAGENT"]._serialized_start = 65 - _globals["_RPCAGENT"]._serialized_end = 108 + _globals["_RPCMSG"]._serialized_end = 81 + _globals["_RPCAGENT"]._serialized_start = 83 + _globals["_RPCAGENT"]._serialized_end = 126 # @@protoc_insertion_point(module_scope) diff --git a/src/agentscope/rpc/rpc_agent_pb2_grpc.py b/src/agentscope/rpc/rpc_agent_pb2_grpc.py index 2aaf80c0f..93ee27369 100644 --- a/src/agentscope/rpc/rpc_agent_pb2_grpc.py +++ b/src/agentscope/rpc/rpc_agent_pb2_grpc.py @@ -10,7 +10,7 @@ class RpcAgentStub(object): - """Rpc agent Server Stub""" + """Servicer for rpc agent server""" def __init__(self, channel): """Constructor. diff --git a/src/agentscope/utils/logging_utils.py b/src/agentscope/utils/logging_utils.py index 35cb6d957..4ea8c8482 100644 --- a/src/agentscope/utils/logging_utils.py +++ b/src/agentscope/utils/logging_utils.py @@ -82,10 +82,10 @@ def _chat( "content" keys, and the message will be logged as ": ". """ - # Save message into file + # Save message into file, add default to ignore not serializable objects logger.log( LEVEL_CHAT_SAVE, - json.dumps(message, ensure_ascii=False), + json.dumps(message, ensure_ascii=False, default=lambda _: None), *args, **kwargs, ) diff --git a/tests/rpc_agent_test.py b/tests/rpc_agent_test.py index 0d7e3ab06..a20edbc42 100644 --- a/tests/rpc_agent_test.py +++ b/tests/rpc_agent_test.py @@ -395,3 +395,124 @@ def test_standalone_multiprocess_init(self) -> None: msg = agent_b(msg) logger.chat(msg) self.assertTrue(msg["content"]["quota_exceeded"]) + + def test_multi_agent(self) -> None: + """test agent server with multi agent""" + launcher = RpcAgentServerLauncher( + # choose port automatically + agent_class=DemoRpcAgentWithMemory, + agent_kwargs={ + "name": "a", + }, + local_mode=False, + host="127.0.0.1", + port=12010, + ) + launcher.launch() + # although agent1 and agent2 connect to the same server + # they are different instances with different memories + agent1 = DemoRpcAgentWithMemory( + name="a", + ) + oid = agent1.agent_id + agent1 = agent1.to_dist( + host="127.0.0.1", + port=launcher.port, + launch_server=False, + ) + self.assertEqual(oid, agent1.agent_id) + self.assertEqual(oid, agent1.client.agent_id) + agent2 = DemoRpcAgentWithMemory( + name="a", + ).to_dist( + host="127.0.0.1", + port=launcher.port, + launch_server=False, + ) + # agent3 has the same agent id as agent1 + # so it share the same memory with agent1 + agent3 = DemoRpcAgentWithMemory( + name="a", + ).to_dist( + host="127.0.0.1", + port=launcher.port, + launch_server=False, + ) + agent3._agent_id = agent1.agent_id # pylint: disable=W0212 + agent3.client.agent_id = agent1.client.agent_id + msg1 = Msg(name="System", content="First Msg for agent1") + res1 = agent1(msg1) + self.assertEqual(res1.content["mem_size"], 1) + msg2 = Msg(name="System", content="First Msg for agent2") + res2 = agent2(msg2) + self.assertEqual(res2.content["mem_size"], 1) + msg3 = Msg(name="System", content="First Msg for agent3") + res3 = agent3(msg3) + self.assertEqual(res3.content["mem_size"], 3) + msg4 = Msg(name="System", content="Second Msg for agent2") + res4 = agent2(msg4) + self.assertEqual(res4.content["mem_size"], 3) + # delete existing agent + agent2.client.delete_agent() + msg2 = Msg(name="System", content="First Msg for agent2") + res2 = agent2(msg2) + self.assertEqual(res2.content["mem_size"], 1) + + # should override remote default parameter(e.g. name field) + agent4 = DemoRpcAgentWithMemory( + name="b", + ).to_dist( + host="127.0.0.1", + port=launcher.port, + launch_server=False, + ) + msg5 = Msg(name="System", content="Second Msg for agent4") + res5 = agent4(msg5) + self.assertEqual(res5.name, "b") + self.assertEqual(res5.content["mem_size"], 1) + launcher.shutdown() + + def test_clone_instances(self) -> None: + """Test the clone_instances method of RpcAgent""" + agent = DemoRpcAgentWithMemory( + name="a", + ).to_dist() + # lazy launch will not init client + self.assertIsNone(agent.client) + # generate two agents (the first is it self) + agents = agent.clone_instances(2) + self.assertEqual(len(agents), 2) + agent1 = agents[0] + agent2 = agents[1] + self.assertTrue(agent1.agent_id.startswith("DemoRpcAgentWithMemory")) + self.assertTrue(agent2.agent_id.startswith("DemoRpcAgentWithMemory")) + self.assertTrue( + agent1.client.agent_id.startswith("DemoRpcAgentWithMemory"), + ) + self.assertTrue( + agent2.client.agent_id.startswith("DemoRpcAgentWithMemory"), + ) + self.assertNotEqual(agent1.agent_id, agent2.agent_id) + self.assertEqual(agent1.agent_id, agent1.client.agent_id) + self.assertEqual(agent2.agent_id, agent2.client.agent_id) + # clone instance will init client + self.assertIsNotNone(agent.client) + self.assertEqual(agent.agent_id, agent1.agent_id) + self.assertNotEqual(agent1.agent_id, agent2.agent_id) + self.assertIsNotNone(agent.server_launcher) + self.assertIsNotNone(agent1.server_launcher) + self.assertIsNone(agent2.server_launcher) + msg1 = Msg(name="System", content="First Msg for agent1") + res1 = agent1(msg1) + self.assertEqual(res1.content["mem_size"], 1) + msg2 = Msg(name="System", content="First Msg for agent2") + res2 = agent2(msg2) + self.assertEqual(res2.content["mem_size"], 1) + new_agents = agent.clone_instances(2, including_self=False) + agent3 = new_agents[0] + agent4 = new_agents[1] + self.assertEqual(len(new_agents), 2) + self.assertNotEqual(agent3.agent_id, agent.agent_id) + self.assertNotEqual(agent4.agent_id, agent.agent_id) + self.assertIsNone(agent3.server_launcher) + self.assertIsNone(agent4.server_launcher)