From be2bbbc46113c632e7c325a581a68b189e83ab1c Mon Sep 17 00:00:00 2001 From: "T.J. Gaffney" Date: Mon, 9 Dec 2024 11:57:02 -0600 Subject: [PATCH] Added more mypy stuff to plot --- axelrod/plot.py | 35 ++++++++++++++++++++++++++------- axelrod/tests/unit/test_plot.py | 20 +++++++++++++++++++ 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/axelrod/plot.py b/axelrod/plot.py index d7891985b..23f5b7c9c 100644 --- a/axelrod/plot.py +++ b/axelrod/plot.py @@ -1,5 +1,5 @@ import pathlib -from typing import List, Optional, Union +from typing import Any, Callable, List, Optional, Union import matplotlib import matplotlib.pyplot as plt @@ -10,7 +10,7 @@ from .load_data_ import axl_filename from .result_set import ResultSet -titleType = List[str] +titleType = str namesType = List[str] dataType = List[List[Union[int, float]]] @@ -27,6 +27,9 @@ def _violinplot( names: namesType, title: Optional[titleType] = None, ax: Optional[matplotlib.axes.Axes] = None, + get_figure: Callable[ + [matplotlib.axes.Axes], Union[matplotlib.figure.Figure, Any, None] + ] = lambda ax: ax.get_figure(), ) -> matplotlib.figure.Figure: """For making violinplots.""" @@ -35,7 +38,11 @@ def _violinplot( else: ax = ax - figure = ax.get_figure() + figure = get_figure(ax) + if not isinstance(figure, matplotlib.figure.Figure): + raise RuntimeError( + "get_figure unexpectedly returned a non-figure object" + ) width = max(self.num_players / 3, 12) height = width / 2 spacing = 4 @@ -50,7 +57,7 @@ def _violinplot( ) ax.set_xticks(positions) ax.set_xticklabels(names, rotation=90) - ax.set_xlim([0, spacing * (self.num_players + 1)]) + ax.set_xlim((0, spacing * (self.num_players + 1))) ax.tick_params(axis="both", which="both", labelsize=8) if title: ax.set_title(title) @@ -185,6 +192,9 @@ def _payoff_heatmap( title: Optional[titleType] = None, ax: Optional[matplotlib.axes.Axes] = None, cmap: str = "viridis", + get_figure: Callable[ + [matplotlib.axes.Axes], Union[matplotlib.figure.Figure, Any, None] + ] = lambda ax: ax.get_figure(), ) -> matplotlib.figure.Figure: """Generic heatmap plot""" @@ -193,7 +203,11 @@ def _payoff_heatmap( else: ax = ax - figure = ax.get_figure() + figure = get_figure(ax) + if not isinstance(figure, matplotlib.figure.Figure): + raise RuntimeError( + "get_figure unexpectedly returned a non-figure object" + ) width = max(self.num_players / 4, 12) height = width figure.set_size_inches(width, height) @@ -238,6 +252,9 @@ def stackplot( title: Optional[titleType] = None, logscale: bool = True, ax: Optional[matplotlib.axes.Axes] = None, + get_figure: Callable[ + [matplotlib.axes.Axes], Union[matplotlib.figure.Figure, Any, None] + ] = lambda ax: ax.get_figure(), ) -> matplotlib.figure.Figure: populations = eco.population_sizes @@ -247,7 +264,11 @@ def stackplot( else: ax = ax - figure = ax.get_figure() + figure = get_figure(ax) + if not isinstance(figure, matplotlib.figure.Figure): + raise RuntimeError( + "get_figure unexpectedly returned a non-figure object" + ) turns = range(len(populations)) pops = [ [populations[iturn][ir] for iturn in turns] @@ -259,7 +280,7 @@ def stackplot( ax.yaxis.set_label_position("right") ax.yaxis.labelpad = 25.0 - ax.set_ylim([0.0, 1.0]) + ax.set_ylim((0.0, 1.0)) ax.set_ylabel("Relative population size") ax.set_xlabel("Turn") if title is not None: diff --git a/axelrod/tests/unit/test_plot.py b/axelrod/tests/unit/test_plot.py index 0db464361..e34b7fabf 100644 --- a/axelrod/tests/unit/test_plot.py +++ b/axelrod/tests/unit/test_plot.py @@ -255,3 +255,23 @@ def test_all_plots(self): progress_bar=True, ) ) + + def test_figure_generation_failure_violinplot(self): + plot = axl.Plot(self.test_result_set) + with self.assertRaises(RuntimeError): + plot._violinplot( + [0, 0, 0], ["a", "b", "c"], get_figure=lambda _: None + ) + + def test_figure_generation_failure_payoff_heatmap(self): + plot = axl.Plot(self.test_result_set) + with self.assertRaises(RuntimeError): + plot._payoff_heatmap( + [0, 0, 0], ["a", "b", "c"], get_figure=lambda _: None + ) + + def test_figure_generation_failure_stackplot(self): + plot = axl.Plot(self.test_result_set) + eco = axl.Ecosystem(self.test_result_set) + with self.assertRaises(RuntimeError): + plot.stackplot(eco, get_figure=lambda _: None)