Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automated Field Solver #2070

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
4 changes: 3 additions & 1 deletion src/metpy/calc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
from .basic import * # noqa: F403
from .cross_sections import * # noqa: F403
from .exceptions import * # noqa: F403
from .field_solver import solver
from .indices import * # noqa: F403
from .kinematics import * # noqa: F403
from .thermo import * # noqa: F403
from .tools import * # noqa: F403
from .turbulence import * # noqa: F403
from ..package_tools import set_module

__all__ = basic.__all__[:] # pylint: disable=undefined-variable
__all__ = ['solver']
__all__.extend(basic.__all__) # pylint: disable=undefined-variable
__all__.extend(cross_sections.__all__) # pylint: disable=undefined-variable
__all__.extend(indices.__all__) # pylint: disable=undefined-variable
__all__.extend(kinematics.__all__) # pylint: disable=undefined-variable
Expand Down
8 changes: 7 additions & 1 deletion src/metpy/calc/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np
from scipy.ndimage import gaussian_filter

from .field_solver import solver
from .. import constants as mpconsts
from ..package_tools import Exporter
from ..units import check_units, masked_array, units
Expand All @@ -30,6 +31,7 @@


@exporter.export
@solver.register()
@preprocess_and_wrap(wrap_like='u')
@check_units('[speed]', '[speed]')
def wind_speed(u, v):
Expand Down Expand Up @@ -57,6 +59,7 @@ def wind_speed(u, v):


@exporter.export
@solver.register()
@preprocess_and_wrap(wrap_like='u')
@check_units('[speed]', '[speed]')
def wind_direction(u, v, convention='from'):
Expand Down Expand Up @@ -112,6 +115,7 @@ def wind_direction(u, v, convention='from'):


@exporter.export
@solver.register('u', 'v')
@preprocess_and_wrap(wrap_like=('speed', 'speed'))
@check_units('[speed]')
def wind_components(speed, wind_direction):
Expand Down Expand Up @@ -153,6 +157,7 @@ def wind_components(speed, wind_direction):


@exporter.export
@solver.register('temperature', 'wind_speed')
@preprocess_and_wrap(wrap_like='temperature')
@check_units(temperature='[temperature]', speed='[speed]')
def windchill(temperature, speed, face_level_winds=False, mask_undefined=True):
Expand Down Expand Up @@ -215,7 +220,8 @@ def windchill(temperature, speed, face_level_winds=False, mask_undefined=True):


@exporter.export
@preprocess_and_wrap(wrap_like='temperature')
@solver.register()
@preprocess_and_wrap(broadcast=('temperature', 'relative_humidity'), wrap_like='temperature')
@check_units('[temperature]')
def heat_index(temperature, relative_humidity, mask_undefined=True):
r"""Calculate the Heat Index from the current temperature and relative humidity.
Expand Down
127 changes: 127 additions & 0 deletions src/metpy/calc/field_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright (c) 2021 MetPy Developers.
# Distributed under the terms of the BSD 3-Clause License.
# SPDX-License-Identifier: BSD-3-Clause
"""Solver to automatically calculate derived parameters from a dataset."""

from collections import ChainMap, deque
import contextlib
from inspect import signature, Parameter


class Path:
def __init__(self, steps, have, need):
self.steps = steps
self.have = have
self.need = need

def is_complete(self):
return not bool(self.need)

def __add__(self, other):
if any(f in set(self.steps) for f in other.steps):
raise ValueError(f'{other.steps} already in steps')

# Prepend steps so that final path is in proper call order
# Don't really "have" what's in the new function call, but instead it just needs
# to be removed from what's needed.
return Path(other.steps + self.steps, self.have,
(self.need | other.need) - (self.have | other.have))

def __str__(self):
return (f'Path<Steps: {[f.__name__ for f in self.steps]} Have: {self.have} '
f'Need: {self.need}>')

__repr__ = __str__


class Solver:
names = {'tw': 'wet_bulb_temperature', 'td': 'dewpoint_temperature',
'dewpoint': 'dewpoint_temperature', 'tv': 'virtual_temperature',
'q': 'specific_humidity', 'r': 'mixing_ratio', 'rh': 'relative_humidity',
'p': 'pressure', 't': 'temperature', 'isobaric': 'pressure'}

standard_names = {'temperature': 'air_temperature'}

fallback_names = {'temperature': ['temp'], 'pressure': ['P', 'isobaric']}

def __init__(self):
self._graph = {}
self._funcs = {}

def register(self, *args, inputs=None):
def dec(func):
nonlocal inputs
nonlocal args
if inputs is None:
funcsig = signature(func)
inputs = [name for name, param in funcsig.parameters.items() if
param.default is Parameter.empty]

if not args:
args = (func.__name__,)

normed_returns = self.normalize_names(args)
normed_inputs = self.normalize_names(inputs)
path = Path([func], set(normed_returns), set(normed_inputs))
self._funcs[func] = (normed_inputs, normed_returns)
for ret in normed_returns:
self._graph.setdefault(ret, []).append(path)
return func

return dec

def normalize_names(self, names):
return [self.normalize(name) for name in names]

def normalize(self, name):
return self.names.get(name.lower(), name.lower())

def _map_func_args(self, func, data):
key_map = {self.normalize(key): key for key in ChainMap(data, data.coords)}
for item in self._funcs[func][0]:
if item in key_map:
yield data[key_map[item]]
elif item in self.standard_names:
ds = data.filter_by_attrs(standard_name=self.standard_names[item])
yield next(iter(ds))
else:
for name in self.fallback_names.get(item, []):
if name in key_map:
yield data[key_map[name]]

def calculate(self, data, name):
data = data.copy()
for func in self.solve(set(data) | set(data.coords), name):
result = func(*self._map_func_args(func, data))
retname = self._funcs[func][-1]
if isinstance(result, tuple):
for name, val in zip(retname, result):
data[name] = val
else:
data[retname] = result

return data[self.normalize(name)]

def solve(self, have, want):
# Using deque as a FIFO queue by pushing at one end and popping
# from the other--this makes this a Breadth-First Search
options = deque([Path([], set(self.normalize_names(have)), {self.normalize(want)})])
while options:
path = options.popleft()
# If calculation path is complete, return the steps
if path.is_complete():
return path.steps
else:
# Otherwise grab one of the remaining needs and
# add all methods for calculating to the current steps
# and make them options to consider
item = path.need.pop()
for trial_step in self._graph.get(item, ()):
# ValueError gets thrown if we try to repeat a function call
with contextlib.suppress(ValueError):
options.append(path + trial_step)

raise ValueError(f'Unable to calculate {want} from {have}')


solver = Solver()
17 changes: 17 additions & 0 deletions src/metpy/calc/thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import scipy.optimize as so
import xarray as xr

from .field_solver import solver
from .exceptions import InvalidSoundingError
from .tools import (_greater_or_close, _less_or_close, _remove_nans, find_bounding_indices,
find_intersections, first_derivative, get_layer)
Expand All @@ -26,6 +27,7 @@


@exporter.export
@solver.register('relative_humidity')
@preprocess_and_wrap(wrap_like='temperature', broadcast=('temperature', 'dewpoint'))
@check_units('[temperature]', '[temperature]')
def relative_humidity_from_dewpoint(temperature, dewpoint):
Expand Down Expand Up @@ -101,6 +103,7 @@ def exner_function(pressure, reference_pressure=mpconsts.P0):


@exporter.export
@solver.register()
@preprocess_and_wrap(wrap_like='temperature', broadcast=('pressure', 'temperature'))
@check_units('[pressure]', '[temperature]')
def potential_temperature(pressure, temperature):
Expand Down Expand Up @@ -148,6 +151,7 @@ def potential_temperature(pressure, temperature):
broadcast=('pressure', 'potential_temperature')
)
@check_units('[pressure]', '[temperature]')
@solver.register('temperature')
def temperature_from_potential_temperature(pressure, potential_temperature):
r"""Calculate the temperature from a given potential temperature.

Expand Down Expand Up @@ -1011,6 +1015,7 @@ def saturation_vapor_pressure(temperature):


@exporter.export
@solver.register('dewpoint')
@preprocess_and_wrap(wrap_like='temperature', broadcast=('temperature', 'relative_humidity'))
@check_units('[temperature]', '[dimensionless]')
def dewpoint_from_relative_humidity(temperature, relative_humidity):
Expand Down Expand Up @@ -1163,6 +1168,7 @@ def saturation_mixing_ratio(total_press, temperature):


@exporter.export
@solver.register()
@preprocess_and_wrap(
wrap_like='temperature',
broadcast=('pressure', 'temperature', 'dewpoint')
Expand Down Expand Up @@ -1290,6 +1296,7 @@ def saturation_equivalent_potential_temperature(pressure, temperature):


@exporter.export
@solver.register()
@preprocess_and_wrap(wrap_like='temperature', broadcast=('temperature', 'mixing_ratio'))
@check_units('[temperature]', '[dimensionless]', '[dimensionless]')
def virtual_temperature(temperature, mixing_ratio, molecular_weight_ratio=mpconsts.epsilon):
Expand Down Expand Up @@ -1329,6 +1336,7 @@ def virtual_temperature(temperature, mixing_ratio, molecular_weight_ratio=mpcons


@exporter.export
@solver.register()
@preprocess_and_wrap(
wrap_like='temperature',
broadcast=('pressure', 'temperature', 'mixing_ratio')
Expand Down Expand Up @@ -1375,6 +1383,7 @@ def virtual_potential_temperature(pressure, temperature, mixing_ratio,


@exporter.export
@solver.register()
@preprocess_and_wrap(
wrap_like='temperature',
broadcast=('pressure', 'temperature', 'mixing_ratio')
Expand Down Expand Up @@ -1533,6 +1542,7 @@ def psychrometric_vapor_pressure_wet(pressure, dry_bulb_temperature, wet_bulb_te


@exporter.export
@solver.register('mixing_ratio')
@preprocess_and_wrap(
wrap_like='temperature',
broadcast=('pressure', 'temperature', 'relative_humidity')
Expand Down Expand Up @@ -1581,6 +1591,7 @@ def mixing_ratio_from_relative_humidity(pressure, temperature, relative_humidity


@exporter.export
@solver.register('relative_humidity')
@preprocess_and_wrap(
wrap_like='temperature',
broadcast=('pressure', 'temperature', 'mixing_ratio')
Expand Down Expand Up @@ -1627,6 +1638,7 @@ def relative_humidity_from_mixing_ratio(pressure, temperature, mixing_ratio):


@exporter.export
@solver.register('mixing_ratio')
@preprocess_and_wrap(wrap_like='specific_humidity')
@check_units('[dimensionless]')
def mixing_ratio_from_specific_humidity(specific_humidity):
Expand Down Expand Up @@ -1662,6 +1674,7 @@ def mixing_ratio_from_specific_humidity(specific_humidity):


@exporter.export
@solver.register('specific_humidity')
@preprocess_and_wrap(wrap_like='mixing_ratio')
@check_units('[dimensionless]')
def specific_humidity_from_mixing_ratio(mixing_ratio):
Expand Down Expand Up @@ -1697,6 +1710,7 @@ def specific_humidity_from_mixing_ratio(mixing_ratio):


@exporter.export
@solver.register('relative_humidity')
@preprocess_and_wrap(
wrap_like='temperature',
broadcast=('pressure', 'temperature', 'specific_humidity')
Expand Down Expand Up @@ -2946,6 +2960,7 @@ def brunt_vaisala_period(height, potential_temperature, vertical_dim=0):


@exporter.export
@solver.register()
@preprocess_and_wrap(
wrap_like='temperature',
broadcast=('pressure', 'temperature', 'dewpoint')
Expand Down Expand Up @@ -3051,6 +3066,7 @@ def static_stability(pressure, temperature, vertical_dim=0):


@exporter.export
@solver.register('dewpoint')
@preprocess_and_wrap(
wrap_like='temperature',
broadcast=('pressure', 'temperature', 'specific_humdiity')
Expand Down Expand Up @@ -3188,6 +3204,7 @@ def vertical_velocity(omega, pressure, temperature, mixing_ratio=0):


@exporter.export
@solver.register('specific_humidity')
@preprocess_and_wrap(wrap_like='dewpoint', broadcast=('dewpoint', 'pressure'))
@check_units('[pressure]', '[temperature]')
def specific_humidity_from_dewpoint(pressure, dewpoint):
Expand Down
7 changes: 5 additions & 2 deletions src/metpy/plots/declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ._mpl import TextCollection
from .cartopy_utils import import_cartopy
from .station_plot import StationPlot
from ..calc import reduce_point_density
from ..calc import reduce_point_density, solver
from ..package_tools import Exporter
from ..units import units

Expand Down Expand Up @@ -1083,7 +1083,10 @@ def griddata(self):

# Select our particular field of interest
if self.field:
data = self.data.metpy.parse_cf(self.field)
if self.field in self.data:
data = self.data.metpy.parse_cf(self.field)
else:
data = solver.calculate(self.data.metpy.parse_cf(), self.field)
elif hasattr(self.data.metpy, 'parse_cf'):
# Handles the case where we have a dataset but no specified field
raise ValueError('field attribute has not been set.')
Expand Down
10 changes: 10 additions & 0 deletions tests/calc/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,16 @@ def test_heat_index_kelvin():
assert_almost_equal(hi.to('degC'), 50.3406 * units.degC, 4)


def test_heat_index_xarray():
"""Test heat_index when working with fields from xarray."""
temp = xr.DataArray(np.full((1, 4, 2, 3), 35.), attrs={'units': 'degC'},
dims=('t', 'p', 'y', 'x'))
rh = xr.DataArray(np.full((4, 1, 2, 3), 0.7), dims = ('p', 't', 'y', 'x'))

hi = heat_index(temp, rh)
assert_almost_equal(hi, units.Quantity(50.3405, 'degC'), 4)


def test_height_to_geopotential(array_type):
"""Test conversion from height to geopotential."""
mask = [False, True, False, True]
Expand Down
Loading