diff --git a/docs/source/reference_guide/pyrte_rrtmgp_python_modules.rst b/docs/source/reference_guide/pyrte_rrtmgp_python_modules.rst index 356c5eb..c672c4d 100644 --- a/docs/source/reference_guide/pyrte_rrtmgp_python_modules.rst +++ b/docs/source/reference_guide/pyrte_rrtmgp_python_modules.rst @@ -8,18 +8,18 @@ The documentation below provides details for several of the high-level functions For more information about the low-level functions available in the ``pyrte_rrtmgp.pyrte_rrtmgp`` submodule, please go to :ref:`low_level_interface`. -pyrte\_rrtmgp.rte module ------------------------- +pyrte\_rrtmgp.kernels.rte module +-------------------------------- -.. automodule:: pyrte_rrtmgp.rte +.. automodule:: pyrte_rrtmgp.kernels.rte :members: :undoc-members: :show-inheritance: -pyrte\_rrtmgp.rrtmgp module ---------------------------- +pyrte\_rrtmgp.kernels.rrtmgp module +----------------------------------- -.. automodule:: pyrte_rrtmgp.rrtmgp +.. automodule:: pyrte_rrtmgp.kernels.rrtmgp :members: :undoc-members: :show-inheritance: diff --git a/docs/source/user_guide/usage.md b/docs/source/user_guide/usage.md index ecabbb0..17b7e21 100644 --- a/docs/source/user_guide/usage.md +++ b/docs/source/user_guide/usage.md @@ -7,8 +7,8 @@ This section provides a brief overview of how to use `pyrte_rrtmgp` with Python. The `pyrte_rrtmgp` package contains the following submodules: - `pyrte_rrtmgp.pyrte_rrtmgp`: The main module that provides access to a subset of RTE-RRTMGP's Fortran functions in Python. The functions available in this module mirror the Fortran functions (see below). You can think of this as the low-level implementation that allows you to access the respective Fortran functions directly in Python. See [](low_level_interface) for more information. -- `pyrte_rrtmgp.rte`: A high-level module that provides a more user-friendly Python interface for select RTE functions. This module is still under development and will be expanded in future releases. See [](module_ref) for details. -- `pyrte_rrtmgp.rrtmgp`: A high-level module that provides a more user-friendly Python interface for select RRTMGP functions. This module is still under development and will be expanded in future releases. See [](module_ref) for details. +- `pyrte_rrtmgp.kernels.rte`: A high-level module that provides a more user-friendly Python interface for select RTE functions. This module is still under development and will be expanded in future releases. See [](module_ref) for details. +- `pyrte_rrtmgp.kernels.rrtmgp`: A high-level module that provides a more user-friendly Python interface for select RRTMGP functions. This module is still under development and will be expanded in future releases. See [](module_ref) for details. - `pyrte_rrtmgp.utils`: A module that provides utility functions for working with RTE-RRTMGP data. This module is still under development and will be expanded in future releases. See [](module_ref) for details. ```{seealso} diff --git a/examples/lw_example.ipynb b/examples/lw_example.ipynb index 4e093e0..3248af6 100644 --- a/examples/lw_example.ipynb +++ b/examples/lw_example.ipynb @@ -6,10 +6,12 @@ "metadata": {}, "outputs": [], "source": [ - "from pyrte_rrtmgp.gas_optics import GasOptics\n", - "import xarray as xr\n", "import numpy as np\n", - "from pyrte_rrtmgp.rte import lw_solver_noscat\n", + "import xarray as xr\n", + "\n", + "from pyrte_rrtmgp import rrtmgp_gas_optics\n", + "from pyrte_rrtmgp.kernels.rte import lw_solver_noscat\n", + "\n", "\n", "ERROR_TOLERANCE = 1e-4\n", "\n", @@ -22,33 +24,25 @@ "rfmip = rfmip.sel(expt=0) # only one experiment\n", "\n", "kdist = xr.load_dataset(f\"{rte_rrtmgp_dir}/rrtmgp-gas-lw-g256.nc\")\n", - "\n", - "# RRTMGP won't run with pressure less than its minimum. so we add a small value to the minimum pressure\n", - "# There was an issue replicating k_dist%get_press_min() + epsilon(k_dist%get_press_min()) in python, sys.epsilon is not the same\n", - "min_index = rfmip[\"pres_level\"].argmin()\n", - "rfmip[\"pres_level\"][:, min_index] = 1.0051835744630002\n", - "\n", - "gas_optics = GasOptics(kdist, rfmip)\n", - "gas_optics.source_is_internal\n", - "tau, _, _, layer_src, level_src, sfc_src, sfc_src_jac = gas_optics.gas_optics()\n", - "\n", - "# Expand the surface emissivity to ngpt\n", - "sfc_emis = rfmip[\"surface_emissivity\"].values\n", - "sfc_emis = np.stack([sfc_emis]*tau.shape[2]).T\n", + "rrtmgp_gas_optics = kdist.gas_optics.load_atmosferic_conditions(rfmip)\n", "\n", "_, solver_flux_up, solver_flux_down, _, _ = lw_solver_noscat(\n", - " tau=tau,\n", - " lay_source=layer_src,\n", - " lev_source=level_src,\n", - " sfc_emis=sfc_emis,\n", - " sfc_src=sfc_src,\n", - " sfc_src_jac=sfc_src_jac\n", + " tau=rrtmgp_gas_optics.tau,\n", + " lay_source=rrtmgp_gas_optics.lay_src,\n", + " lev_source=rrtmgp_gas_optics.lev_src,\n", + " sfc_emis=rfmip[\"surface_emissivity\"].data,\n", + " sfc_src=rrtmgp_gas_optics.sfc_src,\n", + " sfc_src_jac=rrtmgp_gas_optics.sfc_src_jac,\n", ")\n", "\n", - "rlu = xr.load_dataset(\"../tests/test_python_frontend/rlu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\")\n", + "rlu = xr.load_dataset(\n", + " \"../tests/test_python_frontend/rlu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\"\n", + ")\n", "ref_flux_up = rlu.isel(expt=0)[\"rlu\"].values\n", "\n", - "rld = xr.load_dataset(\"../tests/test_python_frontend/rld_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\")\n", + "rld = xr.load_dataset(\n", + " \"../tests/test_python_frontend/rld_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\"\n", + ")\n", "ref_flux_down = rld.isel(expt=0)[\"rld\"].values\n", "\n", "assert np.isclose(solver_flux_up, ref_flux_up, atol=ERROR_TOLERANCE).all()\n", diff --git a/examples/sw_example.ipynb b/examples/sw_example.ipynb index 21244a3..10b5563 100644 --- a/examples/sw_example.ipynb +++ b/examples/sw_example.ipynb @@ -2,15 +2,16 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "from pyrte_rrtmgp.gas_optics import GasOptics\n", - "import xarray as xr\n", "import numpy as np\n", - "from pyrte_rrtmgp.rte import sw_solver_2stream\n", - "from pyrte_rrtmgp.utils import compute_mu0, get_usecols\n", + "import xarray as xr\n", + "\n", + "from pyrte_rrtmgp import rrtmgp_gas_optics\n", + "from pyrte_rrtmgp.kernels.rte import sw_solver_2stream\n", + "from pyrte_rrtmgp.utils import compute_mu0, get_usecols, compute_toa_flux\n", "\n", "ERROR_TOLERANCE = 1e-4\n", "\n", @@ -23,42 +24,24 @@ "rfmip = rfmip.sel(expt=0) # only one experiment\n", "\n", "kdist = xr.load_dataset(f\"{rte_rrtmgp_dir}/rrtmgp-gas-sw-g224.nc\")\n", + "gas_optics = kdist.gas_optics.load_atmosferic_conditions(rfmip)\n", "\n", - "# RRTMGP won't run with pressure less than its minimum. so we add a small value to the minimum pressure\n", - "# There was an issue replicating k_dist%get_press_min() + epsilon(k_dist%get_press_min()) in python, sys.epsilon is not the same\n", - "min_index = rfmip[\"pres_level\"].argmin()\n", - "rfmip[\"pres_level\"][:, min_index] = 1.0051835744630002\n", - "\n", - "gas_optics = GasOptics(kdist, rfmip)\n", - "gas_optics.source_is_internal\n", - "tau, g, ssa, toa_flux = gas_optics.gas_optics()\n", - "\n", - "pres_layers = rfmip[\"pres_layer\"][\"layer\"]\n", - "top_at_1 = (pres_layers[0] < pres_layers[-1]).values.item()\n", - "\n", - "# Expand the surface albedo to ngpt\n", - "ngpt = len(kdist[\"gpt\"])\n", - "surface_albedo = rfmip[\"surface_albedo\"].values\n", - "surface_albedo = np.stack([surface_albedo]*ngpt)\n", - "sfc_alb_dir = surface_albedo.T.copy()\n", - "sfc_alb_dif = surface_albedo.T.copy()\n", + "surface_albedo = rfmip[\"surface_albedo\"].data\n", + "total_solar_irradiance = rfmip[\"total_solar_irradiance\"].data\n", "\n", "nlayer = len(rfmip[\"layer\"])\n", "mu0 = compute_mu0(rfmip[\"solar_zenith_angle\"].values, nlayer=nlayer)\n", "\n", - "total_solar_irradiance = rfmip[\"total_solar_irradiance\"].values\n", - "toa_flux = np.stack([toa_flux]*mu0.shape[0])\n", - "def_tsi = toa_flux.sum(axis=1)\n", - "toa_flux = (toa_flux.T * (total_solar_irradiance/def_tsi)).T\n", + "toa_flux = compute_toa_flux(total_solar_irradiance, gas_optics.solar_source)\n", "\n", "_, _, _, solver_flux_up, solver_flux_down, _ = sw_solver_2stream(\n", - " top_at_1,\n", - " tau,\n", - " ssa,\n", - " g,\n", + " kdist.gas_optics.top_at_1,\n", + " gas_optics.tau,\n", + " gas_optics.ssa,\n", + " gas_optics.g,\n", " mu0,\n", - " sfc_alb_dir,\n", - " sfc_alb_dif,\n", + " sfc_alb_dir=surface_albedo,\n", + " sfc_alb_dif=surface_albedo,\n", " inc_flux_dir=toa_flux,\n", " inc_flux_dif=None,\n", " has_dif_bc=False,\n", @@ -73,10 +56,14 @@ "solver_flux_down = solver_flux_down * usecol[:, np.newaxis]\n", "\n", "# Compare the results with the reference data\n", - "rsu = xr.load_dataset(\"../tests/test_python_frontend/rsu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\")\n", + "rsu = xr.load_dataset(\n", + " \"../tests/test_python_frontend/rsu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\"\n", + ")\n", "ref_flux_up = rsu.isel(expt=0)[\"rsu\"].values\n", "\n", - "rsd = xr.load_dataset(\"../tests/test_python_frontend/rsd_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\")\n", + "rsd = xr.load_dataset(\n", + " \"../tests/test_python_frontend/rsd_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\"\n", + ")\n", "ref_flux_down = rsd.isel(expt=0)[\"rsd\"].values\n", "\n", "assert np.isclose(solver_flux_up, ref_flux_up, atol=ERROR_TOLERANCE).all()\n", diff --git a/pyrte_rrtmgp/constants.py b/pyrte_rrtmgp/constants.py new file mode 100644 index 0000000..d0e59a5 --- /dev/null +++ b/pyrte_rrtmgp/constants.py @@ -0,0 +1,5 @@ +HELMERT1 = 9.80665 +HELMERT2 = 0.02586 +M_DRY = 0.028964 +M_H2O = 0.018016 +AVOGAD = 6.02214076e23 diff --git a/pyrte_rrtmgp/exceptions.py b/pyrte_rrtmgp/exceptions.py new file mode 100644 index 0000000..9daaa77 --- /dev/null +++ b/pyrte_rrtmgp/exceptions.py @@ -0,0 +1,16 @@ +class NotInternalSourceError(ValueError): + pass + + +class NotExternalSourceError(ValueError): + pass + + +class MissingAtmosfericConditionsError(AttributeError): + message = ( + "You need to load the atmospheric conditions first." + "Use the method load_atmosferic_conditions with an appropriated file." + ) + + def __init__(self, message=message): + super().__init__(message) diff --git a/pyrte_rrtmgp/gas_optics.py b/pyrte_rrtmgp/gas_optics.py deleted file mode 100644 index 09cef9d..0000000 --- a/pyrte_rrtmgp/gas_optics.py +++ /dev/null @@ -1,280 +0,0 @@ -import numpy as np - -from pyrte_rrtmgp.rrtmgp import ( - compute_planck_source, - compute_tau_absorption, - compute_tau_rayleigh, - interpolation, -) -from pyrte_rrtmgp.utils import ( - combine_abs_and_rayleigh, - extract_gas_names, - flavors_from_kdist, - get_idx_minor, - gpoint_flavor_from_kdist, - krayl_from_kdist, - rfmip_2_col_gas, -) - - -class GasOptics: - - def __init__(self, kdist, rfmip, gases_to_use=None): - self.kdist = kdist - self.rfmip = rfmip - - kdist_gas_names = extract_gas_names(kdist["gas_names"].values) - self.kdist_gas_names = kdist_gas_names - rfmip_vars = list(rfmip.keys()) - - gas_names = {n: n + "_GM" for n in kdist_gas_names if n + "_GM" in rfmip_vars} - - # Create a dict that maps the gas names in the kdist gas names to the gas names in the rfmip dataset - gas_names.update( - { - "co": "carbon_monoxide_GM", - "ch4": "methane_GM", - "o2": "oxygen_GM", - "n2o": "nitrous_oxide_GM", - "n2": "nitrogen_GM", - "co2": "carbon_dioxide_GM", - "ccl4": "carbon_tetrachloride_GM", - "cfc22": "hcfc22_GM", - "h2o": "water_vapor", - "o3": "ozone", - "no2": "no2", - } - ) - - # sort gas names based on kdist - gas_names = {g: gas_names[g] for g in kdist_gas_names if g in gas_names} - - if gases_to_use is not None: - gas_names = {g: gas_names[g] for g in gases_to_use} - - self.gas_names = gas_names - - self.tlay = rfmip["temp_layer"].values - self.play = rfmip["pres_layer"].values - self.col_gas = rfmip_2_col_gas(rfmip, list(gas_names.values()), dry_air=True) - - @property - def source_is_internal(self): - variables = self.kdist.data_vars - return "totplnk" in variables and "plank_fraction" in variables - - def gas_optics(self): - if self.source_is_internal: - self.interpolate() - self.compute_planck() - self.compute_gas_taus() - return ( - self.tau, - self.g, - self.ssa, - self.lay_src, - self.lev_src, - self.sfc_src, - self.sfc_src_jac, - ) - else: - self.interpolate() - self.compute_gas_taus() - self.compute_solar_variability() - return self.tau, self.g, self.ssa, self.solar_source - - def interpolate(self): - neta = len(self.kdist["mixing_fraction"]) - press_ref = self.kdist["press_ref"].values - temp_ref = self.kdist["temp_ref"].values - - press_ref_trop = self.kdist["press_ref_trop"].values.item() - - # dry air is zero - vmr_idx = [ - i for i, g in enumerate(self.kdist_gas_names, 1) if g in self.gas_names - ] - vmr_idx = [0] + vmr_idx - vmr_ref = self.kdist["vmr_ref"].sel(absorber_ext=vmr_idx).values.T - - # just the unique sets of gases - flavor = flavors_from_kdist(self.kdist) - - ( - self.jtemp, - self.fmajor, - self.fminor, - self.col_mix, - self.tropo, - self.jeta, - self.jpress, - ) = interpolation( - neta=neta, - flavor=flavor, - press_ref=press_ref, - temp_ref=temp_ref, - press_ref_trop=press_ref_trop, - vmr_ref=vmr_ref, - play=self.play, - tlay=self.tlay, - col_gas=self.col_gas, - ) - - def compute_planck(self): - - tlay = self.rfmip["temp_layer"].values - tlev = self.rfmip["temp_level"].values - tsfc = self.rfmip["surface_temperature"].values - pres_layers = self.rfmip["pres_layer"]["layer"] - top_at_1 = pres_layers[0] < pres_layers[-1] - band_lims_gpt = self.kdist["bnd_limits_gpt"].values.T - pfracin = self.kdist["plank_fraction"].values.transpose(0, 2, 1, 3) - temp_ref_min = self.kdist["temp_ref"].values.min() - temp_ref_max = self.kdist["temp_ref"].values.max() - totplnk = self.kdist["totplnk"].values.T - - gpoint_flavor = gpoint_flavor_from_kdist(self.kdist) - - self.sfc_src, self.lay_src, self.lev_src, self.sfc_src_jac = ( - compute_planck_source( - tlay, - tlev, - tsfc, - top_at_1, - self.fmajor, - self.jeta, - self.tropo, - self.jtemp, - self.jpress, - band_lims_gpt, - pfracin, - temp_ref_min, - temp_ref_max, - totplnk, - gpoint_flavor, - ) - ) - - def compute_gas_taus(self): - - idx_h2o = list(self.gas_names).index("h2o") + 1 - - gpoint_flavor = gpoint_flavor_from_kdist(self.kdist) - - kmajor = self.kdist["kmajor"].values - kminor_lower = self.kdist["kminor_lower"].values - kminor_upper = self.kdist["kminor_upper"].values - minor_limits_gpt_lower = self.kdist["minor_limits_gpt_lower"].values.T - minor_limits_gpt_upper = self.kdist["minor_limits_gpt_upper"].values.T - - minor_scales_with_density_lower = self.kdist[ - "minor_scales_with_density_lower" - ].values.astype(bool) - minor_scales_with_density_upper = self.kdist[ - "minor_scales_with_density_upper" - ].values.astype(bool) - scale_by_complement_lower = self.kdist[ - "scale_by_complement_lower" - ].values.astype(bool) - scale_by_complement_upper = self.kdist[ - "scale_by_complement_upper" - ].values.astype(bool) - - gas_name_list = list(self.gas_names.keys()) - - band_lims_gpt = self.kdist["bnd_limits_gpt"].values.T - - minor_gases_lower = extract_gas_names(self.kdist["minor_gases_lower"].values) - minor_gases_upper = extract_gas_names(self.kdist["minor_gases_upper"].values) - # check if the index is correct - idx_minor_lower = get_idx_minor(gas_name_list, minor_gases_lower) - idx_minor_upper = get_idx_minor(gas_name_list, minor_gases_upper) - - minor_scaling_gas_lower = extract_gas_names( - self.kdist["scaling_gas_lower"].values - ) - minor_scaling_gas_upper = extract_gas_names( - self.kdist["scaling_gas_upper"].values - ) - - idx_minor_scaling_lower = get_idx_minor(gas_name_list, minor_scaling_gas_lower) - idx_minor_scaling_upper = get_idx_minor(gas_name_list, minor_scaling_gas_upper) - - kminor_start_lower = self.kdist["kminor_start_lower"].values - kminor_start_upper = self.kdist["kminor_start_upper"].values - - tau_absorption = compute_tau_absorption( - idx_h2o, - gpoint_flavor, - band_lims_gpt, - kmajor, - kminor_lower, - kminor_upper, - minor_limits_gpt_lower, - minor_limits_gpt_upper, - minor_scales_with_density_lower, - minor_scales_with_density_upper, - scale_by_complement_lower, - scale_by_complement_upper, - idx_minor_lower, - idx_minor_upper, - idx_minor_scaling_lower, - idx_minor_scaling_upper, - kminor_start_lower, - kminor_start_upper, - self.tropo, - self.col_mix, - self.fmajor, - self.fminor, - self.play, - self.tlay, - self.col_gas, - self.jeta, - self.jtemp, - self.jpress, - ) - - if self.source_is_internal: - self.tau = tau_absorption - self.ssa = np.full_like(tau_absorption, np.nan) - self.g = np.full_like(tau_absorption, np.nan) - else: - krayl = krayl_from_kdist(self.kdist) - tau_rayleigh = compute_tau_rayleigh( - gpoint_flavor, - band_lims_gpt, - krayl, - idx_h2o, - self.col_gas[:, :, 0], - self.col_gas, - self.fminor, - self.jeta, - self.tropo, - self.jtemp, - ) - self.tau, self.ssa, self.g = combine_abs_and_rayleigh( - tau_absorption, tau_rayleigh - ) - - def compute_solar_variability(self): - """Calculate the solar variability - - Returns: - np.ndarray: Solar source - """ - - a_offset = 0.1495954 - b_offset = 0.00066696 - - solar_source_quiet = self.kdist["solar_source_quiet"] - solar_source_facular = self.kdist["solar_source_facular"] - solar_source_sunspot = self.kdist["solar_source_sunspot"] - - mg_index = self.kdist["mg_default"] - sb_index = self.kdist["sb_default"] - - self.solar_source = ( - solar_source_quiet - + (mg_index - a_offset) * solar_source_facular - + (sb_index - b_offset) * solar_source_sunspot - ) diff --git a/pyrte_rrtmgp/rrtmgp.py b/pyrte_rrtmgp/kernels/rrtmgp.py similarity index 100% rename from pyrte_rrtmgp/rrtmgp.py rename to pyrte_rrtmgp/kernels/rrtmgp.py diff --git a/pyrte_rrtmgp/rte.py b/pyrte_rrtmgp/kernels/rte.py similarity index 77% rename from pyrte_rrtmgp/rte.py rename to pyrte_rrtmgp/kernels/rte.py index 8ba6f60..512ff30 100644 --- a/pyrte_rrtmgp/rte.py +++ b/pyrte_rrtmgp/kernels/rte.py @@ -79,6 +79,9 @@ def lw_solver_noscat( ncol, nlay, ngpt = tau.shape + if len(sfc_emis.shape) == 1: + sfc_emis = np.stack([sfc_emis] * ngpt).T + # default values n_quad_angs = nmus @@ -181,8 +184,41 @@ def sw_solver_2stream( has_dif_bc=False, do_broadband=False, ): + """ + Solve the shortwave radiative transfer equation using the 2-stream approximation. + + Args: + top_at_1 (bool): Flag indicating whether the top of the atmosphere is at level 1. + tau (ndarray): Array of optical depths with shape (ncol, nlay, ngpt). + ssa (ndarray): Array of single scattering albedos with shape (ncol, nlay, ngpt). + g (ndarray): Array of asymmetry parameters with shape (ncol, nlay, ngpt). + mu0 (ndarray): Array of cosine of solar zenith angles with shape (ncol, ngpt). + sfc_alb_dir (ndarray): Array of direct surface albedos with shape (ncol, ngpt). + sfc_alb_dif (ndarray): Array of diffuse surface albedos with shape (ncol, ngpt). + inc_flux_dir (ndarray): Array of direct incident fluxes with shape (ncol, ngpt). + inc_flux_dif (ndarray, optional): Array of diffuse incident fluxes with shape (ncol, ngpt). + Defaults to None. + has_dif_bc (bool, optional): Flag indicating whether the boundary condition includes diffuse fluxes. + Defaults to False. + do_broadband (bool, optional): Flag indicating whether to compute broadband fluxes. + Defaults to False. + + Returns: + Tuple of ndarrays: Tuple containing the following arrays: + - flux_up: Array of upward fluxes with shape (ngpt, nlay + 1, ncol). + - flux_dn: Array of downward fluxes with shape (ngpt, nlay + 1, ncol). + - flux_dir: Array of direct fluxes with shape (ngpt, nlay + 1, ncol). + - broadband_up: Array of broadband upward fluxes with shape (nlay + 1, ncol). + - broadband_dn: Array of broadband downward fluxes with shape (nlay + 1, ncol). + - broadband_dir: Array of broadband direct fluxes with shape (nlay + 1, ncol). + """ ncol, nlay, ngpt = tau.shape + if len(sfc_alb_dir.shape) == 1: + sfc_alb_dir = np.stack([sfc_alb_dir] * ngpt).T + if len(sfc_alb_dif.shape) == 1: + sfc_alb_dif = np.stack([sfc_alb_dif] * ngpt).T + if inc_flux_dif is None: inc_flux_dif = np.zeros((ncol, ngpt), dtype=np.float64) diff --git a/pyrte_rrtmgp/rrtmgp_gas_optics.py b/pyrte_rrtmgp/rrtmgp_gas_optics.py new file mode 100644 index 0000000..693bec4 --- /dev/null +++ b/pyrte_rrtmgp/rrtmgp_gas_optics.py @@ -0,0 +1,464 @@ +import sys +from dataclasses import dataclass +from typing import Optional + +import numpy as np +import numpy.typing as npt +import xarray as xr + +from pyrte_rrtmgp.constants import AVOGAD, HELMERT1, HELMERT2, M_DRY, M_H2O +from pyrte_rrtmgp.exceptions import ( + MissingAtmosfericConditionsError, + NotExternalSourceError, + NotInternalSourceError, +) +from pyrte_rrtmgp.kernels.rrtmgp import ( + compute_planck_source, + compute_tau_absorption, + compute_tau_rayleigh, + interpolation, +) + + +@dataclass +class GasOptics: + tau: Optional[np.ndarray] = None + g: Optional[np.ndarray] = None + ssa: Optional[np.ndarray] = None + lay_src: Optional[np.ndarray] = None + lev_src: Optional[np.ndarray] = None + sfc_src: Optional[np.ndarray] = None + sfc_src_jac: Optional[np.ndarray] = None + solar_source: Optional[np.ndarray] = None + + +@dataclass +class InterpolatedAtmosfereGases: + jtemp: Optional[np.ndarray] = None + fmajor: Optional[np.ndarray] = None + fminor: Optional[np.ndarray] = None + col_mix: Optional[np.ndarray] = None + tropo: Optional[np.ndarray] = None + jeta: Optional[np.ndarray] = None + jpress: Optional[np.ndarray] = None + + +@xr.register_dataset_accessor("gas_optics") +class GasOpticsAccessor: + def __init__(self, xarray_obj, selected_gases=None): + self._obj = xarray_obj + self._selected_gases = selected_gases + self._gas_names = None + self._is_internal = None + self._gas_mappings = None + self._top_at_1 = None + self._vmr_ref = None + self.col_gas = None + + self._interpolated = InterpolatedAtmosfereGases() + self.gas_optics = GasOptics() + + @property + def gas_names(self): + """Gas names""" + if self._gas_names is None: + names = self._obj["gas_names"].values + self._gas_names = self.extract_names(names) + return self._gas_names + + @property + def source_is_internal(self): + """Check if the source is internal""" + if self._is_internal is None: + variables = self._obj.data_vars + self._is_internal = "totplnk" in variables and "plank_fraction" in variables + return self._is_internal + + def solar_source(self): + """Calculate the solar variability + + Returns: + np.ndarray: Solar source + """ + + if self.source_is_internal: + raise NotExternalSourceError( + "Solar source is not available for internal sources." + ) + + if self.gas_optics.solar_source is None: + a_offset = 0.1495954 + b_offset = 0.00066696 + + solar_source_quiet = self._obj["solar_source_quiet"] + solar_source_facular = self._obj["solar_source_facular"] + solar_source_sunspot = self._obj["solar_source_sunspot"] + + mg_index = self._obj["mg_default"] + sb_index = self._obj["sb_default"] + + self.gas_optics.solar_source = ( + solar_source_quiet + + (mg_index - a_offset) * solar_source_facular + + (sb_index - b_offset) * solar_source_sunspot + ).data + + def load_atmosferic_conditions(self, atmosferic_conditions: xr.Dataset): + """Load atmospheric conditions""" + self._atm_cond = atmosferic_conditions + + # RRTMGP won't run with pressure less than its minimum. + # So we add a small value to the minimum pressure + min_index = np.argmin(self._atm_cond["pres_level"].data) + min_press = self._obj["press_ref"].min().item() + sys.float_info.epsilon + self._atm_cond["pres_level"][:, min_index] = min_press + + self.get_col_gas() + + self.interpolate() + self.compute_gas_taus() + if self.source_is_internal: + self.compute_planck() + else: + self.solar_source() + + return self.gas_optics + + def get_col_gas(self): + if self._atm_cond is None: + raise MissingAtmosfericConditionsError() + + ncol = len(self._atm_cond["site"]) + nlay = len(self._atm_cond["layer"]) + col_gas = [] + for gas_name in self.gas_mappings.values(): + # if gas_name is not available, fill it with zeros + if gas_name not in self._atm_cond.data_vars.keys(): + gas_values = np.zeros((ncol, nlay)) + else: + try: + scale = float(self._atm_cond[gas_name].units) + except AttributeError: + scale = 1.0 + gas_values = self._atm_cond[gas_name].values * scale + + if gas_values.ndim == 0: + gas_values = np.full((ncol, nlay), gas_values) + col_gas.append(gas_values) + + vmr_h2o = col_gas[self.gas_names.index("h2o")] + col_dry = self.get_col_dry( + vmr_h2o, self._atm_cond["pres_level"].data, latitude=None + ) + col_gas = [col_dry] + col_gas + + col_gas = np.stack(col_gas, axis=-1).astype(np.float64) + col_gas[:, :, 1:] = col_gas[:, :, 1:] * col_gas[:, :, :1] + + self.col_gas = col_gas + + @property + def gas_mappings(self): + """Gas mappings""" + + if self._atm_cond is None: + raise MissingAtmosfericConditionsError() + + if self._gas_mappings is None: + gas_name_map = { + "h2o": "water_vapor", + "co2": "carbon_dioxide_GM", + "o3": "ozone", + "n2o": "nitrous_oxide_GM", + "co": "carbon_monoxide_GM", + "ch4": "methane_GM", + "o2": "oxygen_GM", + "n2": "nitrogen_GM", + "ccl4": "carbon_tetrachloride_GM", + "cfc11": "cfc11_GM", + "cfc12": "cfc12_GM", + "cfc22": "hcfc22_GM", + "hfc143a": "hfc143a_GM", + "hfc125": "hfc125_GM", + "hfc23": "hfc23_GM", + "hfc32": "hfc32_GM", + "hfc134a": "hfc134a_GM", + "cf4": "cf4_GM", + "no2": "no2", + } + + if self._selected_gases is not None: + gas_name_map = { + g: gas_name_map[g] + for g in self._selected_gases + if g in gas_name_map + } + + gas_name_map = { + g: gas_name_map[g] for g in self.gas_names if g in gas_name_map + } + self._gas_mappings = gas_name_map + return self._gas_mappings + + @property + def top_at_1(self): + if self._top_at_1 is None: + if self._atm_cond is None: + raise MissingAtmosfericConditionsError() + + pres_layers = self._atm_cond["pres_layer"]["layer"] + self._top_at_1 = pres_layers[0] < pres_layers[-1] + return self._top_at_1.item() + + @property + def vmr_ref(self): + if self._vmr_ref is None: + if self._atm_cond is None: + raise MissingAtmosfericConditionsError() + sel_gases = self.gas_mappings.keys() + vmr_idx = [i for i, g in enumerate(self._gas_names, 1) if g in sel_gases] + vmr_idx = [0] + vmr_idx + self._vmr_ref = self._obj["vmr_ref"].sel(absorber_ext=vmr_idx).values.T + return self._vmr_ref + + def interpolate(self): + ( + self._interpolated.jtemp, + self._interpolated.fmajor, + self._interpolated.fminor, + self._interpolated.col_mix, + self._interpolated.tropo, + self._interpolated.jeta, + self._interpolated.jpress, + ) = interpolation( + neta=len(self._obj["mixing_fraction"]), + flavor=self.flavors_sets, + press_ref=self._obj["press_ref"].values, + temp_ref=self._obj["temp_ref"].values, + press_ref_trop=self._obj["press_ref_trop"].values.item(), + vmr_ref=self.vmr_ref, + play=self._atm_cond["pres_layer"].values, + tlay=self._atm_cond["temp_layer"].values, + col_gas=self.col_gas, + ) + + def compute_planck(self): + ( + self.gas_optics.sfc_src, + self.gas_optics.lay_src, + self.gas_optics.lev_src, + self.gas_optics.sfc_src_jac, + ) = compute_planck_source( + self._atm_cond["temp_layer"].values, + self._atm_cond["temp_level"].values, + self._atm_cond["surface_temperature"].values, + self.top_at_1, + self._interpolated.fmajor, + self._interpolated.jeta, + self._interpolated.tropo, + self._interpolated.jtemp, + self._interpolated.jpress, + self._obj["bnd_limits_gpt"].values.T, + self._obj["plank_fraction"].values.transpose(0, 2, 1, 3), + self._obj["temp_ref"].values.min(), + self._obj["temp_ref"].values.max(), + self._obj["totplnk"].values.T, + self.gpoint_flavor, + ) + + def compute_gas_taus(self): + minor_gases_lower = self.extract_names(self._obj["minor_gases_lower"].data) + minor_gases_upper = self.extract_names(self._obj["minor_gases_upper"].data) + # check if the index is correct + idx_minor_lower = self.get_idx_minor(self.gas_names, minor_gases_lower) + idx_minor_upper = self.get_idx_minor(self.gas_names, minor_gases_upper) + + scaling_gas_lower = self.extract_names(self._obj["scaling_gas_lower"].data) + scaling_gas_upper = self.extract_names(self._obj["scaling_gas_upper"].data) + + idx_minor_scaling_lower = self.get_idx_minor(self.gas_names, scaling_gas_lower) + idx_minor_scaling_upper = self.get_idx_minor(self.gas_names, scaling_gas_upper) + + tau_absorption = compute_tau_absorption( + self.idx_h2o, + self.gpoint_flavor, + self._obj["bnd_limits_gpt"].values.T, + self._obj["kmajor"].values, + self._obj["kminor_lower"].values, + self._obj["kminor_upper"].values, + self._obj["minor_limits_gpt_lower"].values.T, + self._obj["minor_limits_gpt_upper"].values.T, + self._obj["minor_scales_with_density_lower"].values.astype(bool), + self._obj["minor_scales_with_density_upper"].values.astype(bool), + self._obj["scale_by_complement_lower"].values.astype(bool), + self._obj["scale_by_complement_upper"].values.astype(bool), + idx_minor_lower, + idx_minor_upper, + idx_minor_scaling_lower, + idx_minor_scaling_upper, + self._obj["kminor_start_lower"].values, + self._obj["kminor_start_upper"].values, + self._interpolated.tropo, + self._interpolated.col_mix, + self._interpolated.fmajor, + self._interpolated.fminor, + self._atm_cond["pres_layer"].values, + self._atm_cond["temp_layer"].values, + self.col_gas, + self._interpolated.jeta, + self._interpolated.jtemp, + self._interpolated.jpress, + ) + + if self.source_is_internal: + self.gas_optics.tau = tau_absorption + self.gas_optics.ssa = np.full_like(tau_absorption, np.nan) + self.gas_optics.g = np.full_like(tau_absorption, np.nan) + else: + krayl = np.stack( + [self._obj["rayl_lower"].values, self._obj["rayl_upper"].values], + axis=-1, + ) + tau_rayleigh = compute_tau_rayleigh( + self.gpoint_flavor, + self._obj["bnd_limits_gpt"].values.T, + krayl, + self.idx_h2o, + self.col_gas[:, :, 0], + self.col_gas, + self._interpolated.fminor, + self._interpolated.jeta, + self._interpolated.tropo, + self._interpolated.jtemp, + ) + + self.gas_optics.tau = tau_absorption + tau_rayleigh + self.gas_optics.ssa = np.where( + self.gas_optics.tau > 2.0 * np.finfo(float).tiny, + tau_rayleigh / self.gas_optics.tau, + 0.0, + ) + self.gas_optics.g = np.zeros(self.gas_optics.tau.shape) + + @property + def idx_h2o(self): + return list(self.gas_names).index("h2o") + 1 + + @property + def gpoint_flavor(self) -> npt.NDArray: + """Get the g-point flavors from the k-distribution file. + + Each g-point is associated with a flavor, which is a pair of key species. + + Returns: + np.ndarray: G-point flavors. + """ + key_species = self._obj["key_species"].values + + band_ranges = [ + [i] * (r.values[1] - r.values[0] + 1) + for i, r in enumerate(self._obj["bnd_limits_gpt"], 1) + ] + gpoint_bands = np.concatenate(band_ranges) + + key_species_rep = key_species.copy() + key_species_rep[np.all(key_species_rep == [0, 0], axis=2)] = [2, 2] + + # unique flavors + flist = self.flavors_sets.T.tolist() + + def key_species_pair2flavor(key_species_pair): + return flist.index(key_species_pair.tolist()) + 1 + + flavors_bands = np.apply_along_axis( + key_species_pair2flavor, 2, key_species_rep + ).tolist() + gpoint_flavor = np.array([flavors_bands[gp - 1] for gp in gpoint_bands]).T + + return gpoint_flavor + + @property + def flavors_sets(self) -> npt.NDArray: + """Get the unique flavors from the k-distribution file. + + Returns: + np.ndarray: Unique flavors. + """ + key_species = self._obj["key_species"].values + tot_flav = len(self._obj["bnd"]) * len(self._obj["atmos_layer"]) + npairs = len(self._obj["pair"]) + all_flav = np.reshape(key_species, (tot_flav, npairs)) + # (0,0) becomes (2,2) because absorption coefficients for these g-points will be 0. + all_flav[np.all(all_flav == [0, 0], axis=1)] = [2, 2] + # we do that instead of unique to preserv the order + _, idx = np.unique(all_flav, axis=0, return_index=True) + return all_flav[np.sort(idx)].T + + @staticmethod + def get_idx_minor(gas_names, minor_gases): + """Index of each minor gas in col_gas + + Args: + gas_names (list): Gas names + minor_gases (list): List of minor gases + + Returns: + list: Index of each minor gas in col_gas + """ + idx_minor_gas = [] + for gas in minor_gases: + try: + gas_idx = gas_names.index(gas) + 1 + except ValueError: + gas_idx = -1 + idx_minor_gas.append(gas_idx) + return np.array(idx_minor_gas, dtype=np.int32) + + @staticmethod + def extract_names(names): + """Extract names from arrays, decoding and removing the suffix + + Args: + names (np.ndarray): Names + + Returns: + tuple: tuple of names + """ + output = tuple(gas.tobytes().decode().strip().split("_")[0] for gas in names) + return output + + @staticmethod + def get_col_dry(vmr_h2o, plev, latitude=None): + """Calculate the dry column of the atmosphere + + Args: + vmr_h2o (np.ndarray): Water vapor volume mixing ratio + plev (np.ndarray): Pressure levels + latitude (np.ndarray): Latitude of the location + + Returns: + np.ndarray: Dry column of the atmosphere + """ + ncol = plev.shape[0] + nlev = plev.shape[1] + col_dry = np.zeros((ncol, nlev - 1)) + + if latitude is not None: + g0 = HELMERT1 - HELMERT2 * np.cos(2.0 * np.pi * latitude / 180.0) + else: + g0 = np.full(ncol, HELMERT1) # Assuming grav is a constant value + + # TODO: use numpy instead of loops + for ilev in range(nlev - 1): + for icol in range(ncol): + delta_plev = abs(plev[icol, ilev] - plev[icol, ilev + 1]) + fact = 1.0 / (1.0 + vmr_h2o[icol, ilev]) + m_air = (M_DRY + M_H2O * vmr_h2o[icol, ilev]) * fact + col_dry[icol, ilev] = ( + 10.0 + * delta_plev + * AVOGAD + * fact + / (1000.0 * m_air * 100.0 * g0[icol]) + ) + return col_dry diff --git a/pyrte_rrtmgp/utils.py b/pyrte_rrtmgp/utils.py index a27387b..366dafe 100644 --- a/pyrte_rrtmgp/utils.py +++ b/pyrte_rrtmgp/utils.py @@ -1,188 +1,4 @@ -from typing import List - import numpy as np -import numpy.typing as npt -import xarray as xr - -# Constants -HELMERT1 = 9.80665 -HELMERT2 = 0.02586 -M_DRY = 0.028964 -M_H2O = 0.018016 -AVOGAD = 6.02214076e23 - - -def flavors_from_kdist(kdist: xr.Dataset) -> npt.NDArray: - """Get the unique flavors from the k-distribution file. - - Args: - kdist (xr.Dataset): K-distribution file. - - Returns: - np.ndarray: Unique flavors. - """ - key_species = kdist["key_species"].values - tot_flav = len(kdist["bnd"]) * len(kdist["atmos_layer"]) - npairs = len(kdist["pair"]) - all_flav = np.reshape(key_species, (tot_flav, npairs)) - # (0,0) becomes (2,2) because absorption coefficients for these g-points will be 0. - all_flav[np.all(all_flav == [0, 0], axis=1)] = [2, 2] - # we do that instead of unique to preserv the order - _, idx = np.unique(all_flav, axis=0, return_index=True) - return all_flav[np.sort(idx)].T - - -def get_col_dry(vmr_h2o, plev, latitude=None): - """Calculate the dry column of the atmosphere - - Args: - vmr_h2o (np.ndarray): Water vapor volume mixing ratio - plev (np.ndarray): Pressure levels - latitude (np.ndarray): Latitude of the location - - Returns: - np.ndarray: Dry column of the atmosphere - """ - ncol = plev.shape[0] - nlev = plev.shape[1] - col_dry = np.zeros((ncol, nlev - 1)) - - if latitude is not None: - g0 = HELMERT1 - HELMERT2 * np.cos(2.0 * np.pi * latitude / 180.0) - else: - g0 = np.full(ncol, HELMERT1) # Assuming grav is a constant value - - # TODO: use numpy instead of loops - for ilev in range(nlev - 1): - for icol in range(ncol): - delta_plev = abs(plev[icol, ilev] - plev[icol, ilev + 1]) - fact = 1.0 / (1.0 + vmr_h2o[icol, ilev]) - m_air = (M_DRY + M_H2O * vmr_h2o[icol, ilev]) * fact - col_dry[icol, ilev] = ( - 10.0 * delta_plev * AVOGAD * fact / (1000.0 * m_air * 100.0 * g0[icol]) - ) - return col_dry - - -def rfmip_2_col_gas(rfmip: xr.Dataset, gas_names: List[str], dry_air: bool = False): - """Convert RFMIP data to column gas concentrations. - - Args: - rfmip (xr.Dataset): RFMIP data. - gas_names (list): List of gas names. - dry_air (bool, optional): Include dry air. Defaults to False. - - Returns: - np.ndarray: Column gas concentrations. - """ - - ncol = len(rfmip["site"]) - nlay = len(rfmip["layer"]) - col_gas = [] - for gas_name in gas_names: - # if gas_name is not available, fill it with zeros - if gas_name not in rfmip.data_vars.keys(): - gas_values = np.zeros((ncol, nlay)) - else: - try: - scale = float(rfmip[gas_name].units) - except AttributeError: - scale = 1.0 - gas_values = rfmip[gas_name].values * scale - - if gas_values.ndim == 0: - gas_values = np.full((ncol, nlay), gas_values) - col_gas.append(gas_values) - - if dry_air: - if "h2o" not in gas_names and "water_vapor" not in gas_names: - raise ValueError( - "h2o gas must be included in gas_names to calculate dry air" - ) - if "h2o" in gas_names: - h2o_idx = gas_names.index("h2o") - else: - h2o_idx = gas_names.index("water_vapor") - vmr_h2o = col_gas[h2o_idx] - col_dry = get_col_dry(vmr_h2o, rfmip["pres_level"].values, latitude=None) - col_gas = [col_dry] + col_gas - - col_gas = np.stack(col_gas, axis=-1).astype(np.float64) - col_gas[:, :, 1:] = col_gas[:, :, 1:] * col_gas[:, :, :1] - - return col_gas - - -def gpoint_flavor_from_kdist(kdist: xr.Dataset) -> npt.NDArray: - """Get the g-point flavors from the k-distribution file. - - Each g-point is associated with a flavor, which is a pair of key species. - - Args: - kdist (xr.Dataset): K-distribution file. - - Returns: - np.ndarray: G-point flavors. - """ - key_species = kdist["key_species"].values - flavors = flavors_from_kdist(kdist) - - band_ranges = [ - [i] * (r.values[1] - r.values[0] + 1) - for i, r in enumerate(kdist["bnd_limits_gpt"], 1) - ] - gpoint_bands = np.concatenate(band_ranges) - - key_species_rep = key_species.copy() - key_species_rep[np.all(key_species_rep == [0, 0], axis=2)] = [2, 2] - - # unique flavors - flist = flavors.T.tolist() - - def key_species_pair2flavor(key_species_pair): - return flist.index(key_species_pair.tolist()) + 1 - - flavors_bands = np.apply_along_axis( - key_species_pair2flavor, 2, key_species_rep - ).tolist() - gpoint_flavor = np.array([flavors_bands[gp - 1] for gp in gpoint_bands]).T - - return gpoint_flavor - - -def extract_gas_names(gas_names): - """Extract gas names from the gas_names array, decoding and removing the suffix - - Args: - gas_names (np.ndarray): Gas names - - Returns: - list: List of gas names - """ - output = [] - for gas in gas_names: - output.append(gas.tobytes().decode().strip().split("_")[0]) - return output - - -def get_idx_minor(gas_names, minor_gases): - """Index of each minor gas in col_gas - - Args: - gas_names (list): Gas names - minor_gases (list): List of minor gases - - Returns: - list: Index of each minor gas in col_gas - """ - idx_minor_gas = [] - for gas in minor_gases: - try: - gas_idx = gas_names.index(gas) + 1 - except ValueError: - gas_idx = -1 - idx_minor_gas.append(gas_idx) - return np.array(idx_minor_gas, dtype=np.int32) def get_usecols(solar_zenith_angle): @@ -211,31 +27,17 @@ def compute_mu0(solar_zenith_angle, nlayer=None): return mu0 -def krayl_from_kdist(kdist: xr.Dataset) -> npt.NDArray: - """Get the Rayleigh scattering coefficients from the k-distribution file. - - Args: - kdist (xr.Dataset): K-distribution file. - - Returns: - np.ndarray: Rayleigh scattering coefficients. - """ - return np.stack([kdist["rayl_lower"].values, kdist["rayl_upper"].values], axis=-1) - - -def combine_abs_and_rayleigh(tau_absorption, tau_rayleigh): - """Combine absorption and Rayleigh scattering optical depths. +def compute_toa_flux(total_solar_irradiance, solar_source): + """Compute the top of atmosphere flux Args: - tau_absorption (np.ndarray): Absorption optical depth. - tau_rayleigh (np.ndarray): Rayleigh scattering optical depth. + total_solar_irradiance (np.ndarray): Total solar irradiance + solar_source (np.ndarray): Solar source Returns: - np.ndarray: Combined optical depth. + np.ndarray: Top of atmosphere flux """ - - tau = tau_absorption + tau_rayleigh - ssa = np.where(tau > 2.0 * np.finfo(float).tiny, tau_rayleigh / tau, 0.0) - g = np.zeros(tau.shape) - - return tau, ssa, g + ncol = total_solar_irradiance.shape[0] + toa_flux = np.stack([solar_source] * ncol) + def_tsi = toa_flux.sum(axis=1) + return (toa_flux.T * (total_solar_irradiance / def_tsi)).T diff --git a/tests/test_python_frontend/test_lw_solver.py b/tests/test_python_frontend/test_lw_solver.py index a0767f8..3cf3aa7 100644 --- a/tests/test_python_frontend/test_lw_solver.py +++ b/tests/test_python_frontend/test_lw_solver.py @@ -1,12 +1,13 @@ import os + import numpy as np import xarray as xr -from pyrte_rrtmgp.gas_optics import GasOptics -from pyrte_rrtmgp.rte import lw_solver_noscat +from pyrte_rrtmgp import rrtmgp_gas_optics +from pyrte_rrtmgp.kernels.rte import lw_solver_noscat ERROR_TOLERANCE = 1e-4 -rte_rrtmgp_dir = os.environ.get("RRTMGP_DATA", "rrtmgp-data") +rte_rrtmgp_dir = os.environ.get("RRTMGP_DATA", "rrtmgp-data") clear_sky_example_files = f"{rte_rrtmgp_dir}/examples/rfmip-clear-sky/inputs" rfmip = xr.load_dataset( @@ -27,22 +28,15 @@ def test_lw_solver_noscat(): - min_index = np.argmin(rfmip["pres_level"].values) - rfmip["pres_level"][:, min_index] = 1.0051835744630002 - - gas_optics = GasOptics(kdist, rfmip) - tau, _, _, layer_src, level_src, sfc_src, sfc_src_jac = gas_optics.gas_optics() - - sfc_emis = rfmip["surface_emissivity"].values - sfc_emis = np.stack([sfc_emis] * tau.shape[2]).T + rrtmgp_gas_optics = kdist.gas_optics.load_atmosferic_conditions(rfmip) _, solver_flux_up, solver_flux_down, _, _ = lw_solver_noscat( - tau=tau, - lay_source=layer_src, - lev_source=level_src, - sfc_emis=sfc_emis, - sfc_src=sfc_src, - sfc_src_jac=sfc_src_jac, + tau=rrtmgp_gas_optics.tau, + lay_source=rrtmgp_gas_optics.lay_src, + lev_source=rrtmgp_gas_optics.lev_src, + sfc_emis=rfmip["surface_emissivity"].data, + sfc_src=rrtmgp_gas_optics.sfc_src, + sfc_src_jac=rrtmgp_gas_optics.sfc_src_jac, ) assert np.isclose(solver_flux_up, ref_flux_up, atol=ERROR_TOLERANCE).all() diff --git a/tests/test_python_frontend/test_sw_solver.py b/tests/test_python_frontend/test_sw_solver.py index f5e22d9..d10984a 100644 --- a/tests/test_python_frontend/test_sw_solver.py +++ b/tests/test_python_frontend/test_sw_solver.py @@ -1,13 +1,15 @@ import os + import numpy as np +import pytest import xarray as xr -from pyrte_rrtmgp.gas_optics import GasOptics -from pyrte_rrtmgp.rte import sw_solver_2stream -from pyrte_rrtmgp.utils import compute_mu0, get_usecols +from pyrte_rrtmgp import rrtmgp_gas_optics +from pyrte_rrtmgp.kernels.rte import sw_solver_2stream +from pyrte_rrtmgp.utils import compute_mu0, compute_toa_flux, get_usecols ERROR_TOLERANCE = 1e-4 -rte_rrtmgp_dir = os.environ.get("RRTMGP_DATA", "rrtmgp-data") +rte_rrtmgp_dir = os.environ.get("RRTMGP_DATA", "rrtmgp-data") clear_sky_example_files = f"{rte_rrtmgp_dir}/examples/rfmip-clear-sky/inputs" rfmip = xr.load_dataset( @@ -28,39 +30,24 @@ def test_lw_solver_noscat(): - min_index = np.argmin(rfmip["pres_level"].values) - rfmip["pres_level"][:, min_index] = 1.0051835744630002 - - gas_optics = GasOptics(kdist, rfmip) - gas_optics.source_is_internal - tau, g, ssa, toa_flux = gas_optics.gas_optics() - - pres_layers = rfmip["pres_layer"]["layer"] - top_at_1 = (pres_layers[0] < pres_layers[-1]).values.item() + gas_optics = kdist.gas_optics.load_atmosferic_conditions(rfmip) - # Expand the surface albedo to ngpt - ngpt = len(kdist["gpt"]) - surface_albedo = rfmip["surface_albedo"].values - surface_albedo = np.stack([surface_albedo] * ngpt) - sfc_alb_dir = surface_albedo.T.copy() - sfc_alb_dif = surface_albedo.T.copy() + surface_albedo = rfmip["surface_albedo"].data + total_solar_irradiance = rfmip["total_solar_irradiance"].data nlayer = len(rfmip["layer"]) mu0 = compute_mu0(rfmip["solar_zenith_angle"].values, nlayer=nlayer) - total_solar_irradiance = rfmip["total_solar_irradiance"].values - toa_flux = np.stack([toa_flux] * mu0.shape[0]) - def_tsi = toa_flux.sum(axis=1) - toa_flux = (toa_flux.T * (total_solar_irradiance / def_tsi)).T + toa_flux = compute_toa_flux(total_solar_irradiance, gas_optics.solar_source) _, _, _, solver_flux_up, solver_flux_down, _ = sw_solver_2stream( - top_at_1, - tau, - ssa, - g, + kdist.gas_optics.top_at_1, + gas_optics.tau, + gas_optics.ssa, + gas_optics.g, mu0, - sfc_alb_dir, - sfc_alb_dif, + sfc_alb_dir=surface_albedo, + sfc_alb_dif=surface_albedo, inc_flux_dir=toa_flux, inc_flux_dif=None, has_dif_bc=False,