From 47b3bd9558d8fd1db9c01eb6de1654596e3f99a9 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Sat, 5 Oct 2024 08:08:02 +0200 Subject: [PATCH] Visualize PropertyLayers (#2336) This PR adds support for visualizing PropertyLayers in the Matplotlib-based space visualization component. It allows users to overlay PropertyLayer data on top of the existing grid and agent visualizations, or on its own. It introduces a new `propertylayer_portrayal` parameter to customize the appearance of PropertyLayers and refactors the existing space visualization code for better modularity and flexibility. --- mesa/visualization/components/matplotlib.py | 263 +++++++++++++------- tests/test_solara_viz.py | 5 +- 2 files changed, 180 insertions(+), 88 deletions(-) diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index 6061356af7b..771978c1cd7 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -1,116 +1,190 @@ """Matplotlib based solara components for visualization MESA spaces and plots.""" -from collections import defaultdict +import warnings +import matplotlib.pyplot as plt import networkx as nx +import numpy as np import solara +from matplotlib.cm import ScalarMappable +from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba from matplotlib.figure import Figure -from matplotlib.ticker import MaxNLocator import mesa from mesa.experimental.cell_space import VoronoiGrid +from mesa.space import PropertyLayer from mesa.visualization.utils import update_counter -def make_space_matplotlib(agent_portrayal=None): # noqa: D103 +def make_space_matplotlib(agent_portrayal=None, propertylayer_portrayal=None): + """Create a Matplotlib-based space visualization component. + + Args: + agent_portrayal (function): Function to portray agents + propertylayer_portrayal (dict): Dictionary of PropertyLayer portrayal specifications + + Returns: + function: A function that creates a SpaceMatplotlib component + """ if agent_portrayal is None: def agent_portrayal(a): return {"id": a.unique_id} def MakeSpaceMatplotlib(model): - return SpaceMatplotlib(model, agent_portrayal) + return SpaceMatplotlib(model, agent_portrayal, propertylayer_portrayal) return MakeSpaceMatplotlib @solara.component -def SpaceMatplotlib(model, agent_portrayal, dependencies: list[any] | None = None): # noqa: D103 +def SpaceMatplotlib( + model, + agent_portrayal, + propertylayer_portrayal, + dependencies: list[any] | None = None, +): + """Create a Matplotlib-based space visualization component.""" update_counter.get() space_fig = Figure() space_ax = space_fig.subplots() space = getattr(model, "grid", None) if space is None: - # Sometimes the space is defined as model.space instead of model.grid - space = model.space - if isinstance(space, mesa.space.NetworkGrid): - _draw_network_grid(space, space_ax, agent_portrayal) + space = getattr(model, "space", None) + + if isinstance(space, mesa.space._Grid): + _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model) elif isinstance(space, mesa.space.ContinuousSpace): - _draw_continuous_space(space, space_ax, agent_portrayal) + _draw_continuous_space(space, space_ax, agent_portrayal, model) + elif isinstance(space, mesa.space.NetworkGrid): + _draw_network_grid(space, space_ax, agent_portrayal) elif isinstance(space, VoronoiGrid): _draw_voronoi(space, space_ax, agent_portrayal) - else: - _draw_grid(space, space_ax, agent_portrayal) - solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies) - + elif space is None and propertylayer_portrayal: + draw_property_layers(space_ax, space, propertylayer_portrayal, model) -# matplotlib scatter does not allow for multiple shapes in one call -def _split_and_scatter(portray_data, space_ax): - grouped_data = defaultdict(lambda: {"x": [], "y": [], "s": [], "c": []}) - - # Extract data from the dictionary - x = portray_data["x"] - y = portray_data["y"] - s = portray_data["s"] - c = portray_data["c"] - m = portray_data["m"] - - if not (len(x) == len(y) == len(s) == len(c) == len(m)): - raise ValueError( - "Length mismatch in portrayal data lists: " - f"x: {len(x)}, y: {len(y)}, size: {len(s)}, " - f"color: {len(c)}, marker: {len(m)}" - ) - - # Group the data by marker - for i in range(len(x)): - marker = m[i] - grouped_data[marker]["x"].append(x[i]) - grouped_data[marker]["y"].append(y[i]) - grouped_data[marker]["s"].append(s[i]) - grouped_data[marker]["c"].append(c[i]) + solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies) - # Plot each group with the same marker - for marker, data in grouped_data.items(): - space_ax.scatter(data["x"], data["y"], s=data["s"], c=data["c"], marker=marker) +def draw_property_layers(ax, space, propertylayer_portrayal, model): + """Draw PropertyLayers on the given axes. + + Args: + ax (matplotlib.axes.Axes): The axes to draw on. + space (mesa.space._Grid): The space containing the PropertyLayers. + propertylayer_portrayal (dict): Dictionary of PropertyLayer portrayal specifications. + model (mesa.Model): The model instance. + """ + for layer_name, portrayal in propertylayer_portrayal.items(): + layer = getattr(model, layer_name, None) + if not isinstance(layer, PropertyLayer): + continue + + data = layer.data.astype(float) if layer.data.dtype == bool else layer.data + width, height = data.shape if space is None else (space.width, space.height) + + if space and data.shape != (width, height): + warnings.warn( + f"Layer {layer_name} dimensions ({data.shape}) do not match space dimensions ({width}, {height}).", + UserWarning, + stacklevel=2, + ) + + # Get portrayal properties, or use defaults + alpha = portrayal.get("alpha", 1) + vmin = portrayal.get("vmin", np.min(data)) + vmax = portrayal.get("vmax", np.max(data)) + colorbar = portrayal.get("colorbar", True) + + # Draw the layer + if "color" in portrayal: + rgba_color = to_rgba(portrayal["color"]) + normalized_data = (data - vmin) / (vmax - vmin) + rgba_data = np.full((*data.shape, 4), rgba_color) + rgba_data[..., 3] *= normalized_data * alpha + rgba_data = np.clip(rgba_data, 0, 1) + cmap = LinearSegmentedColormap.from_list( + layer_name, [(0, 0, 0, 0), (*rgba_color[:3], alpha)] + ) + im = ax.imshow( + rgba_data.transpose(1, 0, 2), + extent=(0, width, 0, height), + origin="lower", + ) + if colorbar: + norm = Normalize(vmin=vmin, vmax=vmax) + sm = ScalarMappable(norm=norm, cmap=cmap) + sm.set_array([]) + ax.figure.colorbar(sm, ax=ax, orientation="vertical") + + elif "colormap" in portrayal: + cmap = portrayal.get("colormap", "viridis") + if isinstance(cmap, list): + cmap = LinearSegmentedColormap.from_list(layer_name, cmap) + im = ax.imshow( + data.T, + cmap=cmap, + alpha=alpha, + vmin=vmin, + vmax=vmax, + extent=(0, width, 0, height), + origin="lower", + ) + if colorbar: + plt.colorbar(im, ax=ax, label=layer_name) + else: + raise ValueError( + f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'." + ) + + +def _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model): + if propertylayer_portrayal: + draw_property_layers(space_ax, space, propertylayer_portrayal, model) + + agent_data = _get_agent_data(space, agent_portrayal) + + space_ax.set_xlim(0, space.width) + space_ax.set_ylim(0, space.height) + _split_and_scatter(agent_data, space_ax) + + # Draw grid lines + for x in range(space.width + 1): + space_ax.axvline(x, color="gray", linestyle=":") + for y in range(space.height + 1): + space_ax.axhline(y, color="gray", linestyle=":") + + +def _get_agent_data(space, agent_portrayal): + """Helper function to get agent data for visualization.""" + x, y, s, c, m = [], [], [], [], [] + for agents, pos in space.coord_iter(): + if not agents: + continue + if not isinstance(agents, list): + agents = [agents] # noqa PLW2901 + for agent in agents: + data = agent_portrayal(agent) + x.append(pos[0] + 0.5) # Center the agent in the cell + y.append(pos[1] + 0.5) # Center the agent in the cell + default_size = (180 / max(space.width, space.height)) ** 2 + s.append(data.get("size", default_size)) + c.append(data.get("color", "tab:blue")) + m.append(data.get("shape", "o")) + return {"x": x, "y": y, "s": s, "c": c, "m": m} -def _draw_grid(space, space_ax, agent_portrayal): - def portray(g): - x = [] - y = [] - s = [] # size - c = [] # color - m = [] # shape - for i in range(g.width): - for j in range(g.height): - content = g._grid[i][j] - if not content: - continue - if not hasattr(content, "__iter__"): - # Is a single grid - content = [content] - for agent in content: - data = agent_portrayal(agent) - x.append(i) - y.append(j) - - # This is the default value for the marker size, which auto-scales - # according to the grid area. - default_size = (180 / max(g.width, g.height)) ** 2 - # establishing a default prevents misalignment if some agents are not given size, color, etc. - size = data.get("size", default_size) - s.append(size) - color = data.get("color", "tab:blue") - c.append(color) - mark = data.get("shape", "o") - m.append(mark) - out = {"x": x, "y": y, "s": s, "c": c, "m": m} - return out - space_ax.set_xlim(-1, space.width) - space_ax.set_ylim(-1, space.height) - _split_and_scatter(portray(space), space_ax) +def _split_and_scatter(portray_data, space_ax): + """Helper function to split and scatter agent data.""" + for marker in set(portray_data["m"]): + mask = [m == marker for m in portray_data["m"]] + space_ax.scatter( + [x for x, show in zip(portray_data["x"], mask) if show], + [y for y, show in zip(portray_data["y"], mask) if show], + s=[s for s, show in zip(portray_data["s"], mask) if show], + c=[c for c, show in zip(portray_data["c"], mask) if show], + marker=marker, + ) def _draw_network_grid(space, space_ax, agent_portrayal): @@ -124,7 +198,7 @@ def _draw_network_grid(space, space_ax, agent_portrayal): ) -def _draw_continuous_space(space, space_ax, agent_portrayal): +def _draw_continuous_space(space, space_ax, agent_portrayal, model): def portray(space): x = [] y = [] @@ -139,15 +213,13 @@ def portray(space): # This is matplotlib's default marker size default_size = 20 - # establishing a default prevents misalignment if some agents are not given size, color, etc. size = data.get("size", default_size) s.append(size) color = data.get("color", "tab:blue") c.append(color) mark = data.get("shape", "o") m.append(mark) - out = {"x": x, "y": y, "s": s, "c": c, "m": m} - return out + return {"x": x, "y": y, "s": s, "c": c, "m": m} # Determine border style based on space.torus border_style = "solid" if not space.torus else (0, (5, 10)) @@ -186,8 +258,6 @@ def portray(g): if "color" in data: c.append(data["color"]) out = {"x": x, "y": y} - # This is the default value for the marker size, which auto-scales - # according to the grid area. out["s"] = s if len(c) > 0: out["c"] = c @@ -216,10 +286,19 @@ def portray(g): alpha=min(1, cell.properties[space.cell_coloring_property]), c="red", ) # Plot filled polygon - space_ax.plot(*zip(*polygon), color="black") # Plot polygon edges in red + space_ax.plot(*zip(*polygon), color="black") # Plot polygon edges in black + + +def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]): + """Create a plotting function for a specified measure. + Args: + measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot. + + Returns: + function: A function that creates a PlotMatplotlib component. + """ -def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]): # noqa: D103 def MakePlotMeasure(model): return PlotMatplotlib(model, measure) @@ -227,7 +306,17 @@ def MakePlotMeasure(model): @solara.component -def PlotMatplotlib(model, measure, dependencies: list[any] | None = None): # noqa: D103 +def PlotMatplotlib(model, measure, dependencies: list[any] | None = None): + """Create a Matplotlib-based plot for a measure or measures. + + Args: + model (mesa.Model): The model instance. + measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot. + dependencies (list[any] | None): Optional dependencies for the plot. + + Returns: + solara.FigureMatplotlib: A component for rendering the plot. + """ update_counter.get() fig = Figure() ax = fig.subplots() @@ -244,5 +333,5 @@ def PlotMatplotlib(model, measure, dependencies: list[any] | None = None): # no ax.plot(df.loc[:, m], label=m) fig.legend() # Set integer x axis - ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True)) solara.FigureMatplotlib(fig, dependencies=dependencies) diff --git a/tests/test_solara_viz.py b/tests/test_solara_viz.py index 301294f25ba..a0d2b449399 100644 --- a/tests/test_solara_viz.py +++ b/tests/test_solara_viz.py @@ -100,11 +100,14 @@ def test_call_space_drawer(mocker): # noqa: D103 "Shape": "circle", "color": "gray", } + propertylayer_portrayal = None # initialize with space drawer unspecified (use default) # component must be rendered for code to run solara.render(SolaraViz(model, components=[make_space_matplotlib(agent_portrayal)])) # should call default method with class instance and agent portrayal - mock_space_matplotlib.assert_called_with(model, agent_portrayal) + mock_space_matplotlib.assert_called_with( + model, agent_portrayal, propertylayer_portrayal + ) # specify no space should be drawn mock_space_matplotlib.reset_mock()