Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Allow multiple names for vector indicators (#382) #980

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions backtesting/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,10 +537,23 @@ def __eq__(self, other):
colors = value._opts['color']
colors = colors and cycle(_as_list(colors)) or (
cycle([next(ohlc_colors)]) if is_overlay else colorgen())
legend_label = LegendStr(value.name)
for j, arr in enumerate(value, 1):

if isinstance(value.name, str):
tooltip_label = value.name
if len(value) == 1:
legend_labels = [LegendStr(value.name)]
else:
legend_labels = [
LegendStr(f"{value.name}[{i}]")
for i in range(len(value))
]
else:
tooltip_label = ", ".join(value.name)
legend_labels = [LegendStr(item) for item in value.name]

for j, arr in enumerate(value):
color = next(colors)
source_name = f'{legend_label}_{i}_{j}'
source_name = f'{legend_labels[j]}_{i}_{j}'
kernc marked this conversation as resolved.
Show resolved Hide resolved
if arr.dtype == bool:
arr = arr.astype(int)
source.add(arr, source_name)
Expand All @@ -550,24 +563,24 @@ def __eq__(self, other):
if is_scatter:
fig.scatter(
'index', source_name, source=source,
legend_label=legend_label, color=color,
legend_label=legend_labels[j], color=color,
line_color='black', fill_alpha=.8,
marker='circle', radius=BAR_WIDTH / 2 * 1.5)
else:
fig.line(
'index', source_name, source=source,
legend_label=legend_label, line_color=color,
legend_label=legend_labels[j], line_color=color,
line_width=1.3)
else:
if is_scatter:
r = fig.scatter(
'index', source_name, source=source,
legend_label=LegendStr(legend_label), color=color,
legend_label=legend_labels[j], color=color,
marker='circle', radius=BAR_WIDTH / 2 * .9)
else:
r = fig.line(
'index', source_name, source=source,
legend_label=LegendStr(legend_label), line_color=color,
legend_label=legend_labels[j], line_color=color,
line_width=1.3)
# Add dashed centerline just because
mean = float(pd.Series(arr).mean())
Expand All @@ -578,9 +591,9 @@ def __eq__(self, other):
line_color='#666666', line_dash='dashed',
line_width=.5))
if is_overlay:
ohlc_tooltips.append((legend_label, NBSP.join(tooltips)))
ohlc_tooltips.append((tooltip_label, NBSP.join(tooltips)))
else:
set_tooltips(fig, [(legend_label, NBSP.join(tooltips))], vline=True, renderers=[r])
set_tooltips(fig, [(tooltip_label, NBSP.join(tooltips))], vline=True, renderers=[r])
# If the sole indicator line on this figure,
# have the legend only contain text without the glyph
if len(value) == 1:
Expand Down
21 changes: 18 additions & 3 deletions backtesting/backtesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def I(self, # noqa: E743
same length as `backtesting.backtesting.Strategy.data`.

In the plot legend, the indicator is labeled with
function name, unless `name` overrides it.
function name, unless `name` overrides it. If `func` returns
multiple arrays, `name` can be a sequence of strings, and
its size must agree with the number of arrays returned.

If `plot` is `True`, the indicator is plotted on the resulting
`backtesting.backtesting.Backtest.plot`.
Expand All @@ -115,13 +117,21 @@ def I(self, # noqa: E743
def init():
self.sma = self.I(ta.SMA, self.data.Close, self.n_sma)
"""
def _format_name(name: str) -> str:
return name.format(*map(_as_str, args),
**dict(zip(kwargs.keys(), map(_as_str, kwargs.values()))))

if name is None:
params = ','.join(filter(None, map(_as_str, chain(args, kwargs.values()))))
func_name = _as_str(func)
name = (f'{func_name}({params})' if params else f'{func_name}')
elif isinstance(name, str):
name = _format_name(name)
elif try_(lambda: all(isinstance(item, str) for item in name), False):
name = [_format_name(item) for item in name]
else:
name = name.format(*map(_as_str, args),
**dict(zip(kwargs.keys(), map(_as_str, kwargs.values()))))
raise TypeError(f'Unexpected `name=` type {type(name)}; expected `str` or '
'`Sequence[str]`')

try:
value = func(*args, **kwargs)
Expand All @@ -139,6 +149,11 @@ def init():
if is_arraylike and np.argmax(value.shape) == 0:
value = value.T

if isinstance(name, list) and (np.atleast_2d(value).shape[0] != len(name)):
raise ValueError(
f'Length of `name=` ({len(name)}) must agree with the number '
f'of arrays the indicator returns ({value.shape[0]}).')

if not is_arraylike or not 1 <= value.ndim <= 2 or value.shape[-1] != len(self._data.Close):
raise ValueError(
'Indicators must return (optionally a tuple of) numpy.arrays of same '
Expand Down
31 changes: 31 additions & 0 deletions backtesting/test/_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,37 @@ def test_resample(self):
# Give browser time to open before tempfile is removed
time.sleep(1)

def test_indicator_name(self):
test_self = self

class S(Strategy):
def init(self):
def _SMA():
return SMA(self.data.Close, 5), SMA(self.data.Close, 10)

test_self.assertRaises(TypeError, self.I, _SMA, name=42)
test_self.assertRaises(ValueError, self.I, _SMA, name=("SMA One", ))
test_self.assertRaises(
ValueError, self.I, _SMA, name=("SMA One", "SMA Two", "SMA Three"))

for overlay in (True, False):
self.I(SMA, self.data.Close, 5, overlay=overlay)
self.I(SMA, self.data.Close, 5, name="My SMA", overlay=overlay)
self.I(SMA, self.data.Close, 5, name=("My SMA", ), overlay=overlay)
self.I(_SMA, overlay=overlay)
self.I(_SMA, name="My SMA", overlay=overlay)
self.I(_SMA, name=("SMA One", "SMA Two"), overlay=overlay)

def next(self):
pass

bt = Backtest(GOOG, S)
bt.run()
with _tempfile() as f:
bt.plot(filename=f,
plot_drawdown=False, plot_equity=False, plot_pl=False, plot_volume=False,
open_browser=False)

def test_indicator_color(self):
class S(Strategy):
def init(self):
Expand Down