From e48d2c7c89746b7546bab4484056b613c9f2f3b6 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Fri, 18 Oct 2024 19:22:36 +0100 Subject: [PATCH] Updates - Fix docs - Provide add_scalar in BaseWeightedLoss --- src/anemoi/training/losses/utils.py | 15 +++++++++------ src/anemoi/training/losses/weightedloss.py | 7 ++++++- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/anemoi/training/losses/utils.py b/src/anemoi/training/losses/utils.py index 2781b95..9d1d5b2 100644 --- a/src/anemoi/training/losses/utils.py +++ b/src/anemoi/training/losses/utils.py @@ -23,7 +23,7 @@ LOGGER = logging.getLogger(__name__) -def grad_scaler( +def grad_scalar( module: nn.Module, grad_in: tuple[torch.Tensor, ...], grad_out: tuple[torch.Tensor, ...], @@ -32,7 +32,7 @@ def grad_scaler( Uses the formula in https://arxiv.org/pdf/2306.06079.pdf, section 4.3.2 - Use .register_full_backward_hook(grad_scaler, prepend=False) to register this hook. + Use .register_full_backward_hook(grad_scalar, prepend=False) to register this hook. Parameters ---------- @@ -89,8 +89,11 @@ class ScaleTensor: >>> tensor = torch.randn(3, 4, 5) >>> scalars = ScaleTensor((0, torch.randn(3)), (1, torch.randn(4))) >>> scaled_tensor = scalars.scale(tensor) - >>> scalars.get_scalar(tensor.shape).shape + >>> scalars.get_scalar(tensor.ndim).shape torch.Size([3, 4, 1]) + >>> scalars.add_scalar(-1, torch.randn(5)) + >>> scalars.get_scalar(tensor.ndim).shape + torch.Size([3, 4, 5]) """ tensors: dict[str, TENSOR_SPEC] @@ -146,7 +149,7 @@ def get_dim_shape(dimension: int) -> int: return Shape(get_dim_shape) - def validate_scaler(self, dimension: int | tuple[int], scalar: torch.Tensor) -> None: + def validate_scalar(self, dimension: int | tuple[int], scalar: torch.Tensor) -> None: """Check if the scalar is compatible with the given dimension. 
Parameters @@ -170,7 +173,7 @@ def validate_scaler(self, dimension: int | tuple[int], scalar: torch.Tensor) -> if self.shape[dim] != scalar.shape[scalar_dim]: error_msg = ( - f"Scaler shape {scalar.shape} at dimension {scalar_dim}" + f"Scalar shape {scalar.shape} at dimension {scalar_dim}" f"does not match shape of scalar at dimension {dim}. Expected {self.shape[dim]}", ) raise ValueError(error_msg) @@ -200,7 +203,7 @@ def add_scalar( dimension = tuple(dimension + i for i in range(len(scalar.shape))) try: - self.validate_scaler(dimension, scalar) + self.validate_scalar(dimension, scalar) except ValueError as e: error_msg = f"Validating tensor {name!r} raised an invalidation." raise ValueError(error_msg) from e diff --git a/src/anemoi/training/losses/weightedloss.py b/src/anemoi/training/losses/weightedloss.py index eaf691e..80602b6 100644 --- a/src/anemoi/training/losses/weightedloss.py +++ b/src/anemoi/training/losses/weightedloss.py @@ -9,6 +9,7 @@ from __future__ import annotations +import functools import logging from abc import ABC from abc import abstractmethod @@ -65,6 +66,10 @@ def __init__( if variable_scaling is not None: self.scalar.add_scalar(-1, variable_scaling, "variable_scaling") + @functools.wraps(ScaleTensor.add_scalar, assigned=("__doc__", "__annotations__")) + def add_scalar(self, dimension: int | tuple[int], scalar: torch.Tensor, *, name: str | None = None) -> None: + self.scalar.add_scalar(dimension=dimension, scalar=scalar, name=name) + def scale( self, x: torch.Tensor, @@ -90,7 +95,7 @@ def scale( scalar = self.scalar.get_scalar(x.shape) - if feature_indices is None: + if feature_indices is None or "variable_scaling" not in self.scalar: return x * scalar return x * scalar[..., feature_indices]