From 134b8519a4716f785c696c8d2d5349fb0ec5dbcf Mon Sep 17 00:00:00 2001 From: adityasuthar20 Date: Wed, 22 Jan 2025 22:24:41 +0530 Subject: [PATCH 1/3] Stop the _adapt_batch() from changing the batch in-place --- pvnet/models/multimodal/multimodal_base.py | 27 +++++++++++++--------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/pvnet/models/multimodal/multimodal_base.py b/pvnet/models/multimodal/multimodal_base.py index 0a710ebf..f377f16a 100644 --- a/pvnet/models/multimodal/multimodal_base.py +++ b/pvnet/models/multimodal/multimodal_base.py @@ -1,4 +1,6 @@ """Base model class for multimodal model and unimodal teacher""" +import copy + from torchvision.transforms.functional import center_crop from pvnet.models.base_model import BaseModel @@ -8,26 +10,29 @@ class MultimodalBaseModel(BaseModel): """Base model class for multimodal model and unimodal teacher""" def _adapt_batch(self, batch): - """Slice batches into appropriate shapes for model + """Slice batches into appropriate shapes for model. + Returns a new batch dictionary with adapted data, leaving the original batch unchanged. We make some specific assumptions about the original batch and the derived sliced batch: - We are only limiting the future projections. I.e. we are never shrinking the batch from the left hand side of the time axis, only slicing it from the right - We are only shrinking the spatial crop of the satellite and NWP data """ + # Create a copy of the batch to avoid modifying the original + new_batch = {key: copy.deepcopy(value) for key, value in batch.items()} - if "gsp" in batch.keys(): + if "gsp" in new_batch.keys(): # Slice off the end of the GSP data gsp_len = self.forecast_len + self.history_len + 1 - batch["gsp"] = batch["gsp"][:, :gsp_len] - batch["gsp_time_utc"] = batch["gsp_time_utc"][:, :gsp_len] + new_batch["gsp"] = new_batch["gsp"][:, :gsp_len] + new_batch["gsp_time_utc"] = new_batch["gsp_time_utc"][:, :gsp_len] if self.include_sat: # Slice off the end of the satellite data and spatially crop # Shape: batch_size, seq_length, channel, height, width - batch["satellite_actual"] = center_crop( - batch["satellite_actual"][:, : self.sat_sequence_len], + new_batch["satellite_actual"] = center_crop( + new_batch["satellite_actual"][:, : self.sat_sequence_len], output_size=self.sat_encoder.image_size_pixels, ) @@ -35,8 +40,8 @@ def _adapt_batch(self, batch): # Slice off the end of the NWP data and spatially crop for nwp_source in self.nwp_encoders_dict: # shape: batch_size, seq_len, n_chans, height, width - batch["nwp"][nwp_source]["nwp"] = center_crop( - batch["nwp"][nwp_source]["nwp"], + new_batch["nwp"][nwp_source]["nwp"] = center_crop( + new_batch["nwp"][nwp_source]["nwp"], output_size=self.nwp_encoders_dict[nwp_source].image_size_pixels, )[:, : self.nwp_encoders_dict[nwp_source].sequence_length] @@ -44,8 +49,8 @@ def _adapt_batch(self, batch): # Slice off the end of the solar coords data for s in ["solar_azimuth", "solar_elevation"]: key = f"{self._target_key}_{s}" - if key in batch.keys(): + if key in new_batch.keys(): sun_len = self.forecast_len + self.history_len + 1 - batch[key] = batch[key][:, :sun_len] + new_batch[key] = new_batch[key][:, :sun_len] - return batch + return new_batch From 381311222b828dcfbb3a4f481e4616316d9d6d57 Mon Sep 17 00:00:00 2001 From: adityasuthar20 Date: Thu, 13 Feb 2025 17:21:22 +0530 Subject: [PATCH 2/3] deepcopy dict items --- pvnet/models/multimodal/multimodal_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pvnet/models/multimodal/multimodal_base.py b/pvnet/models/multimodal/multimodal_base.py index f377f16a..4676daf9 100644 --- a/pvnet/models/multimodal/multimodal_base.py +++ b/pvnet/models/multimodal/multimodal_base.py @@ -2,7 +2,7 @@ import copy from torchvision.transforms.functional import center_crop - +import copy from pvnet.models.base_model import BaseModel From 96408c90ab601c68b3ee22e4ca308ea162b7d0b1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 Feb 2025 11:51:57 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/models/multimodal/multimodal_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pvnet/models/multimodal/multimodal_base.py b/pvnet/models/multimodal/multimodal_base.py index 4676daf9..f377f16a 100644 --- a/pvnet/models/multimodal/multimodal_base.py +++ b/pvnet/models/multimodal/multimodal_base.py @@ -2,7 +2,7 @@ import copy from torchvision.transforms.functional import center_crop -import copy + from pvnet.models.base_model import BaseModel