From 3ecf92c423de3111d9e7b7fbee849990daa1726e Mon Sep 17 00:00:00 2001
From: Gert Mertes
Date: Wed, 31 Jul 2024 15:01:42 +0000
Subject: [PATCH] Copy losses/utils.py

Co-authored-by: Jesper Dramsch
---
 src/anemoi/training/losses/utils.py | 44 +++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)
 create mode 100644 src/anemoi/training/losses/utils.py

diff --git a/src/anemoi/training/losses/utils.py b/src/anemoi/training/losses/utils.py
new file mode 100644
index 00000000..d491d35d
--- /dev/null
+++ b/src/anemoi/training/losses/utils.py
@@ -0,0 +1,44 @@
+from typing import Optional
+
+import torch
+from torch import nn
+
+from anemoi.training.utils.logger import get_code_logger
+
+LOGGER = get_code_logger(__name__)
+
+
+def grad_scaler(
+    module: nn.Module,
+    grad_in: tuple[torch.Tensor, ...],
+    grad_out: tuple[torch.Tensor, ...],
+) -> Optional[tuple[torch.Tensor, ...]]:
+    """Scale the loss gradients channel-wise.
+
+    Uses the formula in https://arxiv.org/pdf/2306.06079.pdf, section 4.3.2.
+
+    Use .register_full_backward_hook(grad_scaler, prepend=False) to register this hook.
+
+    Parameters
+    ----------
+    module : nn.Module
+        Loss object (not used)
+    grad_in : tuple[torch.Tensor, ...]
+        Loss gradients
+    grad_out : tuple[torch.Tensor, ...]
+        Output gradients (not used)
+
+    Returns
+    -------
+    tuple[torch.Tensor, ...]
+        Re-scaled input gradients
+    """
+    del module, grad_out  # not needed
+    # The loss is called as loss = module(x_pred, x_true), so the first grad_input is the
+    # gradient of the predicted state and the second is that of the "ground truth" (== zero).
+    channels = grad_in[0].shape[-1]  # number of channels
+    channel_weights = torch.reciprocal(torch.sum(torch.abs(grad_in[0]), dim=1, keepdim=True))  # channel-wise weights
+    new_grad_in = (
+        (channels * channel_weights) / torch.sum(channel_weights, dim=-1, keepdim=True) * grad_in[0]
+    )  # rescaled gradient
+    return new_grad_in, grad_in[1]
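
Note (reviewer sketch, not part of the patch): the docstring above says to register grad_scaler as a full backward hook on the loss module. Below is a minimal usage sketch under assumed inputs. ToyLoss, the tensor shapes, and the (batch, grid points, channels) layout are illustrative only; grad_scaler and nn.Module.register_full_backward_hook come from the patch and PyTorch respectively.

import torch
from torch import nn

from anemoi.training.losses.utils import grad_scaler


class ToyLoss(nn.Module):
    # Hypothetical stand-in for the real training loss: mean squared error
    # over tensors shaped (batch, grid points, channels).
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return torch.mean((pred - target) ** 2)


loss_fn = ToyLoss()
# Attach the channel-wise gradient rescaling hook, as suggested in the docstring.
loss_fn.register_full_backward_hook(grad_scaler, prepend=False)

pred = torch.randn(4, 100, 3, requires_grad=True)  # (batch, grid points, channels), illustrative shape
target = torch.randn(4, 100, 3)
loss = loss_fn(pred, target)
loss.backward()  # grad_scaler rescales d(loss)/d(pred) per channel before backpropagation continues

In effect the hook weights each channel by the reciprocal of the sum of its absolute gradients over the grid dimension, then renormalises the weights so they sum to the number of channels, equalising each variable's contribution while keeping the overall gradient magnitude comparable.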