Skip to content

Commit

Permalink
Add rescale filter and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
OpheliaMiralles committed Nov 18, 2024
1 parent a3256c9 commit 3fa63da
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ classifiers = [
dynamic = [ "version" ]
dependencies = [
"anemoi-utils>=0.4.4",
"cfunits",
"earthkit-data",
"earthkit-meteo",
]
Expand Down
82 changes: 82 additions & 0 deletions src/anemoi/transform/filters/rescale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


from cfunits import Units

from . import filter_registry
from .base import SimpleFilter

class Rescale(SimpleFilter):
"""A filter to rescale a parameter from a scale and an offset, and back.
"""

def __init__(
self,
*,
scale,
offset,
param,
):
self.scale = scale
self.offset = offset
self.param = param

def forward(self, data):
return self._transform(
data,
self.forward_transform,
self.param
)

def backward(self, data):
return self._transform(
data,
self.backward_transform,
self.param,
)

def forward_transform(self, x):
"""x to ax+b"""

rescaled = x.to_numpy() * self.scale + self.offset

yield self.new_field_from_numpy(rescaled, template=x, param=self.param)

def backward_transform(self, x):
"""ax+b to x"""

descaled = (x.to_numpy() - self.offset)/ self.scale

yield self.new_field_from_numpy(descaled, template=x, param=self.param)

class Convert(Rescale):
"""A filter to convert a parameter in a given unit to another unit, and back.
"""

def __init__(
self,
*,
unit_in,
unit_out,
param
):
u0 = Units(unit_in)
u1 = Units(unit_out)
x1, x2 = 0.0, 1.0
y1, y2 = Units.conform([x1, x2], u0, u1)
a = (y2 - y1) / (x2 - x1)
b = y1 - a * x1
self.scale = a
self.offset = b
self.param = param


filter_registry.register("rescale", Rescale)
filter_registry.register("convert", Convert)
61 changes: 61 additions & 0 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


from anemoi.transform.filters.rescale import Rescale, Convert
import earthkit.data as ekd
from pytest import approx


def test_rescale():
# rescale from K to °C
temp = ekd.from_source(
"mars", {"param": "2t", "levtype": "sfc", "dates": ["2023-11-17 00:00:00"]}
)
fieldlist = temp.to_fieldlist()
k_to_deg = Rescale(scale=1.0, offset=-273.15, param="2t")
rescaled = k_to_deg.forward(fieldlist)
assert rescaled[0].values.min() == temp.values.min() - 273.15
assert rescaled[0].values.std() == approx(temp.values.std())
# and back
rescaled_back = k_to_deg.backward(rescaled)
assert rescaled_back[0].values.min() == temp.values.min()
assert rescaled_back[0].values.std() == approx(temp.values.std())
# rescale from °C to F
deg_to_far = Rescale(scale=9 / 5, offset=32, param="2t")
rescaled_farheneit = deg_to_far.forward(rescaled)
assert rescaled_farheneit[0].values.min() == 9 / 5 * rescaled[0].values.min() + 32
assert rescaled_farheneit[0].values.std() == approx(
(9 / 5) * rescaled[0].values.std()
)
# rescale from F to K
rescaled_back = k_to_deg.backward(deg_to_far.backward(rescaled_farheneit))
assert rescaled_back[0].values.min() == temp.values.min()
assert rescaled_back[0].values.std() == approx(temp.values.std())


def test_convert():
# rescale from K to °C
temp = ekd.from_source(
"mars", {"param": "2t", "levtype": "sfc", "dates": ["2023-11-17 00:00:00"]}
)
fieldlist = temp.to_fieldlist()
k_to_deg = Convert(unit_in="K", unit_out="degC", param="2t")
rescaled = k_to_deg.forward(fieldlist)
assert rescaled[0].values.min() == temp.values.min() - 273.15
assert rescaled[0].values.std() == approx(temp.values.std())
# and back
rescaled_back = k_to_deg.backward(rescaled)
assert rescaled_back[0].values.min() == temp.values.min()
assert rescaled_back[0].values.std() == approx(temp.values.std())


if __name__ == "__main__":
test_rescale()
test_convert()

0 comments on commit 3fa63da

Please sign in to comment.