Skip to content

Commit

Permalink
Update plotting and turn back on
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Dec 11, 2023
1 parent e3db3ea commit 40a3812
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 17 deletions.
18 changes: 9 additions & 9 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,9 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat):

# We only create the figure every 8 log steps
# This was reduced as it was creating figures too often
#if grad_batch_num % (8 * self.trainer.log_every_n_steps) == 0:
# fig = plot_batch_forecasts(batch, y_hat, batch_idx, quantiles=self.output_quantiles)
# fig.savefig("latest_logged_train_batch.png")
if grad_batch_num % (8 * self.trainer.log_every_n_steps) == 0:
fig = plot_batch_forecasts(batch, y_hat, batch_idx, quantiles=self.output_quantiles)
fig.savefig("latest_logged_train_batch.png")

def training_step(self, batch, batch_idx):
"""Run training step"""
Expand Down Expand Up @@ -477,13 +477,13 @@ def validation_step(self, batch: dict, batch_idx):
y_hat = self._val_y_hats.flush()
batch = self._val_batches.flush()

#fig = plot_batch_forecasts(batch, y_hat, quantiles=self.output_quantiles)
fig = plot_batch_forecasts(batch, y_hat, quantiles=self.output_quantiles)

#self.logger.experiment.log(
# {
# f"val_forecast_samples/batch_idx_{accum_batch_num}": wandb.Image(fig),
# }
#)
self.logger.experiment.log(
{
f"val_forecast_samples/batch_idx_{accum_batch_num}": wandb.Image(fig),
}
)
del self._val_y_hats
del self._val_batches

Expand Down
21 changes: 13 additions & 8 deletions pvnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,19 +241,23 @@ def finish(
wandb.finish()


def plot_batch_forecasts(batch, y_hat, batch_idx=None, quantiles=None):
def plot_batch_forecasts(batch, y_hat, batch_idx=None, quantiles=None, key_to_plot: str = "gsp"):
"""Plot a batch of data and the forecast from that batch"""

def _get_numpy(key):
return batch[key].cpu().numpy().squeeze()
y_key = BatchKey.gsp if key_to_plot == "gsp" else BatchKey.sensor
y_id_key = BatchKey.gsp_id if key_to_plot == "gsp" else BatchKey.sensor_id
t0_idx_key = BatchKey.gsp_t0_idx if key_to_plot == "gsp" else BatchKey.sensor_t0_idx
time_utc_key = BatchKey.gsp_time_utc if key_to_plot == "gsp" else BatchKey.sensor_time_utc

y = batch[BatchKey.gsp].cpu().numpy()
y = batch[y_key].cpu().numpy()
y_hat = y_hat.cpu().numpy()
gsp_ids = batch[y_id_key].cpu().numpy().squeeze()
t0_idx = batch[t0_idx_key]
plotting_name = "GSP" if key_to_plot == "gsp" else "Sensor"

gsp_ids = batch[BatchKey.gsp_id].cpu().numpy().squeeze()
t0_idx = batch[BatchKey.gsp_t0_idx]

times_utc = batch[BatchKey.gsp_time_utc].cpu().numpy().squeeze().astype("datetime64[s]")
times_utc = batch[time_utc_key].cpu().numpy().squeeze().astype("datetime64[s]")
times_utc = [pd.to_datetime(t) for t in times_utc]

len(times_utc[0]) - t0_idx - 1
Expand Down Expand Up @@ -293,10 +297,11 @@ def _get_numpy(key):
for ax in axes[-1, :]:
ax.set_xlabel("Time (hour of day)")


if batch_idx is not None:
title = f"Normed GSP output : batch_idx={batch_idx}"
title = f"Normed {plotting_name} output : batch_idx={batch_idx}"
else:
title = "Normed GSP output"
title = f"Normed {plotting_name} output"
plt.suptitle(title)
plt.tight_layout()

Expand Down

0 comments on commit 40a3812

Please sign in to comment.