diff --git a/src/emcpy/plots/create_plots.py b/src/emcpy/plots/create_plots.py index 3b11601b..1c74ae44 100644 --- a/src/emcpy/plots/create_plots.py +++ b/src/emcpy/plots/create_plots.py @@ -260,6 +260,7 @@ def create_figure(self): 'histogram': self._histogram, 'density': self._density, 'line_plot': self._lineplot, + 'gridded_plot': self._gridded, 'vertical_line': self._verticalline, 'horizontal_line': self._horizontalline, 'horizontal_span': self._horizontalspan, @@ -538,6 +539,20 @@ def _scatter(self, plotobj, ax): plotobj.linear_regression['color'] = inputs['color'] ax.plot(plotobj.x, y_pred, label=label, **plotobj.linear_regression) + def _gridded(self, plotobj, ax): + """ + Uses Gridded object to plot on axis. + """ + skipvars = ['plottype', 'plot_ax', 'x', 'y', 'z', + 'colorbar'] + inputs = self._get_inputs_dict(skipvars, plotobj) + + cs = ax.pcolormesh(plotobj.x, plotobj.y, + plotobj.z, **inputs) + + if plotobj.colorbar: + self.cs = cs + def _skewt(self, plotobj, ax): """ Creates a skewt-logp profile plot on axis. diff --git a/src/emcpy/plots/plots.py b/src/emcpy/plots/plots.py index e4f9d1e1..bae84d21 100644 --- a/src/emcpy/plots/plots.py +++ b/src/emcpy/plots/plots.py @@ -148,6 +148,33 @@ def __init__(self, x, y): self.label = None +class GriddedPlot: + + def __init__(self, x, y, z): + """ + Constructor for GriddedPlot. + Args: + x : (array type) + y : (array type) + z : (array type) + """ + super().__init__() + self.plottype = 'gridded_plot' + + self.x = x + self.y = y + self.z = z + + self.cmap = 'viridis' + self.norm = None + self.vmin = None + self.vmax = None + self.edgecolors = None + self.shading = 'auto' + self.alpha = None + self.colorbar = True + + class VerticalLine: def __init__(self, x): diff --git a/src/tests/test_plots.py b/src/tests/test_plots.py index aef722a3..3bf47734 100644 --- a/src/tests/test_plots.py +++ b/src/tests/test_plots.py @@ -1,9 +1,10 @@ import numpy as np +from scipy.ndimage.filters import gaussian_filter import matplotlib.pyplot as plt from emcpy.plots.plots import LinePlot, VerticalLine,\ Histogram, Density, Scatter, HorizontalLine, BarPlot, \ - HorizontalBar, HorizontalSpan, SkewT + GriddedPlot, HorizontalBar, HorizontalSpan, SkewT from emcpy.plots.create_plots import CreatePlot, CreateFigure @@ -248,6 +249,26 @@ def test_bar_plot(): fig.save_figure('test_bar_plot.png') +def test_gridded_plot(): + # Create gridded plot + + x, y, z = _getGriddedData() + + gp = GriddedPlot(x, y, z) + gp.cmap = 'plasma' + + plot1 = CreatePlot() + plot1.plot_layers = [gp] + plot1.add_xlabel(xlabel='X Axis Label') + plot1.add_ylabel(ylabel='Y Axis Label') + plot1.add_title('Test Gridded Plot') + + fig = CreateFigure() + fig.plot_list = [plot1] + fig.create_figure() + fig.save_figure('test_gridded_plot.png') + + def test_horizontal_bar_plot(): # Create horizontal bar plot @@ -522,6 +543,17 @@ def _getBarData(): return x_pos, heights, variance +def _getGriddedData(): + # generate test data for gridded data + + x = np.linspace(0, 1, 51) + y = np.linspace(0, 1, 51) + r = np.random.RandomState(25) + z = gaussian_filter(r.random_sample([50, 50]), sigma=5, mode='wrap') + + return x, y, z + + def _getSkewTData(): # use data for skew-t log-p plot from io import StringIO @@ -616,6 +648,7 @@ def main(): test_histogram_plot() test_scatter_plot() test_bar_plot() + test_gridded_plot() test_horizontal_bar_plot() test_multi_subplot() test_HorizontalSpan()