diff --git a/mesa/examples/advanced/wolf_sheep/agents.py b/mesa/examples/advanced/wolf_sheep/agents.py index b14963442c2..35ad09e66a5 100644 --- a/mesa/examples/advanced/wolf_sheep/agents.py +++ b/mesa/examples/advanced/wolf_sheep/agents.py @@ -1,8 +1,26 @@ from mesa.discrete_space import CellAgent, FixedAgent +from mesa.experimental.mesa_signals import ( + Computable, + Computed, + ContinuousObservable, + HasObservables, + Observable, +) -class Animal(CellAgent): - """The base animal class.""" +class Animal(CellAgent, HasObservables): + """The base animal class with reactive energy management.""" + + # Energy depletes continuously over time + energy = ContinuousObservable( + initial_value=8.0, rate_func=lambda value, elapsed, agent: -agent.metabolic_rate + ) + + # Computed property: animal is hungry when energy is low + is_hungry = Computable() + + # Computed property: animal can reproduce when energy is sufficient + can_reproduce = Computable() def __init__( self, model, energy=8, p_reproduce=0.04, energy_from_food=4, cell=None @@ -17,14 +35,40 @@ def __init__( cell: Cell in which the animal starts """ super().__init__(model) + + # Set base metabolic rate (energy loss per time unit when idle) + self.metabolic_rate = 0.5 + + # Initialize energy (triggers continuous depletion) self.energy = energy self.p_reproduce = p_reproduce self.energy_from_food = energy_from_food self.cell = cell + # Set up computed properties + self.is_hungry = Computed(lambda: self.energy < self.energy_from_food * 2) + self.can_reproduce = Computed(lambda: self.energy > self.energy_from_food * 4) + + # Register threshold: die when energy reaches zero + self.add_threshold("energy", 0.0, self._on_energy_depleted) + + # Register threshold: become critically hungry at 25% of starting energy + self.add_threshold("energy", energy * 0.25, self._on_critical_hunger) + + def _on_energy_depleted(self, signal): + """Called when energy crosses zero - animal dies.""" + if signal.direction == "down": # Only trigger on downward crossing + self.remove() + + def _on_critical_hunger(self, signal): + """Called when energy becomes critically low.""" + if signal.direction == "down": + # Increase metabolic efficiency when starving (survival mode) + self.metabolic_rate *= 0.8 + def spawn_offspring(self): """Create offspring by splitting energy and creating new instance.""" - self.energy /= 2 + self.energy /= 2 # This updates the continuous observable self.__class__( self.model, self.energy, @@ -35,26 +79,31 @@ def spawn_offspring(self): def feed(self): """Abstract method to be implemented by subclasses.""" + raise NotImplementedError def step(self): """Execute one step of the animal's behavior.""" - # Move to random neighboring cell + # Move to neighboring cell (uses more energy than standing still) + self.metabolic_rate = 1.0 # Movement costs more energy self.move() - self.energy -= 1 - # Try to feed self.feed() - # Handle death and reproduction - if self.energy < 0: - self.remove() - elif self.random.random() < self.p_reproduce: + # Return to resting metabolic rate + self.metabolic_rate = 0.5 + + # Reproduce if conditions are met (using computed property) + if self.can_reproduce and self.random.random() < self.p_reproduce: self.spawn_offspring() class Sheep(Animal): - """A sheep that walks around, reproduces (asexually) and gets eaten.""" + """A sheep that walks around, reproduces and gets eaten. + + Sheep prefer cells with grass and avoid wolves. They gain energy by + eating grass, which continuously depletes over time. + """ def feed(self): """If possible, eat grass at current location.""" @@ -62,6 +111,7 @@ def feed(self): obj for obj in self.cell.agents if isinstance(obj, GrassPatch) ) if grass_patch.fully_grown: + # Eating gives instant energy boost self.energy += self.energy_from_food grass_patch.fully_grown = False @@ -70,64 +120,82 @@ def move(self): cells_without_wolves = self.cell.neighborhood.select( lambda cell: not any(isinstance(obj, Wolf) for obj in cell.agents) ) - # If all surrounding cells have wolves, stay put + + # If all surrounding cells have wolves, stay put (fear overrides hunger) if len(cells_without_wolves) == 0: return - # Among safe cells, prefer those with grown grass - cells_with_grass = cells_without_wolves.select( - lambda cell: any( - isinstance(obj, GrassPatch) and obj.fully_grown for obj in cell.agents + # If critically hungry, prioritize grass over safety + if self.is_hungry: # Using computed property + cells_with_grass = cells_without_wolves.select( + lambda cell: any( + isinstance(obj, GrassPatch) and obj.fully_grown + for obj in cell.agents + ) ) - ) - # Move to a cell with grass if available, otherwise move to any safe cell - target_cells = ( - cells_with_grass if len(cells_with_grass) > 0 else cells_without_wolves - ) + # Move to grass if available, otherwise any safe cell + target_cells = ( + cells_with_grass if len(cells_with_grass) > 0 else cells_without_wolves + ) + else: + # Not hungry - just avoid wolves + target_cells = cells_without_wolves + self.cell = target_cells.select_random_cell() class Wolf(Animal): - """A wolf that walks around, reproduces (asexually) and eats sheep.""" + """A wolf that walks around, reproduces and eats sheep. + + Wolves are more efficient predators, with higher base energy and + metabolic rate. They actively hunt sheep and gain substantial energy + from successful kills. + """ + + def __init__( + self, model, energy=20, p_reproduce=0.05, energy_from_food=20, cell=None + ): + """Initialize a wolf with higher energy needs than sheep.""" + super().__init__(model, energy, p_reproduce, energy_from_food, cell) + # Wolves have higher metabolic rate (they're larger predators) + self.metabolic_rate = 1.0 def feed(self): """If possible, eat a sheep at current location.""" sheep = [obj for obj in self.cell.agents if isinstance(obj, Sheep)] - if sheep: # If there are any sheep present + if sheep: # Successful hunt sheep_to_eat = self.random.choice(sheep) + # Eating gives instant energy boost self.energy += self.energy_from_food sheep_to_eat.remove() def move(self): """Move to a neighboring cell, preferably one with sheep.""" - cells_with_sheep = self.cell.neighborhood.select( - lambda cell: any(isinstance(obj, Sheep) for obj in cell.agents) - ) - target_cells = ( - cells_with_sheep if len(cells_with_sheep) > 0 else self.cell.neighborhood - ) - self.cell = target_cells.select_random_cell() + # When hungry, actively hunt for sheep + if self.is_hungry: # Using computed property + cells_with_sheep = self.cell.neighborhood.select( + lambda cell: any(isinstance(obj, Sheep) for obj in cell.agents) + ) + target_cells = ( + cells_with_sheep + if len(cells_with_sheep) > 0 + else self.cell.neighborhood + ) + else: + # When not hungry, wander randomly (conserve energy) + target_cells = self.cell.neighborhood + self.cell = target_cells.select_random_cell() -class GrassPatch(FixedAgent): - """A patch of grass that grows at a fixed rate and can be eaten by sheep.""" - @property - def fully_grown(self): - """Whether the grass patch is fully grown.""" - return self._fully_grown +class GrassPatch(FixedAgent, HasObservables): + """A patch of grass that grows at a fixed rate and can be eaten by sheep. - @fully_grown.setter - def fully_grown(self, value: bool) -> None: - """Set grass growth state and schedule regrowth if eaten.""" - self._fully_grown = value + Grass growth is modeled as a continuous process with a fixed regrowth time. + """ - if not value: # If grass was just eaten - self.model.simulator.schedule_event_relative( - setattr, - self.grass_regrowth_time, - function_args=[self, "fully_grown", True], - ) + # Observable: grass growth state + fully_grown = Observable() def __init__(self, model, countdown, grass_regrowth_time, cell): """Create a new patch of grass. @@ -139,12 +207,25 @@ def __init__(self, model, countdown, grass_regrowth_time, cell): cell: Cell to which this grass patch belongs """ super().__init__(model) - self._fully_grown = countdown == 0 + + self.fully_grown = countdown == 0 self.grass_regrowth_time = grass_regrowth_time self.cell = cell + # Listen for when grass gets eaten, schedule regrowth + self.observe("fully_grown", "change", self._on_growth_change) + # Schedule initial growth if not fully grown if not self.fully_grown: + self.model.simulator.schedule_event_relative(self._regrow, countdown) + + def _on_growth_change(self, signal): + """React to grass being eaten - schedule regrowth.""" + if signal.new is False: # Grass was just eaten self.model.simulator.schedule_event_relative( - setattr, countdown, function_args=[self, "fully_grown", True] + self._regrow, self.grass_regrowth_time ) + + def _regrow(self): + """Regrow the grass patch.""" + self.fully_grown = True diff --git a/mesa/examples/advanced/wolf_sheep/model.py b/mesa/examples/advanced/wolf_sheep/model.py index e93a0bffa4a..674e635f7ea 100644 --- a/mesa/examples/advanced/wolf_sheep/model.py +++ b/mesa/examples/advanced/wolf_sheep/model.py @@ -2,6 +2,8 @@ Wolf-Sheep Predation Model ================================ +Enhanced version with continuous energy depletion and reactive behaviors. + Replication of the model found in NetLogo: Wilensky, U. (1997). NetLogo Wolf Sheep Predation model. http://ccl.northwestern.edu/netlogo/models/WolfSheepPredation. @@ -21,11 +23,16 @@ class WolfSheep(Model): """Wolf-Sheep Predation Model. - A model for simulating wolf and sheep (predator-prey) ecosystem modelling. + A model for simulating wolf and sheep (predator-prey) ecosystem with: + - Continuous energy depletion over time + - Reactive behaviors based on hunger levels + - Threshold-triggered events (death, starvation mode) + - Computed properties for decision making """ description = ( - "A model for simulating wolf and sheep (predator-prey) ecosystem modelling." + "A model for simulating wolf and sheep (predator-prey) ecosystem modelling " + "with continuous energy dynamics and reactive behaviors." ) def __init__( @@ -55,12 +62,13 @@ def __init__( wolf_gain_from_food: Energy a wolf gains from eating a sheep grass: Whether to have the sheep eat grass for energy grass_regrowth_time: How long it takes for a grass patch to regrow - once it is eaten sheep_gain_from_food: Energy sheep gain from grass, if enabled seed: Random seed simulator: ABMSimulator instance for event scheduling """ super().__init__(seed=seed) + + # Initialize time-based simulator for continuous energy dynamics self.simulator = simulator self.simulator.setup(self) @@ -77,11 +85,24 @@ def __init__( random=self.random, ) - # Set up data collection + # Set up data collection (tracks observable changes automatically) model_reporters = { "Wolves": lambda m: len(m.agents_by_type[Wolf]), "Sheep": lambda m: len(m.agents_by_type[Sheep]), + "Avg Wolf Energy": lambda m: ( + sum(w.energy for w in m.agents_by_type[Wolf]) + / len(m.agents_by_type[Wolf]) + if len(m.agents_by_type[Wolf]) > 0 + else 0 + ), + "Avg Sheep Energy": lambda m: ( + sum(s.energy for s in m.agents_by_type[Sheep]) + / len(m.agents_by_type[Sheep]) + if len(m.agents_by_type[Sheep]) > 0 + else 0 + ), } + if grass: model_reporters["Grass"] = lambda m: len( m.agents_by_type[GrassPatch].select(lambda a: a.fully_grown) @@ -89,7 +110,7 @@ def __init__( self.datacollector = DataCollector(model_reporters) - # Create sheep: + # Create sheep with random initial energy Sheep.create_agents( self, initial_sheep, @@ -98,7 +119,8 @@ def __init__( energy_from_food=sheep_gain_from_food, cell=self.random.choices(self.grid.all_cells.cells, k=initial_sheep), ) - # Create Wolves: + + # Create wolves with random initial energy Wolf.create_agents( self, initial_wolves, @@ -123,10 +145,15 @@ def __init__( self.datacollector.collect(self) def step(self): - """Execute one step of the model.""" - # First activate all sheep, then all wolves, both in random order + """Execute one step of the model. + + Energy continuously depletes between steps via ContinuousObservable. + This step method only triggers agent decisions and actions. + """ + # Activate all sheep, then all wolves, both in random order + # Their energy has been continuously depleting since last step self.agents_by_type[Sheep].shuffle_do("step") self.agents_by_type[Wolf].shuffle_do("step") - # Collect data + # Collect data (automatically captures current energy levels) self.datacollector.collect(self) diff --git a/mesa/experimental/mesa_signals/__init__.py b/mesa/experimental/mesa_signals/__init__.py index 39fe83e6495..0036b404dbd 100644 --- a/mesa/experimental/mesa_signals/__init__.py +++ b/mesa/experimental/mesa_signals/__init__.py @@ -10,13 +10,21 @@ when modified. """ -from .mesa_signal import All, Computable, Computed, HasObservables, Observable +from .mesa_signal import ( + All, + Computable, + Computed, + ContinuousObservable, + HasObservables, + Observable, +) from .observable_collections import ObservableList __all__ = [ "All", "Computable", "Computed", + "ContinuousObservable", "HasObservables", "Observable", "ObservableList", diff --git a/mesa/experimental/mesa_signals/mesa_signal.py b/mesa/experimental/mesa_signals/mesa_signal.py index d240260320f..a632ba75d92 100644 --- a/mesa/experimental/mesa_signals/mesa_signal.py +++ b/mesa/experimental/mesa_signals/mesa_signal.py @@ -27,7 +27,7 @@ from mesa.experimental.mesa_signals.signals_util import AttributeDict, create_weakref -__all__ = ["All", "Computable", "HasObservables", "Observable"] +__all__ = ["All", "Computable", "ContinuousObservable", "HasObservables", "Observable"] _hashable_signal = namedtuple("_HashableSignal", "instance name") @@ -472,6 +472,189 @@ def _mesa_notify(self, signal: AttributeDict): # use iteration to also remove inactive observers self.subscribers[observable][signal_type] = active_observers + def add_threshold(self, observable_name: str, threshold: float, callback: Callable): + """Convenience method for adding thresholds.""" + obs = getattr(type(self), observable_name) + if not isinstance(obs, ContinuousObservable): + raise ValueError(f"{observable_name} is not a ContinuousObservable") + + # Get the instance's ContinuousState + state = getattr(self, obs.private_name, None) + if state is None: + # State not yet created - will be created on first access/set + # We need to ensure the observable is initialized first + _ = getattr(self, observable_name) # Trigger initialization + state = getattr(self, obs.private_name) + + # Add threshold to the instance's state + if threshold not in state.thresholds: + state.thresholds[threshold] = set() + + # Add callback to this threshold's callback set + state.thresholds[threshold].add(callback) + + # Subscribe to the threshold_crossed signal + # Check if callback is already subscribed to avoid duplicates + existing_subscribers = self.subscribers.get(observable_name, {}).get( + "threshold_crossed", [] + ) + already_subscribed = any( + ref() == callback for ref in existing_subscribers if ref() is not None + ) + + if not already_subscribed: + self.observe(observable_name, "threshold_crossed", callback) + + +class ContinuousObservable(Observable): + """An Observable that changes continuously over time.""" + + def __init__(self, initial_value: float, rate_func: Callable): + """Initialize a ContinuousObservable.""" + super().__init__(fallback_value=initial_value) + self.signal_types.add("threshold_crossed") + self._rate_func = rate_func + + def __set__(self, instance: HasObservables, value): + """Set the value, ensuring we store a ContinuousState.""" + # Get or create state + state = getattr(instance, self.private_name, None) + + if state is None: + # First time - create ContinuousState + state = ContinuousState( + value=float(value), + last_update=self._get_time(instance), + rate_func=self._rate_func, + ) + setattr(instance, self.private_name, state) + else: + # Update existing - just change the value and reset timestamp + old_value = state.value + state.value = float(value) + state.last_update = self._get_time(instance) + + # Notify changes + instance.notify(self.public_name, old_value, state.value, "change") + + # Check thresholds + for threshold, direction in state.check_thresholds(old_value, state.value): + instance.notify( + self.public_name, + old_value, + state.value, + "threshold_crossed", + threshold=threshold, + direction=direction, + ) + + def __get__(self, instance: HasObservables, owner): + """Lazy evaluation - compute current value based on elapsed time.""" + if instance is None: + return self + + # Get stored state + state = getattr(instance, self.private_name, None) + if state is None: + # First access - initialize + # Use simulator time if available, otherwise fall back to steps + current_time = self._get_time(instance) + state = ContinuousState( + value=self.fallback_value, + last_update=current_time, + rate_func=self._rate_func, + ) + setattr(instance, self.private_name, state) + + # Calculate new value based on time + current_time = self._get_time(instance) + elapsed = current_time - state.last_update + + if elapsed > 0: + old_value = state.value + new_value = state.calculate(elapsed, instance) + + # Check thresholds + crossed = state.check_thresholds(old_value, new_value) + + # Update stored state + state.value = new_value + state.last_update = current_time + + # Emit signals + if new_value != old_value: + instance.notify(self.public_name, old_value, new_value, "change") + + for threshold, direction in crossed: + instance.notify( + self.public_name, + old_value, + new_value, + "threshold_crossed", + threshold=threshold, + direction=direction, + ) + + # Register dependency if inside a Computed + if CURRENT_COMPUTED is not None: + CURRENT_COMPUTED._add_parent(instance, self.public_name, state.value) + PROCESSING_SIGNALS.add(_hashable_signal(instance, self.public_name)) + + return state.value + + # TODO: A universal truth for time should be implemented structurally in Mesa. See https://github.com/projectmesa/mesa/discussions/2228 + def _get_time(self, instance): + """Get current time from model, trying multiple sources.""" + model = instance.model + + # Try simulator time first (for DEVS models) + if hasattr(model, "simulator") and hasattr(model.simulator, "time"): + return model.simulator.time + + # Fall back to model.time if it exists + if hasattr(model, "time"): + return model.time + + # Last resort: use steps as a proxy for time + return float(model.steps) + + +class ContinuousState: + """Internal state tracker for continuous observables.""" + + __slots__ = ["last_update", "rate_func", "thresholds", "value"] + + def __init__(self, value: float, last_update: float, rate_func: Callable): + self.value = value + self.last_update = last_update + self.rate_func = rate_func + self.thresholds = {} # {threshold_value: set(callbacks)} + + def calculate(self, elapsed: float, instance: Any) -> float: + """Calculate new value based on elapsed time. + + Uses simple linear integration for now. Could be extended + to support more complex integration methods. + """ + rate = self.rate_func(self.value, elapsed, instance) + return self.value + (rate * elapsed) + + def check_thresholds(self, old_value: float, new_value: float) -> list: + """Check if any thresholds were crossed. + + Returns: + List of (threshold_value, direction) tuples for crossed thresholds + """ + crossed = [] + for threshold in self.thresholds: + # Crossed upward + if old_value < threshold <= new_value: + crossed.append((threshold, "up")) + # Crossed downward + elif new_value <= threshold < old_value: + crossed.append((threshold, "down")) + return crossed + def descriptor_generator(obj) -> [str, BaseObservable]: """Yield the name and signal_types for each Observable defined on obj.""" diff --git a/tests/test_continuous_observables.py b/tests/test_continuous_observables.py new file mode 100644 index 00000000000..dfd86b4ef58 --- /dev/null +++ b/tests/test_continuous_observables.py @@ -0,0 +1,880 @@ +"""Tests for continuous observables in mesa_signals.""" + +from unittest.mock import Mock + +import numpy as np + +from mesa import Agent, Model +from mesa.experimental.devs import ABMSimulator, DEVSimulator +from mesa.experimental.mesa_signals import ( + Computable, + Computed, + ContinuousObservable, + HasObservables, +) + + +class SimpleModel(Model): + """Simple model with time tracking for testing.""" + + def __init__(self, seed=None): + """Initialize the model.""" + super().__init__(seed=seed) + self.simulator = DEVSimulator() + self.simulator.setup(self) + + +def test_continuous_observable_basic(): + """Test basic ContinuousObservable functionality.""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, + rate_func=lambda value, elapsed, agent: -1.0, # Constant depletion + ) + + def __init__(self, model): + super().__init__(model) + self.energy = 100.0 + + model = SimpleModel() + agent = MyAgent(model) + + # Initial value + assert agent.energy == 100.0 + + # Schedule an event to check energy later + def check_energy(): + assert agent.energy == 90.0 # 100 - (1.0 * 10) + + model.simulator.schedule_event_absolute(check_energy, 10.0) + model.simulator.run_until(10.0) + + +def test_continuous_observable_variable_rate(): + """Test ContinuousObservable with variable rate function.""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, + rate_func=lambda value, elapsed, agent: -agent.metabolic_rate, + ) + + def __init__(self, model): + super().__init__(model) + self.metabolic_rate = 1.0 + self.energy = 100.0 + + model = SimpleModel() + agent = MyAgent(model) + + def check_first(): + assert agent.energy == 90.0 + # Change metabolic rate + agent.metabolic_rate = 2.0 + + def check_second(): + assert agent.energy == 80.0 # 90 - (2.0 * 5) + + model.simulator.schedule_event_absolute(check_first, 10.0) + model.simulator.schedule_event_absolute(check_second, 15.0) + model.simulator.run_until(15.0) + + +def test_continuous_observable_manual_set(): + """Test manually setting a ContinuousObservable value.""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0 + ) + + def __init__(self, model): + super().__init__(model) + self.energy = 100.0 + + model = SimpleModel() + agent = MyAgent(model) + + def check_and_eat(): + assert agent.energy == 90.0 + # Manually increase energy (e.g., eating) + agent.energy = 120.0 + assert agent.energy == 120.0 + + def check_after_eating(): + assert agent.energy == 110.0 # 120 - (1.0 * 10) + + model.simulator.schedule_event_absolute(check_and_eat, 10.0) + model.simulator.schedule_event_absolute(check_after_eating, 20.0) + model.simulator.run_until(20.0) + + +def test_continuous_observable_change_signal(): + """Test that change signals are emitted correctly.""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0 + ) + + def __init__(self, model): + super().__init__(model) + self.energy = 100.0 + + model = SimpleModel() + agent = MyAgent(model) + + handler = Mock() + agent.observe("energy", "change", handler) + + def check_signal(): + _ = agent.energy # Access triggers recalculation + + handler.assert_called_once() + call_args = handler.call_args[0][0] + assert call_args.name == "energy" + assert call_args.old == 100.0 + assert call_args.new == 90.0 + assert call_args.type == "change" + + model.simulator.schedule_event_absolute(check_signal, 10.0) + model.simulator.run_until(10.0) + + +def test_continuous_observable_no_change_no_signal(): + """Test that no signal is emitted when value doesn't change.""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, + rate_func=lambda value, elapsed, agent: 0.0, # No change + ) + + def __init__(self, model): + super().__init__(model) + self.energy = 100.0 + + model = SimpleModel() + agent = MyAgent(model) + + handler = Mock() + agent.observe("energy", "change", handler) + + def check_no_signal(): + _ = agent.energy + # Should not call handler since value didn't change + handler.assert_not_called() + + model.simulator.schedule_event_absolute(check_no_signal, 10.0) + model.simulator.run_until(10.0) + + +def test_continuous_observable_threshold_crossing_down(): + """Test threshold crossing detection (downward).""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0 + ) + + def __init__(self, model): + super().__init__(model) + self.energy = 100.0 + self.low_energy_triggered = False + + def on_low_energy(self, signal): + if signal.direction == "down": + self.low_energy_triggered = True + + model = SimpleModel() + agent = MyAgent(model) + + # Register threshold at 50 + agent.add_threshold("energy", 50.0, agent.on_low_energy) + + def check_threshold(): + _ = agent.energy # Should be 40.0, crossed 50.0 + assert agent.low_energy_triggered + + model.simulator.schedule_event_absolute(check_threshold, 60.0) + model.simulator.run_until(60.0) + + +def test_continuous_observable_threshold_crossing_up(): + """Test threshold crossing detection (upward).""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=40.0, rate_func=lambda value, elapsed, agent: 1.0 + ) + + def __init__(self, model): + super().__init__(model) + self.energy = 40.0 + self.recharged = False + + def on_recharged(self, signal): + if signal.direction == "up": + self.recharged = True + + model = SimpleModel() + agent = MyAgent(model) + + # Register threshold at 50 + agent.add_threshold("energy", 50.0, agent.on_recharged) + + def check_threshold(): + _ = agent.energy # Should be 60.0, crossed 50.0 + assert agent.recharged + + model.simulator.schedule_event_absolute(check_threshold, 20.0) + model.simulator.run_until(20.0) + + +def test_continuous_observable_multiple_thresholds(): + """Test multiple threshold crossings.""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0 + ) + + def __init__(self, model): + super().__init__(model) + self.energy = 100.0 + self.crossings = [] + + def on_threshold(self, signal): + self.crossings.append((signal.threshold, signal.direction)) + + model = SimpleModel() + agent = MyAgent(model) + + # Register multiple thresholds + for threshold in [75.0, 50.0, 25.0]: + agent.add_threshold("energy", threshold, agent.on_threshold) + + def check_thresholds(): + _ = agent.energy # Should be 20.0, crossed all three + + assert len(agent.crossings) == 3 + assert (75.0, "down") in agent.crossings + assert (50.0, "down") in agent.crossings + assert (25.0, "down") in agent.crossings + + model.simulator.schedule_event_absolute(check_thresholds, 80.0) + model.simulator.run_until(80.0) + + +def test_continuous_observable_threshold_on_manual_set(): + """Test that thresholds are checked when manually setting values.""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0 + ) + + def __init__(self, model): + super().__init__(model) + self.energy = 100.0 + self.triggered = False + + def on_threshold(self, signal): + if signal.direction == "down": + self.triggered = True + + model = SimpleModel() + agent = MyAgent(model) + + agent.add_threshold("energy", 50.0, agent.on_threshold) + + # Manually set below threshold + agent.energy = 30.0 + + assert agent.triggered + + +def test_continuous_observable_no_threshold_cross_same_side(): + """Test that thresholds aren't triggered when staying on same side.""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=60.0, rate_func=lambda value, elapsed, agent: -1.0 + ) + + def __init__(self, model): + super().__init__(model) + self.energy = 60.0 + self.triggered = False + + def on_threshold(self, signal): + self.triggered = True + + model = SimpleModel() + agent = MyAgent(model) + + agent.add_threshold("energy", 50.0, agent.on_threshold) + + def check_no_trigger(): + _ = agent.energy # Move from 60 to 55 - doesn't cross threshold + assert not agent.triggered + + model.simulator.schedule_event_absolute(check_no_trigger, 5.0) + model.simulator.run_until(5.0) + + +def test_continuous_observable_exact_threshold_value(): + """Test behavior when value equals threshold exactly.""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0 + ) + + def __init__(self, model): + super().__init__(model) + self.energy = 100.0 + self.crossings = [] + + def on_threshold(self, signal): + self.crossings.append(signal.direction) + + model = SimpleModel() + agent = MyAgent(model) + + agent.add_threshold("energy", 50.0, agent.on_threshold) + + # Set exactly to threshold + agent.energy = 50.0 + + # Should trigger downward crossing (100 -> 50) + assert len(agent.crossings) == 1 + assert agent.crossings[0] == "down" + + +def test_continuous_observable_with_computed(): + """Test ContinuousObservable working with Computed properties.""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0 + ) + is_hungry = Computable() + + def __init__(self, model): + super().__init__(model) + self.energy = 100.0 + # Pass self as an argument to Computed + self.is_hungry = Computed(lambda agent: agent.energy < 50.0, self) + + model = SimpleModel() + agent = MyAgent(model) + + # Not hungry initially + assert not agent.is_hungry + + def check_hungry(): + # The Computed will access agent.energy, which will trigger + # the ContinuousObservable to recalculate based on current time + print(f"Energy at t=60: {agent.energy}") # Debug + print(f"Is hungry: {agent.is_hungry}") # Debug + assert agent.is_hungry # Energy is now 40.0 + + model.simulator.schedule_event_absolute(check_hungry, 60.0) + model.simulator.run_until(60.0) + + +def test_continuous_observable_multiple_accesses_same_time(): + """Test that multiple accesses at same time don't recalculate.""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0 + ) + + def __init__(self, model): + super().__init__(model) + self.energy = 100.0 + + model = SimpleModel() + agent = MyAgent(model) + + handler = Mock() + agent.observe("energy", "change", handler) + + def check_multiple_access(): + # Multiple accesses at same time + value1 = agent.energy + value2 = agent.energy + value3 = agent.energy + + # All should be same + assert value1 == value2 == value3 == 90.0 + + # Should only notify once + handler.assert_called_once() + + model.simulator.schedule_event_absolute(check_multiple_access, 10.0) + model.simulator.run_until(10.0) + + +def test_continuous_observable_zero_elapsed_time(): + """Test behavior when no time has elapsed.""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0 + ) + + def __init__(self, model): + super().__init__(model) + self.energy = 100.0 + + model = SimpleModel() + agent = MyAgent(model) + + handler = Mock() + agent.observe("energy", "change", handler) + + # Access without time passing (at t=0) + value = agent.energy + + assert value == 100.0 + # Should not emit signal since nothing changed + handler.assert_not_called() + + +def test_continuous_observable_numpy_float_compatibility(): + """Test compatibility with numpy float values (like from random).""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0 + ) + + def __init__(self, model, initial_energy): + super().__init__(model) + self.energy = initial_energy # numpy float + + model = SimpleModel() + + # Create with numpy float + numpy_value = np.float64(85.5) + agent = MyAgent(model, numpy_value) + + # Should work without AttributeError + assert agent.energy == 85.5 + + def check_after_time(): + assert agent.energy == 75.5 # 85.5 - (1.0 * 10) + + model.simulator.schedule_event_absolute(check_after_time, 10.0) + model.simulator.run_until(10.0) + + +def test_continuous_observable_with_create_agents(): + """Test ContinuousObservable with batch agent creation.""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0 + ) + + def __init__(self, model, energy=100.0): + super().__init__(model) + self.energy = energy + + model = SimpleModel() + + # Create multiple agents with numpy array of energies + initial_energies = np.random.random(10) * 100 + + agents = MyAgent.create_agents(model, 10, energy=initial_energies) + + # All should be created successfully + assert len(agents) == 10 + + # Each should have correct energy + for agent, expected_energy in zip(agents, initial_energies): + assert abs(agent.energy - expected_energy) < 1e-10 + + +def test_continuous_observable_with_abm_simulator(): + """Test ContinuousObservable with ABMSimulator (integer time steps).""" + + class StepModel(Model): + def __init__(self): + super().__init__() + self.simulator = ABMSimulator() + self.simulator.setup(self) + + def step(self): + pass # Model step gets called automatically + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0 + ) + + def __init__(self, model): + super().__init__(model) + self.energy = 100.0 + + model = StepModel() + agent = MyAgent(model) + + # Run for 10 steps (integer time) + model.simulator.run_for(10) + + # Energy should have depleted + assert agent.energy == 90.0 + + +def test_continuous_observable_negative_values(): + """Test ContinuousObservable can go negative.""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=10.0, rate_func=lambda value, elapsed, agent: -1.0 + ) + + def __init__(self, model): + super().__init__(model) + self.energy = 10.0 + + model = SimpleModel() + agent = MyAgent(model) + + def check_negative(): + assert agent.energy == -10.0 + + model.simulator.schedule_event_absolute(check_negative, 20.0) + model.simulator.run_until(20.0) + + +def test_continuous_observable_integration_with_wolf_sheep(): + """Integration test simulating wolf-sheep scenario.""" + + class Animal(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, + rate_func=lambda value, elapsed, agent: -agent.metabolic_rate, + ) + + def __init__(self, model): + super().__init__(model) + self.metabolic_rate = 0.5 + self.energy = 100.0 + self.died = False + + # Death threshold - use add_threshold helper + self.add_threshold("energy", 0.0, self._on_death) + + def _on_death(self, signal): + if signal.direction == "down": + self.died = True + + def eat(self): + """Boost energy when eating.""" + self.energy += 20 + + model = SimpleModel() + agent = Animal(model) + + # Check survival at t=50 + def check_survival(): + assert agent.energy == 75.0 # 100 - (0.5 * 50) + assert not agent.died + # Eat + agent.eat() + assert agent.energy == 95.0 + + # Check continued depletion at t=100 + def check_continued(): + assert agent.energy == 70.0 # 95 - (0.5 * 50) + assert not agent.died + + # Check death at t=300 + def check_death(): + _ = agent.energy # Trigger check + assert agent.died + + model.simulator.schedule_event_absolute(check_survival, 50.0) + model.simulator.schedule_event_absolute(check_continued, 100.0) + model.simulator.schedule_event_absolute(check_death, 300.0) + model.simulator.run_until(300.0) + + +def test_continuous_observable_multiple_agents_independent_values(): + """Test that multiple agents maintain independent continuous values.""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, + rate_func=lambda value, elapsed, agent: -agent.metabolic_rate, + ) + + def __init__(self, model, metabolic_rate): + super().__init__(model) + self.metabolic_rate = metabolic_rate + self.energy = 100.0 + + model = SimpleModel() + + # Create agents with different metabolic rates + agent1 = MyAgent(model, metabolic_rate=1.0) + agent2 = MyAgent(model, metabolic_rate=2.0) + agent3 = MyAgent(model, metabolic_rate=0.5) + + def check_values(): + # Each agent should deplete at their own rate + assert agent1.energy == 90.0 # 100 - (1.0 * 10) + assert agent2.energy == 80.0 # 100 - (2.0 * 10) + assert agent3.energy == 95.0 # 100 - (0.5 * 10) + + model.simulator.schedule_event_absolute(check_values, 10.0) + model.simulator.run_until(10.0) + + +def test_continuous_observable_multiple_agents_independent_thresholds(): + """Test that different agents can have different thresholds.""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0 + ) + + def __init__(self, model, name): + super().__init__(model) + self.name = name + self.energy = 100.0 + self.threshold_crossed = False + + def on_threshold(self, signal): + if signal.direction == "down": + self.threshold_crossed = True + + model = SimpleModel() + + # Create agents with different thresholds + agent1 = MyAgent(model, "agent1") + agent1.add_threshold("energy", 75.0, agent1.on_threshold) + + agent2 = MyAgent(model, "agent2") + agent2.add_threshold("energy", 25.0, agent2.on_threshold) + + agent3 = MyAgent(model, "agent3") + agent3.add_threshold("energy", 50.0, agent3.on_threshold) + + def check_at_30(): + # At t=30, all agents at energy=70 + _ = agent1.energy + _ = agent2.energy + _ = agent3.energy + + # Only agent1 should have crossed their threshold (75) + assert agent1.threshold_crossed + assert not agent2.threshold_crossed # Hasn't reached 25 yet + assert not agent3.threshold_crossed # Hasn't reached 50 yet + + def check_at_55(): + # At t=55, all agents at energy=45 + _ = agent1.energy + _ = agent2.energy + _ = agent3.energy + + # agent1 and agent3 should have crossed + assert agent1.threshold_crossed + assert not agent2.threshold_crossed # Still hasn't reached 25 + assert agent3.threshold_crossed # Crossed 50 + + def check_at_80(): + # At t=80, all agents at energy=20 + _ = agent1.energy + _ = agent2.energy + _ = agent3.energy + + # All should have crossed now + assert agent1.threshold_crossed + assert agent2.threshold_crossed # Finally crossed 25 + assert agent3.threshold_crossed + + model.simulator.schedule_event_absolute(check_at_30, 30.0) + model.simulator.schedule_event_absolute(check_at_55, 55.0) + model.simulator.schedule_event_absolute(check_at_80, 80.0) + model.simulator.run_until(80.0) + + +def test_continuous_observable_multiple_agents_same_threshold_different_callbacks(): + """Test that multiple agents can watch the same threshold value with different callbacks.""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0 + ) + + def __init__(self, model, name): + super().__init__(model) + self.name = name + self.energy = 100.0 + self.crossed_count = 0 + + def on_threshold(self, signal): + if signal.direction == "down": + self.crossed_count += 1 + + model = SimpleModel() + + # Create multiple agents, all watching threshold at 50 + agents = [MyAgent(model, f"agent{i}") for i in range(5)] + + for agent in agents: + agent.add_threshold("energy", 50.0, agent.on_threshold) + + def check_crossings(): + # Access all agents' energy + for agent in agents: + _ = agent.energy + + # Each should have crossed independently + for agent in agents: + assert agent.crossed_count == 1 + + model.simulator.schedule_event_absolute(check_crossings, 60.0) + model.simulator.run_until(60.0) + + +def test_continuous_observable_agents_with_different_initial_values(): + """Test agents starting with different energy values.""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0 + ) + + def __init__(self, model, initial_energy): + super().__init__(model) + self.energy = initial_energy + + model = SimpleModel() + + # Create agents with different starting energies + agent1 = MyAgent(model, initial_energy=100.0) + agent2 = MyAgent(model, initial_energy=50.0) + agent3 = MyAgent(model, initial_energy=150.0) + + def check_values(): + # Each should deplete from their starting value + assert agent1.energy == 90.0 # 100 - 10 + assert agent2.energy == 40.0 # 50 - 10 + assert agent3.energy == 140.0 # 150 - 10 + + model.simulator.schedule_event_absolute(check_values, 10.0) + model.simulator.run_until(10.0) + + +def test_continuous_observable_agent_interactions(): + """Test agents affecting each other's continuous observables.""" + + class Predator(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=50.0, rate_func=lambda value, elapsed, agent: -0.5 + ) + + def __init__(self, model): + super().__init__(model) + self.energy = 50.0 + self.kills = 0 + + def eat(self, prey): + """Eat prey and gain energy.""" + self.energy += 20 + self.kills += 1 + prey.die() + + class Prey(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0 + ) + + def __init__(self, model): + super().__init__(model) + self.energy = 100.0 + self.alive = True + + def die(self): + self.alive = False + + model = SimpleModel() + + predator = Predator(model) + prey1 = Prey(model) + prey2 = Prey(model) + + def predator_hunts(): + # Predator energy should have depleted + assert predator.energy == 45.0 # 50 - (0.5 * 10) + + # Predator eats prey1 + predator.eat(prey1) + + # Predator gains energy + assert predator.energy == 65.0 # 45 + 20 + assert not prey1.alive + assert prey2.alive + + def check_final(): + # Predator continues depleting from boosted energy + assert predator.energy == 60.0 # 65 - (0.5 * 10) + + # prey2 continues depleting + assert prey2.energy == 80.0 # 100 - (1.0 * 20) + assert prey2.alive + + model.simulator.schedule_event_absolute(predator_hunts, 10.0) + model.simulator.schedule_event_absolute(check_final, 20.0) + model.simulator.run_until(20.0) + + +def test_continuous_observable_batch_creation_with_thresholds(): + """Test batch agent creation where each agent has instance-specific thresholds.""" + + class MyAgent(Agent, HasObservables): + energy = ContinuousObservable( + initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0 + ) + + def __init__(self, model, critical_threshold): + super().__init__(model) + self.energy = 100.0 + self.critical_threshold = critical_threshold + self.critical = False + + # Each agent watches their own critical threshold + self.add_threshold("energy", critical_threshold, self.on_critical) + + def on_critical(self, signal): + if signal.direction == "down": + self.critical = True + + model = SimpleModel() + + # Create 10 agents with different critical thresholds + thresholds = [90.0, 80.0, 70.0, 60.0, 50.0, 40.0, 30.0, 20.0, 10.0, 5.0] + agents = [MyAgent(model, threshold) for threshold in thresholds] + + def check_at_45(): + # At t=45, all agents at energy=55 + for agent in agents: + _ = agent.energy # Trigger recalculation + + # Agents with thresholds > 55 should be critical + for agent, threshold in zip(agents, thresholds): + if threshold > 55: + assert agent.critical + else: + assert not agent.critical + + model.simulator.schedule_event_absolute(check_at_45, 45.0) + model.simulator.run_until(45.0)