diff --git a/ibllib/oneibl/registration.py b/ibllib/oneibl/registration.py index 51a4af84f..e668d75fc 100644 --- a/ibllib/oneibl/registration.py +++ b/ibllib/oneibl/registration.py @@ -13,7 +13,10 @@ from one.converters import ConversionMixin import one.alf.exceptions as alferr from one.api import ONE -from one.util import datasets2records +try: + from one.util import datasets2records +except ImportError: + from one.converters import datasets2records from iblutil.util import ensure_list import ibllib diff --git a/ibllib/plots/misc.py b/ibllib/plots/misc.py index bcc473be7..288d110d5 100644 --- a/ibllib/plots/misc.py +++ b/ibllib/plots/misc.py @@ -1,5 +1,7 @@ #!/usr/bin/env python # -*- coding:utf-8 -*- +from math import pi + import numpy as np import matplotlib.pyplot as plt import scipy @@ -274,7 +276,53 @@ def color_cycle(ind=None): return tuple(c[ind % c.shape[0], :]) -if __name__ == "__main__": - w = np.random.rand(500, 40) - 0.5 - wiggle(w, fs=30000) - Traces(w, fs=30000, color='r') +def starplot(labels, radii, ticks=None, ax=None, ylim=None, color=None, title=None): + """ + Function to create a star plot (also known as a spider plot, polar plot, or radar chart). + + Parameters: + labels (list): A list of labels for the variables to be plotted along the axes. + radii (numpy array): The values to be plotted for each variable. + ticks (numpy array, optional): A list of values to be used for the radial ticks. + If None, 5 ticks will be created between the minimum and maximum values of radii. + ax (matplotlib.axes._subplots.PolarAxesSubplot, optional): A polar axis object to plot on. + If None, a new figure and axis will be created. + ylim (tuple, optional): A tuple specifying the upper and lower limits of the y-axis. + If None, the limits will be set to the minimum and maximum values of radii. + color (str, optional): A string specifying the color of the plot. + If None, the color will be determined by the current matplotlib color cycle. + title (str, optional): A string specifying the title of the plot. + If None, no title will be displayed. + + Returns: + ax (matplotlib.axes._subplots.PolarAxesSubplot): The polar axis object containing the plot. + """ + + # What will be the angle of each axis in the plot? (we divide the plot / number of variable) + angles = [n / float(radii.size) * 2 * pi for n in range(radii.size)] + angles += angles[:1] + + if ax is None: + # Initialise the spider plot + fig = plt.figure(figsize=(8, 8)) + ax = fig.add_subplot(111, polar=True) + # If you want the first axis to be on top: + ax.set_theta_offset(pi / 2) + ax.set_theta_direction(-1) + # Draw one axe per variable + add labels + plt.xticks(angles[:-1], labels) + # Draw ylabels + ax.set_rlabel_position(0) + if ylim is None: + ylim = (0, np.max(radii)) + if ticks is None: + ticks = np.linspace(ylim[0], ylim[1], 5) + plt.yticks(ticks, [f'{t:2.2f}' for t in ticks], color="grey", size=7) + plt.ylim(ylim) + + r = np.r_[radii, radii[0]] + p = ax.plot(angles, r, linewidth=1, linestyle='solid', label="group A", color=color) + ax.fill(angles, r, alpha=0.1, color=p[0].get_color()) + if title is not None: + ax.set_title(title) + return ax diff --git a/ibllib/tests/test_plots.py b/ibllib/tests/test_plots.py index 0b620564a..e2962fa98 100644 --- a/ibllib/tests/test_plots.py +++ b/ibllib/tests/test_plots.py @@ -1,19 +1,23 @@ -import unittest +from pathlib import Path import tempfile +import unittest import uuid -from pathlib import Path +import matplotlib.pyplot as plt +import numpy as np from PIL import Image from urllib.parse import urlparse from one.api import ONE from one.webclient import http_download_file +import ibllib.plots.misc from ibllib.tests import TEST_DB from ibllib.tests.fixtures.utils import register_new_session from ibllib.plots.snapshot import Snapshot from ibllib.plots.figures import dlc_qc_plot + WIDTH, HEIGHT = 1000, 100 @@ -163,3 +167,19 @@ def test_without_inputs(self): # fig.savefig(fig_path) # with Image.open(fig_path) as im: # self.assertEqual(im.size, (1700, 1000)) + + +class TestMiscPlot(unittest.TestCase): + + def test_star_plot(self): + r = np.random.rand(6) + ax = ibllib.plots.misc.starplot(['a', 'b', 'c', 'd', 'e', 'f'], r, ylim=[0, 1]) + r = np.random.rand(6) + ibllib.plots.misc.starplot(['a', 'b', 'c', 'd', 'e', 'f'], r, ax=ax, color='r') + plt.close('all') + + def test_wiggle(self): + w = np.random.rand(500, 40) - 0.5 + ibllib.plots.misc.wiggle(w, fs=30000) + ibllib.plots.misc.Traces(w, fs=30000, color='r') + plt.close('all')