From 7410be9c9b3cb6d88746df67ef6683fd7d16b129 Mon Sep 17 00:00:00 2001 From: Adam Amer <136176500+adamamer20@users.noreply.github.com> Date: Thu, 27 Jun 2024 09:07:17 +0200 Subject: [PATCH] Fixed __repr__ and __str__ of AgentsDF, added removal of AgentSetDF, fixed _ids attribute of AgentsDF --- mesa_frames/abstract/agents.py | 26 +++--- mesa_frames/concrete/agents.py | 110 ++++++++++++------------ mesa_frames/concrete/agentset_pandas.py | 32 +++---- mesa_frames/concrete/agentset_polars.py | 11 +-- mesa_frames/test.py | 4 - mesa_frames/types.py | 4 +- requirements.txt | 3 +- 7 files changed, 91 insertions(+), 99 deletions(-) delete mode 100644 mesa_frames/test.py diff --git a/mesa_frames/abstract/agents.py b/mesa_frames/abstract/agents.py index cc66287..2e01b86 100644 --- a/mesa_frames/abstract/agents.py +++ b/mesa_frames/abstract/agents.py @@ -142,7 +142,7 @@ def copy( return obj - def discard(self, ids: IdsLike, inplace: bool = True) -> Self: + def discard(self, agents: "AgentSetDF" | IdsLike, inplace: bool = True) -> Self: """Removes an agent from the AgentContainer. Does not raise an error if the agent is not found. Parameters @@ -157,7 +157,7 @@ def discard(self, ids: IdsLike, inplace: bool = True) -> Self: Self """ with suppress(KeyError): - return self.remove(ids, inplace=inplace) + return self.remove(agents, inplace=inplace) return self._get_obj(inplace) @abstractmethod @@ -180,19 +180,19 @@ def add(self, other, inplace: bool = True) -> Self: @overload @abstractmethod - def contains(self, ids: int) -> bool: ... + def contains(self, agents: int) -> bool: ... @overload @abstractmethod - def contains(self, ids: IdsLike) -> BoolSeries: ... + def contains(self, agents: "AgentSetDF" | IdsLike) -> BoolSeries: ... @abstractmethod - def contains(self, ids: IdsLike) -> bool | BoolSeries: + def contains(self, agents: IdsLike) -> bool | BoolSeries: """Check if agents with the specified IDs are in the AgentContainer. Parameters ---------- - ids : IdsLike + agents : IdsLike The ID(s) to check for. Returns @@ -281,12 +281,12 @@ def get( ... @abstractmethod - def remove(self, ids: IdsLike, inplace: bool = True) -> Self: + def remove(self, agents: IdsLike, inplace: bool = True) -> Self: """Removes an agent from the AgentContainer. Parameters ---------- - id : MaskLike + agents : MaskLike The ID of the agent to remove. inplace : bool Whether to remove the agent in place. @@ -459,7 +459,7 @@ def __contains__(self, id: int) -> bool: """ if not isinstance(id, int): raise TypeError("id must be an integer") - return self.contains(ids=id) + return self.contains(agents=id) def __copy__(self) -> Self: """Create a shallow copy of the AgentContainer. @@ -537,7 +537,7 @@ def __iadd__(self, other) -> Self: """ return self.add(other=other, inplace=True) - def __isub__(self, other: IdsLike) -> Self: + def __isub__(self, other: "AgentSetDF" | IdsLike) -> Self: """Remove agents from the AgentContainer through the -= operator. Parameters @@ -552,7 +552,7 @@ def __isub__(self, other: IdsLike) -> Self: """ return self.discard(other, inplace=True) - def __sub__(self, other: IdsLike) -> Self: + def __sub__(self, other: "AgentSetDF" | IdsLike) -> Self: """Remove agents from a new AgentContainer through the - operator. Parameters @@ -1003,10 +1003,10 @@ def __iter__(self) -> Iterator: return iter(self._agents) def __repr__(self) -> str: - return repr(self._agents) + return f"{self.__class__.__name__}\n {str(self._agents)}" def __str__(self) -> str: - return str(self._agents) + return f"{self.__class__.__name__}\n {str(self._agents)}" def __reversed__(self) -> Iterator: return reversed(self._agents) diff --git a/mesa_frames/concrete/agents.py b/mesa_frames/concrete/agents.py index 34cfe65..0729980 100644 --- a/mesa_frames/concrete/agents.py +++ b/mesa_frames/concrete/agents.py @@ -2,6 +2,7 @@ from typing import Any, Callable, Iterable, Iterator, Literal, Self, Sequence, overload import polars as pl +from mesa import Agent from mesa_frames.abstract.agents import AgentContainer, AgentSetDF, Collection, Hashable from mesa_frames.concrete.agentset_pandas import AgentSetPandas @@ -22,18 +23,18 @@ class AgentsDF(AgentContainer): ---------- _agentsets : list[AgentSetDF] The agent sets contained in this collection. - _copy_with_method : dict[str, tuple[str, list[str]]] + _copy_with_method : dict[AgentSetDF, tuple[str, list[str]]] A dictionary of attributes to copy with a specified method and arguments. _backend : str The backend used for data operations. Properties ---------- - active_agents(self) -> dict[str, pd.DataFrame] + active_agents(self) -> dict[AgentSetDF, pd.DataFrame] Get the active agents in the AgentsDF. - agents(self) -> dict[str, pd.DataFrame] + agents(self) -> dict[AgentSetDF, pd.DataFrame] Get or set the agents in the AgentsDF. - inactive_agents(self) -> dict[str, pd.DataFrame] + inactive_agents(self) -> dict[AgentSetDF, pd.DataFrame] Get the inactive agents in the AgentsDF. model(self) -> ModelDF Get the model associated with the AgentsDF. @@ -54,13 +55,13 @@ class AgentsDF(AgentContainer): Remove an agent from the AgentsDF. Does not raise an error if the agent is not found. do(self, method_name: str, *args, return_results: bool = False, inplace: bool = True, **kwargs) -> Self | Any Invoke a method on the AgentsDF. - get(self, attr_names: str | Collection[str] | None = None, mask: MaskLike = None) -> dict[str, Series] | dict[str, DataFrame] + get(self, attr_names: str | Collection[str] | None = None, mask: MaskLike = None) -> dict[AgentSetDF, Series] | dict[AgentSetDF, DataFrame] Retrieve the value of a specified attribute for each agent in the AgentsDF. remove(self, ids: IdsLike, inplace: bool = True) -> Self Remove agents from the AgentsDF. select(self, mask: MaskLike = None, filter_func: Callable[[Self], MaskLike] | None = None, n: int | None = None, negate: bool = False, inplace: bool = True) -> Self Select agents in the AgentsDF based on the given criteria. - set(self, attr_names: str | Collection[str] | dict[str, Any] | None = None, values: Any | None = None, mask: MaskLike | None = None, inplace: bool = True) -> Self + set(self, attr_names: str | Collection[str] | dict[AgentSetDF, Any] | None = None, values: Any | None = None, mask: MaskLike | None = None, inplace: bool = True) -> Self Set the value of a specified attribute or attributes for each agent in the mask in the AgentsDF. shuffle(self, inplace: bool = True) -> Self Shuffle the order of agents in the AgentsDF. @@ -109,23 +110,27 @@ def add( """ obj = self._get_obj(inplace) self._check_ids(other) - if isinstance(other, Iterable): - obj._agentsets += other - else: + if isinstance(other, AgentSetDF): obj._agentsets.append(other) + elif isinstance(other, Iterable): + if not all(isinstance(agentset, AgentSetDF) for agentset in other): + raise TypeError("All elements in the iterable must be AgentSetDFs.") + obj._agentsets.extend(other) return self @overload - def contains(self, ids: int) -> bool: ... + def contains(self, agents: AgentSetDF | int) -> bool: ... @overload - def contains(self, ids: IdsLike) -> pl.Series: ... + def contains(self, agents: IdsLike) -> pl.Series: ... - def contains(self, ids: IdsLike) -> bool | pl.Series: - if isinstance(ids, int): - return ids in self._ids + def contains(self, agents: AgentSetDF | IdsLike) -> bool | pl.Series: + if isinstance(agents, int): + return agents in self._ids + elif isinstance(agents, AgentSetDF): + return agents in self._agentsets else: - return pl.Series(ids).is_in(self._ids) + return pl.Series(agents).is_in(self._ids) @overload def do( @@ -145,7 +150,7 @@ def do( return_results: Literal[True], inplace: bool = True, **kwargs, - ) -> dict[str, Any]: ... + ) -> dict[AgentSetDF, Any]: ... def do( self, @@ -158,7 +163,7 @@ def do( obj = self._get_obj(inplace) if return_results: return { - agentset.__class__.__name__: agentset.do( + agentset: agentset.do( method_name, *args, return_results=return_results, @@ -184,26 +189,31 @@ def get( self, attr_names: str | list[str] | None = None, mask: MaskLike | None = None, - ) -> dict[str, Series] | dict[str, DataFrame]: + ) -> dict[AgentSetDF, Series] | dict[AgentSetDF, DataFrame]: return { - agentset.__class__.__name__: agentset.get(attr_names, mask) - for agentset in self._agentsets + agentset: agentset.get(attr_names, mask) for agentset in self._agentsets } - def remove(self, ids: IdsLike, inplace: bool = True) -> Self: + def remove(self, agents: AgentSetDF | IdsLike, inplace: bool = True) -> Self: obj = self._get_obj(inplace) deleted = 0 - for agentset in obj._agentsets: - initial_len = len(agentset) - agentset.discard(ids, inplace=True) - deleted += initial_len - len(agentset) - if deleted < len(list(ids)): # TODO: fix type hint - raise KeyError(f"None of the agentsets contain the ID {MaskLike}.") + if isinstance(agents, AgentSetDF): + try: + obj._agentsets.remove(agents) + except ValueError: + raise KeyError(f"{agents} not found in the AgentsDF.") + else: # elif isinstance(ids, IdsLike): + for agents in obj._agentsets: + initial_len = len(agents) + agents.discard(agents, inplace=True) + deleted += initial_len - len(agents) + if deleted < len(list(agents)): # TODO: fix type hint + raise KeyError(f"Some ids were not found in the AgentsDF.") return obj def set( self, - attr_names: str | dict[str, Any] | Collection[str], + attr_names: str | dict[AgentSetDF, Any] | Collection[str], values: Any | None = None, mask: MaskLike | None = None, inplace: bool = True, @@ -266,11 +276,15 @@ def _check_ids(self, other: AgentSetDF | Iterable[AgentSetDF]) -> None: ValueError If the agent set contains IDs already present in agents. """ - for agentset in other if isinstance(other, Iterable) else [other]: + for agentset in [other] if isinstance(other, AgentSetDF) else other: if isinstance(agentset, AgentSetPandas): - new_ids = pl.Series(agentset._agents.index) + new_ids = pl.from_pandas(agentset._agents.index) elif isinstance(agentset, AgentSetPolars): new_ids = agentset._agents["unique_id"] + else: + raise TypeError( + "AgentSetDF must be of type AgentSetPandas or AgentSetPolars." + ) if new_ids.is_in(self._ids).any(): raise ValueError( "The agent set contains IDs already present in agents." @@ -291,11 +305,8 @@ def __add__(self, other: AgentSetDF | Iterable[AgentSetDF]) -> Self: """ return super().__add__(other) - def __getattr__(self, name: str) -> dict[str, Any]: - return { - agentset.__class__.__name__: getattr(agentset, name) - for agentset in self._agentsets - } + def __getattr__(self, name: str) -> dict[AgentSetDF, Any]: + return {agentset: getattr(agentset, name) for agentset in self._agentsets} def __iadd__(self, other: AgentSetDF | Iterable[AgentSetDF]) -> Self: """Add AgentSetDFs to the AgentsDF through the += operator. @@ -318,17 +329,10 @@ def __iter__(self) -> Iterator: ) def __repr__(self) -> str: - return str( - { - agentset.__class__.__name__: repr(agentset) - for agentset in self._agentsets - } - ) + return "\n".join([repr(agentset) for agentset in self._agentsets]) def __str__(self) -> str: - return str( - {agentset.__class__.__name__: str(agentset) for agentset in self._agentsets} - ) + return "\n".join([str(agentset) for agentset in self._agentsets]) def __reversed__(self) -> Iterator: return ( @@ -341,10 +345,8 @@ def __len__(self) -> int: return sum(len(agentset._agents) for agentset in self._agentsets) @property - def agents(self) -> dict[str, DataFrame]: - return { - agentset.__class__.__name__: agentset.agents for agentset in self._agentsets - } + def agents(self) -> dict[AgentSetDF, DataFrame]: + return {agentset: agentset.agents for agentset in self._agentsets} @agents.setter def agents(self, other: Iterable[AgentSetDF]) -> None: @@ -358,15 +360,9 @@ def agents(self, other: Iterable[AgentSetDF]) -> None: self._agentsets = list(other) @property - def active_agents(self) -> dict[str, DataFrame]: - return { - agentset.__class__.__name__: agentset.active_agents - for agentset in self._agentsets - } + def active_agents(self) -> dict[AgentSetDF, DataFrame]: + return {agentset: agentset.active_agents for agentset in self._agentsets} @property def inactive_agents(self): - return { - agentset.__class__.__name__: agentset.inactive_agents - for agentset in self._agentsets - } + return {agentset: agentset.inactive_agents for agentset in self._agentsets} diff --git a/mesa_frames/concrete/agentset_pandas.py b/mesa_frames/concrete/agentset_pandas.py index d2e7c13..6f89b01 100644 --- a/mesa_frames/concrete/agentset_pandas.py +++ b/mesa_frames/concrete/agentset_pandas.py @@ -1,8 +1,8 @@ from typing import ( + TYPE_CHECKING, Any, Callable, Collection, - Hashable, Iterator, Self, Sequence, @@ -14,9 +14,11 @@ from mesa_frames.abstract.agents import AgentSetDF from mesa_frames.concrete.agentset_polars import AgentSetPolars -from mesa_frames.concrete.model import ModelDF from mesa_frames.types import PandasIdsLike, PandasMaskLike +if TYPE_CHECKING: + from mesa_frames.concrete.model import ModelDF + class AgentSetPandas(AgentSetDF): _agents: pd.DataFrame @@ -102,12 +104,14 @@ class AgentSetPandas(AgentSetDF): Get the string representation of the AgentSetPandas. """ - def __init__(self, model: ModelDF) -> None: + def __init__(self, model: "ModelDF") -> None: self._model = model - self._agents = pd.DataFrame( - columns=["unique_id"], dtype={"unique_id": pd.Int64Dtype} - ).set_index("unique_id") - self._mask = pd.Series(True, index=self._agents.index, dtype=pd.BooleanDtype) + self._agents = ( + pd.DataFrame(columns=["unique_id"]) + .astype({"unique_id": "int64"}) + .set_index("unique_id") + ) + self._mask = pd.Series(True, index=self._agents.index, dtype=pd.BooleanDtype()) def add( self, @@ -138,6 +142,10 @@ def add( new_agents = pd.DataFrame([other], columns=columns).set_index( "unique_id", drop=True ) + + if new_agents.index.dtype != "int64": + raise TypeError("unique_id must be of type int64.") + obj._agents = pd.concat([obj._agents, new_agents]) return obj @@ -149,13 +157,13 @@ def contains(self, ids: PandasIdsLike) -> pd.Series: ... def contains( self, - ids: int | Collection[int] | pd.Series[int] | pd.Index, + ids: PandasIdsLike, ) -> bool | pd.Series: if isinstance(ids, pd.Series): return ids.isin(self._agents.index) elif isinstance(ids, pd.Index): return pd.Series( - ids.isin(self._agents.index), index=ids, dtype=pd.BooleanDtype + ids.isin(self._agents.index), index=ids, dtype=pd.BooleanDtype() ) elif isinstance(ids, Collection): return pd.Series(list(ids), index=list(ids)).isin(self._agents.index) @@ -341,15 +349,9 @@ def __iter__(self) -> Iterator: def __len__(self) -> int: return len(self._agents) - def __repr__(self) -> str: - return repr(self._agents) - def __reversed__(self) -> Iterator: return iter(self._agents[::-1].iterrows()) - def __str__(self) -> str: - return str(self._agents) - @property def agents(self) -> pd.DataFrame: return self._agents diff --git a/mesa_frames/concrete/agentset_polars.py b/mesa_frames/concrete/agentset_polars.py index 1b02094..13cd4e7 100644 --- a/mesa_frames/concrete/agentset_polars.py +++ b/mesa_frames/concrete/agentset_polars.py @@ -30,7 +30,6 @@ class AgentSetPolars(AgentSetDF): } _copy_only_reference: list[str] = ["_model", "_mask"] _mask: pl.Expr | pl.Series - _model: "ModelDF" """A polars-based implementation of the AgentSet. @@ -161,6 +160,10 @@ def add( "Length of data must match the number of columns in the AgentSet if being added as a Collection." ) new_agents = pl.DataFrame([other], schema=obj._agents.schema) + + if new_agents["unique_id"].dtype != pl.Int64: + raise TypeError("unique_id column must be of type int64.") + obj._agents = pl.concat([obj._agents, new_agents], how="diagonal_relaxed") return obj @@ -457,15 +460,9 @@ def __iter__(self) -> Iterator: def __len__(self) -> int: return len(self._agents) - def __repr__(self) -> str: - return repr(self._agents) - def __reversed__(self) -> Iterator: return reversed(iter(self._agents.iter_rows(named=True))) - def __str__(self) -> str: - return str(self._agents) - @property def agents(self) -> pl.DataFrame: return self._agents diff --git a/mesa_frames/test.py b/mesa_frames/test.py deleted file mode 100644 index d5d153e..0000000 --- a/mesa_frames/test.py +++ /dev/null @@ -1,4 +0,0 @@ -import pandas as pd - -x = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) - diff --git a/mesa_frames/types.py b/mesa_frames/types.py index a5da994..499f19f 100644 --- a/mesa_frames/types.py +++ b/mesa_frames/types.py @@ -11,7 +11,7 @@ ArrayLike = pd.api.extensions.ExtensionArray | ndarray AnyArrayLike = ArrayLike | pd.Index | pd.Series PandasMaskLike = AgnosticMask | pd.Series | pd.DataFrame | AnyArrayLike -PandasIdsLike = AgnosticIds | pd.Series[int] | pd.Index +PandasIdsLike = AgnosticIds | pd.Series | pd.Index ###----- Polars Types -----### import polars as pl @@ -25,4 +25,4 @@ Series = pd.Series | pl.Series BoolSeries = pd.Series | pl.Series MaskLike = AgnosticMask | PandasMaskLike | PolarsMaskLike -IdsLike = AgnosticIds | pd.Series[int] | pd.Index | pl.Series +IdsLike = AgnosticIds | PandasIdsLike | PolarsIdsLike diff --git a/requirements.txt b/requirements.txt index 36c4867..f45d0f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ geopandas == 0.14.1 pandas == 2.1.4 -numpy == 1.26.3 \ No newline at end of file +numpy == 1.26.3 +pyarrow == 16.1.0 \ No newline at end of file