Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new tests #43

Merged
merged 1 commit into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pyrte_rrtmgp/kernels/rrtmgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
rrtmgp_compute_tau_rayleigh,
rrtmgp_interpolation,
)
from pyrte_rrtmgp.utils import convert_xarray_args


@convert_xarray_args
def interpolation(
neta: int,
flavor: npt.NDArray,
Expand Down Expand Up @@ -111,6 +113,7 @@ def interpolation(
return jtemp.T, fmajor.T, fminor.T, col_mix.T, tropo.T, jeta.T, jpress.T


@convert_xarray_args
def compute_planck_source(
tlay,
tlev,
Expand Down Expand Up @@ -208,6 +211,7 @@ def compute_planck_source(
return sfc_src.T, lay_src.T, lev_src.T, sfc_src_jac.T


@convert_xarray_args
def compute_tau_absorption(
idx_h2o,
gpoint_flavor,
Expand Down Expand Up @@ -337,6 +341,7 @@ def compute_tau_absorption(
return tau.T


@convert_xarray_args
def compute_tau_rayleigh(
gpoint_flavor,
band_lims_gpt,
Expand Down
4 changes: 4 additions & 0 deletions pyrte_rrtmgp/rrtmgp_gas_optics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
@dataclass
class GasOptics:
tau: Optional[np.ndarray] = None
tau_rayleigh: Optional[np.ndarray] = None
tau_absorption: Optional[np.ndarray] = None
g: Optional[np.ndarray] = None
ssa: Optional[np.ndarray] = None
lay_src: Optional[np.ndarray] = None
Expand Down Expand Up @@ -310,6 +312,7 @@ def compute_gas_taus(self):
self._interpolated.jpress,
)

self.gas_optics.tau_absorption = tau_absorption
if self.source_is_internal:
self.gas_optics.tau = tau_absorption
self.gas_optics.ssa = np.full_like(tau_absorption, np.nan)
Expand All @@ -332,6 +335,7 @@ def compute_gas_taus(self):
self._interpolated.jtemp,
)

self.gas_optics.tau_rayleigh = tau_rayleigh
self.gas_optics.tau = tau_absorption + tau_rayleigh
self.gas_optics.ssa = np.where(
self.gas_optics.tau > 2.0 * np.finfo(float).tiny,
Expand Down
19 changes: 19 additions & 0 deletions pyrte_rrtmgp/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import xarray as xr


def get_usecols(solar_zenith_angle):
Expand Down Expand Up @@ -41,3 +42,21 @@ def compute_toa_flux(total_solar_irradiance, solar_source):
toa_flux = np.stack([solar_source] * ncol)
def_tsi = toa_flux.sum(axis=1)
return (toa_flux.T * (total_solar_irradiance / def_tsi)).T


def convert_xarray_args(func):
def wrapper(*args, **kwargs):
output_args = []
for x in args:
if isinstance(x, xr.DataArray):
output_args.append(x.data)
else:
output_args.append(x)
for k, v in kwargs.items():
if isinstance(v, xr.DataArray):
kwargs[k] = v.data
else:
kwargs[k] = v
return func(*output_args, **kwargs)

return wrapper
193 changes: 193 additions & 0 deletions tests/test_python_frontend/test_gas_optics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import os

import numpy as np
import pytest
import xarray as xr
from pyrte_rrtmgp import rrtmgp_gas_optics
from pyrte_rrtmgp.kernels.rrtmgp import (
compute_planck_source,
compute_tau_absorption,
compute_tau_rayleigh,
interpolation,
)

from utils import convert_args_arrays

ERROR_TOLERANCE = 1e-4

rte_rrtmgp_dir = os.environ.get("RRTMGP_DATA", "rrtmgp-data")
clear_sky_example_files = f"{rte_rrtmgp_dir}/examples/rfmip-clear-sky/inputs"

rfmip = xr.load_dataset(
f"{clear_sky_example_files}/multiple_input4MIPs_radiation_RFMIP_UColorado-RFMIP-1-2_none.nc"
)
rfmip = rfmip.sel(expt=0) # only one experiment
kdist = xr.load_dataset(f"{rte_rrtmgp_dir}/rrtmgp-gas-lw-g256.nc")
kdist_sw = xr.load_dataset(f"{rte_rrtmgp_dir}/rrtmgp-gas-sw-g224.nc")

rrtmgp_gas_optics = kdist.gas_optics.load_atmosferic_conditions(rfmip)
rrtmgp_gas_optics_sw = kdist_sw.gas_optics.load_atmosferic_conditions(rfmip)

# Prepare the arguments for the interpolation function
interpolation_args = [
len(kdist["mixing_fraction"]),
kdist.gas_optics.flavors_sets,
kdist["press_ref"].values,
kdist["temp_ref"].values,
kdist["press_ref_trop"].values.item(),
kdist.gas_optics.vmr_ref,
rfmip["pres_layer"].values,
rfmip["temp_layer"].values,
kdist.gas_optics.col_gas,
]

expected_output = (
kdist.gas_optics._interpolated.jtemp,
kdist.gas_optics._interpolated.fmajor,
kdist.gas_optics._interpolated.fminor,
kdist.gas_optics._interpolated.col_mix,
kdist.gas_optics._interpolated.tropo,
kdist.gas_optics._interpolated.jeta,
kdist.gas_optics._interpolated.jpress,
)


@pytest.mark.parametrize(
"args, expected",
[(i, expected_output) for i in convert_args_arrays(interpolation_args)],
)
def test_compute_interpoaltion(args, expected):
result = interpolation(*args)
assert len(result) == len(expected)
for r, e in zip(result, expected):
assert r.shape == e.shape
assert np.isclose(r, e, atol=ERROR_TOLERANCE).all()


# Prepare the arguments for the compute_planck_source function
planck_source_args = [
rfmip["temp_layer"].data,
rfmip["temp_level"].data,
rfmip["surface_temperature"].data,
kdist.gas_optics.top_at_1,
kdist.gas_optics._interpolated.fmajor,
kdist.gas_optics._interpolated.jeta,
kdist.gas_optics._interpolated.tropo,
kdist.gas_optics._interpolated.jtemp,
kdist.gas_optics._interpolated.jpress,
kdist["bnd_limits_gpt"].data.T,
kdist["plank_fraction"].data.transpose(0, 2, 1, 3),
kdist["temp_ref"].data.min(),
kdist["temp_ref"].data.max(),
kdist["totplnk"].data.T,
kdist.gas_optics.gpoint_flavor,
]

expected_output = (
rrtmgp_gas_optics.sfc_src,
rrtmgp_gas_optics.lay_src,
rrtmgp_gas_optics.lev_src,
rrtmgp_gas_optics.sfc_src_jac,
)


@pytest.mark.parametrize(
"args, expected",
[(i, expected_output) for i in convert_args_arrays(planck_source_args)],
)
def test_compute_planck_source(args, expected):
result = compute_planck_source(*args)
assert len(result) == len(expected)
for r, e in zip(result, expected):
assert r.shape == e.shape
assert np.isclose(r, e, atol=ERROR_TOLERANCE).all()


# Prepare the arguments for the compute_tau_absorption function
minor_gases_lower = kdist.gas_optics.extract_names(kdist["minor_gases_lower"].data)
minor_gases_upper = kdist.gas_optics.extract_names(kdist["minor_gases_upper"].data)
idx_minor_lower = kdist.gas_optics.get_idx_minor(
kdist.gas_optics.gas_names, minor_gases_lower
)
idx_minor_upper = kdist.gas_optics.get_idx_minor(
kdist.gas_optics.gas_names, minor_gases_upper
)

scaling_gas_lower = kdist.gas_optics.extract_names(kdist["scaling_gas_lower"].data)
scaling_gas_upper = kdist.gas_optics.extract_names(kdist["scaling_gas_upper"].data)
idx_minor_scaling_lower = kdist.gas_optics.get_idx_minor(
kdist.gas_optics.gas_names, scaling_gas_lower
)
idx_minor_scaling_upper = kdist.gas_optics.get_idx_minor(
kdist.gas_optics.gas_names, scaling_gas_upper
)

tau_absorption_args = [
kdist.gas_optics.idx_h2o,
kdist.gas_optics.gpoint_flavor,
kdist["bnd_limits_gpt"].values.T,
kdist["kmajor"].values,
kdist["kminor_lower"].values,
kdist["kminor_upper"].values,
kdist["minor_limits_gpt_lower"].values.T,
kdist["minor_limits_gpt_upper"].values.T,
kdist["minor_scales_with_density_lower"].values.astype(bool),
kdist["minor_scales_with_density_upper"].values.astype(bool),
kdist["scale_by_complement_lower"].values.astype(bool),
kdist["scale_by_complement_upper"].values.astype(bool),
idx_minor_lower,
idx_minor_upper,
idx_minor_scaling_lower,
idx_minor_scaling_upper,
kdist["kminor_start_lower"].values,
kdist["kminor_start_upper"].values,
kdist.gas_optics._interpolated.tropo,
kdist.gas_optics._interpolated.col_mix,
kdist.gas_optics._interpolated.fmajor,
kdist.gas_optics._interpolated.fminor,
rfmip["pres_layer"].values,
rfmip["temp_layer"].values,
kdist.gas_optics.col_gas,
kdist.gas_optics._interpolated.jeta,
kdist.gas_optics._interpolated.jtemp,
kdist.gas_optics._interpolated.jpress,
]


@pytest.mark.parametrize(
"args, expected",
[
(i, rrtmgp_gas_optics.tau_absorption)
for i in convert_args_arrays(tau_absorption_args)
],
)
def test_compute_tau_absorption(args, expected):
result = compute_tau_absorption(*args)
assert np.isclose(result, expected, atol=ERROR_TOLERANCE).all()


# Prepare the arguments for the compute_tau_rayleigh function
tau_rayleigh_args = [
kdist_sw.gas_optics.gpoint_flavor,
kdist_sw["bnd_limits_gpt"].values.T,
np.stack([kdist_sw["rayl_lower"].values, kdist_sw["rayl_upper"].values], axis=-1),
kdist_sw.gas_optics.idx_h2o,
kdist_sw.gas_optics.col_gas[:, :, 0],
kdist_sw.gas_optics.col_gas,
kdist_sw.gas_optics._interpolated.fminor,
kdist_sw.gas_optics._interpolated.jeta,
kdist_sw.gas_optics._interpolated.tropo,
kdist_sw.gas_optics._interpolated.jtemp,
]


@pytest.mark.parametrize(
"args, expected",
[
(i, rrtmgp_gas_optics_sw.tau_rayleigh)
for i in convert_args_arrays(tau_rayleigh_args)
],
)
def test_compute_tau_rayleigh(args, expected):
result = compute_tau_rayleigh(*args)
assert np.isclose(result, expected, atol=ERROR_TOLERANCE).all()
24 changes: 24 additions & 0 deletions tests/test_python_frontend/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import numpy as np
import xarray as xr


def convert_args_arrays(input_args, arrays_dtypes=[np.float64, np.float32]):
args_to_test = []
for dtype in arrays_dtypes:
args = []
for item in input_args:
if isinstance(item, np.ndarray) and item.dtype in arrays_dtypes:
output_item = item.astype(dtype)
else:
output_item = item
args.append(output_item)
args_to_test.append(args)
args = []
for item in input_args:
if isinstance(item, np.ndarray) and item.dtype in arrays_dtypes:
output_item = xr.DataArray(item)
else:
output_item = item
args.append(output_item)
args_to_test.append(args)
return args_to_test