diff --git a/pyproject.toml b/pyproject.toml index 37ba2f4..490ff6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ classifiers = [ dynamic = [ "version" ] dependencies = [ "anemoi-utils>=0.4.4", + "cfunits", "earthkit-data", "earthkit-meteo", ] diff --git a/src/anemoi/transform/filters/rescale.py b/src/anemoi/transform/filters/rescale.py new file mode 100644 index 0000000..24eb5ba --- /dev/null +++ b/src/anemoi/transform/filters/rescale.py @@ -0,0 +1,82 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from cfunits import Units + +from . import filter_registry +from .base import SimpleFilter + +class Rescale(SimpleFilter): + """A filter to rescale a parameter from a scale and an offset, and back. + """ + + def __init__( + self, + *, + scale, + offset, + param, + ): + self.scale = scale + self.offset = offset + self.param = param + + def forward(self, data): + return self._transform( + data, + self.forward_transform, + self.param + ) + + def backward(self, data): + return self._transform( + data, + self.backward_transform, + self.param, + ) + + def forward_transform(self, x): + """x to ax+b""" + + rescaled = x.to_numpy() * self.scale + self.offset + + yield self.new_field_from_numpy(rescaled, template=x, param=self.param) + + def backward_transform(self, x): + """ax+b to x""" + + descaled = (x.to_numpy() - self.offset)/ self.scale + + yield self.new_field_from_numpy(descaled, template=x, param=self.param) + +class Convert(Rescale): + """A filter to convert a parameter in a given unit to another unit, and back. + """ + + def __init__( + self, + *, + unit_in, + unit_out, + param + ): + u0 = Units(unit_in) + u1 = Units(unit_out) + x1, x2 = 0.0, 1.0 + y1, y2 = Units.conform([x1, x2], u0, u1) + a = (y2 - y1) / (x2 - x1) + b = y1 - a * x1 + self.scale = a + self.offset = b + self.param = param + + +filter_registry.register("rescale", Rescale) +filter_registry.register("convert", Convert) diff --git a/tests/test_filters.py b/tests/test_filters.py new file mode 100644 index 0000000..0351907 --- /dev/null +++ b/tests/test_filters.py @@ -0,0 +1,61 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from anemoi.transform.filters.rescale import Rescale, Convert +import earthkit.data as ekd +from pytest import approx + + +def test_rescale(): + # rescale from K to °C + temp = ekd.from_source( + "mars", {"param": "2t", "levtype": "sfc", "dates": ["2023-11-17 00:00:00"]} + ) + fieldlist = temp.to_fieldlist() + k_to_deg = Rescale(scale=1.0, offset=-273.15, param="2t") + rescaled = k_to_deg.forward(fieldlist) + assert rescaled[0].values.min() == temp.values.min() - 273.15 + assert rescaled[0].values.std() == approx(temp.values.std()) + # and back + rescaled_back = k_to_deg.backward(rescaled) + assert rescaled_back[0].values.min() == temp.values.min() + assert rescaled_back[0].values.std() == approx(temp.values.std()) + # rescale from °C to F + deg_to_far = Rescale(scale=9 / 5, offset=32, param="2t") + rescaled_farheneit = deg_to_far.forward(rescaled) + assert rescaled_farheneit[0].values.min() == 9 / 5 * rescaled[0].values.min() + 32 + assert rescaled_farheneit[0].values.std() == approx( + (9 / 5) * rescaled[0].values.std() + ) + # rescale from F to K + rescaled_back = k_to_deg.backward(deg_to_far.backward(rescaled_farheneit)) + assert rescaled_back[0].values.min() == temp.values.min() + assert rescaled_back[0].values.std() == approx(temp.values.std()) + + +def test_convert(): + # rescale from K to °C + temp = ekd.from_source( + "mars", {"param": "2t", "levtype": "sfc", "dates": ["2023-11-17 00:00:00"]} + ) + fieldlist = temp.to_fieldlist() + k_to_deg = Convert(unit_in="K", unit_out="degC", param="2t") + rescaled = k_to_deg.forward(fieldlist) + assert rescaled[0].values.min() == temp.values.min() - 273.15 + assert rescaled[0].values.std() == approx(temp.values.std()) + # and back + rescaled_back = k_to_deg.backward(rescaled) + assert rescaled_back[0].values.min() == temp.values.min() + assert rescaled_back[0].values.std() == approx(temp.values.std()) + + +if __name__ == "__main__": + test_rescale() + test_convert()