Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Visualize PropertyLayers #2336

Merged
merged 13 commits into from
Oct 5, 2024
263 changes: 176 additions & 87 deletions mesa/visualization/components/matplotlib.py
Original file line number Diff line number Diff line change
@@ -1,116 +1,190 @@
"""Matplotlib based solara components for visualization MESA spaces and plots."""

from collections import defaultdict
import warnings

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import solara
from matplotlib.cm import ScalarMappable
from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator

import mesa
from mesa.experimental.cell_space import VoronoiGrid
from mesa.space import PropertyLayer
from mesa.visualization.utils import update_counter


def make_space_matplotlib(agent_portrayal=None): # noqa: D103
def make_space_matplotlib(agent_portrayal=None, propertylayer_portrayal=None):
"""Create a Matplotlib-based space visualization component.

Args:
agent_portrayal (function): Function to portray agents
propertylayer_portrayal (dict): Dictionary of PropertyLayer portrayal specifications

Returns:
function: A function that creates a SpaceMatplotlib component
"""
if agent_portrayal is None:

def agent_portrayal(a):
return {"id": a.unique_id}

def MakeSpaceMatplotlib(model):
return SpaceMatplotlib(model, agent_portrayal)
return SpaceMatplotlib(model, agent_portrayal, propertylayer_portrayal)

return MakeSpaceMatplotlib


@solara.component
def SpaceMatplotlib(model, agent_portrayal, dependencies: list[any] | None = None): # noqa: D103
def SpaceMatplotlib(
model,
agent_portrayal,
propertylayer_portrayal,
dependencies: list[any] | None = None,
):
"""Create a Matplotlib-based space visualization component."""
update_counter.get()
space_fig = Figure()
space_ax = space_fig.subplots()
space = getattr(model, "grid", None)
if space is None:
EwoutH marked this conversation as resolved.
Show resolved Hide resolved
# Sometimes the space is defined as model.space instead of model.grid
space = model.space
if isinstance(space, mesa.space.NetworkGrid):
_draw_network_grid(space, space_ax, agent_portrayal)
space = getattr(model, "space", None)

if isinstance(space, mesa.space._Grid):
_draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model)

Check warning on line 56 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L56

Added line #L56 was not covered by tests
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either of:

  • Why pass an entire model instead of just the layers?
  • Why not use the space argument and access it from the model instead?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Edit: s/Why not use the space/Why not remove the space/

Copy link
Member Author

@EwoutH EwoutH Oct 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review and the questions, I will try to look into them soon.

Also feel free to open a PR if you already have concrete solutions in mind.

Edit: Yeah I now remind, there are some inconsistencies between the draw functions here. When I add support for other spaces I will try to address these.

elif isinstance(space, mesa.space.ContinuousSpace):
_draw_continuous_space(space, space_ax, agent_portrayal)
_draw_continuous_space(space, space_ax, agent_portrayal, model)

Check warning on line 58 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L58

Added line #L58 was not covered by tests
elif isinstance(space, mesa.space.NetworkGrid):
_draw_network_grid(space, space_ax, agent_portrayal)

Check warning on line 60 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L60

Added line #L60 was not covered by tests
elif isinstance(space, VoronoiGrid):
_draw_voronoi(space, space_ax, agent_portrayal)
else:
_draw_grid(space, space_ax, agent_portrayal)
solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies)

elif space is None and propertylayer_portrayal:
draw_property_layers(space_ax, space, propertylayer_portrayal, model)

Check warning on line 64 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L64

Added line #L64 was not covered by tests

# matplotlib scatter does not allow for multiple shapes in one call
def _split_and_scatter(portray_data, space_ax):
grouped_data = defaultdict(lambda: {"x": [], "y": [], "s": [], "c": []})

# Extract data from the dictionary
x = portray_data["x"]
y = portray_data["y"]
s = portray_data["s"]
c = portray_data["c"]
m = portray_data["m"]

if not (len(x) == len(y) == len(s) == len(c) == len(m)):
raise ValueError(
"Length mismatch in portrayal data lists: "
f"x: {len(x)}, y: {len(y)}, size: {len(s)}, "
f"color: {len(c)}, marker: {len(m)}"
)

# Group the data by marker
for i in range(len(x)):
marker = m[i]
grouped_data[marker]["x"].append(x[i])
grouped_data[marker]["y"].append(y[i])
grouped_data[marker]["s"].append(s[i])
grouped_data[marker]["c"].append(c[i])
solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies)

# Plot each group with the same marker
for marker, data in grouped_data.items():
space_ax.scatter(data["x"], data["y"], s=data["s"], c=data["c"], marker=marker)

def draw_property_layers(ax, space, propertylayer_portrayal, model):
"""Draw PropertyLayers on the given axes.

Args:
ax (matplotlib.axes.Axes): The axes to draw on.
space (mesa.space._Grid): The space containing the PropertyLayers.
propertylayer_portrayal (dict): Dictionary of PropertyLayer portrayal specifications.
model (mesa.Model): The model instance.
"""
for layer_name, portrayal in propertylayer_portrayal.items():
layer = getattr(model, layer_name, None)

Check warning on line 79 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L79

Added line #L79 was not covered by tests
if not isinstance(layer, PropertyLayer):
continue

Check warning on line 81 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L81

Added line #L81 was not covered by tests

data = layer.data.astype(float) if layer.data.dtype == bool else layer.data
width, height = data.shape if space is None else (space.width, space.height)

Check warning on line 84 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L83-L84

Added lines #L83 - L84 were not covered by tests

if space and data.shape != (width, height):
warnings.warn(

Check warning on line 87 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L87

Added line #L87 was not covered by tests
f"Layer {layer_name} dimensions ({data.shape}) do not match space dimensions ({width}, {height}).",
UserWarning,
stacklevel=2,
)

# Get portrayal properties, or use defaults
alpha = portrayal.get("alpha", 1)
vmin = portrayal.get("vmin", np.min(data))
vmax = portrayal.get("vmax", np.max(data))
colorbar = portrayal.get("colorbar", True)

Check warning on line 97 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L94-L97

Added lines #L94 - L97 were not covered by tests

# Draw the layer
if "color" in portrayal:
rgba_color = to_rgba(portrayal["color"])
normalized_data = (data - vmin) / (vmax - vmin)
rgba_data = np.full((*data.shape, 4), rgba_color)
rgba_data[..., 3] *= normalized_data * alpha
rgba_data = np.clip(rgba_data, 0, 1)
cmap = LinearSegmentedColormap.from_list(

Check warning on line 106 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L101-L106

Added lines #L101 - L106 were not covered by tests
layer_name, [(0, 0, 0, 0), (*rgba_color[:3], alpha)]
)
im = ax.imshow(

Check warning on line 109 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L109

Added line #L109 was not covered by tests
rgba_data.transpose(1, 0, 2),
extent=(0, width, 0, height),
origin="lower",
)
if colorbar:
norm = Normalize(vmin=vmin, vmax=vmax)
sm = ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])
ax.figure.colorbar(sm, ax=ax, orientation="vertical")

Check warning on line 118 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L115-L118

Added lines #L115 - L118 were not covered by tests

elif "colormap" in portrayal:
cmap = portrayal.get("colormap", "viridis")

Check warning on line 121 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L121

Added line #L121 was not covered by tests
if isinstance(cmap, list):
cmap = LinearSegmentedColormap.from_list(layer_name, cmap)
im = ax.imshow(

Check warning on line 124 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L123-L124

Added lines #L123 - L124 were not covered by tests
data.T,
cmap=cmap,
alpha=alpha,
vmin=vmin,
vmax=vmax,
extent=(0, width, 0, height),
origin="lower",
)
if colorbar:
plt.colorbar(im, ax=ax, label=layer_name)

Check warning on line 134 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L134

Added line #L134 was not covered by tests
else:
raise ValueError(

Check warning on line 136 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L136

Added line #L136 was not covered by tests
f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
)


def _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model):
if propertylayer_portrayal:
draw_property_layers(space_ax, space, propertylayer_portrayal, model)

Check warning on line 143 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L143

Added line #L143 was not covered by tests

agent_data = _get_agent_data(space, agent_portrayal)

Check warning on line 145 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L145

Added line #L145 was not covered by tests

space_ax.set_xlim(0, space.width)
space_ax.set_ylim(0, space.height)
_split_and_scatter(agent_data, space_ax)

Check warning on line 149 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L147-L149

Added lines #L147 - L149 were not covered by tests

# Draw grid lines
for x in range(space.width + 1):
space_ax.axvline(x, color="gray", linestyle=":")

Check warning on line 153 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L153

Added line #L153 was not covered by tests
for y in range(space.height + 1):
space_ax.axhline(y, color="gray", linestyle=":")

Check warning on line 155 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L155

Added line #L155 was not covered by tests


def _get_agent_data(space, agent_portrayal):
"""Helper function to get agent data for visualization."""
x, y, s, c, m = [], [], [], [], []

Check warning on line 160 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L160

Added line #L160 was not covered by tests
for agents, pos in space.coord_iter():
if not agents:
continue

Check warning on line 163 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L163

Added line #L163 was not covered by tests
if not isinstance(agents, list):
agents = [agents] # noqa PLW2901

Check warning on line 165 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L165

Added line #L165 was not covered by tests
for agent in agents:
data = agent_portrayal(agent)
x.append(pos[0] + 0.5) # Center the agent in the cell
y.append(pos[1] + 0.5) # Center the agent in the cell
default_size = (180 / max(space.width, space.height)) ** 2
s.append(data.get("size", default_size))
c.append(data.get("color", "tab:blue"))
m.append(data.get("shape", "o"))
return {"x": x, "y": y, "s": s, "c": c, "m": m}

Check warning on line 174 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L167-L174

Added lines #L167 - L174 were not covered by tests

def _draw_grid(space, space_ax, agent_portrayal):
def portray(g):
x = []
y = []
s = [] # size
c = [] # color
m = [] # shape
for i in range(g.width):
for j in range(g.height):
content = g._grid[i][j]
if not content:
continue
if not hasattr(content, "__iter__"):
# Is a single grid
content = [content]
for agent in content:
data = agent_portrayal(agent)
x.append(i)
y.append(j)

# This is the default value for the marker size, which auto-scales
# according to the grid area.
default_size = (180 / max(g.width, g.height)) ** 2
# establishing a default prevents misalignment if some agents are not given size, color, etc.
size = data.get("size", default_size)
s.append(size)
color = data.get("color", "tab:blue")
c.append(color)
mark = data.get("shape", "o")
m.append(mark)
out = {"x": x, "y": y, "s": s, "c": c, "m": m}
return out

space_ax.set_xlim(-1, space.width)
space_ax.set_ylim(-1, space.height)
_split_and_scatter(portray(space), space_ax)
def _split_and_scatter(portray_data, space_ax):
"""Helper function to split and scatter agent data."""
for marker in set(portray_data["m"]):
mask = [m == marker for m in portray_data["m"]]
space_ax.scatter(
[x for x, show in zip(portray_data["x"], mask) if show],
[y for y, show in zip(portray_data["y"], mask) if show],
s=[s for s, show in zip(portray_data["s"], mask) if show],
c=[c for c, show in zip(portray_data["c"], mask) if show],
marker=marker,
)


def _draw_network_grid(space, space_ax, agent_portrayal):
Expand All @@ -124,7 +198,7 @@
)


def _draw_continuous_space(space, space_ax, agent_portrayal):
def _draw_continuous_space(space, space_ax, agent_portrayal, model):
def portray(space):
x = []
y = []
Expand All @@ -139,15 +213,13 @@

# This is matplotlib's default marker size
default_size = 20
# establishing a default prevents misalignment if some agents are not given size, color, etc.
size = data.get("size", default_size)
s.append(size)
color = data.get("color", "tab:blue")
c.append(color)
mark = data.get("shape", "o")
m.append(mark)
out = {"x": x, "y": y, "s": s, "c": c, "m": m}
return out
return {"x": x, "y": y, "s": s, "c": c, "m": m}

Check warning on line 222 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L222

Added line #L222 was not covered by tests

# Determine border style based on space.torus
border_style = "solid" if not space.torus else (0, (5, 10))
Expand Down Expand Up @@ -186,8 +258,6 @@
if "color" in data:
c.append(data["color"])
out = {"x": x, "y": y}
# This is the default value for the marker size, which auto-scales
# according to the grid area.
out["s"] = s
if len(c) > 0:
out["c"] = c
Expand Down Expand Up @@ -216,18 +286,37 @@
alpha=min(1, cell.properties[space.cell_coloring_property]),
c="red",
) # Plot filled polygon
space_ax.plot(*zip(*polygon), color="black") # Plot polygon edges in red
space_ax.plot(*zip(*polygon), color="black") # Plot polygon edges in black


def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]):
"""Create a plotting function for a specified measure.

Args:
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.

Returns:
function: A function that creates a PlotMatplotlib component.
"""

def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]): # noqa: D103
def MakePlotMeasure(model):
return PlotMatplotlib(model, measure)

return MakePlotMeasure


@solara.component
def PlotMatplotlib(model, measure, dependencies: list[any] | None = None): # noqa: D103
def PlotMatplotlib(model, measure, dependencies: list[any] | None = None):
"""Create a Matplotlib-based plot for a measure or measures.

Args:
model (mesa.Model): The model instance.
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
dependencies (list[any] | None): Optional dependencies for the plot.

Returns:
solara.FigureMatplotlib: A component for rendering the plot.
"""
update_counter.get()
fig = Figure()
ax = fig.subplots()
Expand All @@ -244,5 +333,5 @@
ax.plot(df.loc[:, m], label=m)
fig.legend()
# Set integer x axis
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))

Check warning on line 336 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L336

Added line #L336 was not covered by tests
solara.FigureMatplotlib(fig, dependencies=dependencies)
5 changes: 4 additions & 1 deletion tests/test_solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,14 @@ def test_call_space_drawer(mocker): # noqa: D103
"Shape": "circle",
"color": "gray",
}
propertylayer_portrayal = None
# initialize with space drawer unspecified (use default)
# component must be rendered for code to run
solara.render(SolaraViz(model, components=[make_space_matplotlib(agent_portrayal)]))
# should call default method with class instance and agent portrayal
mock_space_matplotlib.assert_called_with(model, agent_portrayal)
mock_space_matplotlib.assert_called_with(
model, agent_portrayal, propertylayer_portrayal
)

# specify no space should be drawn
mock_space_matplotlib.reset_mock()
Expand Down
Loading