Skip to content

Commit

Permalink
Support simulators in SolaraViz (#2470)
Browse files Browse the repository at this point in the history
  • Loading branch information
quaquel authored Nov 10, 2024
1 parent 084275a commit ed2d8fd
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 30 deletions.
23 changes: 17 additions & 6 deletions mesa/examples/advanced/wolf_sheep/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from mesa.examples.advanced.wolf_sheep.agents import GrassPatch, Sheep, Wolf
from mesa.examples.advanced.wolf_sheep.model import WolfSheep
from mesa.experimental.devs import ABMSimulator
from mesa.visualization import (
Slider,
SolaraViz,
Expand Down Expand Up @@ -36,7 +37,11 @@ def wolf_sheep_portrayal(agent):


model_params = {
# The following line is an example to showcase StaticText.
"seed": {
"type": "InputText",
"value": 42,
"label": "Random Seed",
},
"grass": {
"type": "Select",
"value": True,
Expand All @@ -59,26 +64,32 @@ def wolf_sheep_portrayal(agent):
}


def post_process(ax):
def post_process_space(ax):
ax.set_aspect("equal")
ax.set_xticks([])
ax.set_yticks([])


def post_process_lines(ax):
ax.legend(loc="center left", bbox_to_anchor=(1, 0.9))


space_component = make_space_component(
wolf_sheep_portrayal, draw_grid=False, post_process=post_process
wolf_sheep_portrayal, draw_grid=False, post_process=post_process_space
)
lineplot_component = make_plot_component(
{"Wolves": "tab:orange", "Sheep": "tab:cyan", "Grass": "tab:green"}
{"Wolves": "tab:orange", "Sheep": "tab:cyan", "Grass": "tab:green"},
post_process=post_process_lines,
)

model = WolfSheep(grass=True)

simulator = ABMSimulator()
model = WolfSheep(simulator, grass=True)

page = SolaraViz(
model,
components=[space_component, lineplot_component],
model_params=model_params,
name="Wolf Sheep",
simulator=simulator,
)
page # noqa
2 changes: 1 addition & 1 deletion mesa/examples/advanced/wolf_sheep/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class WolfSheep(Model):

def __init__(
self,
height=20,
width=20,
height=20,
initial_sheep=100,
initial_wolves=50,
sheep_reproduce=0.04,
Expand Down
109 changes: 86 additions & 23 deletions mesa/visualization/solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import solara

import mesa.visualization.components.altair_components as components_altair
from mesa.experimental.devs.simulator import Simulator
from mesa.visualization.user_param import Slider
from mesa.visualization.utils import force_update, update_counter

Expand All @@ -42,10 +43,12 @@
@solara.component
def SolaraViz(
model: Model | solara.Reactive[Model],
*,
components: list[reacton.core.Component]
| list[Callable[[Model], reacton.core.Component]]
| Literal["default"] = "default",
play_interval: int = 100,
simulator: Simulator | None = None,
model_params=None,
name: str | None = None,
):
Expand All @@ -65,6 +68,7 @@ def SolaraViz(
Defaults to "default", which uses the default Altair space visualization.
play_interval (int, optional): Interval for playing the model steps in milliseconds.
This controls the speed of the model's automatic stepping. Defaults to 100 ms.
simulator: A simulator that controls the model (optional)
model_params (dict, optional): Parameters for (re-)instantiating a model.
Can include user-adjustable parameters and fixed parameters. Defaults to None.
name (str | None, optional): Name of the visualization. Defaults to the models class name.
Expand Down Expand Up @@ -92,21 +96,6 @@ def SolaraViz(
if not isinstance(model, solara.Reactive):
model = solara.use_reactive(model) # noqa: SH102, RUF100

def connect_to_model():
# Patch the step function to force updates
original_step = model.value.step

def step():
original_step()
force_update()

model.value.step = step
# Add a trigger to model itself
model.value.force_update = force_update
force_update()

solara.use_effect(connect_to_model, [model.value])

# set up reactive model_parameters shared by ModelCreator and ModelController
reactive_model_parameters = solara.use_reactive({})

Expand All @@ -115,11 +104,19 @@ def step():

with solara.Sidebar(), solara.Column():
with solara.Card("Controls"):
ModelController(
model,
model_parameters=reactive_model_parameters,
play_interval=play_interval,
)
if not isinstance(simulator, Simulator):
ModelController(
model,
model_parameters=reactive_model_parameters,
play_interval=play_interval,
)
else:
SimulatorController(
model,
simulator,
model_parameters=reactive_model_parameters,
play_interval=play_interval,
)
with solara.Card("Model Parameters"):
ModelCreator(
model, model_params, model_parameters=reactive_model_parameters
Expand Down Expand Up @@ -207,6 +204,7 @@ def do_step():
"""Advance the model by one step."""
model.value.step()
running.value = model.value.running
force_update()

def do_reset():
"""Reset the model to its initial state."""
Expand Down Expand Up @@ -234,6 +232,73 @@ def do_play_pause():
)


@solara.component
def SimulatorController(
model: solara.Reactive[Model],
simulator,
*,
model_parameters: dict | solara.Reactive[dict] = None,
play_interval: int = 100,
):
"""Create controls for model execution (step, play, pause, reset).
Args:
model: Reactive model instance
simulator: Simulator instance
model_parameters: Reactive parameters for (re-)instantiating a model.
play_interval: Interval for playing the model steps in milliseconds.
"""
playing = solara.use_reactive(False)
running = solara.use_reactive(True)
if model_parameters is None:
model_parameters = {}
model_parameters = solara.use_reactive(model_parameters)

async def step():
while playing.value and running.value:
await asyncio.sleep(play_interval / 1000)
do_step()

solara.lab.use_task(
step, dependencies=[playing.value, running.value], prefer_threaded=False
)

def do_step():
"""Advance the model by one step."""
simulator.run_for(1)
running.value = model.value.running
force_update()

def do_reset():
"""Reset the model to its initial state."""
playing.value = False
running.value = True
simulator.reset()
model.value = model.value = model.value.__class__(
simulator, **model_parameters.value
)

def do_play_pause():
"""Toggle play/pause."""
playing.value = not playing.value

with solara.Row(justify="space-between"):
solara.Button(label="Reset", color="primary", on_click=do_reset)
solara.Button(
label="▶" if not playing.value else "❚❚",
color="primary",
on_click=do_play_pause,
disabled=not running.value,
)
solara.Button(
label="Step",
color="primary",
on_click=do_step,
disabled=playing.value or not running.value,
)


def split_model_params(model_params):
"""Split model parameters into user-adjustable and fixed parameters.
Expand Down Expand Up @@ -324,9 +389,7 @@ def ModelCreator(
}

def on_change(name, value):
new_model_parameters = {**model_parameters.value, name: value}
model.value = model.value.__class__(**new_model_parameters)
model_parameters.value = new_model_parameters
model_parameters.value = {**model_parameters.value, name: value}

UserInputs(user_params, on_change=on_change)

Expand Down

0 comments on commit ed2d8fd

Please sign in to comment.