Skip to content

Commit

Permalink
Split AgentSet into map and do to seperate return types (projectmesa#…
Browse files Browse the repository at this point in the history
…2237)

* seperate original do into map and do in AgentSet based on their difference in return type.
  • Loading branch information
quaquel authored Aug 22, 2024
1 parent 33a9926 commit 2898c01
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 12 deletions.
52 changes: 46 additions & 6 deletions mesa/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import contextlib
import copy
import operator
import warnings
import weakref
from collections.abc import Callable, Iterable, Iterator, MutableSet, Sequence
from random import Random
Expand Down Expand Up @@ -216,25 +217,64 @@ def _update(self, agents: Iterable[Agent]):
self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents})
return self

def do(
self, method: str | Callable, *args, return_results: bool = False, **kwargs
) -> AgentSet | list[Any]:
def do(self, method: str | Callable, *args, **kwargs) -> AgentSet:
"""
Invoke a method or function on each agent in the AgentSet.
Args:
method (str, callable): the callable to do on each agents
method (str, callable): the callable to do on each agent
* in case of str, the name of the method to call on each agent.
* in case of callable, the function to be called with each agent as first argument
return_results (bool, optional): If True, returns the results of the method calls; otherwise, returns the AgentSet itself. Defaults to False, so you can chain method calls.
*args: Variable length argument list passed to the callable being called.
**kwargs: Arbitrary keyword arguments passed to the callable being called.
Returns:
AgentSet | list[Any]: The results of the callable calls if return_results is True, otherwise the AgentSet itself.
"""
try:
return_results = kwargs.pop("return_results")
except KeyError:
return_results = False
else:
warnings.warn(
"Using return_results is deprecated. Use AgenSet.do in case of return_results=False, and "
"AgentSet.map in case of return_results=True",
stacklevel=2,
)

if return_results:
return self.map(method, *args, **kwargs)

# we iterate over the actual weakref keys and check if weakref is alive before calling the method
if isinstance(method, str):
for agentref in self._agents.keyrefs():
if (agent := agentref()) is not None:
getattr(agent, method)(*args, **kwargs)
else:
for agentref in self._agents.keyrefs():
if (agent := agentref()) is not None:
method(agent, *args, **kwargs)

return self

def map(self, method: str | Callable, *args, **kwargs) -> list[Any]:
"""
Invoke a method or function on each agent in the AgentSet and return the results.
Args:
method (str, callable): the callable to apply on each agent
* in case of str, the name of the method to call on each agent.
* in case of callable, the function to be called with each agent as first argument
*args: Variable length argument list passed to the callable being called.
**kwargs: Arbitrary keyword arguments passed to the callable being called.
Returns:
list[Any]: The results of the callable calls
"""
# we iterate over the actual weakref keys and check if weakref is alive before calling the method
if isinstance(method, str):
res = [
Expand All @@ -249,7 +289,7 @@ def do(
if (agent := agentref()) is not None
]

return res if return_results else self
return res

def get(self, attr_names: str | list[str]) -> list[Any]:
"""
Expand Down
35 changes: 29 additions & 6 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,6 @@ def test_function(agent):
assert all(
a1 == a2.unique_id for a1, a2 in zip(agentset.get("unique_id"), agentset)
)
assert all(
a1 == a2.unique_id
for a1, a2 in zip(
agentset.do("get_unique_identifier", return_results=True), agentset
)
)
assert agentset == agentset.do("get_unique_identifier")

agentset.discard(agents[0])
Expand Down Expand Up @@ -276,6 +270,35 @@ def remove_function(agent):
assert len(agentset) == 0


def test_agentset_map_str():
model = Model()
agents = [TestAgent(model.next_id(), model) for _ in range(10)]
agentset = AgentSet(agents, model)

with pytest.raises(AttributeError):
agentset.do("non_existing_method")

results = agentset.map("get_unique_identifier")
assert all(i == entry for i, entry in zip(results, range(1, 11)))


def test_agentset_map_callable():
model = Model()
agents = [TestAgent(model.next_id(), model) for _ in range(10)]
agentset = AgentSet(agents, model)

# Test callable with non-existent function
with pytest.raises(AttributeError):
agentset.map(lambda agent: agent.non_existing_method())

# tests for addition and removal in do using callables
# do iterates, so no error should be raised to change size while iterating
# related to issue #1595

results = agentset.map(lambda agent: agent.unique_id)
assert all(i == entry for i, entry in zip(results, range(1, 11)))


def test_agentset_get_attribute():
model = Model()
agents = [TestAgent(model.next_id(), model) for _ in range(10)]
Expand Down

0 comments on commit 2898c01

Please sign in to comment.