Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
adamamer20 authored Jun 27, 2024
2 parents 8a0783f + 7410be9 commit 232b863
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 99 deletions.
26 changes: 13 additions & 13 deletions mesa_frames/abstract/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def copy(

return obj

def discard(self, ids: IdsLike, inplace: bool = True) -> Self:
def discard(self, agents: "AgentSetDF" | IdsLike, inplace: bool = True) -> Self:
"""Removes an agent from the AgentContainer. Does not raise an error if the agent is not found.
Parameters
Expand All @@ -157,7 +157,7 @@ def discard(self, ids: IdsLike, inplace: bool = True) -> Self:
Self
"""
with suppress(KeyError):
return self.remove(ids, inplace=inplace)
return self.remove(agents, inplace=inplace)
return self._get_obj(inplace)

@abstractmethod
Expand All @@ -180,19 +180,19 @@ def add(self, other, inplace: bool = True) -> Self:

@overload
@abstractmethod
def contains(self, ids: int) -> bool: ...
def contains(self, agents: int) -> bool: ...

@overload
@abstractmethod
def contains(self, ids: IdsLike) -> BoolSeries: ...
def contains(self, agents: "AgentSetDF" | IdsLike) -> BoolSeries: ...

@abstractmethod
def contains(self, ids: IdsLike) -> bool | BoolSeries:
def contains(self, agents: IdsLike) -> bool | BoolSeries:
"""Check if agents with the specified IDs are in the AgentContainer.
Parameters
----------
ids : IdsLike
agents : IdsLike
The ID(s) to check for.
Returns
Expand Down Expand Up @@ -281,12 +281,12 @@ def get(
...

@abstractmethod
def remove(self, ids: IdsLike, inplace: bool = True) -> Self:
def remove(self, agents: IdsLike, inplace: bool = True) -> Self:
"""Removes an agent from the AgentContainer.
Parameters
----------
id : MaskLike
agents : MaskLike
The ID of the agent to remove.
inplace : bool
Whether to remove the agent in place.
Expand Down Expand Up @@ -459,7 +459,7 @@ def __contains__(self, id: int) -> bool:
"""
if not isinstance(id, int):
raise TypeError("id must be an integer")
return self.contains(ids=id)
return self.contains(agents=id)

def __copy__(self) -> Self:
"""Create a shallow copy of the AgentContainer.
Expand Down Expand Up @@ -537,7 +537,7 @@ def __iadd__(self, other) -> Self:
"""
return self.add(other=other, inplace=True)

def __isub__(self, other: IdsLike) -> Self:
def __isub__(self, other: "AgentSetDF" | IdsLike) -> Self:
"""Remove agents from the AgentContainer through the -= operator.
Parameters
Expand All @@ -552,7 +552,7 @@ def __isub__(self, other: IdsLike) -> Self:
"""
return self.discard(other, inplace=True)

def __sub__(self, other: IdsLike) -> Self:
def __sub__(self, other: "AgentSetDF" | IdsLike) -> Self:
"""Remove agents from a new AgentContainer through the - operator.
Parameters
Expand Down Expand Up @@ -1003,10 +1003,10 @@ def __iter__(self) -> Iterator:
return iter(self._agents)

def __repr__(self) -> str:
return repr(self._agents)
return f"{self.__class__.__name__}\n {str(self._agents)}"

def __str__(self) -> str:
return str(self._agents)
return f"{self.__class__.__name__}\n {str(self._agents)}"

def __reversed__(self) -> Iterator:
return reversed(self._agents)
Expand Down
110 changes: 53 additions & 57 deletions mesa_frames/concrete/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Callable, Iterable, Iterator, Literal, Self, Sequence, overload

import polars as pl
from mesa import Agent

from mesa_frames.abstract.agents import AgentContainer, AgentSetDF, Collection, Hashable
from mesa_frames.concrete.agentset_pandas import AgentSetPandas
Expand All @@ -22,18 +23,18 @@ class AgentsDF(AgentContainer):
----------
_agentsets : list[AgentSetDF]
The agent sets contained in this collection.
_copy_with_method : dict[str, tuple[str, list[str]]]
_copy_with_method : dict[AgentSetDF, tuple[str, list[str]]]
A dictionary of attributes to copy with a specified method and arguments.
_backend : str
The backend used for data operations.
Properties
----------
active_agents(self) -> dict[str, pd.DataFrame]
active_agents(self) -> dict[AgentSetDF, pd.DataFrame]
Get the active agents in the AgentsDF.
agents(self) -> dict[str, pd.DataFrame]
agents(self) -> dict[AgentSetDF, pd.DataFrame]
Get or set the agents in the AgentsDF.
inactive_agents(self) -> dict[str, pd.DataFrame]
inactive_agents(self) -> dict[AgentSetDF, pd.DataFrame]
Get the inactive agents in the AgentsDF.
model(self) -> ModelDF
Get the model associated with the AgentsDF.
Expand All @@ -54,13 +55,13 @@ class AgentsDF(AgentContainer):
Remove an agent from the AgentsDF. Does not raise an error if the agent is not found.
do(self, method_name: str, *args, return_results: bool = False, inplace: bool = True, **kwargs) -> Self | Any
Invoke a method on the AgentsDF.
get(self, attr_names: str | Collection[str] | None = None, mask: MaskLike = None) -> dict[str, Series] | dict[str, DataFrame]
get(self, attr_names: str | Collection[str] | None = None, mask: MaskLike = None) -> dict[AgentSetDF, Series] | dict[AgentSetDF, DataFrame]
Retrieve the value of a specified attribute for each agent in the AgentsDF.
remove(self, ids: IdsLike, inplace: bool = True) -> Self
Remove agents from the AgentsDF.
select(self, mask: MaskLike = None, filter_func: Callable[[Self], MaskLike] | None = None, n: int | None = None, negate: bool = False, inplace: bool = True) -> Self
Select agents in the AgentsDF based on the given criteria.
set(self, attr_names: str | Collection[str] | dict[str, Any] | None = None, values: Any | None = None, mask: MaskLike | None = None, inplace: bool = True) -> Self
set(self, attr_names: str | Collection[str] | dict[AgentSetDF, Any] | None = None, values: Any | None = None, mask: MaskLike | None = None, inplace: bool = True) -> Self
Set the value of a specified attribute or attributes for each agent in the mask in the AgentsDF.
shuffle(self, inplace: bool = True) -> Self
Shuffle the order of agents in the AgentsDF.
Expand Down Expand Up @@ -109,23 +110,27 @@ def add(
"""
obj = self._get_obj(inplace)
self._check_ids(other)
if isinstance(other, Iterable):
obj._agentsets += other
else:
if isinstance(other, AgentSetDF):
obj._agentsets.append(other)
elif isinstance(other, Iterable):
if not all(isinstance(agentset, AgentSetDF) for agentset in other):
raise TypeError("All elements in the iterable must be AgentSetDFs.")
obj._agentsets.extend(other)
return self

@overload
def contains(self, ids: int) -> bool: ...
def contains(self, agents: AgentSetDF | int) -> bool: ...

@overload
def contains(self, ids: IdsLike) -> pl.Series: ...
def contains(self, agents: IdsLike) -> pl.Series: ...

def contains(self, ids: IdsLike) -> bool | pl.Series:
if isinstance(ids, int):
return ids in self._ids
def contains(self, agents: AgentSetDF | IdsLike) -> bool | pl.Series:
if isinstance(agents, int):
return agents in self._ids
elif isinstance(agents, AgentSetDF):
return agents in self._agentsets
else:
return pl.Series(ids).is_in(self._ids)
return pl.Series(agents).is_in(self._ids)

@overload
def do(
Expand All @@ -145,7 +150,7 @@ def do(
return_results: Literal[True],
inplace: bool = True,
**kwargs,
) -> dict[str, Any]: ...
) -> dict[AgentSetDF, Any]: ...

def do(
self,
Expand All @@ -158,7 +163,7 @@ def do(
obj = self._get_obj(inplace)
if return_results:
return {
agentset.__class__.__name__: agentset.do(
agentset: agentset.do(
method_name,
*args,
return_results=return_results,
Expand All @@ -184,26 +189,31 @@ def get(
self,
attr_names: str | list[str] | None = None,
mask: MaskLike | None = None,
) -> dict[str, Series] | dict[str, DataFrame]:
) -> dict[AgentSetDF, Series] | dict[AgentSetDF, DataFrame]:
return {
agentset.__class__.__name__: agentset.get(attr_names, mask)
for agentset in self._agentsets
agentset: agentset.get(attr_names, mask) for agentset in self._agentsets
}

def remove(self, ids: IdsLike, inplace: bool = True) -> Self:
def remove(self, agents: AgentSetDF | IdsLike, inplace: bool = True) -> Self:
obj = self._get_obj(inplace)
deleted = 0
for agentset in obj._agentsets:
initial_len = len(agentset)
agentset.discard(ids, inplace=True)
deleted += initial_len - len(agentset)
if deleted < len(list(ids)): # TODO: fix type hint
raise KeyError(f"None of the agentsets contain the ID {MaskLike}.")
if isinstance(agents, AgentSetDF):
try:
obj._agentsets.remove(agents)
except ValueError:
raise KeyError(f"{agents} not found in the AgentsDF.")
else: # elif isinstance(ids, IdsLike):
for agents in obj._agentsets:
initial_len = len(agents)
agents.discard(agents, inplace=True)
deleted += initial_len - len(agents)
if deleted < len(list(agents)): # TODO: fix type hint
raise KeyError(f"Some ids were not found in the AgentsDF.")
return obj

def set(
self,
attr_names: str | dict[str, Any] | Collection[str],
attr_names: str | dict[AgentSetDF, Any] | Collection[str],
values: Any | None = None,
mask: MaskLike | None = None,
inplace: bool = True,
Expand Down Expand Up @@ -266,11 +276,15 @@ def _check_ids(self, other: AgentSetDF | Iterable[AgentSetDF]) -> None:
ValueError
If the agent set contains IDs already present in agents.
"""
for agentset in other if isinstance(other, Iterable) else [other]:
for agentset in [other] if isinstance(other, AgentSetDF) else other:
if isinstance(agentset, AgentSetPandas):
new_ids = pl.Series(agentset._agents.index)
new_ids = pl.from_pandas(agentset._agents.index)
elif isinstance(agentset, AgentSetPolars):
new_ids = agentset._agents["unique_id"]
else:
raise TypeError(
"AgentSetDF must be of type AgentSetPandas or AgentSetPolars."
)
if new_ids.is_in(self._ids).any():
raise ValueError(
"The agent set contains IDs already present in agents."
Expand All @@ -291,11 +305,8 @@ def __add__(self, other: AgentSetDF | Iterable[AgentSetDF]) -> Self:
"""
return super().__add__(other)

def __getattr__(self, name: str) -> dict[str, Any]:
return {
agentset.__class__.__name__: getattr(agentset, name)
for agentset in self._agentsets
}
def __getattr__(self, name: str) -> dict[AgentSetDF, Any]:
return {agentset: getattr(agentset, name) for agentset in self._agentsets}

def __iadd__(self, other: AgentSetDF | Iterable[AgentSetDF]) -> Self:
"""Add AgentSetDFs to the AgentsDF through the += operator.
Expand All @@ -318,17 +329,10 @@ def __iter__(self) -> Iterator:
)

def __repr__(self) -> str:
return str(
{
agentset.__class__.__name__: repr(agentset)
for agentset in self._agentsets
}
)
return "\n".join([repr(agentset) for agentset in self._agentsets])

def __str__(self) -> str:
return str(
{agentset.__class__.__name__: str(agentset) for agentset in self._agentsets}
)
return "\n".join([str(agentset) for agentset in self._agentsets])

def __reversed__(self) -> Iterator:
return (
Expand All @@ -341,10 +345,8 @@ def __len__(self) -> int:
return sum(len(agentset._agents) for agentset in self._agentsets)

@property
def agents(self) -> dict[str, DataFrame]:
return {
agentset.__class__.__name__: agentset.agents for agentset in self._agentsets
}
def agents(self) -> dict[AgentSetDF, DataFrame]:
return {agentset: agentset.agents for agentset in self._agentsets}

@agents.setter
def agents(self, other: Iterable[AgentSetDF]) -> None:
Expand All @@ -358,15 +360,9 @@ def agents(self, other: Iterable[AgentSetDF]) -> None:
self._agentsets = list(other)

@property
def active_agents(self) -> dict[str, DataFrame]:
return {
agentset.__class__.__name__: agentset.active_agents
for agentset in self._agentsets
}
def active_agents(self) -> dict[AgentSetDF, DataFrame]:
return {agentset: agentset.active_agents for agentset in self._agentsets}

@property
def inactive_agents(self):
return {
agentset.__class__.__name__: agentset.inactive_agents
for agentset in self._agentsets
}
return {agentset: agentset.inactive_agents for agentset in self._agentsets}
Loading

0 comments on commit 232b863

Please sign in to comment.