diff --git a/spynnaker/spynnaker_plotting.py b/spynnaker/spynnaker_plotting.py index 28fa4b7126..bdfbf8f09e 100644 --- a/spynnaker/spynnaker_plotting.py +++ b/spynnaker/spynnaker_plotting.py @@ -17,21 +17,33 @@ https://github.com/NeuralEnsemble/PyNN/blob/master/pyNN/utility/plotting.py """ +from typing import Any, Dict, List, Union from types import ModuleType + from neo import SpikeTrain, Block, Segment, AnalogSignal +#from neo.core import AnalogSignal, Block, Segment +from neo.core.spiketrainlist import SpikeTrainList # type: ignore[import] import numpy as np +from numpy.typing import NDArray import quantities +from typing_extensions import TypeAlias + plt: ModuleType try: from pyNN.utility.plotting import repeat import matplotlib.pyplot # type: ignore[import] + from matplotlib.axes import Axes # type: ignore[import] plt = matplotlib.pyplot _matplotlib_missing = False except ImportError: _matplotlib_missing = True + Axes = object + +TA_Data: TypeAlias = Union[List[SpikeTrain], SpikeTrainList, AnalogSignal, + NDArray, Block, Segment] -def _handle_options(axes, options): +def _handle_options(axes: Axes, options: Dict[str, Any]) -> None: """ Handles options that can not be passed to `axes.plot`. @@ -115,22 +127,18 @@ def plot_spikes_numpy(axes, spikes, label='', **options): _plot_spikes(axes, spike_times, neurons, label=label, **options) -def _heat_plot(axes, neurons, times, values, label='', **options): +def _heat_plot(axes: Axes, values: NDArray, label: str = '', + **options: Any) -> None: """ Plots three lists of neurons, times and values into a heat map. :param ~matplotlib.axes.Axes axes: An Axes in a matplotlib figure - :param neurons: List of neuron IDs - :param times: List of times :param values: List of values to plot :param str label: Label for the graph :param options: plotting options """ _handle_options(axes, options) - info_array = np.empty((max(neurons)+1, max(times)+1)) - info_array[:] = np.nan - info_array[neurons, times] = values - heat_map = axes.imshow(info_array, cmap='hot', interpolation='none', + heat_map = axes.imshow(values, cmap='hot', interpolation='none', origin='lower', aspect='auto') axes.figure.colorbar(heat_map) if label: @@ -151,10 +159,14 @@ def heat_plot_numpy(axes, data, label='', **options): neurons = data[:, 0].astype(int) times = data[:, 1].astype(int) values = data[:, 2] - _heat_plot(axes, neurons, times, values, label=label, **options) + info_array = np.empty((max(neurons) + 1, max(times) + 1)) + info_array[:] = np.nan + info_array[neurons, times] = values + _heat_plot(axes, info_array, label=label, **options) -def heat_plot_neo(axes, signal_array, label='', **options): +def heat_plot_neo(axes: Axes, signal_array: AnalogSignal, label: str = '', + **options: Any): """ Plots neurons, times and values into a heat map. @@ -165,18 +177,12 @@ def heat_plot_neo(axes, signal_array, label='', **options): """ if label is None: label = signal_array.name - n_neurons = signal_array.shape[-1] - xs = list(range(n_neurons)) - times = signal_array.times / signal_array.sampling_period - times = np.rint(times.magnitude).astype(int) - all_times = np.tile(times, n_neurons) - neurons = np.repeat(xs, len(times)) - magnitude = signal_array.magnitude - values = np.concatenate([magnitude[:, x] for x in xs]) - _heat_plot(axes, neurons, all_times, values, label=label, **options) - - -def plot_segment(axes, segment, label='', **options): + values = np.transpose(signal_array.magnitude) + _heat_plot(axes, values, label=label, **options) + + +def plot_segment(axes: Axes, segment: Segment, label: str = '', + **options: Any) -> None: """ Plots a segment into a plot of spikes or a heat map. @@ -244,7 +250,7 @@ class SpynnakerPanel(object): Whole Segments can be passed in only if they only contain one type of data. """ - def __init__(self, *data, **options): + def __init__(self, *data: TA_Data, **options: Any): """ :param data: One or more data series to be plotted. :type data: list(~neo.core.SpikeTrain) or ~neo.core.AnalogSignal @@ -258,7 +264,7 @@ def __init__(self, *data, **options): self.data_labels = options.pop("data_labels", repeat(None)) self.line_properties = options.pop("line_properties", repeat({})) - def plot(self, axes): + def plot(self, axes: Axes): """ Plot the Panel's data in the provided Axes/Subplot instance. @@ -276,7 +282,7 @@ def plot(self, axes): if len(datum) == 1 and not isinstance(datum[0], SpikeTrain): datum = datum[0] - if isinstance(datum, list): + if isinstance(datum, (list, SpikeTrainList)): self.__plot_list(axes, datum, label, properties) # AnalogSignal is also a ndarray, but data format different! # We import them as a single name here @@ -293,13 +299,16 @@ def plot(self, axes): f"consider using pyNN.utility.plotting") @staticmethod - def __plot_list(axes, datum, label, properties): + def __plot_list( + axes: Axes, datum: Union[List[SpikeTrain], SpikeTrainList], + label: str, properties: Dict[str, Any]) -> None: if not isinstance(datum[0], SpikeTrain): raise ValueError(f"Can't handle lists of type {type(datum)}") plot_spiketrains(axes, datum, label=label, **properties) @staticmethod - def __plot_array(axes, datum, label, properties): + def __plot_array(axes: Axes, datum: NDArray, label: str, + properties: Dict[str, Any]) -> None: if len(datum[0]) == 2: plot_spikes_numpy(axes, datum, label=label, **properties) elif len(datum[0]) == 3: @@ -309,7 +318,8 @@ def __plot_array(axes, datum, label, properties): f"Can't handle ndarray with {len(datum[0])} columns") @staticmethod - def __plot_block(axes, datum, label, properties): + def __plot_block(axes: Axes, datum: Block, label: str, + properties: Dict[str, Any]) -> None: if "run" in properties: run = int(properties.pop("run")) if len(datum.segments) <= run: diff --git a/unittests/test_pop_views_assembly/manual_plot_checker.py b/unittests/test_pop_views_assembly/manual_plot_checker.py new file mode 100644 index 0000000000..7b47046fa8 --- /dev/null +++ b/unittests/test_pop_views_assembly/manual_plot_checker.py @@ -0,0 +1,38 @@ +import os +import matplotlib.pyplot as plt +import pyNN.utility.plotting as plot +from spynnaker.pyNN.utilities.neo_buffer_database import NeoBufferDatabase +from spynnaker.spynnaker_plotting import SpynnakerPanel + + +my_dir = os.path.dirname(os.path.abspath(__file__)) +my_buffer = os.path.join(my_dir, "all_data.sqlite3") +with NeoBufferDatabase(my_buffer) as db: + pop = db.get_population("pop_1") +sneo = pop.get_data("spikes") +neo = pop.get_data(["spikes", "v"]) +vneo = pop.get_data("v") +spikes = neo.segments[0].spiketrains +v = neo.segments[0].filter(name='v')[0] +v_matrix = pop.spinnaker_get_data("v") +print(v) + +plot.Figure( + # plot spikes (or in this case spike) + plot.Panel(spikes, yticks=True, markersize=5, xlim=(0, 35)), + SpynnakerPanel(spikes, yticks=True, xticks=True, markersize=4), + SpynnakerPanel(sneo, yticks=True, xticks=True, markersize=4), + SpynnakerPanel(sneo.segments[0], yticks=True, xticks=True, markersize=4), + # plot voltage for first ([0]) neuron + SpynnakerPanel(neo, yticks=True, xticks=True, markersize=4, name="spikes"), + plot.Panel(v, ylabel="Membrane potential (mV)", + data_labels=[pop.label], yticks=True, xlim=(0, 35)), + SpynnakerPanel(v, yticks=True, xticks=True, markersize=4), + SpynnakerPanel(vneo, yticks=True, xticks=True, markersize=4), + SpynnakerPanel(vneo.segments[0], yticks=True, xticks=True, markersize=4), + SpynnakerPanel(v_matrix, yticks=True, xticks=True, markersize=4), + SpynnakerPanel(neo, yticks=True, xticks=True, markersize=4, name="v"), + title="Simple Example", + annotations=f"Simulated with " +) +plt.show()