Skip to content

Commit

Permalink
Fix BatchKey
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Jan 8, 2024
1 parent d171ab0 commit 6142796
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def forward(self, x):
sat_data = x[BatchKey.satellite_actual][:, : self.sat_sequence_len]
sat_data = torch.swapaxes(sat_data, 1, 2).float() # switch time and channels
if self.add_image_embedding_channel:
id = x[BatchKey.sensor_id][:, 0].int()
id = x[BatchKey.wind_id][:, 0].int()
sat_data = self.sat_embed(sat_data, id)
modes["sat"] = self.sat_encoder(sat_data)

Expand Down Expand Up @@ -264,7 +264,7 @@ def forward(self, x):

# ********************** Embedding of GSP ID ********************
if self.embedding_dim:
id = x[BatchKey.sensor_id][:, 0].int()
id = x[BatchKey.wind_id][:, 0].int()
id_embedding = self.embed(id)
modes["id"] = id_embedding

Expand Down

0 comments on commit 6142796

Please sign in to comment.