Skip to content

Commit

Permalink
Add AgentSet.groupby (projectmesa#2220)
Browse files Browse the repository at this point in the history
Adds Agentset.groupby and Groupby Helper class
  • Loading branch information
quaquel authored Aug 22, 2024
1 parent 2898c01 commit a48dd80
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 3 deletions.
116 changes: 113 additions & 3 deletions mesa/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import operator
import warnings
import weakref
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, MutableSet, Sequence
from random import Random

Expand Down Expand Up @@ -397,7 +398,116 @@ def random(self) -> Random:
"""
return self.model.random

def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy:
"""
Group agents by the specified attribute or return from the callable
Args:
by (Callable, str): used to determine what to group agents by
* if ``by`` is a callable, it will be called for each agent and the return is used
for grouping
* if ``by`` is a str, it should refer to an attribute on the agent and the value
of this attribute will be used for grouping
result_type (str, optional): The datatype for the resulting groups {"agentset", "list"}
Returns:
GroupBy
Notes:
There might be performance benefits to using `result_type='list'` if you don't need the advanced functionality
of an AgentSet.
"""
groups = defaultdict(list)

if isinstance(by, Callable):
for agent in self:
groups[by(agent)].append(agent)
else:
for agent in self:
groups[getattr(agent, by)].append(agent)

if result_type == "agentset":
return GroupBy(
{k: AgentSet(v, model=self.model) for k, v in groups.items()}
)
else:
return GroupBy(groups)

# consider adding for performance reasons
# for Sequence: __reversed__, index, and count
# for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__


class GroupBy:
"""Helper class for AgentSet.groupby
Attributes:
groups (dict): A dictionary with the group_name as key and group as values
"""

def __init__(self, groups: dict[Any, list | AgentSet]):
self.groups: dict[Any, list | AgentSet] = groups

def map(self, method: Callable | str, *args, **kwargs) -> dict[Any, Any]:
"""Apply the specified callable to each group and return the results.
Args:
method (Callable, str): The callable to apply to each group,
* if ``method`` is a callable, it will be called it will be called with the group as first argument
* if ``method`` is a str, it should refer to a method on the group
Additional arguments and keyword arguments will be passed on to the callable.
Returns:
dict with group_name as key and the return of the method as value
Notes:
this method is useful for methods or functions that do return something. It
will break method chaining. For that, use ``do`` instead.
"""
if isinstance(method, str):
return {
k: getattr(v, method)(*args, **kwargs) for k, v in self.groups.items()
}
else:
return {k: method(v, *args, **kwargs) for k, v in self.groups.items()}

def do(self, method: Callable | str, *args, **kwargs) -> GroupBy:
"""Apply the specified callable to each group
Args:
method (Callable, str): The callable to apply to each group,
* if ``method`` is a callable, it will be called it will be called with the group as first argument
* if ``method`` is a str, it should refer to a method on the group
Additional arguments and keyword arguments will be passed on to the callable.
Returns:
the original GroupBy instance
Notes:
this method is useful for methods or functions that don't return anything and/or
if you want to chain multiple do calls
"""
if isinstance(method, str):
for v in self.groups.values():
getattr(v, method)(*args, **kwargs)
else:
for v in self.groups.values():
method(v, *args, **kwargs)

return self

def __iter__(self):
return iter(self.groups.items())

# consider adding for performance reasons
# for Sequence: __reversed__, index, and count
# for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__
def __len__(self):
return len(self.groups)
42 changes: 42 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,45 @@ def test_agentset_shuffle():
agentset = AgentSet(test_agents, model=model)
agentset.shuffle(inplace=True)
assert not all(a1 == a2 for a1, a2 in zip(test_agents, agentset))


def test_agentset_groupby():
class TestAgent(Agent):
def __init__(self, unique_id, model):
super().__init__(unique_id, model)
self.even = self.unique_id % 2 == 0

def get_unique_identifier(self):
return self.unique_id

model = Model()
agents = [TestAgent(model.next_id(), model) for _ in range(10)]
agentset = AgentSet(agents, model)

groups = agentset.groupby("even")
assert len(groups.groups[True]) == 5
assert len(groups.groups[False]) == 5

groups = agentset.groupby(lambda a: a.unique_id % 2 == 0)
assert len(groups.groups[True]) == 5
assert len(groups.groups[False]) == 5
assert len(groups) == 2

for group_name, group in groups:
assert len(group) == 5
assert group_name in {True, False}

sizes = agentset.groupby("even", result_type="list").map(len)
assert sizes == {True: 5, False: 5}

attributes = agentset.groupby("even", result_type="agentset").map("get", "even")
for group_name, group in attributes.items():
assert all(group_name == entry for entry in group)

groups = agentset.groupby("even", result_type="agentset")
another_ref_to_groups = groups.do("do", "step")
assert groups == another_ref_to_groups

groups = agentset.groupby("even", result_type="agentset")
another_ref_to_groups = groups.do(lambda x: x.do("step"))
assert groups == another_ref_to_groups

0 comments on commit a48dd80

Please sign in to comment.