From ec922df3306221c354a58abd5154b5cd6bec3b9a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 10:46:39 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../configuration/gcp_configuration.yaml | 370 +++++++++++++++++- configs/model/multimodal.yaml | 4 +- pvnet/models/multimodal/multimodal.py | 54 ++- .../multimodal/site_encoders/basic_blocks.py | 3 +- .../multimodal/site_encoders/encoders.py | 155 ++++---- pvnet/optimizers.py | 65 ++- 6 files changed, 486 insertions(+), 165 deletions(-) diff --git a/configs/datamodule/configuration/gcp_configuration.yaml b/configs/datamodule/configuration/gcp_configuration.yaml index 31e98370..1d6dd67b 100644 --- a/configs/datamodule/configuration/gcp_configuration.yaml +++ b/configs/datamodule/configuration/gcp_configuration.yaml @@ -12,28 +12,364 @@ input_data: forecast_minutes: 480 time_resolution_minutes: 30 metadata_only: false - + pv: pv_files_groups: - label: solar_sheffield_passiv pv_filename: /mnt/disks/nwp/passive/v0/passiv.netcdf pv_metadata_filename: /mnt/disks/nwp/passive/v0/system_metadata_OCF_ONLY.csv - pv_ml_ids: [154,155,156,158,159,160,162,164,165,166,167,168,169,171,173,177,178,179,181,182,185, - 186,187,188,189,190,191,192,193,197,198,199,200,202,204,205,206,208,209,211,214,215,216,217, - 218,219,220,221,225,229,230,232,233,234,236,242,243,245,252,254,255,256,257,258,260,261,262, - 265,267,268,272,273,275,276,277,280,281,282,283,287,289,291,292,293,294,295,296,297,298,301, - 302,303,304,306,307,309,310,311,317,318,319,320,321,322,323,325,326,329,332,333,335,336,338, - 340,342,344,345,346,348,349,352,354,355,356,357,360,362,363,368,369,370,371,372,374,375,376, - 378,380,382,384,385,388,390,391,393,396,397,398,399,400,401,403,404,405,406,407,409,411,412, - 413,414,415,416,417,418,419,420,421,422,423,424,425,426,427,429,431,435,437,438,440,441,444, - 447,450,451,453,456,457,458,459,464,465,466,467,468,470,471,473,474,476,477,479,480,481,482, - 485,486,488,490,491,492,493,496,498,501,503,506,507,508,509,510,511,512,513,515,516,517,519, - 520,521,522,524,526,527,528,531,532,536,537,538,540,541,542,543,544,545,549,550,551,552,553, - 554,556,557,560,561,563,566,568,571,572,575,576,577,579,580,581,582,584,585,588,590,594,595, - 597,600,602,603,604,606,611,613,614,616,618,620,622,623,624,625,626,628,629,630,631,636,637, - 638,640,641,642,644,645,646,650,651,652,653,654,655,657,660,661,662,663,666,667,668,670,675, - 676,679,681,683,684,685,687,696,698,701,702,703,704,706,710,722,723,724,725,727,728,729,730, - 732,733,734,735,736,737] + pv_ml_ids: + [ + 154, + 155, + 156, + 158, + 159, + 160, + 162, + 164, + 165, + 166, + 167, + 168, + 169, + 171, + 173, + 177, + 178, + 179, + 181, + 182, + 185, + 186, + 187, + 188, + 189, + 190, + 191, + 192, + 193, + 197, + 198, + 199, + 200, + 202, + 204, + 205, + 206, + 208, + 209, + 211, + 214, + 215, + 216, + 217, + 218, + 219, + 220, + 221, + 225, + 229, + 230, + 232, + 233, + 234, + 236, + 242, + 243, + 245, + 252, + 254, + 255, + 256, + 257, + 258, + 260, + 261, + 262, + 265, + 267, + 268, + 272, + 273, + 275, + 276, + 277, + 280, + 281, + 282, + 283, + 287, + 289, + 291, + 292, + 293, + 294, + 295, + 296, + 297, + 298, + 301, + 302, + 303, + 304, + 306, + 307, + 309, + 310, + 311, + 317, + 318, + 319, + 320, + 321, + 322, + 323, + 325, + 326, + 329, + 332, + 333, + 335, + 336, + 338, + 340, + 342, + 344, + 345, + 346, + 348, + 349, + 352, + 354, + 355, + 356, + 357, + 360, + 362, + 363, + 368, + 369, + 370, + 371, + 372, + 374, + 375, + 376, + 378, + 380, + 382, + 384, + 385, + 388, + 390, + 391, + 393, + 396, + 397, + 398, + 399, + 400, + 401, + 403, + 404, + 405, + 406, + 407, + 409, + 411, + 412, + 413, + 414, + 415, + 416, + 417, + 418, + 419, + 420, + 421, + 422, + 423, + 424, + 425, + 426, + 427, + 429, + 431, + 435, + 437, + 438, + 440, + 441, + 444, + 447, + 450, + 451, + 453, + 456, + 457, + 458, + 459, + 464, + 465, + 466, + 467, + 468, + 470, + 471, + 473, + 474, + 476, + 477, + 479, + 480, + 481, + 482, + 485, + 486, + 488, + 490, + 491, + 492, + 493, + 496, + 498, + 501, + 503, + 506, + 507, + 508, + 509, + 510, + 511, + 512, + 513, + 515, + 516, + 517, + 519, + 520, + 521, + 522, + 524, + 526, + 527, + 528, + 531, + 532, + 536, + 537, + 538, + 540, + 541, + 542, + 543, + 544, + 545, + 549, + 550, + 551, + 552, + 553, + 554, + 556, + 557, + 560, + 561, + 563, + 566, + 568, + 571, + 572, + 575, + 576, + 577, + 579, + 580, + 581, + 582, + 584, + 585, + 588, + 590, + 594, + 595, + 597, + 600, + 602, + 603, + 604, + 606, + 611, + 613, + 614, + 616, + 618, + 620, + 622, + 623, + 624, + 625, + 626, + 628, + 629, + 630, + 631, + 636, + 637, + 638, + 640, + 641, + 642, + 644, + 645, + 646, + 650, + 651, + 652, + 653, + 654, + 655, + 657, + 660, + 661, + 662, + 663, + 666, + 667, + 668, + 670, + 675, + 676, + 679, + 681, + 683, + 684, + 685, + 687, + 696, + 698, + 701, + 702, + 703, + 704, + 706, + 710, + 722, + 723, + 724, + 725, + 727, + 728, + 729, + 730, + 732, + 733, + 734, + 735, + 736, + 737, + ] history_minutes: 180 forecast_minutes: 0 time_resolution_minutes: 5 diff --git a/configs/model/multimodal.yaml b/configs/model/multimodal.yaml index 2b2db1f1..41af91e2 100644 --- a/configs/model/multimodal.yaml +++ b/configs/model/multimodal.yaml @@ -6,7 +6,7 @@ output_quantiles: [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98] # NWP encoder #-------------------------------------------- -nwp_encoder: +nwp_encoder: _target_: pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet _partial_: True in_channels: 2 @@ -85,4 +85,4 @@ optimizer: amsgrad: True patience: 5 factor: 0.1 - threshold: 0.002 \ No newline at end of file + threshold: 0.002 diff --git a/pvnet/models/multimodal/multimodal.py b/pvnet/models/multimodal/multimodal.py index 7657a035..d1faa2d8 100644 --- a/pvnet/models/multimodal/multimodal.py +++ b/pvnet/models/multimodal/multimodal.py @@ -26,8 +26,8 @@ class Model(BaseModel): - NWP, if included, is put through a similar encoder. - PV site-level data, if included, is put through an encoder which transforms it from 2D, with time and system-ID dimensions, to become a 1D feature vector. - - The satellite features*, NWP features*, PV site-level features*, GSP ID embedding*, and sun - paramters* are concatenated into a 1D feature vector and passed through another neural + - The satellite features*, NWP features*, PV site-level features*, GSP ID embedding*, and sun + paramters* are concatenated into a 1D feature vector and passed through another neural network to combine them and produce a forecast. * if included @@ -55,10 +55,7 @@ def __init__( pv_history_minutes: Optional[int] = None, optimizer: AbstractOptimizer = pvnet.optimizers.Adam(), ): - """Neural network which combines information from different sources. - - - """ + """Neural network which combines information from different sources.""" self.include_gsp_yield_history = include_gsp_yield_history self.include_sat = sat_encoder is not None self.include_nwp = nwp_encoder is not None @@ -68,17 +65,17 @@ def __init__( self.add_image_embedding_channel = add_image_embedding_channel super().__init__(history_minutes, forecast_minutes, optimizer, output_quantiles) - + # Number of features expected by the output_network # Add to this as network pices are constructed fusion_input_features = 0 - + if self.include_sat: # We limit the history to have a delay of 15 mins in satellite data if sat_history_minutes is None: sat_history_minutes = history_minutes - + self.sat_sequence_len = (sat_history_minutes - min_sat_delay_minutes) // 5 + 1 self.sat_encoder = sat_encoder( @@ -89,8 +86,8 @@ def __init__( self.sat_embed = ImageEmbedding( 318, self.sat_sequence_len, self.sat_encoder.image_size_pixels ) - - # Update num features + + # Update num features fusion_input_features += self.sat_encoder.out_features if self.include_nwp: @@ -108,25 +105,25 @@ def __init__( self.nwp_embed = ImageEmbedding( 318, nwp_sequence_len, self.nwp_encoder.image_size_pixels ) - - # Update num features + + # Update num features fusion_input_features += self.nwp_encoder.out_features - + if self.include_pv: if pv_history_minutes is None: pv_history_minutes = history_minutes - + self.pv_encoder = pv_encoder( - sequence_length=pv_history_minutes//5 + 1, + sequence_length=pv_history_minutes // 5 + 1, ) - - # Update num features + + # Update num features fusion_input_features += self.pv_encoder.out_features if self.embedding_dim: self.embed = nn.Embedding(num_embeddings=318, embedding_dim=embedding_dim) - - # Update num features + + # Update num features fusion_input_features += embedding_dim if self.include_sun: @@ -135,19 +132,19 @@ def __init__( in_features=2 * (self.forecast_len_30 + self.history_len_30 + 1), out_features=16, ) - - # Update num features + + # Update num features fusion_input_features += 16 - + if include_gsp_yield_history: - # Update num features + # Update num features fusion_input_features += self.history_len_30 self.output_network = output_network( in_features=fusion_input_features, out_features=self.num_output_features, ) - + self.save_hyperparameters() def forward(self, x): @@ -156,7 +153,7 @@ def forward(self, x): # ******************* Satellite imagery ************************* if self.include_sat: # Shape: batch_size, seq_length, channel, height, width - sat_data = x[BatchKey.satellite_actual][:, :self.sat_sequence_len] + sat_data = x[BatchKey.satellite_actual][:, : self.sat_sequence_len] sat_data = torch.swapaxes(sat_data, 1, 2).float() # switch time and channels if self.add_image_embedding_channel: id = x[BatchKey.gsp_id][:, 0].int() @@ -172,12 +169,11 @@ def forward(self, x): id = x[BatchKey.gsp_id][:, 0].int() nwp_data = self.nwp_embed(nwp_data, id) modes["nwp"] = self.nwp_encoder(nwp_data) - - + # *********************** PV Data ************************************* # Add site-level PV yield if self.include_pv: - modes["pv"] = self.pv_encoder(x) + modes["pv"] = self.pv_encoder(x) # *********************** GSP Data ************************************ # add gsp yield history diff --git a/pvnet/models/multimodal/site_encoders/basic_blocks.py b/pvnet/models/multimodal/site_encoders/basic_blocks.py index 1fcd2056..b20835f1 100644 --- a/pvnet/models/multimodal/site_encoders/basic_blocks.py +++ b/pvnet/models/multimodal/site_encoders/basic_blocks.py @@ -29,8 +29,7 @@ def __init__( self.num_sites = num_sites self.out_features = out_features - @abstractmethod def forward(self): """Run model forward""" - pass \ No newline at end of file + pass diff --git a/pvnet/models/multimodal/site_encoders/encoders.py b/pvnet/models/multimodal/site_encoders/encoders.py index 9fb2ea52..4821968f 100644 --- a/pvnet/models/multimodal/site_encoders/encoders.py +++ b/pvnet/models/multimodal/site_encoders/encoders.py @@ -3,9 +3,8 @@ """ import torch -from torch import nn - from ocf_datapipes.utils.consts import BatchKey +from torch import nn from pvnet.models.multimodal.linear_networks.networks import ResFCNet2 from pvnet.models.multimodal.site_encoders.basic_blocks import AbstractPVSitesEncoder @@ -14,11 +13,11 @@ class SimpleLearnedAggregator(AbstractPVSitesEncoder): """A simple model which learns a different weighted-average across all of the PV sites for each GSP. - + Each sequence from each site is independently encodeded through some dense layers wih skip- connections, then the encoded form of each sequence is aggregated through a learned weighted-sum and finally put through more dense layers. - + This model was written to be a simplified version of a single-headed attention layer. """ @@ -32,21 +31,21 @@ def __init__( final_resblocks: int = 2, ): """A simple sequence encoder and weighted-average model. - + Args: sequence_length: The time sequence length of the data. num_sites: Number of PV sites in the input data. out_features: Number of output features. - value_dim: The number of features in each encoded sequence. Similar to the value + value_dim: The number of features in each encoded sequence. Similar to the value dimension in single- or multi-head attention. - value_dim: The number of features in each encoded sequence. Similar to the value + value_dim: The number of features in each encoded sequence. Similar to the value dimension in single- or multi-head attention. value_enc_resblocks: Number of residual blocks in the value-encoder sub-network. final_resblocks: Number of residual blocks in the final sub-network. """ super().__init__(sequence_length, num_sites, out_features) - + # Network used to encode each PV site sequence self._value_encoder = nn.Sequential( ResFCNet2( @@ -58,84 +57,85 @@ def __init__( dropout_frac=0, ), ) - - # The learned weighted average is stored in an embedding layer for ease of use + + # The learned weighted average is stored in an embedding layer for ease of use self._attention_network = nn.Sequential( nn.Embedding(318, num_sites), nn.Softmax(dim=1), ) - + # Network used to process weighted average self.output_network = ResFCNet2( - in_features=value_dim, - out_features=out_features, - fc_hidden_features=value_dim, - n_res_blocks=final_resblocks, - res_block_layers=2, - dropout_frac=0, + in_features=value_dim, + out_features=out_features, + fc_hidden_features=value_dim, + n_res_blocks=final_resblocks, + res_block_layers=2, + dropout_frac=0, ) - + def _calculate_attention(self, x): gsp_ids = x[BatchKey.gsp_id].squeeze().int() attention = self._attention_network(gsp_ids) return attention - + def _encode_value(self, x): # Shape: [batch size, sequence length, PV site] pv_site_seqs = x[BatchKey.pv].float() batch_size = pv_site_seqs.shape[0] - - pv_site_seqs = pv_site_seqs.swapaxes(1,2).flatten(0,1) - + + pv_site_seqs = pv_site_seqs.swapaxes(1, 2).flatten(0, 1) + x_seq_enc = self._value_encoder(pv_site_seqs) x_seq_out = x_seq_enc.unflatten(0, (batch_size, self.num_sites)) - return x_seq_out - + return x_seq_out + def forward(self, x): """Run model forward""" - # Output has shape: [batch size, num_sites, value_dim] + # Output has shape: [batch size, num_sites, value_dim] encodeded_seqs = self._encode_value(x) - + # Calculate learned averaging weights attn_avg_weights = self._calculate_attention(x) - + # Take weighted average across num_sites - value_weighted_avg = (encodeded_seqs*attn_avg_weights.unsqueeze(-1)).sum(dim=1) - + value_weighted_avg = (encodeded_seqs * attn_avg_weights.unsqueeze(-1)).sum(dim=1) + # Put through final processing layers x_out = self.output_network(value_weighted_avg) - + return x_out - + class SingleAttentionNetwork(AbstractPVSitesEncoder): """A simple attention-based model with a single multihead attention layer - + For the attention layer the query is based on the target GSP alone, the key is based on the PV ID and the recent PV data, the value is based on the recent PV data. - + """ + def __init__( - self, - sequence_length: int, - num_sites: int, + self, + sequence_length: int, + num_sites: int, out_features: int, - kdim: int = 10, - num_heads: int = 2, + kdim: int = 10, + num_heads: int = 2, pv_id_embed_dim: int = 10, n_kv_res_blocks: int = 2, kv_res_block_layers: int = 2, use_pv_id_in_value: bool = False, - ): + ): """A simple attention-based model with a single multihead attention layer - + Args: sequence_length: The time sequence length of the data. num_sites: Number of PV sites in the input data. - out_features: Number of output features. In this network this is also the the value + out_features: Number of output features. In this network this is also the the value dimension in the multi-head attention layer. kdim: The dimensions used in both the keys and queries. - num_heads: Number of parallel attention heads. Note that `out_features` will be split + num_heads: Number of parallel attention heads. Note that `out_features` will be split across `num_heads` so `out_features` must be a multiple of `num_heads`. pv_id_embed_dim: The dimension of the PV ID embedding used in calculating the key. n_kv_res_blocks: Number of residual blocks to use in the key and value encoders. @@ -143,21 +143,21 @@ def __init__( the key and value encoders. use_pv_id_in_value: Whether to use the PV ID in network used to produce the value for the attention layer. - + """ super().__init__(sequence_length, num_sites, out_features) - + self.gsp_id_embedding = nn.Embedding(318, kdim) self.pv_id_embedding = nn.Embedding(num_sites, pv_id_embed_dim) self._pv_ids = nn.parameter.Parameter(torch.arange(num_sites), requires_grad=False) self.use_pv_id_in_value = use_pv_id_in_value - + if use_pv_id_in_value: self.value_pv_id_embedding = nn.Embedding(num_sites, pv_id_embed_dim) - + self._value_encoder = nn.Sequential( ResFCNet2( - in_features=sequence_length+int(use_pv_id_in_value)*pv_id_embed_dim, + in_features=sequence_length + int(use_pv_id_in_value) * pv_id_embed_dim, out_features=out_features, fc_hidden_features=sequence_length, n_res_blocks=n_kv_res_blocks, @@ -166,88 +166,83 @@ def __init__( ), nn.Linear(out_features, kdim), ) - - self._key_encoder = nn.Sequential( + + self._key_encoder = nn.Sequential( ResFCNet2( - in_features=pv_id_embed_dim+sequence_length, + in_features=pv_id_embed_dim + sequence_length, out_features=kdim, - fc_hidden_features=pv_id_embed_dim+sequence_length, + fc_hidden_features=pv_id_embed_dim + sequence_length, n_res_blocks=n_kv_res_blocks, res_block_layers=kv_res_block_layers, dropout_frac=0, ), - nn.Linear(kdim, kdim) + nn.Linear(kdim, kdim), ) - + self.multihead_attn = nn.MultiheadAttention( - embed_dim=kdim, + embed_dim=kdim, num_heads=num_heads, batch_first=True, vdim=out_features, ) - - + def _encode_query(self, x): gsp_ids = x[BatchKey.gsp_id].squeeze().int() query = self.gsp_id_embedding(gsp_ids).unsqueeze(1) return query - + def _encode_key(self, x): # Shape: [batch size, sequence length, PV site] pv_site_seqs = x[BatchKey.pv].float() batch_size = pv_site_seqs.shape[0] - + # PV ID embeddings are the same for each sample pv_id_embed = torch.tile(self.pv_id_embedding(self._pv_ids), (batch_size, 1, 1)) - + # Each concated (PV sequence, PV ID embedding) is processed with encoder - x_seq_in = torch.cat((pv_site_seqs.swapaxes(1,2), pv_id_embed), dim=2).flatten(0,1) + x_seq_in = torch.cat((pv_site_seqs.swapaxes(1, 2), pv_id_embed), dim=2).flatten(0, 1) key = self._key_encoder(x_seq_in) - + # Reshape to [batch size, PV site, kdim] key = key.unflatten(0, (batch_size, self.num_sites)) return key - + def _encode_value(self, x): # Shape: [batch size, sequence length, PV site] pv_site_seqs = x[BatchKey.pv].float() batch_size = pv_site_seqs.shape[0] - + if self.use_pv_id_in_value: # PV ID embeddings are the same for each sample pv_id_embed = torch.tile(self.value_pv_id_embedding(self._pv_ids), (batch_size, 1, 1)) # Each concated (PV sequence, PV ID embedding) is processed with encoder - x_seq_in = torch.cat((pv_site_seqs.swapaxes(1,2), pv_id_embed), dim=2).flatten(0,1) + x_seq_in = torch.cat((pv_site_seqs.swapaxes(1, 2), pv_id_embed), dim=2).flatten(0, 1) else: - # Encode each PV sequence independently - x_seq_in = pv_site_seqs.swapaxes(1,2).flatten(0,1) - + # Encode each PV sequence independently + x_seq_in = pv_site_seqs.swapaxes(1, 2).flatten(0, 1) + value = self._value_encoder(x_seq_in) - + # Reshape to [batch size, PV site, vdim] value = value.unflatten(0, (batch_size, self.num_sites)) return value - + def _attention_forward(self, x, average_attn_weights=True): - query = self._encode_query(x) key = self._encode_key(x) value = self._encode_value(x) attn_output, attn_weights = self.multihead_attn( - query, - key, - value, - average_attn_weights=average_attn_weights + query, key, value, average_attn_weights=average_attn_weights ) - - return attn_output, attn_weights - + + return attn_output, attn_weights + def forward(self, x): """Run model forward""" attn_output, attn_output_weights = self._attention_forward(x) - + # Reshape from [batch_size, 1, vdim] to [batch_size, vdim] x_out = attn_output.squeeze() - - return x_out \ No newline at end of file + + return x_out diff --git a/pvnet/optimizers.py b/pvnet/optimizers.py index 92d6fa40..55aa5493 100644 --- a/pvnet/optimizers.py +++ b/pvnet/optimizers.py @@ -50,41 +50,41 @@ def __call__(self, model): """Return optimizer""" return torch.optim.AdamW(model.parameters(), lr=self.lr, **self.kwargs) - + def find_submodule_parameters(model, search_modules): if isinstance(model, search_modules): return model.parameters() - + children = list(model.children()) - if len(children)==0: + if len(children) == 0: return [] else: params = [] for c in children: params += find_submodule_parameters(c, search_modules) return params - - + + def find_other_than_submodule_parameters(model, ignore_modules): if isinstance(model, ignore_modules): return [] - + children = list(model.children()) - if len(children)==0: + if len(children) == 0: return model.parameters() else: params = [] for c in children: params += find_other_than_submodule_parameters(c, ignore_modules) return params - - + + class EmbAdamWReduceLROnPlateau(AbstractOptimizer): """AdamW optimizer and reduce on plateau scheduler""" - def __init__(self, lr=0.0005, weight_decay=0.01, patience=3, factor=0.5, threshold=2e-4, - **opt_kwargs - ): + def __init__( + self, lr=0.0005, weight_decay=0.01, patience=3, factor=0.5, threshold=2e-4, **opt_kwargs + ): """AdamW optimizer and reduce on plateau scheduler""" self.lr = lr self.weight_decay = weight_decay @@ -92,13 +92,12 @@ def __init__(self, lr=0.0005, weight_decay=0.01, patience=3, factor=0.5, thresho self.factor = factor self.threshold = threshold self.opt_kwargs = opt_kwargs - def __call__(self, model): """Return optimizer""" - search_modules = (torch.nn.Embedding, ) - + search_modules = (torch.nn.Embedding,) + no_decay = find_submodule_parameters(model, search_modules) decay = find_other_than_submodule_parameters(model, search_modules) @@ -107,7 +106,7 @@ def __call__(self, model): {"params": no_decay, "weight_decay": 0.0}, ] opt = torch.optim.AdamW(optim_groups, lr=self.lr, **self.opt_kwargs) - + sch = torch.optim.lr_scheduler.ReduceLROnPlateau( opt, factor=self.factor, @@ -115,7 +114,7 @@ def __call__(self, model): threshold=self.threshold, ) sch = { - "scheduler": sch, + "scheduler": sch, "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val", } return [opt], [sch] @@ -124,9 +123,9 @@ def __call__(self, model): class AdamWReduceLROnPlateau(AbstractOptimizer): """AdamW optimizer and reduce on plateau scheduler""" - def __init__(self, lr=0.0005, patience=3, factor=0.5, threshold=2e-4, step_freq=None, - **opt_kwargs - ): + def __init__( + self, lr=0.0005, patience=3, factor=0.5, threshold=2e-4, step_freq=None, **opt_kwargs + ): """AdamW optimizer and reduce on plateau scheduler""" self._lr = lr self.patience = patience @@ -134,31 +133,28 @@ def __init__(self, lr=0.0005, patience=3, factor=0.5, threshold=2e-4, step_freq= self.threshold = threshold self.step_freq = step_freq self.opt_kwargs = opt_kwargs - + def _call_multi(self, model): - - remaining_params = {k:p for k,p in model.named_parameters()} - + remaining_params = {k: p for k, p in model.named_parameters()} + group_args = [] - + for key in self._lr.keys(): - if key=="default": + if key == "default": continue - + submodule_params = [] for param_name in list(remaining_params.keys()): if param_name.startswith(key): submodule_params += [remaining_params.pop(param_name)] - + group_args += [{"params": submodule_params, "lr": self._lr[key]}] remaining_params = [p for k, p in remaining_params.items()] group_args += [{"params": remaining_params}] - + opt = torch.optim.AdamW( - group_args, - lr=self._lr["default"] if model.lr is None else model.lr, - **self.opt_kwargs + group_args, lr=self._lr["default"] if model.lr is None else model.lr, **self.opt_kwargs ) sch = { "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( @@ -166,12 +162,11 @@ def _call_multi(self, model): factor=self.factor, patience=self.patience, threshold=self.threshold, - ), + ), "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val", } return [opt], [sch] - def __call__(self, model): """Return optimizer""" @@ -188,7 +183,7 @@ def __call__(self, model): threshold=self.threshold, ) sch = { - "scheduler": sch, + "scheduler": sch, "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val", } return [opt], [sch]