From bc9db124248fcec967decc437f32e529c8735f93 Mon Sep 17 00:00:00 2001 From: adityasuthar20 Date: Wed, 22 Jan 2025 22:24:41 +0530 Subject: [PATCH 1/4] Stop the _adapt_batch() from changing the batch in-place --- pvnet/models/multimodal/multimodal_base.py | 23 ++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/pvnet/models/multimodal/multimodal_base.py b/pvnet/models/multimodal/multimodal_base.py index dbab8556..4d6b0102 100644 --- a/pvnet/models/multimodal/multimodal_base.py +++ b/pvnet/models/multimodal/multimodal_base.py @@ -10,6 +10,7 @@ class MultimodalBaseModel(BaseModel): def _adapt_batch(self, batch): """Slice batches into appropriate shapes for model + Returns a new batch dictionary with adapted data, leaving the original batch unchanged. We make some specific assumptions about the original batch and the derived sliced batch: - We are only limiting the future projections. I.e. we are never shrinking the batch from @@ -17,18 +18,20 @@ def _adapt_batch(self, batch): - We are only shrinking the spatial crop of the satellite and NWP data """ + # Create a copy of the batch to avoid modifying the original + new_batch = {key: value.copy() for key, value in batch.items()} - if BatchKey.gsp in batch.keys(): + if BatchKey.gsp in new_batch.keys(): # Slice off the end of the GSP data gsp_len = self.forecast_len + self.history_len + 1 - batch[BatchKey.gsp] = batch[BatchKey.gsp][:, :gsp_len] - batch[BatchKey.gsp_time_utc] = batch[BatchKey.gsp_time_utc][:, :gsp_len] + new_batch[BatchKey.gsp] = new_batch[BatchKey.gsp][:, :gsp_len] + new_batch[BatchKey.gsp_time_utc] = new_batch[BatchKey.gsp_time_utc][:, :gsp_len] if self.include_sat: # Slice off the end of the satellite data and spatially crop # Shape: batch_size, seq_length, channel, height, width - batch[BatchKey.satellite_actual] = center_crop( - batch[BatchKey.satellite_actual][:, : self.sat_sequence_len], + new_batch[BatchKey.satellite_actual] = center_crop( + new_batch[BatchKey.satellite_actual][:, : self.sat_sequence_len], output_size=self.sat_encoder.image_size_pixels, ) @@ -36,8 +39,8 @@ def _adapt_batch(self, batch): # Slice off the end of the NWP data and spatially crop for nwp_source in self.nwp_encoders_dict: # shape: batch_size, seq_len, n_chans, height, width - batch[BatchKey.nwp][nwp_source][NWPBatchKey.nwp] = center_crop( - batch[BatchKey.nwp][nwp_source][NWPBatchKey.nwp], + new_batch[BatchKey.nwp][nwp_source][NWPBatchKey.nwp] = center_crop( + new_batch[BatchKey.nwp][nwp_source][NWPBatchKey.nwp], output_size=self.nwp_encoders_dict[nwp_source].image_size_pixels, )[:, : self.nwp_encoders_dict[nwp_source].sequence_length] @@ -45,8 +48,8 @@ def _adapt_batch(self, batch): # Slice off the end of the solar coords data for s in ["solar_azimuth", "solar_elevation"]: key = BatchKey[f"{self._target_key_name}_{s}"] - if key in batch.keys(): + if key in new_batch.keys(): sun_len = self.forecast_len + self.history_len + 1 - batch[key] = batch[key][:, :sun_len] + new_batch[key] = new_batch[key][:, :sun_len] - return batch + return new_batch From 802042679f04591ed504511f23812307665552f9 Mon Sep 17 00:00:00 2001 From: adityasuthar20 Date: Thu, 23 Jan 2025 09:46:46 +0530 Subject: [PATCH 2/4] fix ruff formatting issue --- pvnet/models/multimodal/multimodal_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pvnet/models/multimodal/multimodal_base.py b/pvnet/models/multimodal/multimodal_base.py index 4d6b0102..ed579b43 100644 --- a/pvnet/models/multimodal/multimodal_base.py +++ b/pvnet/models/multimodal/multimodal_base.py @@ -9,9 +9,9 @@ class MultimodalBaseModel(BaseModel): """Base model class for multimodal model and unimodal teacher""" def _adapt_batch(self, batch): - """Slice batches into appropriate shapes for model - Returns a new batch dictionary with adapted data, leaving the original batch unchanged. + """Slice batches into appropriate shapes for model. + Returns a new batch dictionary with adapted data, leaving the original batch unchanged. We make some specific assumptions about the original batch and the derived sliced batch: - We are only limiting the future projections. I.e. we are never shrinking the batch from the left hand side of the time axis, only slicing it from the right From 26238c51e97862fa9c93c3c78bda00dc63c2223d Mon Sep 17 00:00:00 2001 From: adityasuthar20 Date: Thu, 13 Feb 2025 17:21:22 +0530 Subject: [PATCH 3/4] deepcopy dict items --- pvnet/models/multimodal/multimodal_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pvnet/models/multimodal/multimodal_base.py b/pvnet/models/multimodal/multimodal_base.py index ed579b43..81a705f4 100644 --- a/pvnet/models/multimodal/multimodal_base.py +++ b/pvnet/models/multimodal/multimodal_base.py @@ -1,7 +1,7 @@ """Base model class for multimodal model and unimodal teacher""" from ocf_datapipes.batch import BatchKey, NWPBatchKey from torchvision.transforms.functional import center_crop - +import copy from pvnet.models.base_model import BaseModel @@ -19,7 +19,7 @@ def _adapt_batch(self, batch): """ # Create a copy of the batch to avoid modifying the original - new_batch = {key: value.copy() for key, value in batch.items()} + new_batch = {key: copy.deepcopy(value) for key, value in batch.items()} if BatchKey.gsp in new_batch.keys(): # Slice off the end of the GSP data From 98520dfc73d182fc560364651aa4f230b1543b54 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 Feb 2025 11:51:57 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/models/multimodal/multimodal_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pvnet/models/multimodal/multimodal_base.py b/pvnet/models/multimodal/multimodal_base.py index 81a705f4..f2cbebd7 100644 --- a/pvnet/models/multimodal/multimodal_base.py +++ b/pvnet/models/multimodal/multimodal_base.py @@ -1,7 +1,9 @@ """Base model class for multimodal model and unimodal teacher""" +import copy + from ocf_datapipes.batch import BatchKey, NWPBatchKey from torchvision.transforms.functional import center_crop -import copy + from pvnet.models.base_model import BaseModel