diff --git a/benchmarks/Schelling/schelling.py b/benchmarks/Schelling/schelling.py index c7dd3bf1deb..065b9ad6d39 100644 --- a/benchmarks/Schelling/schelling.py +++ b/benchmarks/Schelling/schelling.py @@ -1,9 +1,10 @@ from mesa import Model from mesa.experimental.cell_space import CellAgent, OrthogonalMooreGrid +from mesa.experimental.cell_space.cell import Cell from mesa.time import RandomActivation -class SchellingAgent(CellAgent): +class SchellingAgent(CellAgent["Schelling", "SchellingAgent"]): """ Schelling segregation agent """ @@ -73,6 +74,8 @@ def __init__( torus=True, capacity=1, random=self.random, + cell_klass=Cell[SchellingAgent], + agent_class=SchellingAgent, ) self.happy = 0 diff --git a/mesa/agent.py b/mesa/agent.py index 7ae76871b95..2b14bbf6ef0 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -3,6 +3,7 @@ Core Objects: Agent """ + # Mypy; for the `|` operator purpose # Remove this __future__ import once the oldest supported Python is 3.10 from __future__ import annotations @@ -17,7 +18,7 @@ from random import Random # mypy -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar if TYPE_CHECKING: # We ensure that these are not imported during runtime to prevent cyclic @@ -25,8 +26,10 @@ from mesa.model import Model from mesa.space import Position +T = TypeVar("T", bound="Model") + -class Agent: +class Agent(Generic[T]): """ Base class for a model agent in Mesa. @@ -36,7 +39,7 @@ class Agent: self.pos: Position | None = None """ - def __init__(self, unique_id: int, model: Model) -> None: + def __init__(self, unique_id: int, model: T) -> None: """ Create a new agent. diff --git a/mesa/experimental/cell_space/cell.py b/mesa/experimental/cell_space/cell.py index 55264f68daa..3e283270db4 100644 --- a/mesa/experimental/cell_space/cell.py +++ b/mesa/experimental/cell_space/cell.py @@ -2,15 +2,17 @@ from functools import cache from random import Random -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Generic, TypeVar from mesa.experimental.cell_space.cell_collection import CellCollection if TYPE_CHECKING: from mesa.experimental.cell_space.cell_agent import CellAgent +U = TypeVar("U", bound="CellAgent") -class Cell: + +class Cell(Generic[U]): """The cell represents a position in a discrete space. Attributes: @@ -56,7 +58,7 @@ def __init__( """ super().__init__() self.coordinate = coordinate - self._connections: list[Cell] = [] # TODO: change to CellCollection? + self._connections: list[Cell[U]] = [] # TODO: change to CellCollection? self.agents = [] # TODO:: change to AgentSet or weakrefs? (neither is very performant, ) self.capacity = capacity self.properties: dict[str, object] = {} @@ -121,7 +123,9 @@ def __repr__(self): # FIXME: Revisit caching strategy on methods @cache # noqa: B019 - def neighborhood(self, radius=1, include_center=False): + def neighborhood( + self, radius=1, include_center=False + ) -> CellCollection[Cell[U], U]: return CellCollection( self._neighborhood(radius=radius, include_center=include_center), random=self.random, diff --git a/mesa/experimental/cell_space/cell_agent.py b/mesa/experimental/cell_space/cell_agent.py index abc5155a670..3294c098509 100644 --- a/mesa/experimental/cell_space/cell_agent.py +++ b/mesa/experimental/cell_space/cell_agent.py @@ -1,14 +1,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Generic, TypeVar from mesa import Agent, Model if TYPE_CHECKING: from mesa.experimental.cell_space.cell import Cell +T = TypeVar("T", bound=Model) +U = TypeVar("U", bound="CellAgent") -class CellAgent(Agent): + +class CellAgent(Agent[T], Generic[T, U]): """Cell Agent is an extension of the Agent class and adds behavior for moving in discrete spaces @@ -19,7 +22,7 @@ class CellAgent(Agent): cell: (Cell | None): the cell which the agent occupies """ - def __init__(self, unique_id: int, model: Model) -> None: + def __init__(self, unique_id: int, model: T) -> None: """ Create a new agent. @@ -28,9 +31,9 @@ def __init__(self, unique_id: int, model: Model) -> None: model (Model): The model instance in which the agent exists. """ super().__init__(unique_id, model) - self.cell: Cell | None = None + self.cell: Cell[U] | None = None - def move_to(self, cell) -> None: + def move_to(self, cell: Cell[U]) -> None: if self.cell is not None: self.cell.remove_agent(self) self.cell = cell diff --git a/mesa/experimental/cell_space/cell_collection.py b/mesa/experimental/cell_space/cell_collection.py index 9ca36589849..503ffde1759 100644 --- a/mesa/experimental/cell_space/cell_collection.py +++ b/mesa/experimental/cell_space/cell_collection.py @@ -11,9 +11,10 @@ from mesa.experimental.cell_space.cell_agent import CellAgent T = TypeVar("T", bound="Cell") +U = TypeVar("U", bound="CellAgent") -class CellCollection(Generic[T]): +class CellCollection(Generic[T, U]): """An immutable collection of cells Attributes: @@ -25,7 +26,7 @@ class CellCollection(Generic[T]): def __init__( self, - cells: Mapping[T, list[CellAgent]] | Iterable[T], + cells: Mapping[T, list[U]] | Iterable[T], random: Random | None = None, ) -> None: if isinstance(cells, dict): @@ -43,7 +44,7 @@ def __init__( def __iter__(self): return iter(self._cells) - def __getitem__(self, key: T) -> Iterable[CellAgent]: + def __getitem__(self, key: T) -> Iterable[U]: return self._cells[key] # @cached_property @@ -58,13 +59,13 @@ def cells(self) -> list[T]: return list(self._cells.keys()) @property - def agents(self) -> Iterable[CellAgent]: + def agents(self) -> Iterable[U]: return itertools.chain.from_iterable(self._cells.values()) def select_random_cell(self) -> T: return self.random.choice(self.cells) - def select_random_agent(self) -> CellAgent: + def select_random_agent(self) -> U: return self.random.choice(list(self.agents)) def select(self, filter_func: Callable[[T], bool] | None = None, n=0): diff --git a/mesa/experimental/cell_space/discrete_space.py b/mesa/experimental/cell_space/discrete_space.py index d2161c5b46a..a05ffdb8828 100644 --- a/mesa/experimental/cell_space/discrete_space.py +++ b/mesa/experimental/cell_space/discrete_space.py @@ -4,13 +4,14 @@ from random import Random from typing import Generic, TypeVar -from mesa.experimental.cell_space.cell import Cell +from mesa.experimental.cell_space import Cell, CellAgent from mesa.experimental.cell_space.cell_collection import CellCollection T = TypeVar("T", bound=Cell) +U = TypeVar("U", bound=CellAgent) -class DiscreteSpace(Generic[T]): +class DiscreteSpace(Generic[T, U]): """Base class for all discrete spaces. Attributes: @@ -25,8 +26,9 @@ class DiscreteSpace(Generic[T]): def __init__( self, capacity: int | None = None, - cell_klass: type[T] = Cell, + cell_klass: type[T] = Cell[U], random: Random | None = None, + agent_class: type[U] = CellAgent, ): super().__init__() self.capacity = capacity diff --git a/mesa/experimental/cell_space/grid.py b/mesa/experimental/cell_space/grid.py index cc4b4b9e489..1f722df9f73 100644 --- a/mesa/experimental/cell_space/grid.py +++ b/mesa/experimental/cell_space/grid.py @@ -5,12 +5,13 @@ from random import Random from typing import Generic, TypeVar -from mesa.experimental.cell_space import Cell, DiscreteSpace +from mesa.experimental.cell_space import Cell, CellAgent, DiscreteSpace T = TypeVar("T", bound=Cell) +U = TypeVar("U", bound=CellAgent) -class Grid(DiscreteSpace, Generic[T]): +class Grid(DiscreteSpace[T, U], Generic[T, U]): """Base class for all grid classes Attributes: @@ -29,8 +30,14 @@ def __init__( capacity: float | None = None, random: Random | None = None, cell_klass: type[T] = Cell, + agent_class: type[U] = CellAgent, ) -> None: - super().__init__(capacity=capacity, random=random, cell_klass=cell_klass) + super().__init__( + capacity=capacity, + random=random, + cell_klass=cell_klass, + agent_class=agent_class, + ) self.torus = torus self.dimensions = dimensions self._try_random = True @@ -105,7 +112,7 @@ def _connect_single_cell_2d(self, cell: T, offsets: list[tuple[int, int]]) -> No cell.connect(self._cells[ni, nj]) -class OrthogonalMooreGrid(Grid[T]): +class OrthogonalMooreGrid(Grid[T, U]): """Grid where cells are connected to their 8 neighbors. Example for two dimensions: @@ -137,7 +144,7 @@ def _connect_cells_nd(self) -> None: self._connect_single_cell_nd(cell, offsets) -class OrthogonalVonNeumannGrid(Grid[T]): +class OrthogonalVonNeumannGrid(Grid[T, U]): """Grid where cells are connected to their 4 neighbors. Example for two dimensions: @@ -177,7 +184,7 @@ def _connect_cells_nd(self) -> None: self._connect_single_cell_nd(cell, offsets) -class HexGrid(Grid[T]): +class HexGrid(Grid[T, U]): def _connect_cells_2d(self) -> None: # fmt: off even_offsets = [