diff --git a/codecov.yaml b/codecov.yaml index a19c3daee15..5aa4e0dabc1 100644 --- a/codecov.yaml +++ b/codecov.yaml @@ -7,7 +7,6 @@ coverage: ignore: - "benchmarks/**" - - "mesa/experimental/**" - "mesa/visualization/**" comment: off diff --git a/docs/tutorials/intro_tutorial.ipynb b/docs/tutorials/intro_tutorial.ipynb index 4f257737abd..3398612bcb7 100644 --- a/docs/tutorials/intro_tutorial.ipynb +++ b/docs/tutorials/intro_tutorial.ipynb @@ -294,15 +294,17 @@ " def __init__(self, n, seed=None):\n", " super().__init__(seed=seed)\n", " self.num_agents = n\n", - " \n", + "\n", " # Create agents\n", " for _ in range(self.num_agents):\n", - " a = MoneyAgent(self) # This calls the agent class parameter n number of times \n", + " a = MoneyAgent(\n", + " self\n", + " ) # This calls the agent class parameter n number of times\n", "\n", " def step(self):\n", " \"\"\"Advance the model by one step.\"\"\"\n", "\n", - " # This function psuedo-randomly reorders the list of agent objects and \n", + " # This function psuedo-randomly reorders the list of agent objects and\n", " # then iterates through calling the function passed in as the parameter\n", " self.agents.shuffle_do(\"say_hi\")" ] @@ -467,21 +469,24 @@ " other_agent.wealth += 1\n", " self.wealth -= 1\n", "\n", + "\n", "class MoneyModel(mesa.Model):\n", " \"\"\"A model with some number of agents.\"\"\"\n", "\n", " def __init__(self, n):\n", " super().__init__()\n", " self.num_agents = n\n", - " \n", + "\n", " # Create agents\n", " for _ in range(self.num_agents):\n", - " a = MoneyAgent(self) # This calls the agent class parameter n number of times \n", + " a = MoneyAgent(\n", + " self\n", + " ) # This calls the agent class parameter n number of times\n", "\n", " def step(self):\n", " \"\"\"Advance the model by one step.\"\"\"\n", "\n", - " # This function psuedo-randomly reorders the list of agent objects and \n", + " # This function psuedo-randomly reorders the list of agent objects and\n", " # then iterates through calling the function passed in as the parameter\n", " self.agents.shuffle_do(\"exchange\")" ] @@ -509,8 +514,10 @@ "metadata": {}, "outputs": [], "source": [ - "model = MoneyModel(10) # Tels the model to create 10 agents\n", - "for _ in range(30): #Runs the model for 10 steps; an underscore is common convention for a variable that is not used\n", + "model = MoneyModel(10) # Tells the model to create 10 agents\n", + "for _ in range(\n", + " 30\n", + "): # Runs the model for 10 steps; an underscore is common convention for a variable that is not used\n", " model.step()" ] }, @@ -665,23 +672,16 @@ " #...\n", " def give_money(self):\n", " cellmates = self.model.grid.get_cell_list_contents([self.pos])\n", - " if len(cellmates) > 1:\n", + " # Ensure agent is not giving money to itself\n", + " cellmates.pop(\n", + " cellmates.index(self)\n", + " )\n", + " if len(cellmates) > 0:\n", " other = self.random.choice(cellmates)\n", " other.wealth += 1\n", " self.wealth -= 1\n", "```\n", "\n", - "And with those two methods, the agent's ``step`` method becomes:\n", - "\n", - "```python\n", - "class MoneyAgent(mesa.Agent):\n", - " # ...\n", - " def step(self):\n", - " self.move()\n", - " if self.wealth > 0:\n", - " self.give_money()\n", - "```\n", - "\n", "Now, putting that all together should look like this:" ] }, @@ -706,22 +706,23 @@ " self.model.grid.move_agent(self, new_position)\n", "\n", " def give_money(self):\n", - " if self.wealth > 0: \n", - " cellmates = self.model.grid.get_cell_list_contents([self.pos])\n", - " if len(cellmates) > 1:\n", - " other_agent = self.random.choice(cellmates)\n", - " other_agent.wealth += 1\n", - " self.wealth -= 1\n", + " cellmates = self.model.grid.get_cell_list_contents([self.pos])\n", + " # Ensure agent is not giving money to itself\n", + " cellmates.pop(cellmates.index(self))\n", + " if len(cellmates) > 0:\n", + " other_agent = self.random.choice(cellmates)\n", + " other_agent.wealth += 1\n", + " self.wealth -= 1\n", + "\n", "\n", - " \n", "class MoneyModel(mesa.Model):\n", " \"\"\"A model with some number of agents.\"\"\"\n", "\n", - " def __init__(self, n, width, height,seed=None):\n", + " def __init__(self, n, width, height, seed=None):\n", " super().__init__(seed=seed)\n", " self.num_agents = n\n", " self.grid = mesa.space.MultiGrid(width, height, True)\n", - " \n", + "\n", " # Create agents\n", " for _ in range(self.num_agents):\n", " a = MoneyAgent(self)\n", @@ -732,7 +733,7 @@ "\n", " def step(self):\n", " self.agents.shuffle_do(\"move\")\n", - " self.agents.shuffle_do(\"give_money\")" + " self.agents.do(\"give_money\")" ] }, { @@ -791,7 +792,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Challenge: Change from multigrid to grid (only one agent per cell) " + "# Challenge: Change from multigrid to grid (only one agent per cell)" ] }, { @@ -845,22 +846,13 @@ "\n", " def give_money(self):\n", " cellmates = self.model.grid.get_cell_list_contents([self.pos])\n", - " cellmates.pop(\n", - " cellmates.index(self)\n", - " ) # Ensure agent is not giving money to itself\n", - " if len(cellmates) > 1:\n", + " # Ensure agent is not giving money to itself\n", + " cellmates.pop(cellmates.index(self))\n", + " if len(cellmates) > 0:\n", " other = self.random.choice(cellmates)\n", " other.wealth += 1\n", " self.wealth -= 1\n", - " if other == self:\n", - " print(\"I JUST GAVE MOnEY TO MYSELF HEHEHE!\")\n", - "\n", - " # There are several ways in which one can combine functions to execute the model\n", - " def agent_act(self): \n", - " self.move()\n", - " if self.wealth > 0:\n", - " self.give_money()\n", - "\n", + " \n", "\n", "class MoneyModel(mesa.Model):\n", " \"\"\"A model with some number of agents.\"\"\"\n", @@ -883,10 +875,10 @@ " y = self.random.randrange(self.grid.height)\n", " self.grid.place_agent(a, (x, y))\n", "\n", - " \n", " def step(self):\n", " self.datacollector.collect(self)\n", - " self.agents.shuffle_do(\"agent_act\")" + " self.agents.shuffle_do(\"move\")\n", + " self.agents.do(\"give_money\")" ] }, { @@ -1126,22 +1118,24 @@ " self.model.grid.move_agent(self, new_position)\n", "\n", " def give_money(self):\n", - " if self.wealth > 0: \n", + " if self.wealth > 0:\n", " cellmates = self.model.grid.get_cell_list_contents([self.pos])\n", - " if len(cellmates) > 1:\n", + " # Ensure agent is not giving money to itself\n", + " cellmates.pop(cellmates.index(self))\n", + " if len(cellmates) > 0:\n", " other_agent = self.random.choice(cellmates)\n", " other_agent.wealth += 1\n", " self.wealth -= 1\n", - " \n", + "\n", "\n", "class MoneyModel(mesa.Model):\n", " \"\"\"A model with some number of agents.\"\"\"\n", "\n", - " def __init__(self, n, width, height,seed=None):\n", + " def __init__(self, n, width, height, seed=None):\n", " super().__init__(seed=seed)\n", " self.num_agents = n\n", " self.grid = mesa.space.MultiGrid(width, height, True)\n", - " \n", + "\n", " # Create agents\n", " for _ in range(self.num_agents):\n", " a = MoneyAgent(self)\n", @@ -1226,11 +1220,11 @@ " self.wealth = 1\n", "\n", " def give_money(self, poor_agents):\n", - " if self.wealth > 0: \n", + " if self.wealth > 0:\n", " other_agent = self.random.choice(poor_agents)\n", " other_agent.wealth += 1\n", " self.wealth -= 1\n", - " \n", + "\n", "\n", "class MoneyModel(mesa.Model):\n", " \"\"\"A model with some number of agents.\"\"\"\n", @@ -1238,7 +1232,7 @@ " def __init__(self, n):\n", " super().__init__()\n", " self.num_agents = n\n", - " \n", + "\n", " # Create agents\n", " for _ in range(self.num_agents):\n", " a = MoneyAgent(self)\n", @@ -1246,17 +1240,17 @@ " self.datacollector = mesa.DataCollector(\n", " model_reporters={\"Gini\": compute_gini}, agent_reporters={\"Wealth\": \"wealth\"}\n", " )\n", - " \n", + "\n", " def step(self):\n", " self.datacollector.collect(self)\n", " # Get lists of rich and poor agents\n", " rich_agents = model.agents.select(lambda a: a.wealth >= 3)\n", " poor_agents = model.agents.select(lambda a: a.wealth < 3)\n", " # When there is rich agents only have them give money to the poor agents\n", - " if len(rich_agents) > 0: \n", + " if len(rich_agents) > 0:\n", " rich_agents.shuffle_do(\"give_money\", poor_agents)\n", - " else: \n", - " poor_agents.shuffle_do(\"give_money\", poor_agents) " + " else:\n", + " poor_agents.shuffle_do(\"give_money\", poor_agents)" ] }, { @@ -1274,7 +1268,7 @@ "source": [ "model = MoneyModel(100)\n", "for _ in range(20):\n", - " model.step() \n", + " model.step()\n", "\n", "\n", "data = model.datacollector.get_agent_vars_dataframe()\n", @@ -1309,11 +1303,11 @@ " self.ethnicity = ethnicity\n", "\n", " def give_money(self, similars):\n", - " if self.wealth > 0: \n", + " if self.wealth > 0:\n", " other_agent = self.random.choice(similars)\n", " other_agent.wealth += 1\n", " self.wealth -= 1\n", - " \n", + "\n", "\n", "class MoneyModel(mesa.Model):\n", " \"\"\"A model with some number of agents.\"\"\"\n", @@ -1321,29 +1315,31 @@ " def __init__(self, n):\n", " super().__init__()\n", " self.num_agents = n\n", - " \n", + "\n", " # Create a list of our different ethnicities\n", " ethnicities = [\"Green\", \"Blue\", \"Mixed\"]\n", - " \n", + "\n", " # Create agents\n", " for _ in range(self.num_agents):\n", " a = MoneyAgent(self, self.random.choice(ethnicities))\n", "\n", " self.datacollector = mesa.DataCollector(\n", - " model_reporters={\"Gini\": compute_gini}, agent_reporters={\"Wealth\": \"wealth\", \"Ethnicity\":\"ethnicity\"}\n", + " model_reporters={\"Gini\": compute_gini},\n", + " agent_reporters={\"Wealth\": \"wealth\", \"Ethnicity\": \"ethnicity\"},\n", " )\n", - " \n", + "\n", " def step(self):\n", " self.datacollector.collect(self)\n", " # groupby returns a dictionary of the different ethnicities with a list of agents\n", " grouped_agents = model.agents.groupby(\"ethnicity\")\n", "\n", - " for ethnic, similars in grouped_agents: \n", - " if ethnic != \"Mixed\": \n", + " for ethnic, similars in grouped_agents:\n", + " if ethnic != \"Mixed\":\n", " similars.shuffle_do(\"give_money\", similars)\n", - " else: \n", - " similars.shuffle_do(\"give_money\", self.agents) # This allows mixed to trade with anyone \n", - "\n" + " else:\n", + " similars.shuffle_do(\n", + " \"give_money\", self.agents\n", + " ) # This allows mixed to trade with anyone" ] }, { @@ -1356,12 +1352,12 @@ "model = MoneyModel(100)\n", "for _ in range(20):\n", " model.step()\n", - " \n", + "\n", "# get the data\n", "data = model.datacollector.get_agent_vars_dataframe()\n", "# assign histogram colors\n", - "palette = {'Green': 'green', 'Blue': 'blue', 'Mixed': 'purple'} \n", - "sns.histplot(data=data, x='Wealth', hue='Ethnicity',discrete=True, palette=palette)\n", + "palette = {\"Green\": \"green\", \"Blue\": \"blue\", \"Mixed\": \"purple\"}\n", + "sns.histplot(data=data, x=\"Wealth\", hue=\"Ethnicity\", discrete=True, palette=palette)\n", "g.set(title=\"Wealth distribution\", xlabel=\"Wealth\", ylabel=\"number of agents\");" ] }, @@ -1445,7 +1441,8 @@ "\n", " def give_money(self):\n", " cellmates = self.model.grid.get_cell_list_contents([self.pos])\n", - " if len(cellmates) > 1 and self.wealth > 0:\n", + " cellmates.pop(cellmates.index(self))\n", + " if len(cellmates) > 0 and self.wealth > 0:\n", " other = self.random.choice(cellmates)\n", " other.wealth += 1\n", " self.wealth -= 1\n", @@ -1498,7 +1495,9 @@ { "cell_type": "markdown", "metadata": {}, - "source": "**note for Windows OS users:** If you are running this tutorial in Jupyter, make sure that you set `number_processes = 1` (single process). If `number_processes` is greater than 1, it is less straightforward to set up. For details on how to use multiprocessing on windows, see [multiprocessing's programming guidelines](https://docs.python.org/3/library/multiprocessing.html#multiprocessing-programming). " + "source": [ + "**note for Windows OS users:** If you are running this tutorial in Jupyter, make sure that you set `number_processes = 1` (single process). If `number_processes` is greater than 1, it is less straightforward to set up. For details on how to use multiprocessing on windows, see [multiprocessing's programming guidelines](https://docs.python.org/3/library/multiprocessing.html#multiprocessing-programming). " + ] }, { "cell_type": "code", @@ -1593,7 +1592,7 @@ "outputs": [], "source": [ "# Create a point plot with error bars\n", - "g = sns.pointplot(data=results_filtered, x=\"n\", y=\"Gini\", linestyle='None')\n", + "g = sns.pointplot(data=results_filtered, x=\"n\", y=\"Gini\", linestyle=\"None\")\n", "g.figure.set_size_inches(8, 4)\n", "g.set(\n", " xlabel=\"number of agents\",\n", @@ -1680,7 +1679,7 @@ }, "outputs": [], "source": [ - "params = {\"seed\":None,\"width\": 10, \"height\": 10, \"n\": [5, 10, 20, 40, 80]}\n", + "params = {\"seed\": None, \"width\": 10, \"height\": 10, \"n\": [5, 10, 20, 40, 80]}\n", "\n", "results_5s = mesa.batch_run(\n", " MoneyModel,\n", @@ -1754,8 +1753,8 @@ "metadata": {}, "outputs": [], "source": [ - "# Challenge: Treat the seed as a parameter and see the impact on the Gini Coefficient. \n", - "# You can also plot the seeds against the Gini Coefficient by changing the \"hue\" parameter in sns.lineplot function. " + "# Challenge: Treat the seed as a parameter and see the impact on the Gini Coefficient.\n", + "# You can also plot the seeds against the Gini Coefficient by changing the \"hue\" parameter in sns.lineplot function." ] }, { diff --git a/mesa/experimental/__init__.py b/mesa/experimental/__init__.py index 42a510cb9c6..069e418aa27 100644 --- a/mesa/experimental/__init__.py +++ b/mesa/experimental/__init__.py @@ -1,5 +1,5 @@ """Experimental init.""" -from mesa.experimental import cell_space, devs +from mesa.experimental import cell_space, devs, mesa_signals -__all__ = ["cell_space", "devs"] +__all__ = ["cell_space", "devs", "mesa_signals"] diff --git a/mesa/experimental/mesa_signals/__init__.py b/mesa/experimental/mesa_signals/__init__.py new file mode 100644 index 00000000000..a3a0b8053ef --- /dev/null +++ b/mesa/experimental/mesa_signals/__init__.py @@ -0,0 +1,13 @@ +"""Functionality for Observables.""" + +from .mesa_signal import All, Computable, Computed, HasObservables, Observable +from .observable_collections import ObservableList + +__all__ = [ + "Observable", + "ObservableList", + "HasObservables", + "All", + "Computable", + "Computed", +] diff --git a/mesa/experimental/mesa_signals/mesa_signal.py b/mesa/experimental/mesa_signals/mesa_signal.py new file mode 100644 index 00000000000..18128ce3dc0 --- /dev/null +++ b/mesa/experimental/mesa_signals/mesa_signal.py @@ -0,0 +1,470 @@ +"""Core classes for Observables.""" + +from __future__ import annotations + +import contextlib +import functools +import weakref +from abc import ABC, abstractmethod +from collections import defaultdict, namedtuple +from collections.abc import Callable +from typing import Any + +from mesa.experimental.mesa_signals.signals_util import AttributeDict, create_weakref + +__all__ = ["Observable", "HasObservables", "All", "Computable"] + +_hashable_signal = namedtuple("_HashableSignal", "instance name") + +CURRENT_COMPUTED: Computed | None = None # the current Computed that is evaluating +PROCESSING_SIGNALS: set[tuple[str,]] = set() + + +class BaseObservable(ABC): + """Abstract base class for all Observables.""" + + @abstractmethod + def __init__(self, fallback_value=None): + """Initialize a BaseObservable.""" + super().__init__() + self.public_name: str + self.private_name: str + + # fixme can we make this an inner class enum? + # or some SignalTypes helper class? + # its even more complicated. Ideally you can define + # signal_types throughout the class hierarchy and they are just + # combined together. + # while we also want to make sure that any signal being emitted is valid for that class + self.signal_types: set = set() + self.fallback_value = fallback_value + + def __get__(self, instance: HasObservables, owner): + value = getattr(instance, self.private_name) + + if CURRENT_COMPUTED is not None: + # there is a computed dependent on this Observable, so let's add + # this Observable as a parent + CURRENT_COMPUTED._add_parent(instance, self.public_name, value) + + # fixme, this can be done more cleanly + # problem here is that we cannot use self (i.e., the observable), we need to add the instance as well + PROCESSING_SIGNALS.add(_hashable_signal(instance, self.public_name)) + + return value + + def __set_name__(self, owner: HasObservables, name: str): + self.public_name = name + self.private_name = f"_{name}" + # owner.register_observable(self) + + @abstractmethod + def __set__(self, instance: HasObservables, value): + # this only emits an on change signal, subclasses need to specify + # this in more detail + instance.notify( + self.public_name, + getattr(instance, self.private_name, self.fallback_value), + value, + "change", + ) + + def __str__(self): + return f"{self.__class__.__name__}: {self.public_name}" + + +class Observable(BaseObservable): + """Observable class.""" + + def __init__(self, fallback_value=None): + """Initialize an Observable.""" + super().__init__(fallback_value=fallback_value) + + self.signal_types: set = { + "change", + } + + def __set__(self, instance: HasObservables, value): # noqa D103 + if ( + CURRENT_COMPUTED is not None + and _hashable_signal(instance, self.public_name) in PROCESSING_SIGNALS + ): + raise ValueError( + f"cyclical dependency detected: Computed({CURRENT_COMPUTED.name}) tries to change " + f"{instance.__class__.__name__}.{self.public_name} while also being dependent it" + ) + + super().__set__(instance, value) # send the notify + setattr(instance, self.private_name, value) + + PROCESSING_SIGNALS.clear() # we have notified our children, so we can clear this out + + +class Computable(BaseObservable): + """A Computable that is depended on one or more Observables. + + .. code-block:: python + + class MyAgent(Agent): + wealth = Computable() + + def __init__(self, model): + super().__init__(model) + wealth = Computed(func, args, kwargs) + + """ + + # fixme, with new _register_observable thing + # we can do computed without a descriptor, but then you + # don't have attribute like access, you would need to do a call operation to get the value + + def __init__(self): + """Initialize a Computable.""" + super().__init__() + + # fixme have 2 signal: change and is_dirty? + self.signal_types: set = { + "change", + } + + def __get__(self, instance, owner): # noqa: D105 + computed = getattr(instance, self.private_name) + old_value = computed._value + + if CURRENT_COMPUTED is not None: + CURRENT_COMPUTED._add_parent(instance, self.public_name, old_value) + + new_value = computed() + + if new_value != old_value: + instance.notify( + self.public_name, + old_value, + new_value, + "change", + ) + return new_value + else: + return old_value + + def __set__(self, instance: HasObservables, value: Computed): # noqa D103 + if not isinstance(value, Computed): + raise ValueError("value has to be a Computable instance") + + setattr(instance, self.private_name, value) + value.name = self.public_name + value.owner = instance + getattr( + instance, self.public_name + ) # force evaluation of the computed to build the dependency graph + + +class Computed: + def __init__(self, func: Callable, *args, **kwargs): + self.func = func + self.args = args + self.kwargs = kwargs + self._is_dirty = True + self._first = True + self._value = None + self.name: str = "" # set by Computable + self.owner: HasObservables # set by Computable + + self.parents: weakref.WeakKeyDictionary[HasObservables, dict[str, Any]] = ( + weakref.WeakKeyDictionary() + ) + + def __str__(self): + return f"COMPUTED: {self.name}" + + def _set_dirty(self, signal): + if not self._is_dirty: + self._is_dirty = True + self.owner.notify(self.name, self._value, None, "change") + + def _add_parent( + self, parent: HasObservables, name: str, current_value: Any + ) -> None: + """Add a parent Observable. + + Args: + parent: the HasObservable instance to which the Observable belongs + name: the public name of the Observable + current_value: the current value of the Observable + + """ + parent.observe(name, All(), self._set_dirty) + + try: + self.parents[parent][name] = current_value + except KeyError: + self.parents[parent] = {name: current_value} + + def _remove_parents(self): + """Remove all parent Observables.""" + # we can unsubscribe from everything on each parent + for parent in self.parents: + parent.unobserve(All(), All(), self._set_dirty) + + def __call__(self): + global CURRENT_COMPUTED # noqa: PLW0603 + + if self._is_dirty: + changed = False + + if self._first: + # fixme might be a cleaner solution for this + # basically, we have no parents. + changed = True + self._first = False + + # we might be dirty but values might have changed + # back and forth in our parents so let's check to make sure we + # really need to recalculate + if not changed: + for parent in self.parents.keyrefs(): + # does parent still exist? + if parent := parent(): + # if yes, compare old and new values for all + # tracked observables on this parent + for name, old_value in self.parents[parent].items(): + new_value = getattr(parent, name) + if new_value != old_value: + changed = True + break # we need to recalculate + else: + # trick for breaking cleanly out of nested for loops + # see https://stackoverflow.com/questions/653509/breaking-out-of-nested-loops + continue + break + else: + # one of our parents no longer exists + changed = True + break + + if changed: + # the dependencies of the computable function might have changed + # so, we rebuilt + self._remove_parents() + + old = CURRENT_COMPUTED + CURRENT_COMPUTED = self + + try: + self._value = self.func(*self.args, **self.kwargs) + except Exception as e: + raise e + finally: + CURRENT_COMPUTED = old + + self._is_dirty = False + + return self._value + + +class All: + """Helper constant to subscribe to all Observables.""" + + def __init__(self): # noqa: D107 + self.name = "all" + + def __copy__(self): # noqa: D105 + return self + + def __deepcopy__(self, memo): # noqa: D105 + return self + + +class HasObservables: + """HasObservables class.""" + + # we can't use a weakset here because it does not handle bound methods correctly + # also, a list is faster for our use case + subscribers: dict[str, dict[str, list]] + observables: dict[str, set[str]] + + def __init__(self, *args, **kwargs) -> None: + """Initialize a HasObservables.""" + super().__init__(*args, **kwargs) + self.subscribers = defaultdict(functools.partial(defaultdict, list)) + self.observables = dict(descriptor_generator(self)) + + def _register_signal_emitter(self, name: str, signal_types: set[str]): + """Helper function to register an Observable. + + This method can be used to register custom signals that are emitted by + the class for a given attribute, but which cannot be covered by the Observable descriptor + + Args: + name: the name of the signal emitter + signal_types: the set of signals that might be emitted + + """ + self.observables[name] = signal_types + + def observe( + self, + name: str | All, + signal_type: str | All, + handler: Callable, + ): + """Subscribe to the Observable for signal_type. + + Args: + name: name of the Observable to subscribe to + signal_type: the type of signal on the Observable to subscribe to + handler: the handler to call + + Raises: + ValueError: if the Observable is not registered or if the Observable + does not emit the given signal_type + + """ + # fixme should name/signal_type also take a list of str? + if not isinstance(name, All): + if name not in self.observables: + raise ValueError( + f"you are trying to subscribe to {name}, but this Observable is not known" + ) + else: + names = [ + name, + ] + else: + names = self.observables.keys() + + for name in names: + if not isinstance(signal_type, All): + if signal_type not in self.observables[name]: + raise ValueError( + f"you are trying to subscribe to a signal of {signal_type} " + f"on Observable {name}, which does not emit this signal_type" + ) + else: + signal_types = [ + signal_type, + ] + else: + signal_types = self.observables[name] + + ref = create_weakref(handler) + for signal_type in signal_types: + self.subscribers[name][signal_type].append(ref) + + def unobserve(self, name: str | All, signal_type: str | All, handler: Callable): + """Unsubscribe to the Observable for signal_type. + + Args: + name: name of the Observable to unsubscribe from + signal_type: the type of signal on the Observable to unsubscribe to + handler: the handler that is unsubscribing + + """ + names = ( + [ + name, + ] + if not isinstance(name, All) + else self.observables.keys() + ) + + for name in names: + # we need to do this here because signal types might + # differ for name so for each name we need to check + if isinstance(signal_type, All): + signal_types = self.observables[name] + else: + signal_types = [ + signal_type, + ] + for signal_type in signal_types: + with contextlib.suppress(KeyError): + remaining = [] + for ref in self.subscribers[name][signal_type]: + if subscriber := ref(): # noqa: SIM102 + if subscriber != handler: + remaining.append(ref) + self.subscribers[name][signal_type] = remaining + + def clear_all_subscriptions(self, name: str | All): + """Clears all subscriptions for the observable . + + if name is All, all subscriptions are removed + + Args: + name: name of the Observable to unsubscribe for all signal types + + """ + if not isinstance(name, All): + with contextlib.suppress(KeyError): + del self.subscribers[name] + # ignore when unsubscribing to Observables that have no subscription + else: + self.subscribers = defaultdict(functools.partial(defaultdict, list)) + + def notify( + self, + observable: str, + old_value: Any, + new_value: Any, + signal_type: str, + **kwargs, + ): + """Emit a signal. + + Args: + observable: the public name of the observable emitting the signal + old_value: the old value of the observable + new_value: the new value of the observable + signal_type: the type of signal to emit + kwargs: additional keyword arguments to include in the signal + + """ + signal = AttributeDict( + name=observable, + old=old_value, + new=new_value, + owner=self, + type=signal_type, + **kwargs, + ) + + self._mesa_notify(signal) + + def _mesa_notify(self, signal: AttributeDict): + """Send out the signal. + + Args: + signal: the signal + + Notes: + signal must contain name and type attributes because this is how observers are stored. + + """ + # we put this into a helper method, so we can emit signals with other fields + # then the default ones in notify. + observable = signal.name + signal_type = signal.type + + # because we are using a list of subscribers + # we should update this list to subscribers that are still alive + observers = self.subscribers[observable][signal_type] + active_observers = [] + for observer in observers: + if active_observer := observer(): + active_observer(signal) + active_observers.append(observer) + # use iteration to also remove inactive observers + self.subscribers[observable][signal_type] = active_observers + + +def descriptor_generator(obj) -> [str, BaseObservable]: + """Yield the name and signal_types for each Observable defined on obj.""" + # we need to traverse the entire class hierarchy to properly get + # also observables defined in super classes + for base in type(obj).__mro__: + base_dict = vars(base) + + for entry in base_dict.values(): + if isinstance(entry, BaseObservable): + yield entry.public_name, entry.signal_types diff --git a/mesa/experimental/mesa_signals/observable_collections.py b/mesa/experimental/mesa_signals/observable_collections.py new file mode 100644 index 00000000000..1111a8e64ca --- /dev/null +++ b/mesa/experimental/mesa_signals/observable_collections.py @@ -0,0 +1,126 @@ +"""This module defines observable collections classes. + +Observable collections behave like Observable but then for collections. + + +""" + +from collections.abc import Iterable, MutableSequence +from typing import Any + +from .mesa_signal import BaseObservable, HasObservables + +__all__ = [ + "ObservableList", +] + + +class ObservableList(BaseObservable): + """An ObservableList that emits signals on changes to the underlying list.""" + + def __init__(self): + """Initialize the ObservableList.""" + super().__init__() + self.signal_types: set = {"remove", "replace", "change", "insert", "append"} + self.fallback_value = [] + + def __set__(self, instance: "HasObservables", value: Iterable): + """Set the value of the descriptor attribute. + + Args: + instance: The instance on which to set the attribute. + value: The value to set the attribute to. + + """ + super().__set__(instance, value) + setattr( + instance, + self.private_name, + SignalingList(value, instance, self.public_name), + ) + + +class SignalingList(MutableSequence[Any]): + """A basic lists that emits signals on changes.""" + + __slots__ = ["owner", "name", "data"] + + def __init__(self, iterable: Iterable, owner: HasObservables, name: str): + """Initialize a SignalingList. + + Args: + iterable: initial values in the list + owner: the HasObservables instance on which this list is defined + name: the attribute name to which this list is assigned + + """ + self.owner: HasObservables = owner + self.name: str = name + self.data = list(iterable) + + def __setitem__(self, index: int, value: Any) -> None: + """Set item to index. + + Args: + index: the index to set item to + value: the item to set + + """ + old_value = self.data[index] + self.data[index] = value + self.owner.notify(self.name, old_value, value, "replace", index=index) + + def __delitem__(self, index: int) -> None: + """Delete item at index. + + Args: + index: The index of the item to remove + + """ + old_value = self.data + del self.data[index] + self.owner.notify(self.name, old_value, None, "remove", index=index) + + def __getitem__(self, index) -> Any: + """Get item at index. + + Args: + index: The index of the item to retrieve + + Returns: + the item at index + """ + return self.data[index] + + def __len__(self) -> int: + """Return the length of the list.""" + return len(self.data) + + def insert(self, index, value): + """Insert value at index. + + Args: + index: the index to insert value into + value: the value to insert + + """ + self.data.insert(index, value) + self.owner.notify(self.name, None, value, "insert", index=index) + + def append(self, value): + """Insert value at index. + + Args: + index: the index to insert value into + value: the value to insert + + """ + index = len(self.data) + self.data.append(value) + self.owner.notify(self.name, None, value, "append", index=index) + + def __str__(self): + return self.data.__str__() + + def __repr__(self): + return self.data.__repr__() diff --git a/mesa/experimental/mesa_signals/signals_util.py b/mesa/experimental/mesa_signals/signals_util.py new file mode 100644 index 00000000000..f5152452cea --- /dev/null +++ b/mesa/experimental/mesa_signals/signals_util.py @@ -0,0 +1,43 @@ +"""helper functions and classes for mesa signals.""" + +import weakref + +__all__ = [ + "AttributeDict", + "create_weakref", +] + + +class AttributeDict(dict): + """A dict with attribute like access. + + Each value can be accessed as if it were an attribute with its key as attribute name + + """ + + # I want our signals to act like traitlet signals, so this is inspired by trailets Bunch + # and some stack overflow posts. + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + def __getattr__(self, key): # noqa: D105 + try: + return self.__getitem__(key) + except KeyError as e: + # we need to go from key error to attribute error + raise AttributeError(key) from e + + def __dir__(self): # noqa: D105 + # allows us to easily access all defined attributes + names = dir({}) + names.extend(self.keys()) + return names + + +def create_weakref(item, callback=None): + """Helper function to create a correct weakref for any item.""" + if hasattr(item, "__self__"): + ref = weakref.WeakMethod(item, callback) + else: + ref = weakref.ref(item, callback) + return ref diff --git a/tests/test_mesa_signals.py b/tests/test_mesa_signals.py new file mode 100644 index 00000000000..a41afdf46ce --- /dev/null +++ b/tests/test_mesa_signals.py @@ -0,0 +1,290 @@ +"""Tests for mesa_signals.""" + +from unittest.mock import Mock + +import pytest + +from mesa import Agent, Model +from mesa.experimental.mesa_signals import ( + All, + Computable, + Computed, + HasObservables, + Observable, + ObservableList, +) +from mesa.experimental.mesa_signals.signals_util import AttributeDict + + +def test_observables(): + """Test Observable.""" + + class MyAgent(Agent, HasObservables): + some_attribute = Observable() + + def __init__(self, model, value): + super().__init__(model) + some_attribute = value # noqa: F841 + + handler = Mock() + + model = Model(seed=42) + agent = MyAgent(model, 10) + agent.observe("some_attribute", "change", handler) + + agent.some_attribute = 10 + handler.assert_called_once() + + +def test_HasObservables(): + """Test Observable.""" + + class MyAgent(Agent, HasObservables): + some_attribute = Observable() + some_other_attribute = Observable() + + def __init__(self, model, value): + super().__init__(model) + some_attribute = value # noqa: F841 + some_other_attribute = 5 # noqa: F841 + + handler = Mock() + + model = Model(seed=42) + agent = MyAgent(model, 10) + agent.observe("some_attribute", "change", handler) + + subscribers = {entry() for entry in agent.subscribers["some_attribute"]["change"]} + assert handler in subscribers + + agent.unobserve("some_attribute", "change", handler) + subscribers = {entry() for entry in agent.subscribers["some_attribute"]["change"]} + assert handler not in subscribers + + subscribers = { + entry() for entry in agent.subscribers["some_other_attribute"]["change"] + } + assert len(subscribers) == 0 + + # testing All() + agent.observe(All(), "change", handler) + + for attr in ["some_attribute", "some_other_attribute"]: + subscribers = {entry() for entry in agent.subscribers[attr]["change"]} + assert handler in subscribers + + agent.unobserve(All(), "change", handler) + for attr in ["some_attribute", "some_other_attribute"]: + subscribers = {entry() for entry in agent.subscribers[attr]["change"]} + assert handler not in subscribers + assert len(subscribers) == 0 + + # testing for clear_all_subscriptions + nr_observers = 3 + handlers = [Mock() for _ in range(nr_observers)] + for handler in handlers: + agent.observe("some_attribute", "change", handler) + agent.observe("some_other_attribute", "change", handler) + + subscribers = {entry() for entry in agent.subscribers["some_attribute"]["change"]} + assert len(subscribers) == nr_observers + + agent.clear_all_subscriptions("some_attribute") + subscribers = {entry() for entry in agent.subscribers["some_attribute"]["change"]} + assert len(subscribers) == 0 + + subscribers = { + entry() for entry in agent.subscribers["some_other_attribute"]["change"] + } + assert len(subscribers) == 3 + + agent.clear_all_subscriptions(All()) + subscribers = {entry() for entry in agent.subscribers["some_attribute"]["change"]} + assert len(subscribers) == 0 + + subscribers = { + entry() for entry in agent.subscribers["some_other_attribute"]["change"] + } + assert len(subscribers) == 0 + + # test raises + with pytest.raises(ValueError): + agent.observe("some_attribute", "unknonw_signal", handler) + + with pytest.raises(ValueError): + agent.observe("unknonw_attribute", "change", handler) + + +def test_ObservableList(): + """Test ObservableList.""" + + class MyAgent(Agent, HasObservables): + my_list = ObservableList() + + def __init__( + self, + model, + ): + super().__init__(model) + self.my_list = [] + + model = Model(seed=42) + agent = MyAgent(model) + + assert len(agent.my_list) == 0 + + # add + handler = Mock() + agent.observe("my_list", "append", handler) + + agent.my_list.append(1) + assert len(agent.my_list) == 1 + handler.assert_called_once() + handler.assert_called_once_with( + AttributeDict( + name="my_list", new=1, old=None, type="append", index=0, owner=agent + ) + ) + agent.unobserve("my_list", "append", handler) + + # remove + handler = Mock() + agent.observe("my_list", "remove", handler) + + agent.my_list.remove(1) + assert len(agent.my_list) == 0 + handler.assert_called_once() + + agent.unobserve("my_list", "remove", handler) + + # overwrite the existing list + a_list = [1, 2, 3, 4, 5] + handler = Mock() + agent.observe("my_list", "change", handler) + agent.my_list = a_list + assert len(agent.my_list) == len(a_list) + handler.assert_called_once() + + agent.my_list = a_list + assert len(agent.my_list) == len(a_list) + handler.assert_called() + agent.unobserve("my_list", "change", handler) + + # pop + handler = Mock() + agent.observe("my_list", "remove", handler) + + index = 4 + entry = agent.my_list.pop(index) + assert entry == a_list.pop(index) + assert len(agent.my_list) == len(a_list) + handler.assert_called_once() + agent.unobserve("my_list", "remove", handler) + + # insert + handler = Mock() + agent.observe("my_list", "insert", handler) + agent.my_list.insert(0, 5) + handler.assert_called() + agent.unobserve("my_list", "insert", handler) + + # overwrite + handler = Mock() + agent.observe("my_list", "replace", handler) + agent.my_list[0] = 10 + assert agent.my_list[0] == 10 + handler.assert_called_once() + agent.unobserve("my_list", "replace", handler) + + # combine two lists + handler = Mock() + agent.observe("my_list", "append", handler) + a_list = [1, 2, 3, 4, 5] + agent.my_list = a_list + assert len(agent.my_list) == len(a_list) + agent.my_list += a_list + assert len(agent.my_list) == 2 * len(a_list) + handler.assert_called() + + # some more non signalling functionality tests + assert 5 in agent.my_list + assert agent.my_list.index(5) == 4 + + +def test_AttributeDict(): + """Test AttributeDict.""" + + class MyAgent(Agent, HasObservables): + some_attribute = Observable() + + def __init__(self, model, value): + super().__init__(model) + self.some_attribute = value + + def on_change(signal): + assert signal.name == "some_attribute" + assert signal.type == "change" + assert signal.old == 10 + assert signal.new == 5 + assert signal.owner == agent + + items = dir(signal) + for entry in ["name", "type", "old", "new", "owner"]: + assert entry in items + + model = Model(seed=42) + agent = MyAgent(model, 10) + agent.observe("some_attribute", "change", on_change) + agent.some_attribute = 5 + + +def test_Computable(): + """Test Computable and Computed.""" + + class MyAgent(Agent, HasObservables): + some_attribute = Computable() + some_other_attribute = Observable() + + def __init__(self, model, value): + super().__init__(model) + self.some_other_attribute = value + self.some_attribute = Computed(lambda x: x.some_other_attribute * 2, self) + + model = Model(seed=42) + agent = MyAgent(model, 10) + assert agent.some_attribute == 20 + + handler = Mock() + agent.observe("some_attribute", "change", handler) + agent.some_other_attribute = 9 # we change the dependency of computed + handler.assert_called_once() + agent.unobserve("some_attribute", "change", handler) + + handler = Mock() + agent.observe("some_attribute", "change", handler) + assert ( + agent.some_attribute == 18 + ) # this forces a re-evaluation of the value of computed + handler.assert_called_once() # and so, our change handler should be called + agent.unobserve("some_attribute", "change", handler) + + # cyclical dependencies + def computed_func(agent): + # this creates a cyclical dependency + # our computed is dependent on o1, but also modifies o1 + agent.o1 = agent.o1 - 1 + + class MyAgent(Agent, HasObservables): + c1 = Computable() + o1 = Observable() + + def __init__(self, model, value): + super().__init__(model) + self.o1 = value + self.c1 = Computed(computed_func, self) + + model = Model(seed=42) + with pytest.raises(ValueError): + MyAgent(model, 10) + + # parents disappearing