From 4dfad54ce4a0ecde2852537c7b0e564e05bd29d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Josu=C3=A9=20Sehnem?= Date: Tue, 9 Apr 2024 18:23:48 -0300 Subject: [PATCH] add new tests --- pyrte_rrtmgp/kernels/rrtmgp.py | 5 + pyrte_rrtmgp/rrtmgp_gas_optics.py | 4 + pyrte_rrtmgp/utils.py | 19 ++ tests/test_python_frontend/test_gas_optics.py | 193 ++++++++++++++++++ tests/test_python_frontend/utils.py | 24 +++ 5 files changed, 245 insertions(+) create mode 100644 tests/test_python_frontend/test_gas_optics.py create mode 100644 tests/test_python_frontend/utils.py diff --git a/pyrte_rrtmgp/kernels/rrtmgp.py b/pyrte_rrtmgp/kernels/rrtmgp.py index a286c9d..9bd82d0 100644 --- a/pyrte_rrtmgp/kernels/rrtmgp.py +++ b/pyrte_rrtmgp/kernels/rrtmgp.py @@ -9,8 +9,10 @@ rrtmgp_compute_tau_rayleigh, rrtmgp_interpolation, ) +from pyrte_rrtmgp.utils import convert_xarray_args +@convert_xarray_args def interpolation( neta: int, flavor: npt.NDArray, @@ -111,6 +113,7 @@ def interpolation( return jtemp.T, fmajor.T, fminor.T, col_mix.T, tropo.T, jeta.T, jpress.T +@convert_xarray_args def compute_planck_source( tlay, tlev, @@ -208,6 +211,7 @@ def compute_planck_source( return sfc_src.T, lay_src.T, lev_src.T, sfc_src_jac.T +@convert_xarray_args def compute_tau_absorption( idx_h2o, gpoint_flavor, @@ -337,6 +341,7 @@ def compute_tau_absorption( return tau.T +@convert_xarray_args def compute_tau_rayleigh( gpoint_flavor, band_lims_gpt, diff --git a/pyrte_rrtmgp/rrtmgp_gas_optics.py b/pyrte_rrtmgp/rrtmgp_gas_optics.py index 693bec4..9e27257 100644 --- a/pyrte_rrtmgp/rrtmgp_gas_optics.py +++ b/pyrte_rrtmgp/rrtmgp_gas_optics.py @@ -23,6 +23,8 @@ @dataclass class GasOptics: tau: Optional[np.ndarray] = None + tau_rayleigh: Optional[np.ndarray] = None + tau_absorption: Optional[np.ndarray] = None g: Optional[np.ndarray] = None ssa: Optional[np.ndarray] = None lay_src: Optional[np.ndarray] = None @@ -310,6 +312,7 @@ def compute_gas_taus(self): self._interpolated.jpress, ) + self.gas_optics.tau_absorption = tau_absorption if self.source_is_internal: self.gas_optics.tau = tau_absorption self.gas_optics.ssa = np.full_like(tau_absorption, np.nan) @@ -332,6 +335,7 @@ def compute_gas_taus(self): self._interpolated.jtemp, ) + self.gas_optics.tau_rayleigh = tau_rayleigh self.gas_optics.tau = tau_absorption + tau_rayleigh self.gas_optics.ssa = np.where( self.gas_optics.tau > 2.0 * np.finfo(float).tiny, diff --git a/pyrte_rrtmgp/utils.py b/pyrte_rrtmgp/utils.py index 366dafe..a95d314 100644 --- a/pyrte_rrtmgp/utils.py +++ b/pyrte_rrtmgp/utils.py @@ -1,4 +1,5 @@ import numpy as np +import xarray as xr def get_usecols(solar_zenith_angle): @@ -41,3 +42,21 @@ def compute_toa_flux(total_solar_irradiance, solar_source): toa_flux = np.stack([solar_source] * ncol) def_tsi = toa_flux.sum(axis=1) return (toa_flux.T * (total_solar_irradiance / def_tsi)).T + + +def convert_xarray_args(func): + def wrapper(*args, **kwargs): + output_args = [] + for x in args: + if isinstance(x, xr.DataArray): + output_args.append(x.data) + else: + output_args.append(x) + for k, v in kwargs.items(): + if isinstance(v, xr.DataArray): + kwargs[k] = v.data + else: + kwargs[k] = v + return func(*output_args, **kwargs) + + return wrapper diff --git a/tests/test_python_frontend/test_gas_optics.py b/tests/test_python_frontend/test_gas_optics.py new file mode 100644 index 0000000..35e5707 --- /dev/null +++ b/tests/test_python_frontend/test_gas_optics.py @@ -0,0 +1,193 @@ +import os + +import numpy as np +import pytest +import xarray as xr +from pyrte_rrtmgp import rrtmgp_gas_optics +from pyrte_rrtmgp.kernels.rrtmgp import ( + compute_planck_source, + compute_tau_absorption, + compute_tau_rayleigh, + interpolation, +) + +from utils import convert_args_arrays + +ERROR_TOLERANCE = 1e-4 + +rte_rrtmgp_dir = os.environ.get("RRTMGP_DATA", "rrtmgp-data") +clear_sky_example_files = f"{rte_rrtmgp_dir}/examples/rfmip-clear-sky/inputs" + +rfmip = xr.load_dataset( + f"{clear_sky_example_files}/multiple_input4MIPs_radiation_RFMIP_UColorado-RFMIP-1-2_none.nc" +) +rfmip = rfmip.sel(expt=0) # only one experiment +kdist = xr.load_dataset(f"{rte_rrtmgp_dir}/rrtmgp-gas-lw-g256.nc") +kdist_sw = xr.load_dataset(f"{rte_rrtmgp_dir}/rrtmgp-gas-sw-g224.nc") + +rrtmgp_gas_optics = kdist.gas_optics.load_atmosferic_conditions(rfmip) +rrtmgp_gas_optics_sw = kdist_sw.gas_optics.load_atmosferic_conditions(rfmip) + +# Prepare the arguments for the interpolation function +interpolation_args = [ + len(kdist["mixing_fraction"]), + kdist.gas_optics.flavors_sets, + kdist["press_ref"].values, + kdist["temp_ref"].values, + kdist["press_ref_trop"].values.item(), + kdist.gas_optics.vmr_ref, + rfmip["pres_layer"].values, + rfmip["temp_layer"].values, + kdist.gas_optics.col_gas, +] + +expected_output = ( + kdist.gas_optics._interpolated.jtemp, + kdist.gas_optics._interpolated.fmajor, + kdist.gas_optics._interpolated.fminor, + kdist.gas_optics._interpolated.col_mix, + kdist.gas_optics._interpolated.tropo, + kdist.gas_optics._interpolated.jeta, + kdist.gas_optics._interpolated.jpress, +) + + +@pytest.mark.parametrize( + "args, expected", + [(i, expected_output) for i in convert_args_arrays(interpolation_args)], +) +def test_compute_interpoaltion(args, expected): + result = interpolation(*args) + assert len(result) == len(expected) + for r, e in zip(result, expected): + assert r.shape == e.shape + assert np.isclose(r, e, atol=ERROR_TOLERANCE).all() + + +# Prepare the arguments for the compute_planck_source function +planck_source_args = [ + rfmip["temp_layer"].data, + rfmip["temp_level"].data, + rfmip["surface_temperature"].data, + kdist.gas_optics.top_at_1, + kdist.gas_optics._interpolated.fmajor, + kdist.gas_optics._interpolated.jeta, + kdist.gas_optics._interpolated.tropo, + kdist.gas_optics._interpolated.jtemp, + kdist.gas_optics._interpolated.jpress, + kdist["bnd_limits_gpt"].data.T, + kdist["plank_fraction"].data.transpose(0, 2, 1, 3), + kdist["temp_ref"].data.min(), + kdist["temp_ref"].data.max(), + kdist["totplnk"].data.T, + kdist.gas_optics.gpoint_flavor, +] + +expected_output = ( + rrtmgp_gas_optics.sfc_src, + rrtmgp_gas_optics.lay_src, + rrtmgp_gas_optics.lev_src, + rrtmgp_gas_optics.sfc_src_jac, +) + + +@pytest.mark.parametrize( + "args, expected", + [(i, expected_output) for i in convert_args_arrays(planck_source_args)], +) +def test_compute_planck_source(args, expected): + result = compute_planck_source(*args) + assert len(result) == len(expected) + for r, e in zip(result, expected): + assert r.shape == e.shape + assert np.isclose(r, e, atol=ERROR_TOLERANCE).all() + + +# Prepare the arguments for the compute_tau_absorption function +minor_gases_lower = kdist.gas_optics.extract_names(kdist["minor_gases_lower"].data) +minor_gases_upper = kdist.gas_optics.extract_names(kdist["minor_gases_upper"].data) +idx_minor_lower = kdist.gas_optics.get_idx_minor( + kdist.gas_optics.gas_names, minor_gases_lower +) +idx_minor_upper = kdist.gas_optics.get_idx_minor( + kdist.gas_optics.gas_names, minor_gases_upper +) + +scaling_gas_lower = kdist.gas_optics.extract_names(kdist["scaling_gas_lower"].data) +scaling_gas_upper = kdist.gas_optics.extract_names(kdist["scaling_gas_upper"].data) +idx_minor_scaling_lower = kdist.gas_optics.get_idx_minor( + kdist.gas_optics.gas_names, scaling_gas_lower +) +idx_minor_scaling_upper = kdist.gas_optics.get_idx_minor( + kdist.gas_optics.gas_names, scaling_gas_upper +) + +tau_absorption_args = [ + kdist.gas_optics.idx_h2o, + kdist.gas_optics.gpoint_flavor, + kdist["bnd_limits_gpt"].values.T, + kdist["kmajor"].values, + kdist["kminor_lower"].values, + kdist["kminor_upper"].values, + kdist["minor_limits_gpt_lower"].values.T, + kdist["minor_limits_gpt_upper"].values.T, + kdist["minor_scales_with_density_lower"].values.astype(bool), + kdist["minor_scales_with_density_upper"].values.astype(bool), + kdist["scale_by_complement_lower"].values.astype(bool), + kdist["scale_by_complement_upper"].values.astype(bool), + idx_minor_lower, + idx_minor_upper, + idx_minor_scaling_lower, + idx_minor_scaling_upper, + kdist["kminor_start_lower"].values, + kdist["kminor_start_upper"].values, + kdist.gas_optics._interpolated.tropo, + kdist.gas_optics._interpolated.col_mix, + kdist.gas_optics._interpolated.fmajor, + kdist.gas_optics._interpolated.fminor, + rfmip["pres_layer"].values, + rfmip["temp_layer"].values, + kdist.gas_optics.col_gas, + kdist.gas_optics._interpolated.jeta, + kdist.gas_optics._interpolated.jtemp, + kdist.gas_optics._interpolated.jpress, +] + + +@pytest.mark.parametrize( + "args, expected", + [ + (i, rrtmgp_gas_optics.tau_absorption) + for i in convert_args_arrays(tau_absorption_args) + ], +) +def test_compute_tau_absorption(args, expected): + result = compute_tau_absorption(*args) + assert np.isclose(result, expected, atol=ERROR_TOLERANCE).all() + + +# Prepare the arguments for the compute_tau_rayleigh function +tau_rayleigh_args = [ + kdist_sw.gas_optics.gpoint_flavor, + kdist_sw["bnd_limits_gpt"].values.T, + np.stack([kdist_sw["rayl_lower"].values, kdist_sw["rayl_upper"].values], axis=-1), + kdist_sw.gas_optics.idx_h2o, + kdist_sw.gas_optics.col_gas[:, :, 0], + kdist_sw.gas_optics.col_gas, + kdist_sw.gas_optics._interpolated.fminor, + kdist_sw.gas_optics._interpolated.jeta, + kdist_sw.gas_optics._interpolated.tropo, + kdist_sw.gas_optics._interpolated.jtemp, +] + + +@pytest.mark.parametrize( + "args, expected", + [ + (i, rrtmgp_gas_optics_sw.tau_rayleigh) + for i in convert_args_arrays(tau_rayleigh_args) + ], +) +def test_compute_tau_rayleigh(args, expected): + result = compute_tau_rayleigh(*args) + assert np.isclose(result, expected, atol=ERROR_TOLERANCE).all() diff --git a/tests/test_python_frontend/utils.py b/tests/test_python_frontend/utils.py new file mode 100644 index 0000000..ce5c50a --- /dev/null +++ b/tests/test_python_frontend/utils.py @@ -0,0 +1,24 @@ +import numpy as np +import xarray as xr + + +def convert_args_arrays(input_args, arrays_dtypes=[np.float64, np.float32]): + args_to_test = [] + for dtype in arrays_dtypes: + args = [] + for item in input_args: + if isinstance(item, np.ndarray) and item.dtype in arrays_dtypes: + output_item = item.astype(dtype) + else: + output_item = item + args.append(output_item) + args_to_test.append(args) + args = [] + for item in input_args: + if isinstance(item, np.ndarray) and item.dtype in arrays_dtypes: + output_item = xr.DataArray(item) + else: + output_item = item + args.append(output_item) + args_to_test.append(args) + return args_to_test