Skip to content

Commit

Permalink
Updates
Browse files Browse the repository at this point in the history
- Fix docs
- Provide add_scalar in BaseWeighedLoss
  • Loading branch information
HCookie committed Oct 18, 2024
1 parent 83f7f5c commit e48d2c7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
15 changes: 9 additions & 6 deletions src/anemoi/training/losses/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
LOGGER = logging.getLogger(__name__)


def grad_scaler(
def grad_scalar(
module: nn.Module,
grad_in: tuple[torch.Tensor, ...],
grad_out: tuple[torch.Tensor, ...],
Expand All @@ -32,7 +32,7 @@ def grad_scaler(
Uses the formula in https://arxiv.org/pdf/2306.06079.pdf, section 4.3.2
Use <module>.register_full_backward_hook(grad_scaler, prepend=False) to register this hook.
Use <module>.register_full_backward_hook(grad_scalar, prepend=False) to register this hook.
Parameters
----------
Expand Down Expand Up @@ -89,8 +89,11 @@ class ScaleTensor:
>>> tensor = torch.randn(3, 4, 5)
>>> scalars = ScaleTensor((0, torch.randn(3)), (1, torch.randn(4)))
>>> scaled_tensor = scalars.scale(tensor)
>>> scalars.get_scalar(tensor.shape).shape
>>> scalars.get_scalar(tensor.ndim).shape
torch.Size([3, 4, 1])
>>> scalars.add_scalar(-1, torch.randn(5))
>>> scalars.get_scalar(tensor.ndim).shape
torch.Size([3, 4, 5])
"""

tensors: dict[str, TENSOR_SPEC]
Expand Down Expand Up @@ -146,7 +149,7 @@ def get_dim_shape(dimension: int) -> int:

return Shape(get_dim_shape)

def validate_scaler(self, dimension: int | tuple[int], scalar: torch.Tensor) -> None:
def validate_scalar(self, dimension: int | tuple[int], scalar: torch.Tensor) -> None:
"""Check if the scalar is compatible with the given dimension.
Parameters
Expand All @@ -170,7 +173,7 @@ def validate_scaler(self, dimension: int | tuple[int], scalar: torch.Tensor) ->

if self.shape[dim] != scalar.shape[scalar_dim]:
error_msg = (
f"Scaler shape {scalar.shape} at dimension {scalar_dim}"
f"Scalar shape {scalar.shape} at dimension {scalar_dim}"
f"does not match shape of scalar at dimension {dim}. Expected {self.shape[dim]}",
)
raise ValueError(error_msg)
Expand Down Expand Up @@ -200,7 +203,7 @@ def add_scalar(
dimension = tuple(dimension + i for i in range(len(scalar.shape)))

try:
self.validate_scaler(dimension, scalar)
self.validate_scalar(dimension, scalar)
except ValueError as e:
error_msg = f"Validating tensor {name!r} raised an invalidation."
raise ValueError(error_msg) from e
Expand Down
7 changes: 6 additions & 1 deletion src/anemoi/training/losses/weightedloss.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from __future__ import annotations

import functools
import logging
from abc import ABC
from abc import abstractmethod
Expand Down Expand Up @@ -65,6 +66,10 @@ def __init__(
if variable_scaling is not None:
self.scalar.add_scalar(-1, variable_scaling, "variable_scaling")

@functools.wraps(ScaleTensor.add_scalar, assigned=("__doc__", "__annotations__"))
def add_scalar(self, dimension: int | tuple[int], scalar: torch.Tensor, *, name: str | None = None) -> None:
self.scalar.add_scalar(dimension=dimension, scalar=scalar, name=name)

def scale(
self,
x: torch.Tensor,
Expand All @@ -90,7 +95,7 @@ def scale(

scalar = self.scalar.get_scalar(x.shape)

if feature_indices is None:
if feature_indices is None or "variable_scaling" not in self.scalar:
return x * scalar
return x * scalar[..., feature_indices]

Expand Down

0 comments on commit e48d2c7

Please sign in to comment.