
Commit

Remove prints
jacobbieker committed Dec 5, 2023
1 parent 149180d commit 76a79bf
Showing 3 changed files with 0 additions and 18 deletions.
8 changes: 0 additions & 8 deletions pvnet/models/base_model.py
@@ -286,7 +286,6 @@ def num_output_features(self):
            out_features = self.forecast_len_30 * len(self.output_quantiles)
        else:
            out_features = self.forecast_len_30
-       print(f"num_output_features: {out_features}")
        return out_features

    def _quantiles_to_prediction(self, y_quantiles):
@@ -326,8 +325,6 @@ def _calculate_qauntile_loss(self, y_quantiles, y):
        """
        # calculate quantile loss
        losses = []
-       print(f"y_quantiles.shape: {y_quantiles.shape}")
-       print(f"y.shape: {y.shape}")
        for i, q in enumerate(self.output_quantiles):
            errors = y - y_quantiles[..., i]
            losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
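
As an aside on the loop kept above: a minimal, self-contained sketch of the pinball (quantile) loss it computes is shown below. The tensor shapes, quantile levels, and the final mean-reduction are illustrative assumptions, not values taken from the PVNet configuration or from the rest of this method, which the diff does not show.

```python
# Hedged sketch of the quantile (pinball) loss loop above.
# Shapes, quantile levels, and the final reduction are assumptions.
import torch

output_quantiles = [0.1, 0.5, 0.9]                        # assumed quantile levels
y = torch.randn(4, 16)                                    # (batch, forecast steps)
y_quantiles = torch.randn(4, 16, len(output_quantiles))   # (batch, steps, quantiles)

losses = []
for i, q in enumerate(output_quantiles):
    errors = y - y_quantiles[..., i]
    # Under-prediction (errors > 0) is weighted by q, over-prediction by (1 - q)
    losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))

# Assumed aggregation; the remainder of the method is not visible in this diff
loss = torch.cat(losses, dim=-1).mean()
print(loss)
```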
@@ -380,7 +377,6 @@ def _calculate_val_losses(self, y, y_hat):

        # Take median value for remaining metric calculations
        y_hat = self._quantiles_to_prediction(y_hat)
-       # print(f"{y_hat.shape=}, {y.shape=}")
        common_metrics_each_step = {"mae": torch.mean(torch.abs(y_hat - y), dim=0),
                                    "rmse": torch.sqrt(torch.mean((y_hat - y) ** 2, dim=0))}
        # common_metrics_each_step = common_metrics(predictions=y_hat.numpy(), target=y.numpy())
@@ -451,10 +447,6 @@ def training_step(self, batch, batch_idx):
    def validation_step(self, batch: dict, batch_idx):
        """Run validation step"""
        y_hat = self(batch)
-       print(f"y_hat.shape: {y_hat.shape}")
-       print(f"{batch[self._target_key].shape=}")
-       print(f"{batch[self._target_key][:, -self.forecast_len_30 :, 0].shape=}")
-       print(f"{self.forecast_len_30=}")
        # Sensor seems to be in batch, station, time order
        y = batch[self._target_key][:, 0, -self.forecast_len_30 :]

2 changes: 0 additions & 2 deletions pvnet/models/multimodal/multimodal.py
@@ -267,11 +267,9 @@ def forward(self, x):
            modes["sun"] = sun

        out = self.output_network(modes)
-       print(f"out.shape: {out.shape}")

        if self.use_quantile_regression:
            # Shape: batch_size, seq_length * num_quantiles
            out = out.reshape(out.shape[0], self.forecast_len_30, len(self.output_quantiles))
-           print(f"out.shape: {out.shape}")

        return out
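
For context, the reshape kept above turns the flat network output into one value per forecast step and quantile. A rough sketch under assumed sizes is below; forecast_len_30 and the quantile list are illustrative, and the median indexing mirrors, but is not copied from, _quantiles_to_prediction in base_model.py.

```python
# Hedged sketch: reshape a flat quantile-regression output and take the median
# quantile as a point forecast. Sizes and quantile levels are assumptions.
import torch

batch_size, forecast_len_30 = 4, 16
output_quantiles = [0.1, 0.25, 0.5, 0.75, 0.9]             # assumed levels

out = torch.randn(batch_size, forecast_len_30 * len(output_quantiles))
out = out.reshape(batch_size, forecast_len_30, len(output_quantiles))

median_idx = len(output_quantiles) // 2                    # index of the 0.5 quantile here
y_hat = out[..., median_idx]                               # (batch_size, forecast_len_30)
print(y_hat.shape)
```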
8 changes: 0 additions & 8 deletions pvnet/models/multimodal/site_encoders/encoders.py
@@ -328,25 +328,20 @@ def __init__(
    def _encode_query(self, x):
        # Select the first one
        gsp_ids = x[BatchKey.sensor_id][:, 0].squeeze().int()
-       print(f"{gsp_ids.shape=}")
        query = self.sensor_id_embedding(gsp_ids).unsqueeze(1)
-       print(f"{query.shape=}")
        return query

    def _encode_key(self, x):
        # Shape: [batch size, sequence length, PV site]
        sensor_site_seqs = x[BatchKey.sensor][:, : self.sequence_length].float()
        batch_size = sensor_site_seqs.shape[0]
-       print(f"{sensor_site_seqs.shape=}")

        # Sensor ID embeddings are the same for each sample
        sensor_id_embed = torch.tile(self.pv_id_embedding(self._sensor_ids), (batch_size, 1, 1))
-       print(f"{sensor_id_embed.shape=}")
        # Each concated (Sensor sequence, Sensor ID embedding) is processed with encoder
        x_seq_in = torch.cat((sensor_site_seqs.swapaxes(1, 2), sensor_id_embed), dim=2).flatten(
            0, 1
        )
-       print(f"{x_seq_in.shape=}")
        key = self._key_encoder(x_seq_in)

        # Reshape to [batch size, PV site, kdim]
@@ -381,9 +376,6 @@ def _attention_forward(self, x, average_attn_weights=True):
        query = self._encode_query(x)
        key = self._encode_key(x)
        value = self._encode_value(x)
-       print(f"{query.shape=}")
-       print(f"{key.shape=}")
-       print(f"{value.shape=}")

        attn_output, attn_weights = self.multihead_attn(
            query, key, value, average_attn_weights=average_attn_weights
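For reference, the attention call at the end of the hunk above can be exercised in isolation roughly as follows. The embedding size, head count, and site count are made-up values, and batch_first=True is an assumption about how the layer is constructed in this encoder.

```python
# Hedged sketch of the cross-attention step: one query per sample attends over
# per-site keys/values. Dimensions here are illustrative, not PVNet's.
import torch
from torch import nn

batch_size, n_sites, embed_dim, num_heads = 4, 10, 32, 4
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

query = torch.randn(batch_size, 1, embed_dim)              # encoded query
key = torch.randn(batch_size, n_sites, embed_dim)          # encoded keys, one per site
value = torch.randn(batch_size, n_sites, embed_dim)        # encoded values, one per site

attn_output, attn_weights = multihead_attn(
    query, key, value, average_attn_weights=True
)
print(attn_output.shape, attn_weights.shape)               # (4, 1, 32), (4, 1, 10)
```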
