Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better Typing support for (Cell)agents #2072

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion benchmarks/Schelling/schelling.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from mesa import Model
from mesa.experimental.cell_space import CellAgent, OrthogonalMooreGrid
from mesa.experimental.cell_space.cell import Cell
from mesa.time import RandomActivation


class SchellingAgent(CellAgent):
class SchellingAgent(CellAgent["Schelling", "SchellingAgent"]):
"""
Schelling segregation agent
"""
Expand Down Expand Up @@ -73,6 +74,8 @@ def __init__(
torus=True,
capacity=1,
random=self.random,
cell_klass=Cell[SchellingAgent],
agent_class=SchellingAgent,
)

self.happy = 0
Expand Down
9 changes: 6 additions & 3 deletions mesa/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

Core Objects: Agent
"""

# Mypy; for the `|` operator purpose
# Remove this __future__ import once the oldest supported Python is 3.10
from __future__ import annotations
Expand All @@ -17,16 +18,18 @@
from random import Random

# mypy
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar

if TYPE_CHECKING:
# We ensure that these are not imported during runtime to prevent cyclic
# dependency.
from mesa.model import Model
from mesa.space import Position

T = TypeVar("T", bound="Model")


class Agent:
class Agent(Generic[T]):
"""
Base class for a model agent in Mesa.

Expand All @@ -36,7 +39,7 @@ class Agent:
self.pos: Position | None = None
"""

def __init__(self, unique_id: int, model: Model) -> None:
def __init__(self, unique_id: int, model: T) -> None:
"""
Create a new agent.

Expand Down
12 changes: 8 additions & 4 deletions mesa/experimental/cell_space/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@

from functools import cache
from random import Random
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Generic, TypeVar

from mesa.experimental.cell_space.cell_collection import CellCollection

if TYPE_CHECKING:
from mesa.experimental.cell_space.cell_agent import CellAgent

U = TypeVar("U", bound="CellAgent")

class Cell:

class Cell(Generic[U]):
"""The cell represents a position in a discrete space.

Attributes:
Expand Down Expand Up @@ -56,7 +58,7 @@ def __init__(
"""
super().__init__()
self.coordinate = coordinate
self._connections: list[Cell] = [] # TODO: change to CellCollection?
self._connections: list[Cell[U]] = [] # TODO: change to CellCollection?
self.agents = [] # TODO:: change to AgentSet or weakrefs? (neither is very performant, )
self.capacity = capacity
self.properties: dict[str, object] = {}
Expand Down Expand Up @@ -121,7 +123,9 @@ def __repr__(self):

# FIXME: Revisit caching strategy on methods
@cache # noqa: B019
def neighborhood(self, radius=1, include_center=False):
def neighborhood(
self, radius=1, include_center=False
) -> CellCollection[Cell[U], U]:
return CellCollection(
self._neighborhood(radius=radius, include_center=include_center),
random=self.random,
Expand Down
13 changes: 8 additions & 5 deletions mesa/experimental/cell_space/cell_agent.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Generic, TypeVar

from mesa import Agent, Model

if TYPE_CHECKING:
from mesa.experimental.cell_space.cell import Cell

T = TypeVar("T", bound=Model)
U = TypeVar("U", bound="CellAgent")

class CellAgent(Agent):

class CellAgent(Agent[T], Generic[T, U]):
"""Cell Agent is an extension of the Agent class and adds behavior for moving in discrete spaces


Expand All @@ -19,7 +22,7 @@ class CellAgent(Agent):
cell: (Cell | None): the cell which the agent occupies
"""

def __init__(self, unique_id: int, model: Model) -> None:
def __init__(self, unique_id: int, model: T) -> None:
"""
Create a new agent.

Expand All @@ -28,9 +31,9 @@ def __init__(self, unique_id: int, model: Model) -> None:
model (Model): The model instance in which the agent exists.
"""
super().__init__(unique_id, model)
self.cell: Cell | None = None
self.cell: Cell[U] | None = None

def move_to(self, cell) -> None:
def move_to(self, cell: Cell[U]) -> None:
if self.cell is not None:
self.cell.remove_agent(self)
self.cell = cell
Expand Down
11 changes: 6 additions & 5 deletions mesa/experimental/cell_space/cell_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
from mesa.experimental.cell_space.cell_agent import CellAgent

T = TypeVar("T", bound="Cell")
U = TypeVar("U", bound="CellAgent")


class CellCollection(Generic[T]):
class CellCollection(Generic[T, U]):
"""An immutable collection of cells

Attributes:
Expand All @@ -25,7 +26,7 @@ class CellCollection(Generic[T]):

def __init__(
self,
cells: Mapping[T, list[CellAgent]] | Iterable[T],
cells: Mapping[T, list[U]] | Iterable[T],
random: Random | None = None,
) -> None:
if isinstance(cells, dict):
Expand All @@ -43,7 +44,7 @@ def __init__(
def __iter__(self):
return iter(self._cells)

def __getitem__(self, key: T) -> Iterable[CellAgent]:
def __getitem__(self, key: T) -> Iterable[U]:
return self._cells[key]

# @cached_property
Expand All @@ -58,13 +59,13 @@ def cells(self) -> list[T]:
return list(self._cells.keys())

@property
def agents(self) -> Iterable[CellAgent]:
def agents(self) -> Iterable[U]:
return itertools.chain.from_iterable(self._cells.values())

def select_random_cell(self) -> T:
return self.random.choice(self.cells)

def select_random_agent(self) -> CellAgent:
def select_random_agent(self) -> U:
return self.random.choice(list(self.agents))

def select(self, filter_func: Callable[[T], bool] | None = None, n=0):
Expand Down
8 changes: 5 additions & 3 deletions mesa/experimental/cell_space/discrete_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from random import Random
from typing import Generic, TypeVar

from mesa.experimental.cell_space.cell import Cell
from mesa.experimental.cell_space import Cell, CellAgent
from mesa.experimental.cell_space.cell_collection import CellCollection

T = TypeVar("T", bound=Cell)
U = TypeVar("U", bound=CellAgent)


class DiscreteSpace(Generic[T]):
class DiscreteSpace(Generic[T, U]):
"""Base class for all discrete spaces.

Attributes:
Expand All @@ -25,8 +26,9 @@ class DiscreteSpace(Generic[T]):
def __init__(
self,
capacity: int | None = None,
cell_klass: type[T] = Cell,
cell_klass: type[T] = Cell[U],
random: Random | None = None,
agent_class: type[U] = CellAgent,
):
super().__init__()
self.capacity = capacity
Expand Down
19 changes: 13 additions & 6 deletions mesa/experimental/cell_space/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from random import Random
from typing import Generic, TypeVar

from mesa.experimental.cell_space import Cell, DiscreteSpace
from mesa.experimental.cell_space import Cell, CellAgent, DiscreteSpace

T = TypeVar("T", bound=Cell)
U = TypeVar("U", bound=CellAgent)


class Grid(DiscreteSpace, Generic[T]):
class Grid(DiscreteSpace[T, U], Generic[T, U]):
"""Base class for all grid classes

Attributes:
Expand All @@ -29,8 +30,14 @@ def __init__(
capacity: float | None = None,
random: Random | None = None,
cell_klass: type[T] = Cell,
agent_class: type[U] = CellAgent,
) -> None:
super().__init__(capacity=capacity, random=random, cell_klass=cell_klass)
super().__init__(
capacity=capacity,
random=random,
cell_klass=cell_klass,
agent_class=agent_class,
)
self.torus = torus
self.dimensions = dimensions
self._try_random = True
Expand Down Expand Up @@ -105,7 +112,7 @@ def _connect_single_cell_2d(self, cell: T, offsets: list[tuple[int, int]]) -> No
cell.connect(self._cells[ni, nj])


class OrthogonalMooreGrid(Grid[T]):
class OrthogonalMooreGrid(Grid[T, U]):
"""Grid where cells are connected to their 8 neighbors.

Example for two dimensions:
Expand Down Expand Up @@ -137,7 +144,7 @@ def _connect_cells_nd(self) -> None:
self._connect_single_cell_nd(cell, offsets)


class OrthogonalVonNeumannGrid(Grid[T]):
class OrthogonalVonNeumannGrid(Grid[T, U]):
"""Grid where cells are connected to their 4 neighbors.

Example for two dimensions:
Expand Down Expand Up @@ -177,7 +184,7 @@ def _connect_cells_nd(self) -> None:
self._connect_single_cell_nd(cell, offsets)


class HexGrid(Grid[T]):
class HexGrid(Grid[T, U]):
def _connect_cells_2d(self) -> None:
# fmt: off
even_offsets = [
Expand Down
Loading