diff --git a/mesa/agent.py b/mesa/agent.py index e65b25f69c2..8805673193d 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -11,6 +11,7 @@ import contextlib import copy import operator +import warnings import weakref from collections.abc import Callable, Iterable, Iterator, MutableSet, Sequence from random import Random @@ -216,25 +217,64 @@ def _update(self, agents: Iterable[Agent]): self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents}) return self - def do( - self, method: str | Callable, *args, return_results: bool = False, **kwargs - ) -> AgentSet | list[Any]: + def do(self, method: str | Callable, *args, **kwargs) -> AgentSet: """ Invoke a method or function on each agent in the AgentSet. Args: - method (str, callable): the callable to do on each agents + method (str, callable): the callable to do on each agent * in case of str, the name of the method to call on each agent. * in case of callable, the function to be called with each agent as first argument - return_results (bool, optional): If True, returns the results of the method calls; otherwise, returns the AgentSet itself. Defaults to False, so you can chain method calls. *args: Variable length argument list passed to the callable being called. **kwargs: Arbitrary keyword arguments passed to the callable being called. Returns: AgentSet | list[Any]: The results of the callable calls if return_results is True, otherwise the AgentSet itself. """ + try: + return_results = kwargs.pop("return_results") + except KeyError: + return_results = False + else: + warnings.warn( + "Using return_results is deprecated. Use AgenSet.do in case of return_results=False, and " + "AgentSet.map in case of return_results=True", + stacklevel=2, + ) + + if return_results: + return self.map(method, *args, **kwargs) + + # we iterate over the actual weakref keys and check if weakref is alive before calling the method + if isinstance(method, str): + for agentref in self._agents.keyrefs(): + if (agent := agentref()) is not None: + getattr(agent, method)(*args, **kwargs) + else: + for agentref in self._agents.keyrefs(): + if (agent := agentref()) is not None: + method(agent, *args, **kwargs) + + return self + + def map(self, method: str | Callable, *args, **kwargs) -> list[Any]: + """ + Invoke a method or function on each agent in the AgentSet and return the results. + + Args: + method (str, callable): the callable to apply on each agent + + * in case of str, the name of the method to call on each agent. + * in case of callable, the function to be called with each agent as first argument + + *args: Variable length argument list passed to the callable being called. + **kwargs: Arbitrary keyword arguments passed to the callable being called. + + Returns: + list[Any]: The results of the callable calls + """ # we iterate over the actual weakref keys and check if weakref is alive before calling the method if isinstance(method, str): res = [ @@ -249,7 +289,7 @@ def do( if (agent := agentref()) is not None ] - return res if return_results else self + return res def get(self, attr_names: str | list[str]) -> list[Any]: """ diff --git a/tests/test_agent.py b/tests/test_agent.py index f4e64ce5338..1541a46b6c2 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -83,12 +83,6 @@ def test_function(agent): assert all( a1 == a2.unique_id for a1, a2 in zip(agentset.get("unique_id"), agentset) ) - assert all( - a1 == a2.unique_id - for a1, a2 in zip( - agentset.do("get_unique_identifier", return_results=True), agentset - ) - ) assert agentset == agentset.do("get_unique_identifier") agentset.discard(agents[0]) @@ -276,6 +270,35 @@ def remove_function(agent): assert len(agentset) == 0 +def test_agentset_map_str(): + model = Model() + agents = [TestAgent(model.next_id(), model) for _ in range(10)] + agentset = AgentSet(agents, model) + + with pytest.raises(AttributeError): + agentset.do("non_existing_method") + + results = agentset.map("get_unique_identifier") + assert all(i == entry for i, entry in zip(results, range(1, 11))) + + +def test_agentset_map_callable(): + model = Model() + agents = [TestAgent(model.next_id(), model) for _ in range(10)] + agentset = AgentSet(agents, model) + + # Test callable with non-existent function + with pytest.raises(AttributeError): + agentset.map(lambda agent: agent.non_existing_method()) + + # tests for addition and removal in do using callables + # do iterates, so no error should be raised to change size while iterating + # related to issue #1595 + + results = agentset.map(lambda agent: agent.unique_id) + assert all(i == entry for i, entry in zip(results, range(1, 11))) + + def test_agentset_get_attribute(): model = Model() agents = [TestAgent(model.next_id(), model) for _ in range(10)]