Skip to content

Commit

Permalink
Added more mypy stuff to plot
Browse files Browse the repository at this point in the history
  • Loading branch information
gaffney2010 committed Dec 9, 2024
1 parent d75fdd5 commit be2bbbc
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 7 deletions.
35 changes: 28 additions & 7 deletions axelrod/plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pathlib
from typing import List, Optional, Union
from typing import Any, Callable, List, Optional, Union

import matplotlib
import matplotlib.pyplot as plt
Expand All @@ -10,7 +10,7 @@
from .load_data_ import axl_filename
from .result_set import ResultSet

titleType = List[str]
titleType = str
namesType = List[str]
dataType = List[List[Union[int, float]]]

Expand All @@ -27,6 +27,9 @@ def _violinplot(
names: namesType,
title: Optional[titleType] = None,
ax: Optional[matplotlib.axes.Axes] = None,
get_figure: Callable[
[matplotlib.axes.Axes], Union[matplotlib.figure.Figure, Any, None]
] = lambda ax: ax.get_figure(),
) -> matplotlib.figure.Figure:
"""For making violinplots."""

Expand All @@ -35,7 +38,11 @@ def _violinplot(
else:
ax = ax

figure = ax.get_figure()
figure = get_figure(ax)
if not isinstance(figure, matplotlib.figure.Figure):
raise RuntimeError(
"get_figure unexpectedly returned a non-figure object"
)
width = max(self.num_players / 3, 12)
height = width / 2
spacing = 4
Expand All @@ -50,7 +57,7 @@ def _violinplot(
)
ax.set_xticks(positions)
ax.set_xticklabels(names, rotation=90)
ax.set_xlim([0, spacing * (self.num_players + 1)])
ax.set_xlim((0, spacing * (self.num_players + 1)))
ax.tick_params(axis="both", which="both", labelsize=8)
if title:
ax.set_title(title)
Expand Down Expand Up @@ -185,6 +192,9 @@ def _payoff_heatmap(
title: Optional[titleType] = None,
ax: Optional[matplotlib.axes.Axes] = None,
cmap: str = "viridis",
get_figure: Callable[
[matplotlib.axes.Axes], Union[matplotlib.figure.Figure, Any, None]
] = lambda ax: ax.get_figure(),
) -> matplotlib.figure.Figure:
"""Generic heatmap plot"""

Expand All @@ -193,7 +203,11 @@ def _payoff_heatmap(
else:
ax = ax

figure = ax.get_figure()
figure = get_figure(ax)
if not isinstance(figure, matplotlib.figure.Figure):
raise RuntimeError(
"get_figure unexpectedly returned a non-figure object"
)
width = max(self.num_players / 4, 12)
height = width
figure.set_size_inches(width, height)
Expand Down Expand Up @@ -238,6 +252,9 @@ def stackplot(
title: Optional[titleType] = None,
logscale: bool = True,
ax: Optional[matplotlib.axes.Axes] = None,
get_figure: Callable[
[matplotlib.axes.Axes], Union[matplotlib.figure.Figure, Any, None]
] = lambda ax: ax.get_figure(),
) -> matplotlib.figure.Figure:

populations = eco.population_sizes
Expand All @@ -247,7 +264,11 @@ def stackplot(
else:
ax = ax

figure = ax.get_figure()
figure = get_figure(ax)
if not isinstance(figure, matplotlib.figure.Figure):
raise RuntimeError(
"get_figure unexpectedly returned a non-figure object"
)
turns = range(len(populations))
pops = [
[populations[iturn][ir] for iturn in turns]
Expand All @@ -259,7 +280,7 @@ def stackplot(
ax.yaxis.set_label_position("right")
ax.yaxis.labelpad = 25.0

ax.set_ylim([0.0, 1.0])
ax.set_ylim((0.0, 1.0))
ax.set_ylabel("Relative population size")
ax.set_xlabel("Turn")
if title is not None:
Expand Down
20 changes: 20 additions & 0 deletions axelrod/tests/unit/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,23 @@ def test_all_plots(self):
progress_bar=True,
)
)

def test_figure_generation_failure_violinplot(self):
plot = axl.Plot(self.test_result_set)
with self.assertRaises(RuntimeError):
plot._violinplot(
[0, 0, 0], ["a", "b", "c"], get_figure=lambda _: None
)

def test_figure_generation_failure_payoff_heatmap(self):
plot = axl.Plot(self.test_result_set)
with self.assertRaises(RuntimeError):
plot._payoff_heatmap(
[0, 0, 0], ["a", "b", "c"], get_figure=lambda _: None
)

def test_figure_generation_failure_stackplot(self):
plot = axl.Plot(self.test_result_set)
eco = axl.Ecosystem(self.test_result_set)
with self.assertRaises(RuntimeError):
plot.stackplot(eco, get_figure=lambda _: None)

0 comments on commit be2bbbc

Please sign in to comment.