Skip to content

Commit

Permalink
Visualize PropertyLayers (projectmesa#2336)
Browse files Browse the repository at this point in the history
This PR adds support for visualizing PropertyLayers in the Matplotlib-based space visualization component. It allows users to overlay PropertyLayer data on top of the existing grid and agent visualizations, or on its own.

It introduces a new `propertylayer_portrayal` parameter to customize the appearance of PropertyLayers and refactors the existing space visualization code for better modularity and flexibility.
  • Loading branch information
EwoutH authored Oct 5, 2024
1 parent a7dc9b2 commit 47b3bd9
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 88 deletions.
263 changes: 176 additions & 87 deletions mesa/visualization/components/matplotlib.py
Original file line number Diff line number Diff line change
@@ -1,116 +1,190 @@
"""Matplotlib based solara components for visualization MESA spaces and plots."""

from collections import defaultdict
import warnings

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import solara
from matplotlib.cm import ScalarMappable
from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator

import mesa
from mesa.experimental.cell_space import VoronoiGrid
from mesa.space import PropertyLayer
from mesa.visualization.utils import update_counter


def make_space_matplotlib(agent_portrayal=None): # noqa: D103
def make_space_matplotlib(agent_portrayal=None, propertylayer_portrayal=None):
"""Create a Matplotlib-based space visualization component.
Args:
agent_portrayal (function): Function to portray agents
propertylayer_portrayal (dict): Dictionary of PropertyLayer portrayal specifications
Returns:
function: A function that creates a SpaceMatplotlib component
"""
if agent_portrayal is None:

def agent_portrayal(a):
return {"id": a.unique_id}

def MakeSpaceMatplotlib(model):
return SpaceMatplotlib(model, agent_portrayal)
return SpaceMatplotlib(model, agent_portrayal, propertylayer_portrayal)

return MakeSpaceMatplotlib


@solara.component
def SpaceMatplotlib(model, agent_portrayal, dependencies: list[any] | None = None): # noqa: D103
def SpaceMatplotlib(
model,
agent_portrayal,
propertylayer_portrayal,
dependencies: list[any] | None = None,
):
"""Create a Matplotlib-based space visualization component."""
update_counter.get()
space_fig = Figure()
space_ax = space_fig.subplots()
space = getattr(model, "grid", None)
if space is None:
# Sometimes the space is defined as model.space instead of model.grid
space = model.space
if isinstance(space, mesa.space.NetworkGrid):
_draw_network_grid(space, space_ax, agent_portrayal)
space = getattr(model, "space", None)

if isinstance(space, mesa.space._Grid):
_draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model)
elif isinstance(space, mesa.space.ContinuousSpace):
_draw_continuous_space(space, space_ax, agent_portrayal)
_draw_continuous_space(space, space_ax, agent_portrayal, model)
elif isinstance(space, mesa.space.NetworkGrid):
_draw_network_grid(space, space_ax, agent_portrayal)
elif isinstance(space, VoronoiGrid):
_draw_voronoi(space, space_ax, agent_portrayal)
else:
_draw_grid(space, space_ax, agent_portrayal)
solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies)

elif space is None and propertylayer_portrayal:
draw_property_layers(space_ax, space, propertylayer_portrayal, model)

# matplotlib scatter does not allow for multiple shapes in one call
def _split_and_scatter(portray_data, space_ax):
grouped_data = defaultdict(lambda: {"x": [], "y": [], "s": [], "c": []})

# Extract data from the dictionary
x = portray_data["x"]
y = portray_data["y"]
s = portray_data["s"]
c = portray_data["c"]
m = portray_data["m"]

if not (len(x) == len(y) == len(s) == len(c) == len(m)):
raise ValueError(
"Length mismatch in portrayal data lists: "
f"x: {len(x)}, y: {len(y)}, size: {len(s)}, "
f"color: {len(c)}, marker: {len(m)}"
)

# Group the data by marker
for i in range(len(x)):
marker = m[i]
grouped_data[marker]["x"].append(x[i])
grouped_data[marker]["y"].append(y[i])
grouped_data[marker]["s"].append(s[i])
grouped_data[marker]["c"].append(c[i])
solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies)

# Plot each group with the same marker
for marker, data in grouped_data.items():
space_ax.scatter(data["x"], data["y"], s=data["s"], c=data["c"], marker=marker)

def draw_property_layers(ax, space, propertylayer_portrayal, model):
"""Draw PropertyLayers on the given axes.
Args:
ax (matplotlib.axes.Axes): The axes to draw on.
space (mesa.space._Grid): The space containing the PropertyLayers.
propertylayer_portrayal (dict): Dictionary of PropertyLayer portrayal specifications.
model (mesa.Model): The model instance.
"""
for layer_name, portrayal in propertylayer_portrayal.items():
layer = getattr(model, layer_name, None)
if not isinstance(layer, PropertyLayer):
continue

data = layer.data.astype(float) if layer.data.dtype == bool else layer.data
width, height = data.shape if space is None else (space.width, space.height)

if space and data.shape != (width, height):
warnings.warn(
f"Layer {layer_name} dimensions ({data.shape}) do not match space dimensions ({width}, {height}).",
UserWarning,
stacklevel=2,
)

# Get portrayal properties, or use defaults
alpha = portrayal.get("alpha", 1)
vmin = portrayal.get("vmin", np.min(data))
vmax = portrayal.get("vmax", np.max(data))
colorbar = portrayal.get("colorbar", True)

# Draw the layer
if "color" in portrayal:
rgba_color = to_rgba(portrayal["color"])
normalized_data = (data - vmin) / (vmax - vmin)
rgba_data = np.full((*data.shape, 4), rgba_color)
rgba_data[..., 3] *= normalized_data * alpha
rgba_data = np.clip(rgba_data, 0, 1)
cmap = LinearSegmentedColormap.from_list(
layer_name, [(0, 0, 0, 0), (*rgba_color[:3], alpha)]
)
im = ax.imshow(
rgba_data.transpose(1, 0, 2),
extent=(0, width, 0, height),
origin="lower",
)
if colorbar:
norm = Normalize(vmin=vmin, vmax=vmax)
sm = ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])
ax.figure.colorbar(sm, ax=ax, orientation="vertical")

elif "colormap" in portrayal:
cmap = portrayal.get("colormap", "viridis")
if isinstance(cmap, list):
cmap = LinearSegmentedColormap.from_list(layer_name, cmap)
im = ax.imshow(
data.T,
cmap=cmap,
alpha=alpha,
vmin=vmin,
vmax=vmax,
extent=(0, width, 0, height),
origin="lower",
)
if colorbar:
plt.colorbar(im, ax=ax, label=layer_name)
else:
raise ValueError(
f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
)


def _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model):
if propertylayer_portrayal:
draw_property_layers(space_ax, space, propertylayer_portrayal, model)

agent_data = _get_agent_data(space, agent_portrayal)

space_ax.set_xlim(0, space.width)
space_ax.set_ylim(0, space.height)
_split_and_scatter(agent_data, space_ax)

# Draw grid lines
for x in range(space.width + 1):
space_ax.axvline(x, color="gray", linestyle=":")
for y in range(space.height + 1):
space_ax.axhline(y, color="gray", linestyle=":")


def _get_agent_data(space, agent_portrayal):
"""Helper function to get agent data for visualization."""
x, y, s, c, m = [], [], [], [], []
for agents, pos in space.coord_iter():
if not agents:
continue
if not isinstance(agents, list):
agents = [agents] # noqa PLW2901
for agent in agents:
data = agent_portrayal(agent)
x.append(pos[0] + 0.5) # Center the agent in the cell
y.append(pos[1] + 0.5) # Center the agent in the cell
default_size = (180 / max(space.width, space.height)) ** 2
s.append(data.get("size", default_size))
c.append(data.get("color", "tab:blue"))
m.append(data.get("shape", "o"))
return {"x": x, "y": y, "s": s, "c": c, "m": m}

def _draw_grid(space, space_ax, agent_portrayal):
def portray(g):
x = []
y = []
s = [] # size
c = [] # color
m = [] # shape
for i in range(g.width):
for j in range(g.height):
content = g._grid[i][j]
if not content:
continue
if not hasattr(content, "__iter__"):
# Is a single grid
content = [content]
for agent in content:
data = agent_portrayal(agent)
x.append(i)
y.append(j)

# This is the default value for the marker size, which auto-scales
# according to the grid area.
default_size = (180 / max(g.width, g.height)) ** 2
# establishing a default prevents misalignment if some agents are not given size, color, etc.
size = data.get("size", default_size)
s.append(size)
color = data.get("color", "tab:blue")
c.append(color)
mark = data.get("shape", "o")
m.append(mark)
out = {"x": x, "y": y, "s": s, "c": c, "m": m}
return out

space_ax.set_xlim(-1, space.width)
space_ax.set_ylim(-1, space.height)
_split_and_scatter(portray(space), space_ax)
def _split_and_scatter(portray_data, space_ax):
"""Helper function to split and scatter agent data."""
for marker in set(portray_data["m"]):
mask = [m == marker for m in portray_data["m"]]
space_ax.scatter(
[x for x, show in zip(portray_data["x"], mask) if show],
[y for y, show in zip(portray_data["y"], mask) if show],
s=[s for s, show in zip(portray_data["s"], mask) if show],
c=[c for c, show in zip(portray_data["c"], mask) if show],
marker=marker,
)


def _draw_network_grid(space, space_ax, agent_portrayal):
Expand All @@ -124,7 +198,7 @@ def _draw_network_grid(space, space_ax, agent_portrayal):
)


def _draw_continuous_space(space, space_ax, agent_portrayal):
def _draw_continuous_space(space, space_ax, agent_portrayal, model):
def portray(space):
x = []
y = []
Expand All @@ -139,15 +213,13 @@ def portray(space):

# This is matplotlib's default marker size
default_size = 20
# establishing a default prevents misalignment if some agents are not given size, color, etc.
size = data.get("size", default_size)
s.append(size)
color = data.get("color", "tab:blue")
c.append(color)
mark = data.get("shape", "o")
m.append(mark)
out = {"x": x, "y": y, "s": s, "c": c, "m": m}
return out
return {"x": x, "y": y, "s": s, "c": c, "m": m}

# Determine border style based on space.torus
border_style = "solid" if not space.torus else (0, (5, 10))
Expand Down Expand Up @@ -186,8 +258,6 @@ def portray(g):
if "color" in data:
c.append(data["color"])
out = {"x": x, "y": y}
# This is the default value for the marker size, which auto-scales
# according to the grid area.
out["s"] = s
if len(c) > 0:
out["c"] = c
Expand Down Expand Up @@ -216,18 +286,37 @@ def portray(g):
alpha=min(1, cell.properties[space.cell_coloring_property]),
c="red",
) # Plot filled polygon
space_ax.plot(*zip(*polygon), color="black") # Plot polygon edges in red
space_ax.plot(*zip(*polygon), color="black") # Plot polygon edges in black


def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]):
"""Create a plotting function for a specified measure.
Args:
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
Returns:
function: A function that creates a PlotMatplotlib component.
"""

def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]): # noqa: D103
def MakePlotMeasure(model):
return PlotMatplotlib(model, measure)

return MakePlotMeasure


@solara.component
def PlotMatplotlib(model, measure, dependencies: list[any] | None = None): # noqa: D103
def PlotMatplotlib(model, measure, dependencies: list[any] | None = None):
"""Create a Matplotlib-based plot for a measure or measures.
Args:
model (mesa.Model): The model instance.
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
dependencies (list[any] | None): Optional dependencies for the plot.
Returns:
solara.FigureMatplotlib: A component for rendering the plot.
"""
update_counter.get()
fig = Figure()
ax = fig.subplots()
Expand All @@ -244,5 +333,5 @@ def PlotMatplotlib(model, measure, dependencies: list[any] | None = None): # no
ax.plot(df.loc[:, m], label=m)
fig.legend()
# Set integer x axis
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
solara.FigureMatplotlib(fig, dependencies=dependencies)
5 changes: 4 additions & 1 deletion tests/test_solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,14 @@ def test_call_space_drawer(mocker): # noqa: D103
"Shape": "circle",
"color": "gray",
}
propertylayer_portrayal = None
# initialize with space drawer unspecified (use default)
# component must be rendered for code to run
solara.render(SolaraViz(model, components=[make_space_matplotlib(agent_portrayal)]))
# should call default method with class instance and agent portrayal
mock_space_matplotlib.assert_called_with(model, agent_portrayal)
mock_space_matplotlib.assert_called_with(
model, agent_portrayal, propertylayer_portrayal
)

# specify no space should be drawn
mock_space_matplotlib.reset_mock()
Expand Down

0 comments on commit 47b3bd9

Please sign in to comment.