Skip to content

Commit

Permalink
Merge pull request matplotlib#605 from alexpvpmindustry/master
Browse files Browse the repository at this point in the history
API for adding labels: `mpf.make_addplot(..., label="myLabel")`
  • Loading branch information
DanielGoldfarb authored Aug 1, 2023
2 parents 46dcc89 + cbda0af commit 50d7eb3
Show file tree
Hide file tree
Showing 4 changed files with 626 additions and 9 deletions.
576 changes: 576 additions & 0 deletions examples/addplot_legends.ipynb

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions src/mplfinance/_arg_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import matplotlib as mpl
import warnings


def _check_and_prepare_data(data, config):
'''
Check and Prepare the data input:
Expand Down Expand Up @@ -94,6 +95,19 @@ def _check_and_prepare_data(data, config):

return dates, opens, highs, lows, closes, volumes


def _label_validator(label_value):
''' Validates the input of [legend] label for added plots.
label_value may be a str or a sequence of str.
'''
if isinstance(label_value,str):
return True
if isinstance(label_value,(list,tuple,np.ndarray)):
if all([isinstance(v,str) for v in label_value]):
return True
return False


def _get_valid_plot_types(plottype=None):

_alias_types = {
Expand Down
2 changes: 1 addition & 1 deletion src/mplfinance/_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version_info = (0, 12, 9, 'beta', 9)
version_info = (0, 12, 10, 'beta', 0)

_specifier_ = {'alpha': 'a','beta': 'b','candidate': 'rc','final': ''}

Expand Down
43 changes: 35 additions & 8 deletions src/mplfinance/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from mplfinance import _styles

from mplfinance._arg_validators import _check_and_prepare_data, _mav_validator
from mplfinance._arg_validators import _check_and_prepare_data, _mav_validator, _label_validator
from mplfinance._arg_validators import _get_valid_plot_types, _fill_between_validator
from mplfinance._arg_validators import _process_kwargs, _validate_vkwargs_dict
from mplfinance._arg_validators import _kwarg_not_implemented, _bypass_kwarg_validation
Expand Down Expand Up @@ -765,6 +765,8 @@ def plot( data, **kwargs ):

elif not _list_of_dict(addplot):
raise TypeError('addplot must be `dict`, or `list of dict`, NOT '+str(type(addplot)))

contains_legend_label=[] # a list of axes that contains legend labels

for apdict in addplot:

Expand All @@ -788,10 +790,28 @@ def plot( data, **kwargs ):
else:
havedf = False # must be a single series or array
apdata = [apdata,] # make it iterable
if havedf and apdict['label']:
if not isinstance(apdict['label'],(list,tuple,np.ndarray)):
nlabels = 1
else:
nlabels = len(apdict['label'])
ncolumns = len(apdata.columns)
#print('nlabels=',nlabels,'ncolumns=',ncolumns)
if nlabels < ncolumns:
warnings.warn('\n =======================================\n'+
' addplot MISMATCH between data and labels:\n'+
' have '+str(ncolumns)+' columns to plot \n'+
' BUT '+str(nlabels)+' labels for them.\n')
colcount = 0
for column in apdata:
ydata = apdata.loc[:,column] if havedf else column
ax = _addplot_columns(panid,panels,ydata,apdict,xdates,config)
ax = _addplot_columns(panid,panels,ydata,apdict,xdates,config,colcount)
_addplot_apply_supplements(ax,apdict,xdates)
colcount += 1
if apdict['label']: # not supported for aptype == 'ohlc' or 'candle'
contains_legend_label.append(ax)
for ax in set(contains_legend_label): # there might be duplicates
ax.legend()

# fill_between is NOT supported for external_axes_mode
# (caller can easily call ax.fill_between() themselves).
Expand Down Expand Up @@ -1079,7 +1099,7 @@ def _addplot_collections(panid,panels,apdict,xdates,config):
ax.autoscale_view()
return ax

def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
def _addplot_columns(panid,panels,ydata,apdict,xdates,config,colcount):
external_axes_mode = apdict['ax'] is not None
if not external_axes_mode:
secondary_y = False
Expand All @@ -1101,6 +1121,10 @@ def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
ax = apdict['ax']

aptype = apdict['type']
if isinstance(apdict['label'],(list,tuple,np.ndarray)):
label = apdict['label'][colcount]
else: # isinstance(...,str)
label = apdict['label']
if aptype == 'scatter':
size = apdict['markersize']
mark = apdict['marker']
Expand All @@ -1111,27 +1135,27 @@ def _addplot_columns(panid,panels,ydata,apdict,xdates,config):

if isinstance(mark,(list,tuple,np.ndarray)):
_mscatter(xdates, ydata, ax=ax, m=mark, s=size, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths)
else:
ax.scatter(xdates, ydata, s=size, marker=mark, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths)
else:
ax.scatter(xdates, ydata, s=size, marker=mark, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths,label=label)
elif aptype == 'bar':
width = 0.8 if apdict['width'] is None else apdict['width']
bottom = apdict['bottom']
color = apdict['color']
alpha = apdict['alpha']
ax.bar(xdates,ydata,width=width,bottom=bottom,color=color,alpha=alpha)
ax.bar(xdates,ydata,width=width,bottom=bottom,color=color,alpha=alpha,label=label)
elif aptype == 'line':
ls = apdict['linestyle']
color = apdict['color']
width = apdict['width'] if apdict['width'] is not None else 1.6*config['_width_config']['line_width']
alpha = apdict['alpha']
ax.plot(xdates,ydata,linestyle=ls,color=color,linewidth=width,alpha=alpha)
ax.plot(xdates,ydata,linestyle=ls,color=color,linewidth=width,alpha=alpha,label=label)
elif aptype == 'step':
stepwhere = apdict['stepwhere']
ls = apdict['linestyle']
color = apdict['color']
width = apdict['width'] if apdict['width'] is not None else 1.6*config['_width_config']['line_width']
alpha = apdict['alpha']
ax.step(xdates,ydata,where = stepwhere,linestyle=ls,color=color,linewidth=width,alpha=alpha)
ax.step(xdates,ydata,where = stepwhere,linestyle=ls,color=color,linewidth=width,alpha=alpha,label=label)
else:
raise ValueError('addplot type "'+str(aptype)+'" NOT yet supported.')

Expand Down Expand Up @@ -1384,6 +1408,9 @@ def _valid_addplot_kwargs():
'fill_between': { 'Default' : None, # added by Wen
'Description' : " fill region",
'Validator' : _fill_between_validator },
'label' : { 'Default' : None,
'Description' : 'Label for the added plot. One per added plot.',
'Validator' : _label_validator },

}

Expand Down

0 comments on commit 50d7eb3

Please sign in to comment.