Skip to content

Commit

Permalink
changed list to iterable in AgentsDF, change ids to IdsLike, added id…
Browse files Browse the repository at this point in the history
… check in AgentsDF
  • Loading branch information
adamamer20 committed Jun 25, 2024
1 parent f317ce1 commit 03f375e
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 66 deletions.
40 changes: 17 additions & 23 deletions mesa_frames/abstract/agents.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@


from __future__ import annotations # PEP 563: postponed evaluation of type annotations

from abc import ABC, abstractmethod
Expand All @@ -20,7 +18,7 @@

from numpy.random import Generator

from mesa_frames.types import BoolSeries, DataFrame, MaskLike, Series
from mesa_frames.types import BoolSeries, DataFrame, IdsLike, MaskLike, Series

if TYPE_CHECKING:
from mesa_frames.concrete.model import ModelDF
Expand All @@ -42,17 +40,17 @@ class AgentContainer(ABC):
-------
copy(deep: bool = False, memo: dict | None = None) -> Self
Create a copy of the AgentContainer.
discard(ids: MaskLike, inplace: bool = True) -> Self
discard(ids: IdsLike, inplace: bool = True) -> Self
Removes an agent from the AgentContainer. Does not raise an error if the agent is not found.
add(other: Any, inplace: bool = True) -> Self
Add agents to the AgentContainer.
contains(ids: Hashable | Collection[Hashable]) -> bool | BoolSeries
contains(ids: IdsLike) -> bool | BoolSeries
Check if agents with the specified IDs are in the AgentContainer.
do(method_name: str, *args, return_results: bool = False, inplace: bool = True, **kwargs) -> Self | Any | dict[str, Any]
Invoke a method on the AgentContainer.
get(attr_names: str | Collection[str] | None = None, mask: MaskLike | None = None) -> Series | DataFrame | dict[str, Series] | dict[str, DataFrame]
Retrieve the value of a specified attribute for each agent in the AgentContainer.
remove(ids: MaskLike, inplace: bool = True) -> Self
remove(ids: IdsLike, inplace: bool = True) -> Self
Removes an agent from the AgentContainer.
select(mask: MaskLike | None = None, filter_func: Callable[[Self], MaskLike] | None = None, n: int | None = None, negate: bool = False, inplace: bool = True) -> Self
Select agents in the AgentContainer based on the given criteria.
Expand Down Expand Up @@ -144,7 +142,7 @@ def copy(

return obj

def discard(self, ids: MaskLike, inplace: bool = True) -> Self:
def discard(self, ids: IdsLike, inplace: bool = True) -> Self:
"""Removes an agent from the AgentContainer. Does not raise an error if the agent is not found.
Parameters
Expand Down Expand Up @@ -182,19 +180,19 @@ def add(self, other, inplace: bool = True) -> Self:

@overload
@abstractmethod
def contains(self, ids: Collection[Hashable]) -> BoolSeries: ...
def contains(self, ids: int) -> bool: ...

@overload
@abstractmethod
def contains(self, ids: Hashable) -> bool: ...
def contains(self, ids: IdsLike) -> BoolSeries: ...

@abstractmethod
def contains(self, ids: Hashable | Collection[Hashable]) -> bool | BoolSeries:
def contains(self, ids: IdsLike) -> bool | BoolSeries:
"""Check if agents with the specified IDs are in the AgentContainer.
Parameters
----------
ids : Hashable | Collection[Any]
ids : IdsLike
The ID(s) to check for.
Returns
Expand Down Expand Up @@ -283,7 +281,7 @@ def get(
...

@abstractmethod
def remove(self, ids: MaskLike, inplace: bool = True) -> Self:
def remove(self, ids: IdsLike, inplace: bool = True) -> Self:
"""Removes an agent from the AgentContainer.
Parameters
Expand Down Expand Up @@ -446,7 +444,7 @@ def _get_obj(self, inplace: bool) -> Self:
def __add__(self, other) -> Self:
return self.add(other=other, inplace=False)

def __contains__(self, id: Hashable) -> bool:
def __contains__(self, id: int) -> bool:
"""Check if an agent is in the AgentContainer.
Parameters
Expand All @@ -459,13 +457,9 @@ def __contains__(self, id: Hashable) -> bool:
bool
True if the agent is in the AgentContainer, False otherwise.
"""
bool_series = self.contains(ids=id)
if isinstance(bool_series, bool):
return bool_series
elif len(bool_series) == 1:
return bool_series[0].value
else:
raise ValueError("The in operator can only be used with a single ID.")
if not isinstance(id, int):
raise TypeError("id must be an integer")
return self.contains(ids=id)

def __copy__(self) -> Self:
"""Create a shallow copy of the AgentContainer.
Expand Down Expand Up @@ -543,7 +537,7 @@ def __iadd__(self, other) -> Self:
"""
return self.add(other=other, inplace=True)

def __isub__(self, other: MaskLike) -> Self:
def __isub__(self, other: IdsLike) -> Self:
"""Remove agents from the AgentContainer through the -= operator.
Parameters
Expand All @@ -558,7 +552,7 @@ def __isub__(self, other: MaskLike) -> Self:
"""
return self.discard(other, inplace=True)

def __sub__(self, other: MaskLike) -> Self:
def __sub__(self, other: IdsLike) -> Self:
"""Remove agents from a new AgentContainer through the - operator.
Parameters
Expand Down Expand Up @@ -1038,4 +1032,4 @@ def active_agents(self) -> DataFrame: ...

@property
@abstractmethod
def inactive_agents(self) -> DataFrame: ...
def inactive_agents(self) -> DataFrame: ...
81 changes: 60 additions & 21 deletions mesa_frames/concrete/agents.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from operator import ne
from typing import Any, Callable, Iterable, Iterator, Literal, Self, Sequence, overload

import polars as pl

from mesa_frames.abstract.agents import AgentContainer, AgentSetDF, Collection, Hashable
from mesa_frames.types import BoolSeries, DataFrame, MaskLike, Series
from mesa_frames.concrete.agentset_pandas import AgentSetPandas
from mesa_frames.concrete.agentset_polars import AgentSetPolars
from mesa_frames.types import BoolSeries, DataFrame, IdsLike, MaskLike, Series


class AgentsDF(AgentContainer):
Expand All @@ -10,6 +15,7 @@ class AgentsDF(AgentContainer):
"_agentsets": ("copy", []),
}
_backend: str
_ids: pl.Series
"""A collection of AgentSetDFs. All agents of the model are stored here.
Attributes
Expand Down Expand Up @@ -38,19 +44,19 @@ class AgentsDF(AgentContainer):
-------
__init__(self) -> None
Initialize a new AgentsDF.
add(self, other: AgentSetDF | list[AgentSetDF], inplace: bool = True) -> Self
add(self, other: AgentSetDF | Iterable[AgentSetDF], inplace: bool = True) -> Self
Add agents to the AgentsDF.
contains(self, ids: Hashable | Collection[Hashable]) -> bool | dict[str, pd.Series]
contains(self, ids: IdsLike) -> bool | pl.Series
Check if agents with the specified IDs are in the AgentsDF.
copy(self, deep: bool = False, memo: dict | None = None) -> Self
Create a copy of the AgentsDF.
discard(self, ids: MaskLike, inplace: bool = True) -> Self
discard(self, ids: IdsLike, inplace: bool = True) -> Self
Remove an agent from the AgentsDF. Does not raise an error if the agent is not found.
do(self, method_name: str, *args, return_results: bool = False, inplace: bool = True, **kwargs) -> Self | Any
Invoke a method on the AgentsDF.
get(self, attr_names: str | Collection[str] | None = None, mask: MaskLike = None) -> dict[str, pd.Series] | dict[str, pd.DataFrame]
get(self, attr_names: str | Collection[str] | None = None, mask: MaskLike = None) -> dict[str, Series] | dict[str, DataFrame]
Retrieve the value of a specified attribute for each agent in the AgentsDF.
remove(self, ids: MaskLike, inplace: bool = True) -> Self
remove(self, ids: IdsLike, inplace: bool = True) -> Self
Remove agents from the AgentsDF.
select(self, mask: MaskLike = None, filter_func: Callable[[Self], MaskLike] | None = None, n: int | None = None, negate: bool = False, inplace: bool = True) -> Self
Select agents in the AgentsDF based on the given criteria.
Expand All @@ -60,8 +66,14 @@ class AgentsDF(AgentContainer):
Shuffle the order of agents in the AgentsDF.
sort(self, by: str | Sequence[str], ascending: bool | Sequence[bool] = True, inplace: bool = True, **kwargs) -> Self
Sort the agents in the AgentsDF based on the given criteria.
_check_ids(self, other: AgentSetDF | Iterable[AgentSetDF]) -> None
Check if the IDs of the agents to be added are unique.
__add__(self, other: AgentSetDF | Iterable[AgentSetDF]) -> Self
Add AgentSetDFs to a new AgentsDF through the + operator.
__getattr__(self, key: str) -> Any
Retrieve an attribute of the underlying agent sets.
__iadd__(self, other: AgentSetDF | Iterable[AgentSetDF]) -> Self
Add AgentSetDFs to the AgentsDF through the += operator.
__iter__(self) -> Iterator
Get an iterator for the agents in the AgentsDF.
__len__(self) -> int
Expand All @@ -76,8 +88,11 @@ class AgentsDF(AgentContainer):

def __init__(self) -> None:
self._agentsets = []
self._ids = pl.Series(name="unique_id", dtype=pl.Int64)

def add(self, other: AgentSetDF | list[AgentSetDF], inplace: bool = True) -> Self:
def add(
self, other: AgentSetDF | Iterable[AgentSetDF], inplace: bool = True
) -> Self:
"""Add an AgentSetDF to the AgentsDF.
Parameters
Expand All @@ -93,23 +108,24 @@ def add(self, other: AgentSetDF | list[AgentSetDF], inplace: bool = True) -> Sel
The updated AgentsDF.
"""
obj = self._get_obj(inplace)
if isinstance(other, list):
self._check_ids(other)
if isinstance(other, Iterable):
obj._agentsets += other
else:
obj._agentsets.append(other)
return self

@overload
def contains(self, ids: Collection[Hashable]) -> BoolSeries: ...
def contains(self, ids: int) -> bool: ...

@overload
def contains(self, ids: Hashable) -> bool: ...
def contains(self, ids: IdsLike) -> pl.Series: ...

def contains(self, ids: Hashable | Collection[Hashable]) -> bool | BoolSeries:
bool_series = self._agentsets[0].contains(ids)
for agentset in self._agentsets[1:]:
bool_series = bool_series | agentset.contains(ids)
return bool_series
def contains(self, ids: IdsLike) -> bool | pl.Series:
if isinstance(ids, int):
return ids in self._ids
else:
return pl.Series(ids).is_in(self._ids)

@overload
def do(
Expand Down Expand Up @@ -174,7 +190,7 @@ def get(
for agentset in self._agentsets
}

def remove(self, ids: MaskLike, inplace: bool = True) -> Self:
def remove(self, ids: IdsLike, inplace: bool = True) -> Self:
obj = self._get_obj(inplace)
deleted = 0
for agentset in obj._agentsets:
Expand Down Expand Up @@ -204,7 +220,7 @@ def set(
def select(
self,
mask: MaskLike | None = None,
filter_func: Callable[[DataFrame], MaskLike] | None = None,
filter_func: Callable[[AgentSetDF], MaskLike] | None = None,
n: int | None = None,
inplace: bool = True,
negate: bool = False,
Expand Down Expand Up @@ -237,12 +253,35 @@ def sort(
]
return obj

def __add__(self, other: AgentSetDF | list[AgentSetDF]) -> Self:
def _check_ids(self, other: AgentSetDF | Iterable[AgentSetDF]) -> None:
"""Check if the IDs of the agents to be added are unique.
Parameters
----------
other : AgentSetDF | Iterable[AgentSetDF]
The AgentSetDFs to check.
Raises
------
ValueError
If the agent set contains IDs already present in agents.
"""
for agentset in other if isinstance(other, Iterable) else [other]:
if isinstance(agentset, AgentSetPandas):
new_ids = pl.Series(agentset._agents.index)
elif isinstance(agentset, AgentSetPolars):
new_ids = agentset._agents["unique_id"]
if new_ids.is_in(self._ids).any():
raise ValueError(
"The agent set contains IDs already present in agents."
)

def __add__(self, other: AgentSetDF | Iterable[AgentSetDF]) -> Self:
"""Add AgentSetDFs to a new AgentsDF through the + operator.
Parameters
----------
other : AgentSetDF | list[AgentSetDF]
other : AgentSetDF | Iterable[AgentSetDF]
The AgentSetDFs to add.
Returns
Expand All @@ -258,12 +297,12 @@ def __getattr__(self, name: str) -> dict[str, Any]:
for agentset in self._agentsets
}

def __iadd__(self, other: AgentSetDF | list[AgentSetDF]) -> Self:
def __iadd__(self, other: AgentSetDF | Iterable[AgentSetDF]) -> Self:
"""Add AgentSetDFs to the AgentsDF through the += operator.
Parameters
----------
other : Self | AgentSetDF | list[AgentSetDF]
other : Self | AgentSetDF | Iterable[AgentSetDF]
The AgentSetDFs to add.
Returns
Expand Down
26 changes: 16 additions & 10 deletions mesa_frames/concrete/agentset_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from mesa_frames.abstract.agents import AgentSetDF
from mesa_frames.concrete.agentset_polars import AgentSetPolars
from mesa_frames.concrete.model import ModelDF
from mesa_frames.types import PandasMaskLike
from mesa_frames.types import PandasIdsLike, PandasMaskLike


class AgentSetPandas(AgentSetDF):
Expand Down Expand Up @@ -62,17 +62,17 @@ class AgentSetPandas(AgentSetDF):
Initialize a new AgentSetPandas.
add(self, other: pd.DataFrame | Sequence[Any] | dict[str, Any], inplace: bool = True) -> Self
Add agents to the AgentSetPandas.
contains(self, ids: Hashable | Collection[Hashable]) -> bool | pd.Series
contains(self, ids: PandasIdsLike) -> bool | pd.Series
Check if agents with the specified IDs are in the AgentSetPandas.
copy(self, deep: bool = False, memo: dict | None = None) -> Self
Create a copy of the AgentSetPandas.
discard(self, ids: PandasMaskLike, inplace: bool = True) -> Self
discard(self, ids: PandasIdsLike, inplace: bool = True) -> Self
Remove an agent from the AgentSetPandas. Does not raise an error if the agent is not found.
do(self, method_name: str, *args, return_results: bool = False, inplace: bool = True, **kwargs) -> Self | Any
Invoke a method on the AgentSetPandas.
get(self, attr_names: str | Collection[str] | None, mask: PandasMaskLike = None) -> pd.Series | pd.DataFrame
Retrieve the value of a specified attribute for each agent in the AgentSetPandas.
remove(self, ids: PandasMaskLike, inplace: bool = True) -> Self
remove(self, ids: PandasIdsLike, inplace: bool = True) -> Self
Remove agents from the AgentSetPandas.
select(self, mask: PandasMaskLike = None, filter_func: Callable[[Self], PandasMaskLike] | None = None, n: int | None = None, negate: bool = False, inplace: bool = True) -> Self
Select agents in the AgentSetPandas based on the given criteria.
Expand Down Expand Up @@ -104,8 +104,10 @@ class AgentSetPandas(AgentSetDF):

def __init__(self, model: ModelDF) -> None:
self._model = model
self._agents = pd.DataFrame(columns=["unique_id"]).set_index("unique_id")
self._mask = pd.Series(True, index=self._agents.index)
self._agents = pd.DataFrame(
columns=["unique_id"], dtype={"unique_id": pd.Int64Dtype}
).set_index("unique_id")
self._mask = pd.Series(True, index=self._agents.index, dtype=pd.BooleanDtype)

def add(
self,
Expand Down Expand Up @@ -140,17 +142,21 @@ def add(
return obj

@overload
def contains(self, ids: Collection[Hashable]) -> pd.Series: ...
def contains(self, ids: int) -> bool: ...

@overload
def contains(self, ids: Hashable) -> bool: ...
def contains(self, ids: PandasIdsLike) -> pd.Series: ...

def contains(
self,
ids: Hashable | Collection[Hashable],
ids: int | Collection[int] | pd.Series[int] | pd.Index,
) -> bool | pd.Series:
if isinstance(ids, pd.Series):
return ids.isin(self._agents.index)
elif isinstance(ids, pd.Index):
return pd.Series(
ids.isin(self._agents.index), index=ids, dtype=pd.BooleanDtype
)
elif isinstance(ids, Collection):
return pd.Series(list(ids), index=list(ids)).isin(self._agents.index)
else:
Expand All @@ -172,7 +178,7 @@ def get(

def remove(
self,
ids: PandasMaskLike,
ids: PandasIdsLike,
inplace: bool = True,
) -> Self:
obj = self._get_obj(inplace)
Expand Down
Loading

0 comments on commit 03f375e

Please sign in to comment.