From da2ffccd67e64822a222b28fa51254754856a5de Mon Sep 17 00:00:00 2001
From: Gabriel Moldovan
Date: Thu, 1 Aug 2024 14:56:09 +0000
Subject: [PATCH] Add activation functions for bounded outputs

---
 src/anemoi/models/models/bounding.py     | 122 ++++++++++++++++++
 .../models/encoder_processor_decoder.py  |  33 +++++
 2 files changed, 155 insertions(+)
 create mode 100644 src/anemoi/models/models/bounding.py

diff --git a/src/anemoi/models/models/bounding.py b/src/anemoi/models/models/bounding.py
new file mode 100644
index 0000000..c9913da
--- /dev/null
+++ b/src/anemoi/models/models/bounding.py
@@ -0,0 +1,122 @@
+from abc import ABC
+from abc import abstractmethod
+
+import torch
+
+
+class BaseBoundingStrategy(ABC):
+    """Abstract base class for bounding strategies.
+
+    This class defines the interface for bounding strategies, which apply a
+    restriction to selected variables of a model's predictions.
+
+    Methods
+    -------
+    apply(y_pred: torch.Tensor, indices: list) -> torch.Tensor
+        Applies the bounding strategy to the given variables (indices) of the
+        input prediction (y_pred).
+
+    Parameters
+    ----------
+    y_pred : torch.Tensor
+        The tensor containing the predictions to be bounded.
+    indices : list
+        A list of indices specifying which variables in `y_pred` should be bounded.
+
+    Returns
+    -------
+    torch.Tensor
+        A tensor with the bounding strategy applied.
+    """
+
+    @abstractmethod
+    def apply(self, y_pred: torch.Tensor, indices: list) -> torch.Tensor:
+        pass
+
+
+class ReluBoundingStrategy(BaseBoundingStrategy):
+    """Bounds the variable to be non-negative using a ReLU activation."""
+
+    def apply(self, y_pred: torch.Tensor, indices: list) -> torch.Tensor:
+        return torch.nn.functional.relu(y_pred[..., indices[0]])
+
+
+class HardtanhBoundingStrategy(BaseBoundingStrategy):
+    """Bounds the variable between specified minimum and maximum values using a HardTanh activation.
+
+    Parameters
+    ----------
+    min_val : float
+        The minimum value for the HardTanh activation.
+    max_val : float
+        The maximum value for the HardTanh activation.
+    """
+
+    def __init__(self, min_val: float, max_val: float):
+        super().__init__()
+        self.min_val = min_val
+        self.max_val = max_val
+
+    def apply(self, y_pred: torch.Tensor, indices: list) -> torch.Tensor:
+        return torch.nn.functional.hardtanh(y_pred[..., indices[0]], min_val=self.min_val, max_val=self.max_val)
+
+
+class FractionHardtanhBoundingStrategy(BaseBoundingStrategy):
+    """Bounds the variable as a HardTanh-limited fraction of a total variable.
+
+    The bounded variable is computed as
+    hardtanh(y_pred[var], min_val, max_val) * y_pred[total_var], so that, for
+    example, convective precipitation (cp) is expressed as a fraction of total
+    precipitation (tp).
+
+    Parameters
+    ----------
+    min_val : float
+        The minimum value for the HardTanh activation function.
+    max_val : float
+        The maximum value for the HardTanh activation function.
+    total_var : str
+        Name of the variable from which the bounded variable is derived, e.g.
+        total_var = "tp" (total precipitation) when bounding convective
+        precipitation (cp).
+    """
+
+    def __init__(self, min_val: float, max_val: float, total_var: str):
+        super().__init__()
+        self.min_val = min_val
+        self.max_val = max_val
+        self.total_var = total_var
+
+    def apply(self, y_pred: torch.Tensor, indices: list) -> torch.Tensor:
+        return (
+            torch.nn.functional.hardtanh(y_pred[..., indices[0]], min_val=self.min_val, max_val=self.max_val)
+            * y_pred[..., indices[1]]
+        )
+
+
+class CustomFractionHardtanhBoundingStrategy(BaseBoundingStrategy):
+    """Bounds the variable as a HardTanh-limited fraction of a derived total variable.
+
+    This is a special case of FractionHardtanhBoundingStrategy in which the
+    total variable is constructed as the difference of two other variables.
+    For example, large-scale precipitation (lsp) can be derived from total
+    precipitation (tp) and convective precipitation (cp) as lsp = tp - cp.
+
+    Parameters
+    ----------
+    min_val : float
+        The minimum value for the HardTanh activation function.
+    max_val : float
+        The maximum value for the HardTanh activation function.
+    first_var : str
+        First variable from which the total variable is derived.
+    second_var : str
+        Second variable, subtracted from `first_var` to form the total variable.
+    """
+
+    def __init__(self, min_val: float, max_val: float, first_var: str, second_var: str):
+        super().__init__()
+        self.min_val = min_val
+        self.max_val = max_val
+        self.first_var = first_var
+        self.second_var = second_var
+
+    def apply(self, y_pred: torch.Tensor, indices: list) -> torch.Tensor:
+        return torch.nn.functional.hardtanh(y_pred[..., indices[0]], min_val=self.min_val, max_val=self.max_val) * (
+            y_pred[..., indices[1]] - y_pred[..., indices[2]]
+        )
diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py
index 0f37474..b371284 100644
--- a/src/anemoi/models/models/encoder_processor_decoder.py
+++ b/src/anemoi/models/models/encoder_processor_decoder.py
@@ -21,6 +21,7 @@
 from torch_geometric.data import HeteroData
 
 from anemoi.models.distributed.shapes import get_shape_shards
 from anemoi.models.layers.graph import TrainableTensor
+from anemoi.models.models.bounding import BaseBoundingStrategy
+from anemoi.models.models.bounding import CustomFractionHardtanhBoundingStrategy
+from anemoi.models.models.bounding import FractionHardtanhBoundingStrategy
 
 LOGGER = logging.getLogger(__name__)
@@ -67,6 +68,18 @@ def __init__(
         self._register_latlon("data", self._graph_name_data)
         self._register_latlon("hidden", self._graph_name_hidden)
 
+        # Bounding strategies that constrain selected output variables,
+        # as configured under config.training.bounding_strategies.
+        def create_bounding_strategy(config: DotDict) -> BaseBoundingStrategy:
+            return instantiate(config)
+
+        self.data_indices = data_indices
+        if config.training.bounding_strategies is not None:
+            self.bounding_strategies = {
+                var: create_bounding_strategy(cfg) for var, cfg in config.training.bounding_strategies.items()
+            }
+        else:
+            self.bounding_strategies = {}
+
         self.num_channels = config.model.num_channels
 
         input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size
@@ -250,4 +263,24 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
         # residual connection (just for the prognostic variables)
         x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx]
+
+        # Apply the bounding strategies in the order specified in the config file.
+        for var, strategy in self.bounding_strategies.items():
+            # The first index is always the bounded variable itself.
+            indices = [self.data_indices.model.output.name_to_index[var]]
+
+            # Fraction strategies also need the index of the total variable(s): var = frac * var_total
+            if isinstance(strategy, FractionHardtanhBoundingStrategy):
+                indices.append(self.data_indices.model.output.name_to_index[strategy.total_var])
+            elif isinstance(strategy, CustomFractionHardtanhBoundingStrategy):
+                indices.extend(
+                    [
+                        self.data_indices.model.output.name_to_index[strategy.first_var],
+                        self.data_indices.model.output.name_to_index[strategy.second_var],
+                    ],
+                )
+
+            bounded_var = strategy.apply(x_out, indices)
+            x_out = x_out.clone()  # needed to avoid an in-place operation error during backpropagation
+            x_out[..., indices[0]] = bounded_var
+
         return x_out
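
As a minimal sketch of how the new strategies behave, the following standalone snippet applies them to a dummy prediction tensor outside the model. The channel layout and the name_to_index mapping are illustrative assumptions, not taken from the patch; only the variable names tp and cp come from the docstring examples above.

# Minimal usage sketch of the bounding strategies added above.
# The name_to_index mapping and tensor shape are illustrative assumptions.
import torch

from anemoi.models.models.bounding import FractionHardtanhBoundingStrategy
from anemoi.models.models.bounding import ReluBoundingStrategy

name_to_index = {"tp": 0, "cp": 1}  # assumed output-channel layout
y_pred = torch.randn(2, 10, 2)      # (batch, grid points, variables)

# Force total precipitation to be non-negative.
relu = ReluBoundingStrategy()
y_pred[..., name_to_index["tp"]] = relu.apply(y_pred, [name_to_index["tp"]])

# Express convective precipitation as a [0, 1] fraction of total precipitation:
# cp = hardtanh(cp_raw, 0, 1) * tp, mirroring the docstring example.
frac = FractionHardtanhBoundingStrategy(min_val=0.0, max_val=1.0, total_var="tp")
y_pred[..., name_to_index["cp"]] = frac.apply(y_pred, [name_to_index["cp"], name_to_index["tp"]])

Ordering matters here just as it does in forward(): tp is bounded before cp, so the fraction multiplies an already non-negative total.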
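Each entry of config.training.bounding_strategies is passed to instantiate, so each value must describe one strategy object. The sketch below is a hypothetical configuration, assuming instantiate here is hydra.utils.instantiate (which accepts a mapping with a _target_ key); the actual YAML schema is not part of this patch.

# Hypothetical bounding configuration, fed entry-by-entry to hydra.utils.instantiate.
from hydra.utils import instantiate

bounding_strategies_cfg = {
    "tp": {"_target_": "anemoi.models.models.bounding.ReluBoundingStrategy"},
    "cp": {
        "_target_": "anemoi.models.models.bounding.FractionHardtanhBoundingStrategy",
        "min_val": 0.0,
        "max_val": 1.0,
        "total_var": "tp",
    },
}

# Mirrors the dict comprehension in __init__; insertion order of the mapping
# determines the order in which the strategies are applied in forward().
bounding_strategies = {var: instantiate(cfg) for var, cfg in bounding_strategies_cfg.items()}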