activation functions added for bounded outputs
gabrieloks committed Aug 1, 2024
1 parent a34cb8b commit da2ffcc
Showing 2 changed files with 155 additions and 0 deletions.
src/anemoi/models/models/bounding.py: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
from abc import ABC
from abc import abstractmethod

import torch


class BaseBoundingStrategy(ABC):
    """Abstract base class for bounding strategies.

    A bounding strategy applies a specific restriction to selected
    variables of a model's predictions.
    """

    @abstractmethod
    def apply(self, y_pred: torch.Tensor, indices: list) -> torch.Tensor:
        """Apply the bounding strategy to the given variables of the prediction.

        Parameters
        ----------
        y_pred : torch.Tensor
            The tensor containing the predictions to be bounded.
        indices : list
            Indices of the variables in `y_pred` that the bounding applies to.

        Returns
        -------
        torch.Tensor
            The bounded values of the selected variables.
        """


class ReluBoundingStrategy(BaseBoundingStrategy):
    """Bounds the selected variable to be non-negative via ReLU."""

    def apply(self, y_pred: torch.Tensor, indices: list) -> torch.Tensor:
        return torch.nn.functional.relu(y_pred[..., indices[0]])


class HardtanhBoundingStrategy(BaseBoundingStrategy):
    """Bounds the selected variable to [min_val, max_val] via HardTanh.

    Parameters
    ----------
    min_val : float
        The minimum value for the HardTanh activation.
    max_val : float
        The maximum value for the HardTanh activation.
    """

    def __init__(self, min_val: float, max_val: float):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val

    def apply(self, y_pred: torch.Tensor, indices: list) -> torch.Tensor:
        return torch.nn.functional.hardtanh(y_pred[..., indices[0]], min_val=self.min_val, max_val=self.max_val)


class FractionHardtanhBoundingStrategy(BaseBoundingStrategy):
    """Bounds the selected variable to a HardTanh-limited fraction of another variable.

    The prediction is interpreted as a fraction, bounded to [min_val, max_val]
    with HardTanh and then multiplied by the total variable:
    var = hardtanh(frac) * total_var.

    Parameters
    ----------
    min_val : float
        The minimum value for the HardTanh activation function.
    max_val : float
        The maximum value for the HardTanh activation function.
    total_var : str
        Name of the variable the fraction refers to. For example, convective
        precipitation (cp) can be expressed as a fraction of total
        precipitation (tp), so total_var = tp.
    """

    def __init__(self, min_val: float, max_val: float, total_var: str):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val
        self.total_var = total_var

    def apply(self, y_pred: torch.Tensor, indices: list) -> torch.Tensor:
        return (
            torch.nn.functional.hardtanh(y_pred[..., indices[0]], min_val=self.min_val, max_val=self.max_val)
            * y_pred[..., indices[1]]
        )


class CustomFractionHardtanhBoundingStrategy(BaseBoundingStrategy):
    """Bounds the selected variable to a fraction of a derived total variable.

    A special case of FractionHardtanhBoundingStrategy where the total
    variable is constructed as the difference of two other variables. For
    example, large-scale precipitation (lsp) can be derived from total
    precipitation (tp) and convective precipitation (cp) as lsp = tp - cp.

    Parameters
    ----------
    min_val : float
        The minimum value for the HardTanh activation function.
    max_val : float
        The maximum value for the HardTanh activation function.
    first_var : str
        First variable from which the total variable is derived.
    second_var : str
        Second variable, subtracted from the first to form the total.
    """

    def __init__(self, min_val: float, max_val: float, first_var: str, second_var: str):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val
        self.first_var = first_var
        self.second_var = second_var

    def apply(self, y_pred: torch.Tensor, indices: list) -> torch.Tensor:
        return torch.nn.functional.hardtanh(y_pred[..., indices[0]], min_val=self.min_val, max_val=self.max_val) * (
            y_pred[..., indices[1]] - y_pred[..., indices[2]]
        )
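For orientation, a minimal usage sketch (not part of the commit) applying two of the strategies above to a dummy prediction tensor; the column layout and the values are invented for illustration:

import torch

from anemoi.models.models.bounding import FractionHardtanhBoundingStrategy
from anemoi.models.models.bounding import ReluBoundingStrategy

# Dummy prediction with three variables in the last dimension:
# column 0 = a fraction-like field, column 1 = its total, column 2 = any field.
y_pred = torch.tensor([[-0.5, 2.0, 1.3]])

# ReLU bounding clamps the selected variable at zero.
relu = ReluBoundingStrategy()
print(relu.apply(y_pred, indices=[2]))  # tensor([1.3000])

# Fraction bounding: hardtanh(fraction, 0, 1) * total -> here 0.0 * 2.0 = 0.0
frac = FractionHardtanhBoundingStrategy(min_val=0.0, max_val=1.0, total_var="tp")
print(frac.apply(y_pred, indices=[0, 1]))  # tensor([0.])

Note that apply returns only the bounded values of the selected variable; in the model's forward pass (second file below) these are written back into the output tensor.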
src/anemoi/models/models/encoder_processor_decoder.py: 33 additions & 0 deletions
@@ -21,6 +21,7 @@
from torch_geometric.data import HeteroData

from anemoi.models.distributed.shapes import get_shape_shards
from anemoi.models.models.bounding import BaseBoundingStrategy
from anemoi.models.layers.graph import TrainableTensor

LOGGER = logging.getLogger(__name__)
@@ -67,6 +68,18 @@ def __init__(
        self._register_latlon("data", self._graph_name_data)
        self._register_latlon("hidden", self._graph_name_hidden)

        # Bounding strategies restrict selected output variables via activation functions
        def create_bounding_strategy(config: DotDict) -> BaseBoundingStrategy:
            return instantiate(config)

        self.data_indices = data_indices
        if config.training.bounding_strategies is not None:
            self.bounding_strategies = {
                var: create_bounding_strategy(cfg) for var, cfg in config.training.bounding_strategies.items()
            }
        else:
            self.bounding_strategies = {}

        self.num_channels = config.model.num_channels

        input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size
@@ -250,4 +263,24 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->

        # residual connection (just for the prognostic variables)
        x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx]

        # bounding is performed in the order specified in the config file
        for var, strategy in self.bounding_strategies.items():
            indices = [self.data_indices.model.output.name_to_index[var]]

            # Special case when a fraction activation is used: var = frac * var_total
            if strategy.__class__.__name__ == "FractionHardtanhBoundingStrategy":
                indices.append(self.data_indices.model.output.name_to_index[strategy.total_var])
            elif strategy.__class__.__name__ == "CustomFractionHardtanhBoundingStrategy":
                indices.extend(
                    [
                        self.data_indices.model.output.name_to_index[strategy.first_var],
                        self.data_indices.model.output.name_to_index[strategy.second_var],
                    ],
                )

            activated_var = strategy.apply(x_out, indices)
            x_out = x_out.clone()  # needed to avoid an in-place operation error during backpropagation
            x_out[..., self.data_indices.model.output.name_to_index[var]] = activated_var

        return x_out
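
Since create_bounding_strategy simply forwards to Hydra's instantiate, each entry under config.training.bounding_strategies is evidently expected to be an instantiable config carrying a _target_ path. A hypothetical sketch of such a mapping (the variable names and bounds are illustrative, not taken from this commit):

from hydra.utils import instantiate

# Hypothetical training.bounding_strategies mapping: output variable name
# -> Hydra config for the strategy that bounds it.
bounding_strategies_cfg = {
    "tp": {"_target_": "anemoi.models.models.bounding.ReluBoundingStrategy"},
    "cp": {
        "_target_": "anemoi.models.models.bounding.FractionHardtanhBoundingStrategy",
        "min_val": 0.0,
        "max_val": 1.0,
        "total_var": "tp",
    },
}

strategies = {var: instantiate(cfg) for var, cfg in bounding_strategies_cfg.items()}
# Order matters: forward() applies bounding in this dict's iteration order,
# so here "tp" is bounded before "cp" (which scales by the bounded total).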
