Skip to content

Commit

Permalink
test: add a test for hydra instantiating of bounding
Browse files Browse the repository at this point in the history
  • Loading branch information
JesperDramsch committed Aug 29, 2024
1 parent 703c8fc commit 6c8ff12
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions tests/layers/test_bounding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import torch
from anemoi.utils.config import DotDict
from hydra.utils import instantiate

from anemoi.models.layers.bounding import FractionBounding
from anemoi.models.layers.bounding import HardtanhBounding
Expand Down Expand Up @@ -64,3 +65,31 @@ def test_multi_chained_bounding(config, name_to_index, input_tensor):
# Data with Relu applied first and then Hardtanh
expected_output = torch.tensor([[minimum, maximum, 3.0], [maximum, minimum, 6.0], [0.5, 0.5, 0.5]])
assert torch.equal(output, expected_output)


def test_hydra_instantiate_bounding(config, name_to_index, input_tensor):
layer_definitions = [
{
"_target_": "anemoi.models.layers.bounding.ReluBounding",
"variables": config.variables,
"name_to_index": name_to_index,
},
{
"_target_": "anemoi.models.layers.bounding.HardtanhBounding",
"variables": config.variables,
"name_to_index": name_to_index,
"min_val": 0.0,
"max_val": 1.0,
},
{
"_target_": "anemoi.models.layers.bounding.FractionBounding",
"variables": config.variables,
"name_to_index": name_to_index,
"min_val": 0.0,
"max_val": 1.0,
"total_var": config.total_var,
},
]
for layer_definition in layer_definitions:
bounding = instantiate(layer_definition)
bounding(input_tensor.clone())

0 comments on commit 6c8ff12

Please sign in to comment.