From a48dd80a6ce424638e78ae962225d099551dfe46 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Thu, 22 Aug 2024 11:05:36 +0200 Subject: [PATCH] Add AgentSet.groupby (#2220) Adds Agentset.groupby and Groupby Helper class --- mesa/agent.py | 116 ++++++++++++++++++++++++++++++++++++++++++-- tests/test_agent.py | 42 ++++++++++++++++ 2 files changed, 155 insertions(+), 3 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 8805673193d..f137ba09c2c 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -13,6 +13,7 @@ import operator import warnings import weakref +from collections import defaultdict from collections.abc import Callable, Iterable, Iterator, MutableSet, Sequence from random import Random @@ -397,7 +398,116 @@ def random(self) -> Random: """ return self.model.random + def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy: + """ + Group agents by the specified attribute or return from the callable + + Args: + by (Callable, str): used to determine what to group agents by + + * if ``by`` is a callable, it will be called for each agent and the return is used + for grouping + * if ``by`` is a str, it should refer to an attribute on the agent and the value + of this attribute will be used for grouping + result_type (str, optional): The datatype for the resulting groups {"agentset", "list"} + Returns: + GroupBy + + + Notes: + There might be performance benefits to using `result_type='list'` if you don't need the advanced functionality + of an AgentSet. + + """ + groups = defaultdict(list) + + if isinstance(by, Callable): + for agent in self: + groups[by(agent)].append(agent) + else: + for agent in self: + groups[getattr(agent, by)].append(agent) + + if result_type == "agentset": + return GroupBy( + {k: AgentSet(v, model=self.model) for k, v in groups.items()} + ) + else: + return GroupBy(groups) + + # consider adding for performance reasons + # for Sequence: __reversed__, index, and count + # for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__ + + +class GroupBy: + """Helper class for AgentSet.groupby + + + Attributes: + groups (dict): A dictionary with the group_name as key and group as values + + """ + + def __init__(self, groups: dict[Any, list | AgentSet]): + self.groups: dict[Any, list | AgentSet] = groups + + def map(self, method: Callable | str, *args, **kwargs) -> dict[Any, Any]: + """Apply the specified callable to each group and return the results. + + Args: + method (Callable, str): The callable to apply to each group, + + * if ``method`` is a callable, it will be called it will be called with the group as first argument + * if ``method`` is a str, it should refer to a method on the group + + Additional arguments and keyword arguments will be passed on to the callable. + + Returns: + dict with group_name as key and the return of the method as value + + Notes: + this method is useful for methods or functions that do return something. It + will break method chaining. For that, use ``do`` instead. + + """ + if isinstance(method, str): + return { + k: getattr(v, method)(*args, **kwargs) for k, v in self.groups.items() + } + else: + return {k: method(v, *args, **kwargs) for k, v in self.groups.items()} + + def do(self, method: Callable | str, *args, **kwargs) -> GroupBy: + """Apply the specified callable to each group + + Args: + method (Callable, str): The callable to apply to each group, + + * if ``method`` is a callable, it will be called it will be called with the group as first argument + * if ``method`` is a str, it should refer to a method on the group + + Additional arguments and keyword arguments will be passed on to the callable. + + Returns: + the original GroupBy instance + + Notes: + this method is useful for methods or functions that don't return anything and/or + if you want to chain multiple do calls + + """ + if isinstance(method, str): + for v in self.groups.values(): + getattr(v, method)(*args, **kwargs) + else: + for v in self.groups.values(): + method(v, *args, **kwargs) + + return self + + def __iter__(self): + return iter(self.groups.items()) -# consider adding for performance reasons -# for Sequence: __reversed__, index, and count -# for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__ + def __len__(self): + return len(self.groups) diff --git a/tests/test_agent.py b/tests/test_agent.py index 1541a46b6c2..516902f969b 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -371,3 +371,45 @@ def test_agentset_shuffle(): agentset = AgentSet(test_agents, model=model) agentset.shuffle(inplace=True) assert not all(a1 == a2 for a1, a2 in zip(test_agents, agentset)) + + +def test_agentset_groupby(): + class TestAgent(Agent): + def __init__(self, unique_id, model): + super().__init__(unique_id, model) + self.even = self.unique_id % 2 == 0 + + def get_unique_identifier(self): + return self.unique_id + + model = Model() + agents = [TestAgent(model.next_id(), model) for _ in range(10)] + agentset = AgentSet(agents, model) + + groups = agentset.groupby("even") + assert len(groups.groups[True]) == 5 + assert len(groups.groups[False]) == 5 + + groups = agentset.groupby(lambda a: a.unique_id % 2 == 0) + assert len(groups.groups[True]) == 5 + assert len(groups.groups[False]) == 5 + assert len(groups) == 2 + + for group_name, group in groups: + assert len(group) == 5 + assert group_name in {True, False} + + sizes = agentset.groupby("even", result_type="list").map(len) + assert sizes == {True: 5, False: 5} + + attributes = agentset.groupby("even", result_type="agentset").map("get", "even") + for group_name, group in attributes.items(): + assert all(group_name == entry for entry in group) + + groups = agentset.groupby("even", result_type="agentset") + another_ref_to_groups = groups.do("do", "step") + assert groups == another_ref_to_groups + + groups = agentset.groupby("even", result_type="agentset") + another_ref_to_groups = groups.do(lambda x: x.do("step")) + assert groups == another_ref_to_groups