From 9c62f831436a2de90c8ef20780b66c56a0684618 Mon Sep 17 00:00:00 2001 From: Corvince Date: Wed, 6 Mar 2024 19:14:06 +0100 Subject: [PATCH 1/4] Make Agents generic over model --- benchmarks/Schelling/schelling.py | 2 +- mesa/agent.py | 9 ++++++--- mesa/experimental/cell_space/cell_agent.py | 8 +++++--- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/benchmarks/Schelling/schelling.py b/benchmarks/Schelling/schelling.py index c7dd3bf1deb..bbd4427f72f 100644 --- a/benchmarks/Schelling/schelling.py +++ b/benchmarks/Schelling/schelling.py @@ -3,7 +3,7 @@ from mesa.time import RandomActivation -class SchellingAgent(CellAgent): +class SchellingAgent(CellAgent["Schelling"]): """ Schelling segregation agent """ diff --git a/mesa/agent.py b/mesa/agent.py index 7ae76871b95..a6584e05532 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -3,6 +3,7 @@ Core Objects: Agent """ + # Mypy; for the `|` operator purpose # Remove this __future__ import once the oldest supported Python is 3.10 from __future__ import annotations @@ -17,7 +18,7 @@ from random import Random # mypy -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar if TYPE_CHECKING: # We ensure that these are not imported during runtime to prevent cyclic @@ -25,8 +26,10 @@ from mesa.model import Model from mesa.space import Position +T = TypeVar("T", bound=Model) + -class Agent: +class Agent(Generic[T]): """ Base class for a model agent in Mesa. @@ -36,7 +39,7 @@ class Agent: self.pos: Position | None = None """ - def __init__(self, unique_id: int, model: Model) -> None: + def __init__(self, unique_id: int, model: T) -> None: """ Create a new agent. diff --git a/mesa/experimental/cell_space/cell_agent.py b/mesa/experimental/cell_space/cell_agent.py index abc5155a670..ea79be8aa16 100644 --- a/mesa/experimental/cell_space/cell_agent.py +++ b/mesa/experimental/cell_space/cell_agent.py @@ -1,14 +1,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Generic, TypeVar from mesa import Agent, Model if TYPE_CHECKING: from mesa.experimental.cell_space.cell import Cell +T = TypeVar("T", bound=Model) -class CellAgent(Agent): + +class CellAgent(Agent[T], Generic[T]): """Cell Agent is an extension of the Agent class and adds behavior for moving in discrete spaces @@ -19,7 +21,7 @@ class CellAgent(Agent): cell: (Cell | None): the cell which the agent occupies """ - def __init__(self, unique_id: int, model: Model) -> None: + def __init__(self, unique_id: int, model: T) -> None: """ Create a new agent. From c58ae733a635e252605e0a514156fe23c4e54aa3 Mon Sep 17 00:00:00 2001 From: Corvince Date: Wed, 6 Mar 2024 20:01:55 +0100 Subject: [PATCH 2/4] Make Cells generic of Agents --- benchmarks/Schelling/schelling.py | 5 ++++- mesa/experimental/cell_space/cell.py | 16 +++++++++++----- mesa/experimental/cell_space/cell_agent.py | 7 ++++--- .../cell_space/cell_collection.py | 11 ++++++----- .../experimental/cell_space/discrete_space.py | 11 ++++++----- mesa/experimental/cell_space/grid.py | 19 ++++++++++++------- 6 files changed, 43 insertions(+), 26 deletions(-) diff --git a/benchmarks/Schelling/schelling.py b/benchmarks/Schelling/schelling.py index bbd4427f72f..065b9ad6d39 100644 --- a/benchmarks/Schelling/schelling.py +++ b/benchmarks/Schelling/schelling.py @@ -1,9 +1,10 @@ from mesa import Model from mesa.experimental.cell_space import CellAgent, OrthogonalMooreGrid +from mesa.experimental.cell_space.cell import Cell from mesa.time import RandomActivation -class SchellingAgent(CellAgent["Schelling"]): +class SchellingAgent(CellAgent["Schelling", "SchellingAgent"]): """ Schelling segregation agent """ @@ -73,6 +74,8 @@ def __init__( torus=True, capacity=1, random=self.random, + cell_klass=Cell[SchellingAgent], + agent_class=SchellingAgent, ) self.happy = 0 diff --git a/mesa/experimental/cell_space/cell.py b/mesa/experimental/cell_space/cell.py index 55264f68daa..b948a4d6af9 100644 --- a/mesa/experimental/cell_space/cell.py +++ b/mesa/experimental/cell_space/cell.py @@ -2,15 +2,17 @@ from functools import cache from random import Random -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar, Generic from mesa.experimental.cell_space.cell_collection import CellCollection if TYPE_CHECKING: from mesa.experimental.cell_space.cell_agent import CellAgent +U = TypeVar("U", bound=CellAgent) -class Cell: + +class Cell(Generic[U]): """The cell represents a position in a discrete space. Attributes: @@ -56,8 +58,10 @@ def __init__( """ super().__init__() self.coordinate = coordinate - self._connections: list[Cell] = [] # TODO: change to CellCollection? - self.agents = [] # TODO:: change to AgentSet or weakrefs? (neither is very performant, ) + self._connections: list[Cell[U]] = [] # TODO: change to CellCollection? + self.agents = ( + [] + ) # TODO:: change to AgentSet or weakrefs? (neither is very performant, ) self.capacity = capacity self.properties: dict[str, object] = {} self.random = random @@ -121,7 +125,9 @@ def __repr__(self): # FIXME: Revisit caching strategy on methods @cache # noqa: B019 - def neighborhood(self, radius=1, include_center=False): + def neighborhood( + self, radius=1, include_center=False + ) -> CellCollection[Cell[U], U]: return CellCollection( self._neighborhood(radius=radius, include_center=include_center), random=self.random, diff --git a/mesa/experimental/cell_space/cell_agent.py b/mesa/experimental/cell_space/cell_agent.py index ea79be8aa16..3294c098509 100644 --- a/mesa/experimental/cell_space/cell_agent.py +++ b/mesa/experimental/cell_space/cell_agent.py @@ -8,9 +8,10 @@ from mesa.experimental.cell_space.cell import Cell T = TypeVar("T", bound=Model) +U = TypeVar("U", bound="CellAgent") -class CellAgent(Agent[T], Generic[T]): +class CellAgent(Agent[T], Generic[T, U]): """Cell Agent is an extension of the Agent class and adds behavior for moving in discrete spaces @@ -30,9 +31,9 @@ def __init__(self, unique_id: int, model: T) -> None: model (Model): The model instance in which the agent exists. """ super().__init__(unique_id, model) - self.cell: Cell | None = None + self.cell: Cell[U] | None = None - def move_to(self, cell) -> None: + def move_to(self, cell: Cell[U]) -> None: if self.cell is not None: self.cell.remove_agent(self) self.cell = cell diff --git a/mesa/experimental/cell_space/cell_collection.py b/mesa/experimental/cell_space/cell_collection.py index 9ca36589849..503ffde1759 100644 --- a/mesa/experimental/cell_space/cell_collection.py +++ b/mesa/experimental/cell_space/cell_collection.py @@ -11,9 +11,10 @@ from mesa.experimental.cell_space.cell_agent import CellAgent T = TypeVar("T", bound="Cell") +U = TypeVar("U", bound="CellAgent") -class CellCollection(Generic[T]): +class CellCollection(Generic[T, U]): """An immutable collection of cells Attributes: @@ -25,7 +26,7 @@ class CellCollection(Generic[T]): def __init__( self, - cells: Mapping[T, list[CellAgent]] | Iterable[T], + cells: Mapping[T, list[U]] | Iterable[T], random: Random | None = None, ) -> None: if isinstance(cells, dict): @@ -43,7 +44,7 @@ def __init__( def __iter__(self): return iter(self._cells) - def __getitem__(self, key: T) -> Iterable[CellAgent]: + def __getitem__(self, key: T) -> Iterable[U]: return self._cells[key] # @cached_property @@ -58,13 +59,13 @@ def cells(self) -> list[T]: return list(self._cells.keys()) @property - def agents(self) -> Iterable[CellAgent]: + def agents(self) -> Iterable[U]: return itertools.chain.from_iterable(self._cells.values()) def select_random_cell(self) -> T: return self.random.choice(self.cells) - def select_random_agent(self) -> CellAgent: + def select_random_agent(self) -> U: return self.random.choice(list(self.agents)) def select(self, filter_func: Callable[[T], bool] | None = None, n=0): diff --git a/mesa/experimental/cell_space/discrete_space.py b/mesa/experimental/cell_space/discrete_space.py index d2161c5b46a..fe3c0af78de 100644 --- a/mesa/experimental/cell_space/discrete_space.py +++ b/mesa/experimental/cell_space/discrete_space.py @@ -4,13 +4,14 @@ from random import Random from typing import Generic, TypeVar -from mesa.experimental.cell_space.cell import Cell +from mesa.experimental.cell_space import Cell, CellAgent from mesa.experimental.cell_space.cell_collection import CellCollection T = TypeVar("T", bound=Cell) +U = TypeVar("U", bound=CellAgent) -class DiscreteSpace(Generic[T]): +class DiscreteSpace(Generic[T, U]): """Base class for all discrete spaces. Attributes: @@ -25,8 +26,9 @@ class DiscreteSpace(Generic[T]): def __init__( self, capacity: int | None = None, - cell_klass: type[T] = Cell, + cell_klass: type[T] = Cell[U], random: Random | None = None, + agent_class: type[U] = CellAgent, ): super().__init__() self.capacity = capacity @@ -43,8 +45,7 @@ def __init__( def cutoff_empties(self): return 7.953 * len(self._cells) ** 0.384 - def _connect_single_cell(self, cell: T): - ... + def _connect_single_cell(self, cell: T): ... @cached_property def all_cells(self): diff --git a/mesa/experimental/cell_space/grid.py b/mesa/experimental/cell_space/grid.py index cc4b4b9e489..a4a752b6869 100644 --- a/mesa/experimental/cell_space/grid.py +++ b/mesa/experimental/cell_space/grid.py @@ -5,12 +5,13 @@ from random import Random from typing import Generic, TypeVar -from mesa.experimental.cell_space import Cell, DiscreteSpace +from mesa.experimental.cell_space import Cell, CellAgent, DiscreteSpace T = TypeVar("T", bound=Cell) +U = TypeVar("U", bound=CellAgent) -class Grid(DiscreteSpace, Generic[T]): +class Grid(DiscreteSpace[T, U], Generic[T, U]): """Base class for all grid classes Attributes: @@ -29,8 +30,14 @@ def __init__( capacity: float | None = None, random: Random | None = None, cell_klass: type[T] = Cell, + agent_class: type[U] = CellAgent, ) -> None: - super().__init__(capacity=capacity, random=random, cell_klass=cell_klass) + super().__init__( + capacity=capacity, + random=random, + cell_klass=cell_klass, + agent_class=agent_class, + ) self.torus = torus self.dimensions = dimensions self._try_random = True @@ -51,11 +58,9 @@ def _connect_cells(self) -> None: else: self._connect_cells_nd() - def _connect_cells_2d(self) -> None: - ... + def _connect_cells_2d(self) -> None: ... - def _connect_cells_nd(self) -> None: - ... + def _connect_cells_nd(self) -> None: ... def _validate_parameters(self): if not all(isinstance(dim, int) and dim > 0 for dim in self.dimensions): From 8d2fc29ea96e8b73c19fd2080ca29f24d396c99e Mon Sep 17 00:00:00 2001 From: Corvince Date: Wed, 6 Mar 2024 20:04:44 +0100 Subject: [PATCH 3/4] Fix imports --- mesa/agent.py | 2 +- mesa/experimental/cell_space/cell.py | 2 +- mesa/experimental/cell_space/grid.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index a6584e05532..2b14bbf6ef0 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -26,7 +26,7 @@ from mesa.model import Model from mesa.space import Position -T = TypeVar("T", bound=Model) +T = TypeVar("T", bound="Model") class Agent(Generic[T]): diff --git a/mesa/experimental/cell_space/cell.py b/mesa/experimental/cell_space/cell.py index b948a4d6af9..1fcea747076 100644 --- a/mesa/experimental/cell_space/cell.py +++ b/mesa/experimental/cell_space/cell.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from mesa.experimental.cell_space.cell_agent import CellAgent -U = TypeVar("U", bound=CellAgent) +U = TypeVar("U", bound="CellAgent") class Cell(Generic[U]): diff --git a/mesa/experimental/cell_space/grid.py b/mesa/experimental/cell_space/grid.py index a4a752b6869..457a2ce274f 100644 --- a/mesa/experimental/cell_space/grid.py +++ b/mesa/experimental/cell_space/grid.py @@ -110,7 +110,7 @@ def _connect_single_cell_2d(self, cell: T, offsets: list[tuple[int, int]]) -> No cell.connect(self._cells[ni, nj]) -class OrthogonalMooreGrid(Grid[T]): +class OrthogonalMooreGrid(Grid[T, U]): """Grid where cells are connected to their 8 neighbors. Example for two dimensions: @@ -142,7 +142,7 @@ def _connect_cells_nd(self) -> None: self._connect_single_cell_nd(cell, offsets) -class OrthogonalVonNeumannGrid(Grid[T]): +class OrthogonalVonNeumannGrid(Grid[T, U]): """Grid where cells are connected to their 4 neighbors. Example for two dimensions: @@ -182,7 +182,7 @@ def _connect_cells_nd(self) -> None: self._connect_single_cell_nd(cell, offsets) -class HexGrid(Grid[T]): +class HexGrid(Grid[T, U]): def _connect_cells_2d(self) -> None: # fmt: off even_offsets = [ From de057a48bb472763d686b40234938d436026ffdc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 Mar 2024 19:25:30 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa/experimental/cell_space/cell.py | 6 ++---- mesa/experimental/cell_space/discrete_space.py | 3 ++- mesa/experimental/cell_space/grid.py | 6 ++++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/mesa/experimental/cell_space/cell.py b/mesa/experimental/cell_space/cell.py index 1fcea747076..3e283270db4 100644 --- a/mesa/experimental/cell_space/cell.py +++ b/mesa/experimental/cell_space/cell.py @@ -2,7 +2,7 @@ from functools import cache from random import Random -from typing import TYPE_CHECKING, TypeVar, Generic +from typing import TYPE_CHECKING, Generic, TypeVar from mesa.experimental.cell_space.cell_collection import CellCollection @@ -59,9 +59,7 @@ def __init__( super().__init__() self.coordinate = coordinate self._connections: list[Cell[U]] = [] # TODO: change to CellCollection? - self.agents = ( - [] - ) # TODO:: change to AgentSet or weakrefs? (neither is very performant, ) + self.agents = [] # TODO:: change to AgentSet or weakrefs? (neither is very performant, ) self.capacity = capacity self.properties: dict[str, object] = {} self.random = random diff --git a/mesa/experimental/cell_space/discrete_space.py b/mesa/experimental/cell_space/discrete_space.py index fe3c0af78de..a05ffdb8828 100644 --- a/mesa/experimental/cell_space/discrete_space.py +++ b/mesa/experimental/cell_space/discrete_space.py @@ -45,7 +45,8 @@ def __init__( def cutoff_empties(self): return 7.953 * len(self._cells) ** 0.384 - def _connect_single_cell(self, cell: T): ... + def _connect_single_cell(self, cell: T): + ... @cached_property def all_cells(self): diff --git a/mesa/experimental/cell_space/grid.py b/mesa/experimental/cell_space/grid.py index 457a2ce274f..1f722df9f73 100644 --- a/mesa/experimental/cell_space/grid.py +++ b/mesa/experimental/cell_space/grid.py @@ -58,9 +58,11 @@ def _connect_cells(self) -> None: else: self._connect_cells_nd() - def _connect_cells_2d(self) -> None: ... + def _connect_cells_2d(self) -> None: + ... - def _connect_cells_nd(self) -> None: ... + def _connect_cells_nd(self) -> None: + ... def _validate_parameters(self): if not all(isinstance(dim, int) and dim > 0 for dim in self.dimensions):