Skip to content

Commit

Permalink
Simplify AutoBNN plotting util
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 618963756
  • Loading branch information
ursk authored and tensorflower-gardener committed Mar 25, 2024
1 parent 49838c8 commit cd5ceb7
Showing 1 changed file with 19 additions and 20 deletions.
39 changes: 19 additions & 20 deletions tensorflow_probability/python/experimental/autobnn/training_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def plot_results(
y_train: Optional[jax.Array] = None,
diagnostics: Optional[Dict[str, jax.Array]] = None,
log_scale: bool = False,
show_particles: bool = True,
left_limit: int = 24*7*2,
right_limit: int = 24*7*2,
) -> plt.Figure:
Expand All @@ -307,44 +308,41 @@ def plot_results(
else:
fig, res_ax = plt.subplots(figsize=(16, 3), constrained_layout=True)

for idx, p in enumerate(preds):
res_ax.plot(
dates_preds,
p,
'k-',
alpha=0.1,
label='Particle predictions' if idx == 0 else None,
)
if show_particles:
for idx, p in enumerate(preds):
res_ax.plot(
dates_preds,
p,
'k-',
alpha=0.1,
label='Particle predictions' if idx == 0 else None,
)

color = 'steelblue'
if p50 is not None:
res_ax.plot(
dates_preds, p50, '-', lw=5, color=color, label='Prediction')
dates_preds, p50, '-', lw=2.5, color=color, label='Prediction')
if p97_5 is not None and p2_5 is not None:
res_ax.plot(dates_preds, p97_5, '-',
lw=3, color=color, label='Upper/lower bound')
res_ax.plot(dates_preds, p2_5, '-', lw=3, color=color)
lw=1.5, color=color, label='Upper/lower bound')
res_ax.plot(dates_preds, p2_5, '-', lw=1.5, color=color)
res_ax.fill_between(
dates_preds, p2_5, p97_5, color=color, alpha=0.2
)

data_kwargs = {'ms': 7, 'mec': 'k', 'mew': 2}
if dates_train is not None and y_train is not None:
res_ax.plot(
dates_train,
y_train,
'o',
mfc='red',
label='Train data',
**data_kwargs)
'k-',
label='Ground truth data',
)
if dates_test is not None and y_test is not None:
res_ax.plot(
dates_test,
y_test,
'o',
mfc='green',
label='Test data',
**data_kwargs)
'k-',
)
res_ax.set_title('Predictions')
res_ax.legend()
left_limit = min(len(dates_preds) - len(dates_test), left_limit)
Expand All @@ -353,6 +351,7 @@ def plot_results(
# TODO(ursk): Rather than modifying xlim, don't plot invisible points at all.
res_ax.set_xlim([dates_preds[first_test_point-left_limit],
dates_preds[first_test_point+right_limit]])
res_ax.axvline(dates_preds[first_test_point], linestyle='--')
return fig


Expand Down

0 comments on commit cd5ceb7

Please sign in to comment.