Skip to content

Commit

Permalink
fix, type and test SpynnakerPanel
Browse files Browse the repository at this point in the history
  • Loading branch information
Christian-B committed Feb 5, 2025
1 parent d2753e2 commit fc12722
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 28 deletions.
66 changes: 38 additions & 28 deletions spynnaker/spynnaker_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,33 @@
https://github.com/NeuralEnsemble/PyNN/blob/master/pyNN/utility/plotting.py
"""

from typing import Any, Dict, List, Union
from types import ModuleType

from neo import SpikeTrain, Block, Segment, AnalogSignal
#from neo.core import AnalogSignal, Block, Segment
from neo.core.spiketrainlist import SpikeTrainList # type: ignore[import]
import numpy as np
from numpy.typing import NDArray
import quantities
from typing_extensions import TypeAlias

plt: ModuleType
try:
from pyNN.utility.plotting import repeat
import matplotlib.pyplot # type: ignore[import]
from matplotlib.axes import Axes # type: ignore[import]
plt = matplotlib.pyplot
_matplotlib_missing = False
except ImportError:
_matplotlib_missing = True
Axes = object

TA_Data: TypeAlias = Union[List[SpikeTrain], SpikeTrainList, AnalogSignal,
NDArray, Block, Segment]


def _handle_options(axes, options):
def _handle_options(axes: Axes, options: Dict[str, Any]) -> None:
"""
Handles options that can not be passed to `axes.plot`.
Expand Down Expand Up @@ -115,22 +127,18 @@ def plot_spikes_numpy(axes, spikes, label='', **options):
_plot_spikes(axes, spike_times, neurons, label=label, **options)


def _heat_plot(axes, neurons, times, values, label='', **options):
def _heat_plot(axes: Axes, values: NDArray, label: str = '',
**options: Any) -> None:
"""
Plots three lists of neurons, times and values into a heat map.
:param ~matplotlib.axes.Axes axes: An Axes in a matplotlib figure
:param neurons: List of neuron IDs
:param times: List of times
:param values: List of values to plot
:param str label: Label for the graph
:param options: plotting options
"""
_handle_options(axes, options)
info_array = np.empty((max(neurons)+1, max(times)+1))
info_array[:] = np.nan
info_array[neurons, times] = values
heat_map = axes.imshow(info_array, cmap='hot', interpolation='none',
heat_map = axes.imshow(values, cmap='hot', interpolation='none',
origin='lower', aspect='auto')
axes.figure.colorbar(heat_map)
if label:
Expand All @@ -151,10 +159,14 @@ def heat_plot_numpy(axes, data, label='', **options):
neurons = data[:, 0].astype(int)
times = data[:, 1].astype(int)
values = data[:, 2]
_heat_plot(axes, neurons, times, values, label=label, **options)
info_array = np.empty((max(neurons) + 1, max(times) + 1))
info_array[:] = np.nan
info_array[neurons, times] = values
_heat_plot(axes, info_array, label=label, **options)


def heat_plot_neo(axes, signal_array, label='', **options):
def heat_plot_neo(axes: Axes, signal_array: AnalogSignal, label: str = '',
**options: Any):
"""
Plots neurons, times and values into a heat map.
Expand All @@ -165,18 +177,12 @@ def heat_plot_neo(axes, signal_array, label='', **options):
"""
if label is None:
label = signal_array.name
n_neurons = signal_array.shape[-1]
xs = list(range(n_neurons))
times = signal_array.times / signal_array.sampling_period
times = np.rint(times.magnitude).astype(int)
all_times = np.tile(times, n_neurons)
neurons = np.repeat(xs, len(times))
magnitude = signal_array.magnitude
values = np.concatenate([magnitude[:, x] for x in xs])
_heat_plot(axes, neurons, all_times, values, label=label, **options)


def plot_segment(axes, segment, label='', **options):
values = np.transpose(signal_array.magnitude)
_heat_plot(axes, values, label=label, **options)


def plot_segment(axes: Axes, segment: Segment, label: str = '',
**options: Any) -> None:
"""
Plots a segment into a plot of spikes or a heat map.
Expand Down Expand Up @@ -244,7 +250,7 @@ class SpynnakerPanel(object):
Whole Segments can be passed in only if they only contain one type of data.
"""

def __init__(self, *data, **options):
def __init__(self, *data: TA_Data, **options: Any):
"""
:param data: One or more data series to be plotted.
:type data: list(~neo.core.SpikeTrain) or ~neo.core.AnalogSignal
Expand All @@ -258,7 +264,7 @@ def __init__(self, *data, **options):
self.data_labels = options.pop("data_labels", repeat(None))
self.line_properties = options.pop("line_properties", repeat({}))

def plot(self, axes):
def plot(self, axes: Axes):
"""
Plot the Panel's data in the provided Axes/Subplot instance.
Expand All @@ -276,7 +282,7 @@ def plot(self, axes):
if len(datum) == 1 and not isinstance(datum[0], SpikeTrain):
datum = datum[0]

if isinstance(datum, list):
if isinstance(datum, (list, SpikeTrainList)):
self.__plot_list(axes, datum, label, properties)
# AnalogSignal is also a ndarray, but data format different!
# We import them as a single name here
Expand All @@ -293,13 +299,16 @@ def plot(self, axes):
f"consider using pyNN.utility.plotting")

@staticmethod
def __plot_list(axes, datum, label, properties):
def __plot_list(
axes: Axes, datum: Union[List[SpikeTrain], SpikeTrainList],
label: str, properties: Dict[str, Any]) -> None:
if not isinstance(datum[0], SpikeTrain):
raise ValueError(f"Can't handle lists of type {type(datum)}")
plot_spiketrains(axes, datum, label=label, **properties)

@staticmethod
def __plot_array(axes, datum, label, properties):
def __plot_array(axes: Axes, datum: NDArray, label: str,
properties: Dict[str, Any]) -> None:
if len(datum[0]) == 2:
plot_spikes_numpy(axes, datum, label=label, **properties)
elif len(datum[0]) == 3:
Expand All @@ -309,7 +318,8 @@ def __plot_array(axes, datum, label, properties):
f"Can't handle ndarray with {len(datum[0])} columns")

@staticmethod
def __plot_block(axes, datum, label, properties):
def __plot_block(axes: Axes, datum: Block, label: str,
properties: Dict[str, Any]) -> None:
if "run" in properties:
run = int(properties.pop("run"))
if len(datum.segments) <= run:
Expand Down
38 changes: 38 additions & 0 deletions unittests/test_pop_views_assembly/manual_plot_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os
import matplotlib.pyplot as plt
import pyNN.utility.plotting as plot
from spynnaker.pyNN.utilities.neo_buffer_database import NeoBufferDatabase
from spynnaker.spynnaker_plotting import SpynnakerPanel


my_dir = os.path.dirname(os.path.abspath(__file__))
my_buffer = os.path.join(my_dir, "all_data.sqlite3")
with NeoBufferDatabase(my_buffer) as db:
pop = db.get_population("pop_1")
sneo = pop.get_data("spikes")
neo = pop.get_data(["spikes", "v"])
vneo = pop.get_data("v")
spikes = neo.segments[0].spiketrains
v = neo.segments[0].filter(name='v')[0]
v_matrix = pop.spinnaker_get_data("v")
print(v)

plot.Figure(
# plot spikes (or in this case spike)
plot.Panel(spikes, yticks=True, markersize=5, xlim=(0, 35)),
SpynnakerPanel(spikes, yticks=True, xticks=True, markersize=4),
SpynnakerPanel(sneo, yticks=True, xticks=True, markersize=4),
SpynnakerPanel(sneo.segments[0], yticks=True, xticks=True, markersize=4),
# plot voltage for first ([0]) neuron
SpynnakerPanel(neo, yticks=True, xticks=True, markersize=4, name="spikes"),
plot.Panel(v, ylabel="Membrane potential (mV)",
data_labels=[pop.label], yticks=True, xlim=(0, 35)),
SpynnakerPanel(v, yticks=True, xticks=True, markersize=4),
SpynnakerPanel(vneo, yticks=True, xticks=True, markersize=4),
SpynnakerPanel(vneo.segments[0], yticks=True, xticks=True, markersize=4),
SpynnakerPanel(v_matrix, yticks=True, xticks=True, markersize=4),
SpynnakerPanel(neo, yticks=True, xticks=True, markersize=4, name="v"),
title="Simple Example",
annotations=f"Simulated with "
)
plt.show()

0 comments on commit fc12722

Please sign in to comment.