diff --git a/.github/workflows/conda.yml b/.github/workflows/conda.yml index dd019d8..968d3e9 100644 --- a/.github/workflows/conda.yml +++ b/.github/workflows/conda.yml @@ -19,7 +19,7 @@ jobs: fail-fast: false matrix: platform: [ubuntu-latest, macos-latest] - python-version: ["3.11"] + python-version: ["3.12"] runs-on: ${{ matrix.platform }} @@ -33,8 +33,6 @@ jobs: steps: - uses: actions/checkout@v4 - with: - submodules: recursive - name: Get conda uses: conda-incubator/setup-miniconda@v3.0.0 diff --git a/CMakeLists.txt b/CMakeLists.txt index 1834d3f..007473b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -42,7 +42,7 @@ endif() ExternalProject_Add( rte-rrtmgp GIT_REPOSITORY https://github.com/earth-system-radiation/rte-rrtmgp.git - GIT_TAG origin/develop + GIT_TAG v1.8 GIT_SHALLOW TRUE SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/rte-rrtmgp CONFIGURE_COMMAND "" @@ -81,8 +81,24 @@ target_compile_definitions(${TARGET_NAME} PRIVATE DCMAKE_LIBRARY_OUTPUT_DIRECTORY=pyrte_rrtmgp ) +# Add these checks after the initial Linux check +set(APPLE_ARM FALSE) +if(APPLE) + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + set(APPLE_ARM TRUE) + endif() +endif() + if (${LINUX}) target_link_libraries(${TARGET_NAME} PUBLIC gfortran) +elseif(APPLE) + # On macOS, explicitly link against gfortran runtime + if(APPLE_ARM) + target_link_directories(${TARGET_NAME} PUBLIC /opt/homebrew/lib/gcc/current) + else() + target_link_directories(${TARGET_NAME} PUBLIC /usr/local/lib/gcc/current) + endif() + target_link_libraries(${TARGET_NAME} PUBLIC gfortran) endif() # The install directory is the output (wheel) directory diff --git a/examples/lw_example.ipynb b/examples/lw_example.ipynb index 3248af6..03ac7cb 100644 --- a/examples/lw_example.ipynb +++ b/examples/lw_example.ipynb @@ -6,47 +6,45 @@ "metadata": {}, "outputs": [], "source": [ + "import os\n", + "\n", "import numpy as np\n", "import xarray as xr\n", "\n", "from pyrte_rrtmgp import rrtmgp_gas_optics\n", - "from pyrte_rrtmgp.kernels.rte import lw_solver_noscat\n", - "\n", + "from pyrte_rrtmgp.rrtmgp_gas_optics import GasOpticsFiles, load_gas_optics\n", + "from pyrte_rrtmgp.rrtmgp_data import download_rrtmgp_data\n", + "from pyrte_rrtmgp.rte_solver import RTESolver\n", "\n", - "ERROR_TOLERANCE = 1e-4\n", + "ERROR_TOLERANCE = 1e-7\n", "\n", - "rte_rrtmgp_dir = \"../rrtmgp-data\"\n", - "clear_sky_example_files = f\"{rte_rrtmgp_dir}/examples/rfmip-clear-sky/inputs\"\n", + "rte_rrtmgp_dir = download_rrtmgp_data()\n", + "rfmip_dir = os.path.join(rte_rrtmgp_dir, \"examples\", \"rfmip-clear-sky\")\n", + "input_dir = os.path.join(rfmip_dir, \"inputs\")\n", + "ref_dir = os.path.join(rfmip_dir, \"reference\")\n", "\n", - "rfmip = xr.load_dataset(\n", - " f\"{clear_sky_example_files}/multiple_input4MIPs_radiation_RFMIP_UColorado-RFMIP-1-2_none.nc\"\n", - ")\n", - "rfmip = rfmip.sel(expt=0) # only one experiment\n", + "gas_optics_lw = load_gas_optics(gas_optics_file=GasOpticsFiles.LW_G256)\n", "\n", - "kdist = xr.load_dataset(f\"{rte_rrtmgp_dir}/rrtmgp-gas-lw-g256.nc\")\n", - "rrtmgp_gas_optics = kdist.gas_optics.load_atmosferic_conditions(rfmip)\n", + "atmosphere_file = \"multiple_input4MIPs_radiation_RFMIP_UColorado-RFMIP-1-2_none.nc\"\n", + "atmosphere_path = os.path.join(input_dir, atmosphere_file)\n", + "atmosphere = xr.load_dataset(atmosphere_path)\n", "\n", - "_, solver_flux_up, solver_flux_down, _, _ = lw_solver_noscat(\n", - " tau=rrtmgp_gas_optics.tau,\n", - " lay_source=rrtmgp_gas_optics.lay_src,\n", - " lev_source=rrtmgp_gas_optics.lev_src,\n", - " sfc_emis=rfmip[\"surface_emissivity\"].data,\n", - " sfc_src=rrtmgp_gas_optics.sfc_src,\n", - " sfc_src_jac=rrtmgp_gas_optics.sfc_src_jac,\n", - ")\n", + "gas_optics_lw.gas_optics.compute(atmosphere, problem_type=\"absorption\")\n", "\n", - "rlu = xr.load_dataset(\n", - " \"../tests/test_python_frontend/rlu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\"\n", - ")\n", - "ref_flux_up = rlu.isel(expt=0)[\"rlu\"].values\n", + "solver = RTESolver()\n", + "fluxes = solver.solve(atmosphere, add_to_input=False)\n", "\n", - "rld = xr.load_dataset(\n", - " \"../tests/test_python_frontend/rld_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\"\n", - ")\n", - "ref_flux_down = rld.isel(expt=0)[\"rld\"].values\n", + "rlu_reference = f\"{ref_dir}/rlu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\"\n", + "rld_reference = f\"{ref_dir}/rld_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\"\n", + "rlu = xr.load_dataset(rlu_reference, decode_cf=False)\n", + "rld = xr.load_dataset(rld_reference, decode_cf=False)\n", "\n", - "assert np.isclose(solver_flux_up, ref_flux_up, atol=ERROR_TOLERANCE).all()\n", - "assert np.isclose(solver_flux_down, ref_flux_down, atol=ERROR_TOLERANCE).all()" + "assert np.isclose(\n", + " fluxes[\"lw_flux_up_broadband\"], rlu[\"rlu\"], atol=ERROR_TOLERANCE\n", + ").all()\n", + "assert np.isclose(\n", + " fluxes[\"lw_flux_down_broadband\"], rld[\"rld\"], atol=ERROR_TOLERANCE\n", + ").all()" ] } ], diff --git a/examples/sw_example.ipynb b/examples/sw_example.ipynb index 10b5563..a83c75f 100644 --- a/examples/sw_example.ipynb +++ b/examples/sw_example.ipynb @@ -2,72 +2,45 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ + "import os\n", + "\n", "import numpy as np\n", "import xarray as xr\n", "\n", "from pyrte_rrtmgp import rrtmgp_gas_optics\n", - "from pyrte_rrtmgp.kernels.rte import sw_solver_2stream\n", - "from pyrte_rrtmgp.utils import compute_mu0, get_usecols, compute_toa_flux\n", - "\n", - "ERROR_TOLERANCE = 1e-4\n", - "\n", - "rte_rrtmgp_dir = \"../rrtmgp-data\"\n", - "clear_sky_example_files = f\"{rte_rrtmgp_dir}/examples/rfmip-clear-sky/inputs\"\n", - "\n", - "rfmip = xr.load_dataset(\n", - " f\"{clear_sky_example_files}/multiple_input4MIPs_radiation_RFMIP_UColorado-RFMIP-1-2_none.nc\"\n", - ")\n", - "rfmip = rfmip.sel(expt=0) # only one experiment\n", - "\n", - "kdist = xr.load_dataset(f\"{rte_rrtmgp_dir}/rrtmgp-gas-sw-g224.nc\")\n", - "gas_optics = kdist.gas_optics.load_atmosferic_conditions(rfmip)\n", + "from pyrte_rrtmgp.rrtmgp_gas_optics import GasOpticsFiles, load_gas_optics\n", + "from pyrte_rrtmgp.rrtmgp_data import download_rrtmgp_data\n", + "from pyrte_rrtmgp.rte_solver import RTESolver\n", "\n", - "surface_albedo = rfmip[\"surface_albedo\"].data\n", - "total_solar_irradiance = rfmip[\"total_solar_irradiance\"].data\n", + "ERROR_TOLERANCE = 1e-7\n", "\n", - "nlayer = len(rfmip[\"layer\"])\n", - "mu0 = compute_mu0(rfmip[\"solar_zenith_angle\"].values, nlayer=nlayer)\n", + "rte_rrtmgp_dir = download_rrtmgp_data()\n", + "rfmip_dir = os.path.join(rte_rrtmgp_dir, \"examples\", \"rfmip-clear-sky\")\n", + "input_dir = os.path.join(rfmip_dir, \"inputs\")\n", + "ref_dir = os.path.join(rfmip_dir, \"reference\")\n", "\n", - "toa_flux = compute_toa_flux(total_solar_irradiance, gas_optics.solar_source)\n", + "gas_optics_sw = load_gas_optics(gas_optics_file=GasOpticsFiles.SW_G224)\n", "\n", - "_, _, _, solver_flux_up, solver_flux_down, _ = sw_solver_2stream(\n", - " kdist.gas_optics.top_at_1,\n", - " gas_optics.tau,\n", - " gas_optics.ssa,\n", - " gas_optics.g,\n", - " mu0,\n", - " sfc_alb_dir=surface_albedo,\n", - " sfc_alb_dif=surface_albedo,\n", - " inc_flux_dir=toa_flux,\n", - " inc_flux_dif=None,\n", - " has_dif_bc=False,\n", - " do_broadband=True,\n", - ")\n", + "atmosphere_file = \"multiple_input4MIPs_radiation_RFMIP_UColorado-RFMIP-1-2_none.nc\"\n", + "atmosphere_path = os.path.join(input_dir, atmosphere_file)\n", + "atmosphere = xr.load_dataset(atmosphere_path)\n", "\n", - "# RTE will fail if passed solar zenith angles greater than 90 degree. We replace any with\n", - "# nighttime columns with a default solar zenith angle. We'll mask these out later, of\n", - "# course, but this gives us more work and so a better measure of timing.\n", - "usecol = get_usecols(rfmip[\"solar_zenith_angle\"].values)\n", - "solver_flux_up = solver_flux_up * usecol[:, np.newaxis]\n", - "solver_flux_down = solver_flux_down * usecol[:, np.newaxis]\n", + "gas_optics_sw.gas_optics.compute(atmosphere, problem_type=\"two-stream\")\n", "\n", - "# Compare the results with the reference data\n", - "rsu = xr.load_dataset(\n", - " \"../tests/test_python_frontend/rsu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\"\n", - ")\n", - "ref_flux_up = rsu.isel(expt=0)[\"rsu\"].values\n", + "solver = RTESolver()\n", + "fluxes = solver.solve(atmosphere, add_to_input=False)\n", "\n", - "rsd = xr.load_dataset(\n", - " \"../tests/test_python_frontend/rsd_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\"\n", - ")\n", - "ref_flux_down = rsd.isel(expt=0)[\"rsd\"].values\n", + "rsu_reference = f\"{ref_dir}/rsu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\"\n", + "rsd_reference = f\"{ref_dir}/rsd_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\"\n", + "rsu = xr.load_dataset(rsu_reference, decode_cf=False)\n", + "rsd = xr.load_dataset(rsd_reference, decode_cf=False)\n", "\n", - "assert np.isclose(solver_flux_up, ref_flux_up, atol=ERROR_TOLERANCE).all()\n", - "assert np.isclose(solver_flux_down, ref_flux_down, atol=ERROR_TOLERANCE).all()" + "assert np.isclose(fluxes[\"sw_flux_up\"], rsu[\"rsu\"], atol=ERROR_TOLERANCE).all()\n", + "assert np.isclose(fluxes[\"sw_flux_down\"], rsd[\"rsd\"], atol=ERROR_TOLERANCE).all()" ] } ], diff --git a/pyproject.toml b/pyproject.toml index 5e5fa93..0ecdfbb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ version = "0.0.6" description = "A Python interface to the RTE+RRTMGP Fortran software package." readme = "README.md" requires-python = ">=3.7" -dependencies = ["numpy>=1.21.0", "xarray>=2023.5.0", "netcdf4>=1.5.7"] +dependencies = ["numpy>=2.0.0", "xarray>=2023.5.0", "netcdf4>=1.5.7", "requests>=2.4.0"] classifiers = [ "Development Status :: 4 - Beta", "License :: OSI Approved :: MIT License", @@ -15,6 +15,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] @@ -24,7 +25,7 @@ build-backend = "scikit_build_core.build" [project.optional-dependencies] -test = ["pytest", "numpy>=1.21.0", "xarray>=2023.5.0", "netcdf4>=1.5.7", "requests>=2.4.0"] +test = ["pytest", "numpy>=2.0.0", "xarray>=2023.5.0", "netcdf4>=1.5.7", "requests>=2.4.0"] [tool.scikit-build] diff --git a/pyrte_rrtmgp/config.py b/pyrte_rrtmgp/config.py new file mode 100644 index 0000000..314e9b7 --- /dev/null +++ b/pyrte_rrtmgp/config.py @@ -0,0 +1,52 @@ +"""Default mappings for gas names, dimensions and variables used in RRTMGP. + +This module contains dictionaries that map standard names to dataset-specific names +for gases, dimensions and variables used in radiative transfer calculations. +""" + +from typing import Dict, Final + +# Mapping of standard gas names to RRTMGP-specific names +DEFAULT_GAS_MAPPING: Final[Dict[str, str]] = { + "h2o": "water_vapor", + "co2": "carbon_dioxide_GM", + "o3": "ozone", + "n2o": "nitrous_oxide_GM", + "co": "carbon_monoxide_GM", + "ch4": "methane_GM", + "o2": "oxygen_GM", + "n2": "nitrogen_GM", + "ccl4": "carbon_tetrachloride_GM", + "cfc11": "cfc11_GM", + "cfc12": "cfc12_GM", + "cfc22": "hcfc22_GM", + "hfc143a": "hfc143a_GM", + "hfc125": "hfc125_GM", + "hfc23": "hfc23_GM", + "hfc32": "hfc32_GM", + "hfc134a": "hfc134a_GM", + "cf4": "cf4_GM", + "no2": "no2", +} + +# Mapping of standard dimension names to dataset-specific names +DEFAULT_DIM_MAPPING: Final[Dict[str, str]] = { + "site": "site", + "layer": "layer", + "level": "level", +} + +# Mapping of standard variable names to dataset-specific names +DEFAULT_VAR_MAPPING: Final[Dict[str, str]] = { + "pres_layer": "pres_layer", + "pres_level": "pres_level", + "temp_layer": "temp_layer", + "temp_level": "temp_level", + "surface_temperature": "surface_temperature", + "solar_zenith_angle": "solar_zenith_angle", + "surface_albedo": "surface_albedo", + "surface_albedo_dir": "surface_albedo_dir", + "surface_albedo_dif": "surface_albedo_dif", + "surface_emissivity": "surface_emissivity", + "surface_emissivity_jacobian": "surface_emissivity_jacobian", +} diff --git a/pyrte_rrtmgp/constants.py b/pyrte_rrtmgp/constants.py index d0e59a5..dfee75a 100644 --- a/pyrte_rrtmgp/constants.py +++ b/pyrte_rrtmgp/constants.py @@ -1,5 +1,49 @@ -HELMERT1 = 9.80665 -HELMERT2 = 0.02586 -M_DRY = 0.028964 -M_H2O = 0.018016 -AVOGAD = 6.02214076e23 +"""Physical and mathematical constants used in radiative transfer calculations. + +This module contains various physical and mathematical constants needed for +radiative transfer calculations, including gravitational parameters, molecular +masses, and Gaussian quadrature weights and points. +""" + +from typing import Dict, Final + +import numpy as np +from numpy.typing import NDArray + +# Gravitational parameters from Helmert's equation (m/s^2) +HELMERT1: Final[float] = 9.80665 # Standard gravity at sea level +HELMERT2: Final[float] = 0.02586 # Gravity variation with latitude + +# Molecular masses (kg/mol) +M_DRY: Final[float] = 0.028964 # Dry air +M_H2O: Final[float] = 0.018016 # Water vapor + +# Avogadro's number (molecules/mol) +AVOGAD: Final[float] = 6.02214076e23 + +# Solar constants for orbit calculations +SOLAR_CONSTANTS: Final[Dict[str, float]] = { + "A_OFFSET": 0.1495954, # Semi-major axis offset (AU) + "B_OFFSET": 0.00066696, # Orbital eccentricity factor +} + +# Gaussian quadrature constants for radiative transfer +GAUSS_DS: NDArray[np.float64] = np.reciprocal( + np.array( + [ + [0.6096748751, np.inf, np.inf, np.inf], + [0.2509907356, 0.7908473988, np.inf, np.inf], + [0.1024922169, 0.4417960320, 0.8633751621, np.inf], + [0.0454586727, 0.2322334416, 0.5740198775, 0.9030775973], + ] + ) +) + +GAUSS_WTS: NDArray[np.float64] = np.array( + [ + [1.0, 0.0, 0.0, 0.0], + [0.2300253764, 0.7699746236, 0.0, 0.0], + [0.0437820218, 0.3875796738, 0.5686383044, 0.0], + [0.0092068785, 0.1285704278, 0.4323381850, 0.4298845087], + ] +) diff --git a/pyrte_rrtmgp/data_types.py b/pyrte_rrtmgp/data_types.py new file mode 100644 index 0000000..ff16737 --- /dev/null +++ b/pyrte_rrtmgp/data_types.py @@ -0,0 +1,40 @@ +from enum import Enum, StrEnum + + +class GasOpticsFiles(StrEnum): + """Enumeration of default RRTMGP gas optics data files. + + This enum defines the available pre-configured gas optics data files that can be used + with RRTMGP. The files contain absorption coefficients and other optical properties + needed for radiative transfer calculations. + + Attributes: + LW_G128: Longwave gas optics file with 128 g-points + LW_G256: Longwave gas optics file with 256 g-points + SW_G112: Shortwave gas optics file with 112 g-points + SW_G224: Shortwave gas optics file with 224 g-points + """ + + LW_G128 = "rrtmgp-gas-lw-g128.nc" + LW_G256 = "rrtmgp-gas-lw-g256.nc" + SW_G112 = "rrtmgp-gas-sw-g112.nc" + SW_G224 = "rrtmgp-gas-sw-g224.nc" + + +class ProblemTypes(StrEnum): + """Enumeration of available radiation calculation types. + + This enum defines the different types of radiation calculations that can be performed, + including both longwave and shortwave calculations with different solution methods. + + Attributes: + LW_ABSORPTION: Longwave absorption-only calculation + LW_2STREAM: Longwave two-stream approximation calculation + SW_DIRECT: Shortwave direct beam calculation + SW_2STREAM: Shortwave two-stream approximation calculation + """ + + LW_ABSORPTION = "Longwave absorption" + LW_2STREAM = "Longwave 2-stream" + SW_DIRECT = "Shortwave direct" + SW_2STREAM = "Shortwave 2-stream" diff --git a/pyrte_rrtmgp/data_validation.py b/pyrte_rrtmgp/data_validation.py new file mode 100644 index 0000000..5d9f7e7 --- /dev/null +++ b/pyrte_rrtmgp/data_validation.py @@ -0,0 +1,193 @@ +from dataclasses import asdict, dataclass +from typing import Dict, Optional, Set, Union + +import xarray as xr + +from pyrte_rrtmgp.config import ( + DEFAULT_DIM_MAPPING, + DEFAULT_GAS_MAPPING, + DEFAULT_VAR_MAPPING, +) + + +@dataclass +class GasMapping: + """Class for managing gas name mappings between standard and dataset-specific names. + + Attributes: + _mapping: Dictionary mapping standard gas names to dataset-specific names + _required_gases: Set of required gas names that must be present + """ + + _mapping: Dict[str, str] + _required_gases: Set[str] + + @classmethod + def create( + cls, gas_names: Set[str], custom_mapping: Optional[Dict[str, str]] = None + ) -> "GasMapping": + """Create a new GasMapping instance with default and custom mappings. + + Args: + gas_names: Set of required gas names + custom_mapping: Optional custom mapping to override defaults + + Returns: + New GasMapping instance + """ + mapping = DEFAULT_GAS_MAPPING.copy() + if custom_mapping: + mapping.update(custom_mapping) + + return cls(mapping, gas_names) + + def validate(self) -> Dict[str, str]: + """Validate and return the final gas name mapping. + + Returns: + Dictionary mapping standard gas names to dataset-specific names + + Raises: + ValueError: If a required gas is not found in any mapping + """ + validated_mapping = {} + + for gas in self._required_gases: + if gas not in self._mapping: + if gas not in DEFAULT_GAS_MAPPING: + raise ValueError(f"Gas {gas} not found in any mapping") + validated_mapping[gas] = DEFAULT_GAS_MAPPING[gas] + else: + validated_mapping[gas] = self._mapping[gas] + + return validated_mapping + + +@dataclass +class DatasetMapping: + """Container for dimension and variable name mappings. + + Attributes: + dim_mapping: Dictionary mapping standard dimension names to dataset-specific names + var_mapping: Dictionary mapping standard variable names to dataset-specific names + """ + + dim_mapping: Dict[str, str] + var_mapping: Dict[str, str] + + def __post_init__(self) -> None: + """Validate mappings upon initialization.""" + pass + + @classmethod + def from_dict(cls, d: Dict[str, Dict[str, str]]) -> "DatasetMapping": + """Create mapping from dictionary representation. + + Args: + d: Dictionary containing dim_mapping and var_mapping + + Returns: + New DatasetMapping instance + """ + return cls(dim_mapping=d["dim_mapping"], var_mapping=d["var_mapping"]) + + +@xr.register_dataset_accessor("mapping") +class DatasetMappingAccessor: + """Accessor for xarray datasets that provides variable mapping functionality. + + The mapping is stored in the dataset's attributes to maintain persistence. + """ + + def __init__(self, xarray_obj: xr.Dataset) -> None: + self._obj = xarray_obj + + def set_mapping(self, mapping: DatasetMapping) -> None: + """Set the mapping in dataset attributes. + + Args: + mapping: DatasetMapping instance to store + + Raises: + ValueError: If mapped dimensions don't exist in dataset + """ + missing_dims = set(mapping.dim_mapping.values()) - set(self._obj.dims) + if missing_dims: + raise ValueError(f"Dataset missing required dimensions: {missing_dims}") + + self._obj.attrs["dataset_mapping"] = asdict(mapping) + + @property + def mapping(self) -> Optional[DatasetMapping]: + """Get the mapping from dataset attributes. + + Returns: + DatasetMapping if exists, None otherwise + """ + if "dataset_mapping" not in self._obj.attrs: + return None + return DatasetMapping.from_dict(self._obj.attrs["dataset_mapping"]) + + def get_var(self, standard_name: str) -> Optional[str]: + """Get the dataset-specific variable name for a standard name. + + Args: + standard_name: Standard variable name + + Returns: + Dataset-specific variable name if found, None otherwise + """ + mapping = self.mapping + if mapping is None: + return None + return mapping.var_mapping.get(standard_name) + + def get_dim(self, standard_name: str) -> Optional[str]: + """Get the dataset-specific dimension name for a standard name. + + Args: + standard_name: Standard dimension name + + Returns: + Dataset-specific dimension name if found, None otherwise + """ + mapping = self.mapping + if mapping is None: + return None + return mapping.dim_mapping.get(standard_name) + + +@dataclass +class AtmosphericMapping(DatasetMapping): + """Specific mapping for atmospheric data with required dimensions and variables. + + Inherits from DatasetMapping and adds validation for required atmospheric fields. + """ + + def __post_init__(self) -> None: + """Validate atmospheric-specific mappings. + + Raises: + ValueError: If required dimensions or variables are missing + """ + required_dims = {"site", "layer", "level"} + missing_dims = required_dims - set(self.dim_mapping.keys()) + if missing_dims: + raise ValueError(f"Missing required dimensions in mapping: {missing_dims}") + + required_vars = {"pres_layer", "pres_level", "temp_layer", "temp_level"} + missing_vars = required_vars - set(self.var_mapping.keys()) + if missing_vars: + raise ValueError(f"Missing required variables in mapping: {missing_vars}") + + +def create_default_mapping() -> AtmosphericMapping: + """Create a default atmospheric mapping configuration. + + Returns: + AtmosphericMapping instance with default dimension and variable mappings + """ + return AtmosphericMapping( + dim_mapping=DEFAULT_DIM_MAPPING, + var_mapping=DEFAULT_VAR_MAPPING, + ) diff --git a/pyrte_rrtmgp/exceptions.py b/pyrte_rrtmgp/exceptions.py deleted file mode 100644 index 9daaa77..0000000 --- a/pyrte_rrtmgp/exceptions.py +++ /dev/null @@ -1,16 +0,0 @@ -class NotInternalSourceError(ValueError): - pass - - -class NotExternalSourceError(ValueError): - pass - - -class MissingAtmosfericConditionsError(AttributeError): - message = ( - "You need to load the atmospheric conditions first." - "Use the method load_atmosferic_conditions with an appropriated file." - ) - - def __init__(self, message=message): - super().__init__(message) diff --git a/pyrte_rrtmgp/kernels/rrtmgp.py b/pyrte_rrtmgp/kernels/rrtmgp.py index 9bd82d0..42faa3d 100644 --- a/pyrte_rrtmgp/kernels/rrtmgp.py +++ b/pyrte_rrtmgp/kernels/rrtmgp.py @@ -9,58 +9,65 @@ rrtmgp_compute_tau_rayleigh, rrtmgp_interpolation, ) -from pyrte_rrtmgp.utils import convert_xarray_args -@convert_xarray_args def interpolation( + ncol: int, + nlay: int, + ngas: int, + nflav: int, neta: int, - flavor: npt.NDArray, - press_ref: npt.NDArray, - temp_ref: npt.NDArray, + npres: int, + ntemp: int, + flavor: npt.NDArray[np.int32], + press_ref: npt.NDArray[np.float64], + temp_ref: npt.NDArray[np.float64], press_ref_trop: float, - vmr_ref: npt.NDArray, - play: npt.NDArray, - tlay: npt.NDArray, - col_gas: npt.NDArray, + vmr_ref: npt.NDArray[np.float64], + play: npt.NDArray[np.float64], + tlay: npt.NDArray[np.float64], + col_gas: npt.NDArray[np.float64], ) -> Tuple[ - npt.NDArray, - npt.NDArray, - npt.NDArray, - npt.NDArray, - npt.NDArray, - npt.NDArray, - npt.NDArray, + npt.NDArray[np.int32], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.bool_], + npt.NDArray[np.int32], + npt.NDArray[np.int32], ]: - """Interpolate the RRTMGP coefficients. + """Interpolate the RRTMGP coefficients to the current atmospheric state. + + This function performs interpolation of gas optics coefficients based on the current + atmospheric temperature and pressure profiles. Args: - neta (int): Number of mixing_fraction. - flavor (np.ndarray): Index into vmr_ref of major gases for each flavor. - press_ref (np.ndarray): Reference pressure grid. - temp_ref (np.ndarray): Reference temperature grid. - press_ref_trop (float): Reference pressure at the tropopause. - vmr_ref (np.ndarray): Reference volume mixing ratio. - play (np.ndarray): Pressure layers. - tlay (np.ndarray): Temperature layers. - col_gas (np.ndarray): Gas concentrations. + ncol: Number of atmospheric columns + nlay: Number of atmospheric layers + ngas: Number of gases + nflav: Number of gas flavors + neta: Number of mixing fraction points + npres: Number of reference pressure grid points + ntemp: Number of reference temperature grid points + flavor: Index into vmr_ref of major gases for each flavor with shape (nflav,) + press_ref: Reference pressure grid with shape (npres,) + temp_ref: Reference temperature grid with shape (ntemp,) + press_ref_trop: Reference pressure at the tropopause + vmr_ref: Reference volume mixing ratios with shape (ngas,) + play: Layer pressures with shape (ncol, nlay) + tlay: Layer temperatures with shape (ncol, nlay) + col_gas: Gas concentrations with shape (ncol, nlay, ngas) Returns: - Tuple: A tuple containing the following arrays: - - jtemp (np.ndarray): Temperature interpolation index. - - fmajor (np.ndarray): Major gas interpolation fraction. - - fminor (np.ndarray): Minor gas interpolation fraction. - - col_mix (np.ndarray): Mixing fractions. - - tropo (np.ndarray): Use lower (or upper) atmosphere tables. - - jeta (np.ndarray): Index for binary species interpolation. - - jpress (np.ndarray): Pressure interpolation index. + Tuple containing: + - jtemp: Temperature interpolation indices with shape (ncol, nlay) + - fmajor: Major gas interpolation fractions with shape (2, 2, 2, ncol, nlay, nflav) + - fminor: Minor gas interpolation fractions with shape (2, 2, ncol, nlay, nflav) + - col_mix: Mixing fractions with shape (2, ncol, nlay, nflav) + - tropo: Boolean mask for troposphere with shape (ncol, nlay) + - jeta: Binary species interpolation indices with shape (2, ncol, nlay, nflav) + - jpress: Pressure interpolation indices with shape (ncol, nlay) """ - npres = press_ref.shape[0] - ntemp = temp_ref.shape[0] - ncol, nlay, ngas = col_gas.shape - ngas = ngas - 1 # Fortran uses index 0 here - nflav = flavor.shape[1] - press_ref_log = np.log(press_ref) press_ref_log_delta = (press_ref_log.min() - press_ref_log.max()) / ( len(press_ref_log) - 1 @@ -70,14 +77,16 @@ def interpolation( temp_ref_min = temp_ref.min() temp_ref_delta = (temp_ref.max() - temp_ref.min()) / (len(temp_ref) - 1) - # outputs - jtemp = np.ndarray([nlay, ncol], dtype=np.int32) - fmajor = np.ndarray([nflav, nlay, ncol, 2, 2, 2], dtype=np.float64) - fminor = np.ndarray([nflav, nlay, ncol, 2, 2], dtype=np.float64) - col_mix = np.ndarray([nflav, nlay, ncol, 2], dtype=np.float64) - tropo = np.ndarray([nlay, ncol], dtype=np.int32) - jeta = np.ndarray([nflav, nlay, ncol, 2], dtype=np.int32) - jpress = np.ndarray([nlay, ncol], dtype=np.int32) + ngas = ngas - 1 # Fortran uses index 0 here + + # Initialize output arrays + jtemp = np.ndarray([ncol, nlay], dtype=np.int32, order="F") + fmajor = np.ndarray([2, 2, 2, ncol, nlay, nflav], dtype=np.float64, order="F") + fminor = np.ndarray([2, 2, ncol, nlay, nflav], dtype=np.float64, order="F") + col_mix = np.ndarray([2, ncol, nlay, nflav], dtype=np.float64, order="F") + tropo = np.ndarray([ncol, nlay], dtype=np.int32, order="F") + jeta = np.ndarray([2, ncol, nlay, nflav], dtype=np.int32, order="F") + jpress = np.ndarray([ncol, nlay], dtype=np.int32, order="F") args = [ ncol, @@ -87,17 +96,17 @@ def interpolation( neta, npres, ntemp, - flavor.flatten("F"), - press_ref_log.flatten("F"), - temp_ref.flatten("F"), + np.asfortranarray(flavor), + np.asfortranarray(press_ref_log), + np.asfortranarray(temp_ref), press_ref_log_delta, temp_ref_min, temp_ref_delta, press_ref_trop_log, - vmr_ref.flatten("F"), - play.flatten("F"), - tlay.flatten("F"), - col_gas.flatten("F"), + np.asfortranarray(vmr_ref), + np.asfortranarray(play), + np.asfortranarray(tlay), + np.asfortranarray(col_gas), jtemp, fmajor, fminor, @@ -110,69 +119,87 @@ def interpolation( rrtmgp_interpolation(*args) tropo = tropo != 0 # Convert to boolean - return jtemp.T, fmajor.T, fminor.T, col_mix.T, tropo.T, jeta.T, jpress.T + return jtemp, fmajor, fminor, col_mix, tropo, jeta, jpress -@convert_xarray_args def compute_planck_source( - tlay, - tlev, - tsfc, - top_at_1, - fmajor, - jeta, - tropo, - jtemp, - jpress, - band_lims_gpt, - pfracin, - temp_ref_min, - temp_ref_max, - totplnk, - gpoint_flavor, -): - """Compute the Planck source function for a radiative transfer calculation. + ncol: int, + nlay: int, + nbnd: int, + ngpt: int, + nflav: int, + neta: int, + npres: int, + ntemp: int, + nPlanckTemp: int, + tlay: npt.NDArray[np.float64], + tlev: npt.NDArray[np.float64], + tsfc: npt.NDArray[np.float64], + top_at_1: bool, + fmajor: npt.NDArray[np.float64], + jeta: npt.NDArray[np.int32], + tropo: npt.NDArray[np.bool_], + jtemp: npt.NDArray[np.int32], + jpress: npt.NDArray[np.int32], + band_lims_gpt: npt.NDArray[np.int32], + pfracin: npt.NDArray[np.float64], + temp_ref_min: float, + temp_ref_max: float, + totplnk: npt.NDArray[np.float64], + gpoint_flavor: npt.NDArray[np.int32], +) -> Tuple[ + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], +]: + """Compute the Planck source function for radiative transfer calculations. + + This function calculates the Planck blackbody emission source terms needed for + longwave radiative transfer calculations. Args: - tlay (numpy.ndarray): Temperature at layer centers (K), shape (ncol, nlay). - tlev (numpy.ndarray): Temperature at layer interfaces (K), shape (ncol, nlay+1). - tsfc (numpy.ndarray): Surface temperature, shape (ncol,). - top_at_1 (bool): Flag indicating if the top layer is at index 0. - sfc_lay (int): Index of the surface layer. - fmajor (numpy.ndarray): Interpolation weights for major gases, shape (2, 2, 2, ncol, nlay, nflav). - jeta (numpy.ndarray): Interpolation indexes in eta, shape (2, ncol, nlay, nflav). - tropo (numpy.ndarray): Use upper- or lower-atmospheric tables, shape (ncol, nlay). - jtemp (numpy.ndarray): Interpolation indexes in temperature, shape (ncol, nlay). - jpress (numpy.ndarray): Interpolation indexes in pressure, shape (ncol, nlay). - band_lims_gpt (numpy.ndarray): Start and end g-point for each band, shape (2, nbnd). - pfracin (numpy.ndarray): Fraction of the Planck function in each g-point, shape (ntemp, neta, npres+1, ngpt). - temp_ref_min (float): Minimum reference temperature for Planck function interpolation. - totplnk (numpy.ndarray): Total Planck function by band at each temperature, shape (nPlanckTemp, nbnd). - gpoint_flavor (numpy.ndarray): Major gas flavor (pair) by upper/lower, g-point, shape (2, ngpt). + ncol: Number of atmospheric columns + nlay: Number of atmospheric layers + nbnd: Number of spectral bands + ngpt: Number of g-points + nflav: Number of gas flavors + neta: Number of eta points + npres: Number of pressure points + ntemp: Number of temperature points + nPlanckTemp: Number of temperatures for Planck function + tlay: Layer temperatures with shape (ncol, nlay) + tlev: Level temperatures with shape (ncol, nlay+1) + tsfc: Surface temperatures with shape (ncol,) + top_at_1: Whether the top of the atmosphere is at index 1 + fmajor: Major gas interpolation weights with shape (2, 2, 2, ncol, nlay, nflav) + jeta: Eta interpolation indices with shape (2, ncol, nlay, nflav) + tropo: Troposphere mask with shape (ncol, nlay) + jtemp: Temperature interpolation indices with shape (ncol, nlay) + jpress: Pressure interpolation indices with shape (ncol, nlay) + band_lims_gpt: Band limits in g-point space with shape (2, nbnd) + pfracin: Planck fractions with shape (ntemp, neta, npres+1, ngpt) + temp_ref_min: Minimum reference temperature + temp_ref_max: Maximum reference temperature + totplnk: Total Planck function by band with shape (nPlanckTemp, nbnd) + gpoint_flavor: G-point flavors with shape (2, ngpt) Returns: - sfc_src (numpy.ndarray): Planck emission from the surface, shape (ncol, ngpt). - lay_src (numpy.ndarray): Planck emission from layer centers, shape (ncol, nlay, ngpt). - lev_src (numpy.ndarray): Planck emission from layer boundaries, shape (ncol, nlay+1, ngpt). - sfc_source_Jac (numpy.ndarray): Jacobian (derivative) of the surface Planck source with respect to surface temperature, shape (ncol, ngpt). + Tuple containing: + - sfc_src: Surface emission with shape (ncol, ngpt) + - lay_src: Layer emission with shape (ncol, nlay, ngpt) + - lev_src: Level emission with shape (ncol, nlay+1, ngpt) + - sfc_src_jac: Surface emission Jacobian with shape (ncol, ngpt) """ - - _, ncol, nlay, nflav = jeta.shape - nPlanckTemp, nbnd = totplnk.shape - ntemp, neta, npres_e, ngpt = pfracin.shape - npres = npres_e - 1 - sfc_lay = nlay if top_at_1 else 1 - gpoint_bands = [] - totplnk_delta = (temp_ref_max - temp_ref_min) / (nPlanckTemp - 1) - # outputs - sfc_src = np.ndarray([ngpt, ncol], dtype=np.float64) - lay_src = np.ndarray([ngpt, nlay, ncol], dtype=np.float64) - lev_src = np.ndarray([ngpt, nlay + 1, ncol], dtype=np.float64) - sfc_src_jac = np.ndarray([ngpt, ncol], dtype=np.float64) + # Initialize output arrays + sfc_src = np.ndarray((ncol, ngpt), dtype=np.float64, order="F") + lay_src = np.ndarray((ncol, nlay, ngpt), dtype=np.float64, order="F") + lev_src = np.ndarray((ncol, nlay + 1, ngpt), dtype=np.float64, order="F") + sfc_src_jac = np.ndarray((ncol, ngpt), dtype=np.float64, order="F") args = [ ncol, @@ -184,22 +211,22 @@ def compute_planck_source( npres, ntemp, nPlanckTemp, - tlay.flatten("F"), - tlev.flatten("F"), - tsfc.flatten("F"), + np.asfortranarray(tlay), + np.asfortranarray(tlev), + np.asfortranarray(tsfc), sfc_lay, - fmajor.flatten("F"), - jeta.flatten("F"), - tropo.flatten("F"), - jtemp.flatten("F"), - jpress.flatten("F"), + np.asfortranarray(fmajor), + np.asfortranarray(jeta), + np.asfortranarray(tropo), + np.asfortranarray(jtemp), + np.asfortranarray(jpress), gpoint_bands, - band_lims_gpt.flatten("F"), - pfracin.flatten("F"), + np.asfortranarray(band_lims_gpt), + np.asfortranarray(pfracin), temp_ref_min, totplnk_delta, - totplnk.flatten("F"), - gpoint_flavor.flatten("F"), + np.asfortranarray(totplnk), + np.asfortranarray(gpoint_flavor), sfc_src, lay_src, lev_src, @@ -208,88 +235,107 @@ def compute_planck_source( rrtmgp_compute_Planck_source(*args) - return sfc_src.T, lay_src.T, lev_src.T, sfc_src_jac.T + return sfc_src, lay_src, lev_src, sfc_src_jac -@convert_xarray_args def compute_tau_absorption( - idx_h2o, - gpoint_flavor, - band_lims_gpt, - kmajor, - kminor_lower, - kminor_upper, - minor_limits_gpt_lower, - minor_limits_gpt_upper, - minor_scales_with_density_lower, - minor_scales_with_density_upper, - scale_by_complement_lower, - scale_by_complement_upper, - idx_minor_lower, - idx_minor_upper, - idx_minor_scaling_lower, - idx_minor_scaling_upper, - kminor_start_lower, - kminor_start_upper, - tropo, - col_mix, - fmajor, - fminor, - play, - tlay, - col_gas, - jeta, - jtemp, - jpress, -): - """Compute the absorption optical depth for a set of atmospheric profiles. + ncol: int, + nlay: int, + nbnd: int, + ngpt: int, + ngas: int, + nflav: int, + neta: int, + npres: int, + ntemp: int, + nminorlower: int, + nminorklower: int, + nminorupper: int, + nminorkupper: int, + idx_h2o: int, + gpoint_flavor: npt.NDArray[np.int32], + band_lims_gpt: npt.NDArray[np.int32], + kmajor: npt.NDArray[np.float64], + kminor_lower: npt.NDArray[np.float64], + kminor_upper: npt.NDArray[np.float64], + minor_limits_gpt_lower: npt.NDArray[np.int32], + minor_limits_gpt_upper: npt.NDArray[np.int32], + minor_scales_with_density_lower: npt.NDArray[np.bool_], + minor_scales_with_density_upper: npt.NDArray[np.bool_], + scale_by_complement_lower: npt.NDArray[np.bool_], + scale_by_complement_upper: npt.NDArray[np.bool_], + idx_minor_lower: npt.NDArray[np.int32], + idx_minor_upper: npt.NDArray[np.int32], + idx_minor_scaling_lower: npt.NDArray[np.int32], + idx_minor_scaling_upper: npt.NDArray[np.int32], + kminor_start_lower: npt.NDArray[np.int32], + kminor_start_upper: npt.NDArray[np.int32], + tropo: npt.NDArray[np.bool_], + col_mix: npt.NDArray[np.float64], + fmajor: npt.NDArray[np.float64], + fminor: npt.NDArray[np.float64], + play: npt.NDArray[np.float64], + tlay: npt.NDArray[np.float64], + col_gas: npt.NDArray[np.float64], + jeta: npt.NDArray[np.int32], + jtemp: npt.NDArray[np.int32], + jpress: npt.NDArray[np.int32], +) -> npt.NDArray[np.float64]: + """Compute the absorption optical depth for atmospheric profiles. + + This function calculates the total absorption optical depth by combining contributions + from major and minor gas species in both the upper and lower atmosphere. Args: - idx_h2o (int): Index of the water vapor gas species. - gpoint_flavor (np.ndarray): Spectral g-point flavor indices. - band_lims_gpt (np.ndarray): Spectral band limits in g-point space. - kmajor (np.ndarray): Major gas absorption coefficients. - kminor_lower (np.ndarray): Minor gas absorption coefficients for the lower atmosphere. - kminor_upper (np.ndarray): Minor gas absorption coefficients for the upper atmosphere. - minor_limits_gpt_lower (np.ndarray): Spectral g-point limits for minor contributors in the lower atmosphere. - minor_limits_gpt_upper (np.ndarray): Spectral g-point limits for minor contributors in the upper atmosphere. - minor_scales_with_density_lower (np.ndarray): Flags indicating if minor contributors in the lower atmosphere scale with density. - minor_scales_with_density_upper (np.ndarray): Flags indicating if minor contributors in the upper atmosphere scale with density. - scale_by_complement_lower (np.ndarray): Flags indicating if minor contributors in the lower atmosphere should be scaled by the complement. - scale_by_complement_upper (np.ndarray): Flags indicating if minor contributors in the upper atmosphere should be scaled by the complement. - idx_minor_lower (np.ndarray): Indices of minor contributors in the lower atmosphere. - idx_minor_upper (np.ndarray): Indices of minor contributors in the upper atmosphere. - idx_minor_scaling_lower (np.ndarray): Indices of minor contributors in the lower atmosphere that require scaling. - idx_minor_scaling_upper (np.ndarray): Indices of minor contributors in the upper atmosphere that require scaling. - kminor_start_lower (np.ndarray): Starting indices of minor absorption coefficients in the lower atmosphere. - kminor_start_upper (np.ndarray): Starting indices of minor absorption coefficients in the upper atmosphere. - tropo (np.ndarray): Flags indicating if a layer is in the troposphere. - col_mix (np.ndarray): Column-dependent gas mixing ratios. - fmajor (np.ndarray): Major gas absorption coefficient scaling factors. - fminor (np.ndarray): Minor gas absorption coefficient scaling factors. - play (np.ndarray): Pressure in each layer. - tlay (np.ndarray): Temperature in each layer. - col_gas (np.ndarray): Column-dependent gas concentrations. - jeta (np.ndarray): Indices of temperature/pressure levels. - jtemp (np.ndarray): Indices of temperature levels. - jpress (np.ndarray): Indices of pressure levels. + ncol: Number of atmospheric columns + nlay: Number of atmospheric layers + nbnd: Number of spectral bands + ngpt: Number of g-points + ngas: Number of gases + nflav: Number of gas flavors + neta: Number of eta points + npres: Number of pressure points + ntemp: Number of temperature points + nminorlower: Number of minor species in lower atmosphere + nminorklower: Number of minor absorption coefficients in lower atmosphere + nminorupper: Number of minor species in upper atmosphere + nminorkupper: Number of minor absorption coefficients in upper atmosphere + idx_h2o: Index of water vapor + gpoint_flavor: G-point flavors with shape (2, ngpt) + band_lims_gpt: Band limits in g-point space with shape (2, nbnd) + kmajor: Major gas absorption coefficients + kminor_lower: Minor gas absorption coefficients for lower atmosphere + kminor_upper: Minor gas absorption coefficients for upper atmosphere + minor_limits_gpt_lower: G-point limits for minor gases in lower atmosphere + minor_limits_gpt_upper: G-point limits for minor gases in upper atmosphere + minor_scales_with_density_lower: Density scaling flags for lower atmosphere + minor_scales_with_density_upper: Density scaling flags for upper atmosphere + scale_by_complement_lower: Complement scaling flags for lower atmosphere + scale_by_complement_upper: Complement scaling flags for upper atmosphere + idx_minor_lower: Minor gas indices for lower atmosphere + idx_minor_upper: Minor gas indices for upper atmosphere + idx_minor_scaling_lower: Minor gas scaling indices for lower atmosphere + idx_minor_scaling_upper: Minor gas scaling indices for upper atmosphere + kminor_start_lower: Starting indices for minor gases in lower atmosphere + kminor_start_upper: Starting indices for minor gases in upper atmosphere + tropo: Troposphere mask with shape (ncol, nlay) + col_mix: Gas mixing ratios with shape (2, ncol, nlay, nflav) + fmajor: Major gas interpolation weights + fminor: Minor gas interpolation weights + play: Layer pressures with shape (ncol, nlay) + tlay: Layer temperatures with shape (ncol, nlay) + col_gas: Gas concentrations with shape (ncol, nlay, ngas) + jeta: Eta interpolation indices + jtemp: Temperature interpolation indices + jpress: Pressure interpolation indices Returns: - np.ndarray): tau Absorption optical depth. + Absorption optical depth with shape (ncol, nlay, ngpt) """ + ngas = ngas - 1 # Fortran uses index 0 here - ntemp, npres_e, neta, ngpt = kmajor.shape - npres = npres_e - 1 - nbnd = band_lims_gpt.shape[1] - _, ncol, nlay, nflav = jeta.shape - ngas = col_gas.shape[2] - 1 - nminorlower = minor_scales_with_density_lower.shape[0] - nminorupper = minor_scales_with_density_upper.shape[0] - nminorklower = kminor_lower.shape[2] - nminorkupper = kminor_upper.shape[2] - - # outputs - tau = np.zeros([ngpt, nlay, ncol], dtype=np.float64) + # Initialize output array + tau = np.zeros((ncol, nlay, ngpt), dtype=np.float64, order="F") args = [ ncol, @@ -306,79 +352,90 @@ def compute_tau_absorption( nminorupper, nminorkupper, idx_h2o, - gpoint_flavor.flatten("F"), # correct - band_lims_gpt.flatten("F"), - kmajor.transpose(0, 2, 1, 3).flatten("F"), - kminor_lower.flatten("F"), - kminor_upper.flatten("F"), - minor_limits_gpt_lower.flatten("F"), - minor_limits_gpt_upper.flatten("F"), - minor_scales_with_density_lower.flatten("F"), - minor_scales_with_density_upper.flatten("F"), - scale_by_complement_lower.flatten("F"), - scale_by_complement_upper.flatten("F"), - idx_minor_lower.flatten("F"), - idx_minor_upper.flatten("F"), - idx_minor_scaling_lower.flatten("F"), - idx_minor_scaling_upper.flatten("F"), - kminor_start_lower.flatten("F"), - kminor_start_upper.flatten("F"), - tropo.flatten("F"), - col_mix.flatten("F"), - fmajor.flatten("F"), - fminor.flatten("F"), - play.flatten("F"), - tlay.flatten("F"), - col_gas.flatten("F"), - jeta.flatten("F"), - jtemp.flatten("F"), - jpress.flatten("F"), + np.asfortranarray(gpoint_flavor), + np.asfortranarray(band_lims_gpt), + np.asfortranarray(kmajor), + np.asfortranarray(kminor_lower), + np.asfortranarray(kminor_upper), + np.asfortranarray(minor_limits_gpt_lower), + np.asfortranarray(minor_limits_gpt_upper), + np.asfortranarray(minor_scales_with_density_lower), + np.asfortranarray(minor_scales_with_density_upper), + np.asfortranarray(scale_by_complement_lower), + np.asfortranarray(scale_by_complement_upper), + np.asfortranarray(idx_minor_lower), + np.asfortranarray(idx_minor_upper), + np.asfortranarray(idx_minor_scaling_lower), + np.asfortranarray(idx_minor_scaling_upper), + np.asfortranarray(kminor_start_lower), + np.asfortranarray(kminor_start_upper), + np.asfortranarray(tropo), + np.asfortranarray(col_mix), + np.asfortranarray(fmajor), + np.asfortranarray(fminor), + np.asfortranarray(play), + np.asfortranarray(tlay), + np.asfortranarray(col_gas), + np.asfortranarray(jeta), + np.asfortranarray(jtemp), + np.asfortranarray(jpress), tau, ] rrtmgp_compute_tau_absorption(*args) - return tau.T + return tau -@convert_xarray_args def compute_tau_rayleigh( - gpoint_flavor, - band_lims_gpt, - krayl, - idx_h2o, - col_dry, - col_gas, - fminor, - jeta, - tropo, - jtemp, -): - """Compute Rayleigh optical depth. + ncol: int, + nlay: int, + nbnd: int, + ngpt: int, + ngas: int, + nflav: int, + neta: int, + ntemp: int, + gpoint_flavor: npt.NDArray[np.int32], + band_lims_gpt: npt.NDArray[np.int32], + krayl: npt.NDArray[np.float64], + idx_h2o: int, + col_dry: npt.NDArray[np.float64], + col_gas: npt.NDArray[np.float64], + fminor: npt.NDArray[np.float64], + jeta: npt.NDArray[np.int32], + tropo: npt.NDArray[np.bool_], + jtemp: npt.NDArray[np.int32], +) -> npt.NDArray[np.float64]: + """Compute Rayleigh scattering optical depth. + + This function calculates the optical depth due to Rayleigh scattering by air molecules. Args: - gpoint_flavor (numpy.ndarray): Major gas flavor (pair) by upper/lower, g-point (shape: (2, ngpt)). - band_lims_gpt (numpy.ndarray): Start and end g-point for each band (shape: (2, nbnd)). - krayl (numpy.ndarray): Rayleigh scattering coefficients (shape: (ntemp, neta, ngpt, 2)). - idx_h2o (int): Index of water vapor in col_gas. - col_dry (numpy.ndarray): Column amount of dry air (shape: (ncol, nlay)). - col_gas (numpy.ndarray): Input column gas amount (molecules/cm^2) (shape: (ncol, nlay, 0:ngas)). - fminor (numpy.ndarray): Interpolation weights for major gases - computed in interpolation() (shape: (2, 2, ncol, nlay, nflav)). - jeta (numpy.ndarray): Interpolation indexes in eta - computed in interpolation() (shape: (2, ncol, nlay, nflav)). - tropo (numpy.ndarray): Use upper- or lower-atmospheric tables? (shape: (ncol, nlay)). - jtemp (numpy.ndarray): Interpolation indexes in temperature - computed in interpolation() (shape: (ncol, nlay)). + ncol: Number of atmospheric columns + nlay: Number of atmospheric layers + nbnd: Number of spectral bands + ngpt: Number of g-points + ngas: Number of gases + nflav: Number of gas flavors + neta: Number of eta points + ntemp: Number of temperature points + gpoint_flavor: G-point flavors with shape (2, ngpt) + band_lims_gpt: Band limits in g-point space with shape (2, nbnd) + krayl: Rayleigh scattering coefficients with shape (ntemp, neta, ngpt, 2) + idx_h2o: Index of water vapor + col_dry: Dry air column amounts with shape (ncol, nlay) + col_gas: Gas concentrations with shape (ncol, nlay, ngas) + fminor: Minor gas interpolation weights + jeta: Eta interpolation indices + tropo: Troposphere mask with shape (ncol, nlay) + jtemp: Temperature interpolation indices Returns: - numpy.ndarray: Rayleigh optical depth (shape: (ncol, nlay, ngpt)). + Rayleigh scattering optical depth with shape (ncol, nlay, ngpt) """ - - ncol, nlay, ngas = col_gas.shape - ntemp, neta, ngpt, _ = krayl.shape - nflav = jeta.shape[3] - nbnd = band_lims_gpt.shape[1] - - # outputs - tau_rayleigh = np.ndarray((ngpt, nlay, ncol), dtype=np.float64) + # Initialize output array + tau_rayleigh = np.ndarray((ncol, nlay, ngpt), dtype=np.float64, order="F") args = [ ncol, @@ -390,19 +447,19 @@ def compute_tau_rayleigh( neta, 0, # not used in fortran ntemp, - gpoint_flavor.flatten("F"), - band_lims_gpt.flatten("F"), - krayl.flatten("F"), + np.asfortranarray(gpoint_flavor), + np.asfortranarray(band_lims_gpt), + np.asfortranarray(krayl), idx_h2o, - col_dry.flatten("F"), - col_gas.flatten("F"), - fminor.flatten("F"), - jeta.flatten("F"), - tropo.flatten("F"), - jtemp.flatten("F"), + np.asfortranarray(col_dry), + np.asfortranarray(col_gas), + np.asfortranarray(fminor), + np.asfortranarray(jeta), + np.asfortranarray(tropo), + np.asfortranarray(jtemp), tau_rayleigh, ] rrtmgp_compute_tau_rayleigh(*args) - return tau_rayleigh.T + return tau_rayleigh diff --git a/pyrte_rrtmgp/kernels/rte.py b/pyrte_rrtmgp/kernels/rte.py index 512ff30..2596196 100644 --- a/pyrte_rrtmgp/kernels/rte.py +++ b/pyrte_rrtmgp/kernels/rte.py @@ -4,109 +4,80 @@ import numpy.typing as npt from pyrte_rrtmgp.pyrte_rrtmgp import ( + rte_lw_solver_2stream, rte_lw_solver_noscat, rte_sw_solver_2stream, rte_sw_solver_noscat, ) -GAUSS_DS = np.array( - [ - [1.66, 0.0, 0.0, 0.0], # Diffusivity angle, not Gaussian angle - [1.18350343, 2.81649655, 0.0, 0.0], - [1.09719858, 1.69338507, 4.70941630, 0.0], - [1.06056257, 1.38282560, 2.40148179, 7.15513024], - ] -) - - -GAUSS_WTS = np.array( - [ - [0.5, 0.0, 0.0, 0.0], - [0.3180413817, 0.1819586183, 0.0, 0.0], - [0.2009319137, 0.2292411064, 0.0698269799, 0.0], - [0.1355069134, 0.2034645680, 0.1298475476, 0.0311809710], - ] -) - def lw_solver_noscat( - tau: npt.NDArray, - lay_source: npt.NDArray, - lev_source: npt.NDArray, - sfc_emis: npt.NDArray, - sfc_src: npt.NDArray, + ncol: int, + nlay: int, + ngpt: int, + ds: npt.NDArray[np.float64], + weights: npt.NDArray[np.float64], + tau: npt.NDArray[np.float64], + ssa: npt.NDArray[np.float64], + g: npt.NDArray[np.float64], + lay_source: npt.NDArray[np.float64], + lev_source: npt.NDArray[np.float64], + sfc_emis: npt.NDArray[np.float64], + sfc_src: npt.NDArray[np.float64], + sfc_src_jac: npt.NDArray[np.float64], + inc_flux: npt.NDArray[np.float64], top_at_1: bool = True, nmus: int = 1, - inc_flux: Optional[npt.NDArray] = None, - ds: Optional[npt.NDArray] = None, - weights: Optional[npt.NDArray] = None, - do_broadband: Optional[bool] = True, - do_Jacobians: Optional[bool] = False, - sfc_src_jac: Optional[npt.NDArray] = [], - do_rescaling: Optional[bool] = False, - ssa: Optional[npt.NDArray] = None, - g: Optional[np.ndarray] = None, -) -> Tuple: - """ - Perform longwave radiation transfer calculations without scattering. + do_broadband: bool = True, + do_Jacobians: bool = False, + do_rescaling: bool = False, +) -> Tuple[ + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], +]: + """Perform longwave radiation transfer calculations without scattering. + + This function solves the longwave radiative transfer equation in the absence of scattering, + computing fluxes and optionally their Jacobians. Args: - top_at_1 (bool): Flag indicating whether the top of the atmosphere is at level 1. - nmus (int): Number of quadrature points. - tau (npt.NDArray): Array of optical depths. - lay_source (npt.NDArray): Array of layer sources. - lev_source (npt.NDArray): Array of level sources. - sfc_emis (npt.NDArray): Array of surface emissivities. - sfc_src (npt.NDArray): Array of surface sources. - inc_flux (npt.NDArray): Array of incoming fluxes. - ds (Optional[npt.NDArray], optional): Array of integration weights. Defaults to None. - weights (Optional[npt.NDArray], optional): Array of Gaussian quadrature weights. Defaults to None. - do_broadband (Optional[bool], optional): Flag indicating whether to compute broadband fluxes. Defaults to None. - do_Jacobians (Optional[bool], optional): Flag indicating whether to compute Jacobians. Defaults to None. - sfc_src_jac (Optional[npt.NDArray], optional): Array of surface source Jacobians. Defaults to None. - do_rescaling (Optional[bool], optional): Flag indicating whether to perform flux rescaling. Defaults to None. - ssa (Optional[npt.NDArray], optional): Array of single scattering albedos. Defaults to None. - g (Optional[np.ndarray], optional): Array of asymmetry parameters. Defaults to None. + ncol: Number of columns + nlay: Number of layers + ngpt: Number of g-points + ds: Integration weights with shape (ncol, ngpt, n_quad_angs) + weights: Gaussian quadrature weights with shape (n_quad_angs,) + tau: Optical depths with shape (ncol, nlay, ngpt) + ssa: Single scattering albedos with shape (ncol, nlay, ngpt) + g: Asymmetry parameters with shape (ncol, nlay, ngpt) + lay_source: Layer source terms with shape (ncol, nlay, ngpt) + lev_source: Level source terms with shape (ncol, nlay+1, ngpt) + sfc_emis: Surface emissivities with shape (ncol, ngpt) or (ncol,) + sfc_src: Surface source terms with shape (ncol, ngpt) + sfc_src_jac: Surface source Jacobians with shape (ncol, nlay+1) + inc_flux: Incident fluxes with shape (ncol, ngpt) + top_at_1: Whether the top of the atmosphere is at index 1 + nmus: Number of quadrature points (1-4) + do_broadband: Whether to compute broadband fluxes + do_Jacobians: Whether to compute Jacobians + do_rescaling: Whether to perform flux rescaling Returns: - Tuple: A tuple containing the following arrays: - - flux_up_jac (np.ndarray): Array of upward flux Jacobians. - - broadband_up (np.ndarray): Array of upward broadband fluxes. - - broadband_dn (np.ndarray): Array of downward broadband fluxes. - - flux_up (np.ndarray): Array of upward fluxes. - - flux_dn (np.ndarray): Array of downward fluxes. + Tuple containing: + flux_up_jac: Upward flux Jacobians with shape (ncol, nlay+1) + broadband_up: Upward broadband fluxes with shape (ncol, nlay+1) + broadband_dn: Downward broadband fluxes with shape (ncol, nlay+1) + flux_up: Upward fluxes with shape (ncol, nlay+1, ngpt) + flux_dn: Downward fluxes with shape (ncol, nlay+1, ngpt) """ - - ncol, nlay, ngpt = tau.shape - - if len(sfc_emis.shape) == 1: - sfc_emis = np.stack([sfc_emis] * ngpt).T - - # default values - n_quad_angs = nmus - - if inc_flux is None: - inc_flux = np.zeros(sfc_src.shape) - - if ds is None: - ds = np.empty((ncol, ngpt, n_quad_angs)) - for imu in range(n_quad_angs): - for igpt in range(ngpt): - for icol in range(ncol): - ds[icol, igpt, imu] = GAUSS_DS[imu, n_quad_angs - 1] - - if weights is None: - weights = GAUSS_WTS[0:n_quad_angs, n_quad_angs - 1] - - ssa = ssa or tau - g = g or tau - - # outputs - flux_up_jac = np.full([nlay + 1, ncol], np.nan, dtype=np.float64) - broadband_up = np.full([nlay + 1, ncol], np.nan, dtype=np.float64) - broadband_dn = np.full([nlay + 1, ncol], np.nan, dtype=np.float64) - flux_up = np.full([ngpt, nlay + 1, ncol], np.nan, dtype=np.float64) - flux_dn = np.full([ngpt, nlay + 1, ncol], np.nan, dtype=np.float64) + # Initialize output arrays + flux_up_jac = np.full((ncol, nlay + 1), np.nan, dtype=np.float64, order="F") + broadband_up = np.full((ncol, nlay + 1), np.nan, dtype=np.float64, order="F") + broadband_dn = np.full((ncol, nlay + 1), np.nan, dtype=np.float64, order="F") + flux_up = np.full((ncol, nlay + 1, ngpt), np.nan, dtype=np.float64, order="F") + flux_dn = np.full((ncol, nlay + 1, ngpt), np.nan, dtype=np.float64, order="F") args = [ ncol, @@ -114,57 +85,132 @@ def lw_solver_noscat( ngpt, top_at_1, nmus, - ds.flatten("F"), - weights.flatten("F"), - tau.flatten("F"), - lay_source.flatten("F"), - lev_source.flatten("F"), - sfc_emis.flatten("F"), - sfc_src.flatten("F"), - inc_flux.flatten("F"), + np.asfortranarray(ds), + np.asfortranarray(weights), + np.asfortranarray(tau), + np.asfortranarray(lay_source), + np.asfortranarray(lev_source), + np.asfortranarray(sfc_emis), + np.asfortranarray(sfc_src), + np.asfortranarray(inc_flux), flux_up, flux_dn, do_broadband, broadband_up, broadband_dn, do_Jacobians, - sfc_src_jac.flatten("F"), + np.asfortranarray(sfc_src_jac), flux_up_jac, do_rescaling, - ssa.flatten("F"), - g.flatten("F"), + np.asfortranarray(ssa), + np.asfortranarray(g), ] rte_lw_solver_noscat(*args) - return flux_up_jac.T, broadband_up.T, broadband_dn.T, flux_up.T, flux_dn.T - + return flux_up_jac, broadband_up, broadband_dn, flux_up, flux_dn + + +def lw_solver_2stream( + ncol: int, + nlay: int, + ngpt: int, + tau: npt.NDArray[np.float64], + ssa: npt.NDArray[np.float64], + g: npt.NDArray[np.float64], + lay_source: npt.NDArray[np.float64], + lev_source: npt.NDArray[np.float64], + sfc_emis: npt.NDArray[np.float64], + sfc_src: npt.NDArray[np.float64], + inc_flux: npt.NDArray[np.float64], + top_at_1: bool = True, +) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: + """Solve the longwave radiative transfer equation using the 2-stream approximation. -def sw_solver_noscat( - top_at_1, - tau, - mu0, - inc_flux_dir, -): - """ - Computes the direct-beam flux for a shortwave radiative transfer problem without scattering. + This function implements the two-stream approximation for longwave radiative transfer, + accounting for both absorption and scattering processes. Args: - top_at_1 (bool): Logical flag indicating if the top layer is at index 1. - tau (numpy.ndarray): Absorption optical thickness of size (ncol, nlay, ngpt). - mu0 (numpy.ndarray): Cosine of solar zenith angle of size (ncol, nlay). - inc_flux_dir (numpy.ndarray): Direct beam incident flux of size (ncol, ngpt). + ncol: Number of columns + nlay: Number of layers + ngpt: Number of g-points + tau: Optical depths with shape (ncol, nlay, ngpt) + ssa: Single-scattering albedos with shape (ncol, nlay, ngpt) + g: Asymmetry parameters with shape (ncol, nlay, ngpt) + lay_source: Layer source terms with shape (ncol, nlay, ngpt) + lev_source: Level source terms with shape (ncol, nlay+1, ngpt) + sfc_emis: Surface emissivities with shape (ncol, ngpt) or (ncol,) + sfc_src: Surface source terms with shape (ncol, ngpt) + inc_flux: Incident fluxes with shape (ncol, ngpt) + top_at_1: Whether the top of the atmosphere is at index 1 Returns: - numpy.ndarray: Direct-beam flux of size (ncol, nlay+1, ngpt). + Tuple containing: + flux_up: Upward fluxes with shape (ncol, nlay+1, ngpt) + flux_dn: Downward fluxes with shape (ncol, nlay+1, ngpt) """ + # Initialize output arrays + flux_up = np.zeros((ncol, nlay + 1, ngpt), dtype=np.float64, order="F") + flux_dn = np.zeros((ncol, nlay + 1, ngpt), dtype=np.float64, order="F") - ncol, nlay, ngpt = tau.shape + args = [ + ncol, + nlay, + ngpt, + top_at_1, + np.asfortranarray(tau), + np.asfortranarray(ssa), + np.asfortranarray(g), + np.asfortranarray(lay_source), + np.asfortranarray(lev_source), + np.asfortranarray(sfc_emis), + np.asfortranarray(sfc_src), + np.asfortranarray(inc_flux), + flux_up, + flux_dn, + ] + + rte_lw_solver_2stream(*args) + + return flux_up, flux_dn - # outputs - flux_dir = np.ndarray((ncol, nlay + 1, ngpt), dtype=np.float64) - args = [ncol, nlay, ngpt, top_at_1, tau, mu0, inc_flux_dir, flux_dir] +def sw_solver_noscat( + ncol: int, + nlay: int, + ngpt: int, + tau: npt.NDArray[np.float64], + mu0: npt.NDArray[np.float64], + inc_flux_dir: npt.NDArray[np.float64], + top_at_1: bool = True, +) -> npt.NDArray[np.float64]: + """Perform shortwave radiation transfer calculations without scattering. + + This function solves the shortwave radiative transfer equation in the absence of + scattering, computing direct beam fluxes only. + + Args: + tau: Optical depths with shape (ncol, nlay, ngpt) + mu0: Cosine of solar zenith angles with shape (ncol, nlay) + inc_flux_dir: Direct beam incident fluxes with shape (ncol, ngpt) + top_at_1: Whether the top of the atmosphere is at index 1 + + Returns: + Direct-beam fluxes with shape (ncol, nlay+1, ngpt) + """ + # Initialize output array + flux_dir = np.zeros((ncol, nlay + 1, ngpt), dtype=np.float64, order="F") + + args = [ + ncol, + nlay, + ngpt, + top_at_1, + np.asfortranarray(tau), + np.asfortranarray(mu0), + np.asfortranarray(inc_flux_dir), + flux_dir, + ] rte_sw_solver_noscat(*args) @@ -172,81 +218,83 @@ def sw_solver_noscat( def sw_solver_2stream( - top_at_1, - tau, - ssa, - g, - mu0, - sfc_alb_dir, - sfc_alb_dif, - inc_flux_dir, - inc_flux_dif=None, - has_dif_bc=False, - do_broadband=False, -): - """ - Solve the shortwave radiative transfer equation using the 2-stream approximation. + ncol: int, + nlay: int, + ngpt: int, + tau: npt.NDArray[np.float64], + ssa: npt.NDArray[np.float64], + g: npt.NDArray[np.float64], + mu0: npt.NDArray[np.float64], + sfc_alb_dir: npt.NDArray[np.float64], + sfc_alb_dif: npt.NDArray[np.float64], + inc_flux_dir: npt.NDArray[np.float64], + inc_flux_dif: npt.NDArray[np.float64], + top_at_1: bool = True, + has_dif_bc: bool = False, + do_broadband: bool = True, +) -> Tuple[ + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], +]: + """Perform shortwave radiation transfer calculations using the 2-stream approximation. + + This function implements the two-stream approximation for shortwave radiative transfer, + computing direct, diffuse upward and downward fluxes, as well as optional broadband fluxes. Args: - top_at_1 (bool): Flag indicating whether the top of the atmosphere is at level 1. - tau (ndarray): Array of optical depths with shape (ncol, nlay, ngpt). - ssa (ndarray): Array of single scattering albedos with shape (ncol, nlay, ngpt). - g (ndarray): Array of asymmetry parameters with shape (ncol, nlay, ngpt). - mu0 (ndarray): Array of cosine of solar zenith angles with shape (ncol, ngpt). - sfc_alb_dir (ndarray): Array of direct surface albedos with shape (ncol, ngpt). - sfc_alb_dif (ndarray): Array of diffuse surface albedos with shape (ncol, ngpt). - inc_flux_dir (ndarray): Array of direct incident fluxes with shape (ncol, ngpt). - inc_flux_dif (ndarray, optional): Array of diffuse incident fluxes with shape (ncol, ngpt). - Defaults to None. - has_dif_bc (bool, optional): Flag indicating whether the boundary condition includes diffuse fluxes. - Defaults to False. - do_broadband (bool, optional): Flag indicating whether to compute broadband fluxes. - Defaults to False. + ncol: Number of columns + nlay: Number of layers + ngpt: Number of g-points + tau: Optical depths with shape (ncol, nlay, ngpt) + ssa: Single scattering albedos with shape (ncol, nlay, ngpt) + g: Asymmetry parameters with shape (ncol, nlay, ngpt) + mu0: Cosine of solar zenith angles with shape (ncol, ngpt) + sfc_alb_dir: Direct surface albedos with shape (ncol, ngpt) or (ncol,) + sfc_alb_dif: Diffuse surface albedos with shape (ncol, ngpt) or (ncol,) + inc_flux_dir: Direct incident fluxes with shape (ncol, ngpt) + inc_flux_dif: Diffuse incident fluxes with shape (ncol, ngpt) + top_at_1: Whether the top of the atmosphere is at index 1 + has_dif_bc: Whether the boundary condition includes diffuse fluxes + do_broadband: Whether to compute broadband fluxes Returns: - Tuple of ndarrays: Tuple containing the following arrays: - - flux_up: Array of upward fluxes with shape (ngpt, nlay + 1, ncol). - - flux_dn: Array of downward fluxes with shape (ngpt, nlay + 1, ncol). - - flux_dir: Array of direct fluxes with shape (ngpt, nlay + 1, ncol). - - broadband_up: Array of broadband upward fluxes with shape (nlay + 1, ncol). - - broadband_dn: Array of broadband downward fluxes with shape (nlay + 1, ncol). - - broadband_dir: Array of broadband direct fluxes with shape (nlay + 1, ncol). + Tuple containing: + flux_up: Upward fluxes with shape (ncol, nlay+1, ngpt) + flux_dn: Downward fluxes with shape (ncol, nlay+1, ngpt) + flux_dir: Direct fluxes with shape (ncol, nlay+1, ngpt) + broadband_up: Broadband upward fluxes with shape (ncol, nlay+1) + broadband_dn: Broadband downward fluxes with shape (ncol, nlay+1) + broadband_dir: Broadband direct fluxes with shape (ncol, nlay+1) """ - ncol, nlay, ngpt = tau.shape - - if len(sfc_alb_dir.shape) == 1: - sfc_alb_dir = np.stack([sfc_alb_dir] * ngpt).T - if len(sfc_alb_dif.shape) == 1: - sfc_alb_dif = np.stack([sfc_alb_dif] * ngpt).T - - if inc_flux_dif is None: - inc_flux_dif = np.zeros((ncol, ngpt), dtype=np.float64) - - # outputs - flux_up = np.zeros((ngpt, nlay + 1, ncol), dtype=np.float64) - flux_dn = np.zeros((ngpt, nlay + 1, ncol), dtype=np.float64) - flux_dir = np.zeros((ngpt, nlay + 1, ncol), dtype=np.float64) - broadband_up = np.zeros((nlay + 1, ncol), dtype=np.float64) - broadband_dn = np.zeros((nlay + 1, ncol), dtype=np.float64) - broadband_dir = np.zeros((nlay + 1, ncol), dtype=np.float64) + # Initialize output arrays + flux_up = np.zeros((ncol, nlay + 1, ngpt), dtype=np.float64, order="F") + flux_dn = np.zeros((ncol, nlay + 1, ngpt), dtype=np.float64, order="F") + flux_dir = np.zeros((ncol, nlay + 1, ngpt), dtype=np.float64, order="F") + broadband_up = np.zeros((ncol, nlay + 1), dtype=np.float64, order="F") + broadband_dn = np.zeros((ncol, nlay + 1), dtype=np.float64, order="F") + broadband_dir = np.zeros((ncol, nlay + 1), dtype=np.float64, order="F") args = [ ncol, nlay, ngpt, top_at_1, - tau.flatten("F"), - ssa.flatten("F"), - g.flatten("F"), - mu0.flatten("F"), - sfc_alb_dir.flatten("F"), - sfc_alb_dif.flatten("F"), - inc_flux_dir.flatten("F"), + np.asfortranarray(tau), + np.asfortranarray(ssa), + np.asfortranarray(g), + np.asfortranarray(mu0), + np.asfortranarray(sfc_alb_dir), + np.asfortranarray(sfc_alb_dif), + np.asfortranarray(inc_flux_dir), flux_up, flux_dn, flux_dir, has_dif_bc, - inc_flux_dif.flatten("F"), + np.asfortranarray(inc_flux_dif), do_broadband, broadband_up, broadband_dn, @@ -255,11 +303,4 @@ def sw_solver_2stream( rte_sw_solver_2stream(*args) - return ( - flux_up.T, - flux_dn.T, - flux_dir.T, - broadband_up.T, - broadband_dn.T, - broadband_dir.T, - ) + return flux_up, flux_dn, flux_dir, broadband_up, broadband_dn, broadband_dir diff --git a/pyrte_rrtmgp/rrtmgp_data.py b/pyrte_rrtmgp/rrtmgp_data.py index 69089aa..6084d3d 100644 --- a/pyrte_rrtmgp/rrtmgp_data.py +++ b/pyrte_rrtmgp/rrtmgp_data.py @@ -2,15 +2,22 @@ import os import platform import tarfile +from pathlib import Path +from typing import Union import requests # URL of the file to download -TAG = "v1.8" +TAG = "v1.8.2" DATA_URL = f"https://github.com/earth-system-radiation/rrtmgp-data/archive/refs/tags/{TAG}.tar.gz" -def get_cache_dir(): +def get_cache_dir() -> str: + """Get the system-specific cache directory for pyrte_rrtmgp data. + + Returns: + str: Path to the cache directory + """ # Determine the system cache folder if platform.system() == "Windows": cache_path = os.getenv("LOCALAPPDATA") @@ -27,7 +34,19 @@ def get_cache_dir(): return cache_path -def download_rrtmgp_data(): +def download_rrtmgp_data() -> str: + """Download and extract RRTMGP data files. + + Downloads the RRTMGP data files from GitHub if not already present in the cache, + verifies the checksum, and extracts the contents. + + Returns: + str: Path to the extracted data directory + + Raises: + requests.exceptions.RequestException: If download fails + tarfile.TarError: If extraction fails + """ # Directory where the data will be stored cache_dir = get_cache_dir() @@ -40,8 +59,8 @@ def download_rrtmgp_data(): # Download the file if it doesn't exist or if the checksum doesn't match if not os.path.exists(file_path) or ( os.path.exists(checksum_file_path) - and open(checksum_file_path).read() - != hashlib.sha256(open(file_path, "rb").read()).hexdigest() + and _get_file_checksum(checksum_file_path) + != _get_file_checksum(file_path, mode="rb") ): response = requests.get(DATA_URL, stream=True) response.raise_for_status() @@ -52,10 +71,25 @@ def download_rrtmgp_data(): # Save the checksum of the downloaded file with open(checksum_file_path, "w") as f: - f.write(hashlib.sha256(open(file_path, "rb").read()).hexdigest()) + f.write(_get_file_checksum(file_path, mode="rb")) # Uncompress the file with tarfile.open(file_path) as tar: - tar.extractall(path=cache_dir) + tar.extractall(path=cache_dir, filter="data") return os.path.join(cache_dir, f"rrtmgp-data-{TAG[1:]}") + + +def _get_file_checksum(filepath: Union[str, Path], mode: str = "r") -> str: + """Calculate SHA256 checksum of a file or read existing checksum. + + Args: + filepath: Path to the file + mode: File open mode, "r" for text or "rb" for binary + + Returns: + str: File content if mode="r", or SHA256 hex digest if mode="rb" + """ + with open(filepath, mode) as f: + content = f.read() + return hashlib.sha256(content).hexdigest() if mode == "rb" else content diff --git a/pyrte_rrtmgp/rrtmgp_gas_optics.py b/pyrte_rrtmgp/rrtmgp_gas_optics.py index 9e27257..5674de8 100644 --- a/pyrte_rrtmgp/rrtmgp_gas_optics.py +++ b/pyrte_rrtmgp/rrtmgp_gas_optics.py @@ -1,16 +1,27 @@ +import logging +import os import sys -from dataclasses import dataclass -from typing import Optional +from typing import Union import numpy as np import numpy.typing as npt +import pandas as pd import xarray as xr -from pyrte_rrtmgp.constants import AVOGAD, HELMERT1, HELMERT2, M_DRY, M_H2O -from pyrte_rrtmgp.exceptions import ( - MissingAtmosfericConditionsError, - NotExternalSourceError, - NotInternalSourceError, +from pyrte_rrtmgp import data_validation +from pyrte_rrtmgp.constants import ( + AVOGAD, + HELMERT1, + HELMERT2, + M_DRY, + M_H2O, + SOLAR_CONSTANTS, +) +from pyrte_rrtmgp.data_types import GasOpticsFiles, ProblemTypes +from pyrte_rrtmgp.data_validation import ( + AtmosphericMapping, + GasMapping, + create_default_mapping, ) from pyrte_rrtmgp.kernels.rrtmgp import ( compute_planck_source, @@ -18,451 +29,1070 @@ compute_tau_rayleigh, interpolation, ) +from pyrte_rrtmgp.rrtmgp_data import download_rrtmgp_data + +logger = logging.getLogger(__name__) + + +def load_gas_optics( + file_path: str | None = None, + gas_optics_file: GasOpticsFiles | None = None, + selected_gases: list[str] | None = None, +) -> xr.Dataset: + """Load gas optics data from a file or predefined gas optics file. + + This function loads gas optics data either from a custom netCDF file or from + a predefined gas optics file included in the RRTMGP data package. The data + contains absorption coefficients and other optical properties needed for + radiative transfer calculations. + + Args: + file_path: Path to a custom gas optics netCDF file. If provided, this takes + precedence over gas_optics_file. + gas_optics_file: Enum specifying a predefined gas optics file from the RRTMGP + data package. Only used if file_path is None. + selected_gases: Optional list of gas names to include in calculations. + If None, all gases in the file will be used. + + Returns: + xr.Dataset: Dataset containing the gas optics data with selected_gases + stored in the attributes. + + Raises: + ValueError: If neither file_path nor gas_optics_file is provided. + """ + if file_path is not None: + dataset = xr.load_dataset(file_path) + elif gas_optics_file is not None: + rte_rrtmgp_dir = download_rrtmgp_data() + dataset = xr.load_dataset(os.path.join(rte_rrtmgp_dir, gas_optics_file.value)) + else: + raise ValueError("Either file_path or gas_optics_file must be provided") + + dataset.attrs["selected_gases"] = selected_gases + return dataset + + +class BaseGasOpticsAccessor: + """Base class for gas optics calculations. + + This class provides common functionality for both longwave and shortwave gas optics + calculations, including gas interpolation, optical depth computation, and handling of + atmospheric conditions. + + Args: + xarray_obj (xr.Dataset): Dataset containing gas optics data + is_internal (bool): Whether this is for internal (longwave) radiation + selected_gases (list[str] | None): List of gases to include in calculations + + Raises: + ValueError: If 'h2o' is not included in the gas mapping + """ + + def __init__( + self, + xarray_obj: xr.Dataset, + is_internal: bool, + selected_gases: list[str] | None = None, + ) -> None: + self._dataset = xarray_obj + self.is_internal = is_internal + + # Get the gas names from the dataset + self._gas_names: tuple[str, ...] = self.extract_names( + self._dataset["gas_names"].values + ) + + if selected_gases is not None: + # Filter gas names to only include those that exist in the dataset + available_gases = tuple(g for g in selected_gases if g in self._gas_names) + # Log warning for any gases that weren't found + missing_gases = set(selected_gases) - set(available_gases) + for gas in missing_gases: + logger.warning(f"Gas {gas} not found in gas optics file") -@dataclass -class GasOptics: - tau: Optional[np.ndarray] = None - tau_rayleigh: Optional[np.ndarray] = None - tau_absorption: Optional[np.ndarray] = None - g: Optional[np.ndarray] = None - ssa: Optional[np.ndarray] = None - lay_src: Optional[np.ndarray] = None - lev_src: Optional[np.ndarray] = None - sfc_src: Optional[np.ndarray] = None - sfc_src_jac: Optional[np.ndarray] = None - solar_source: Optional[np.ndarray] = None + self._gas_names = available_gases + if "h2o" not in self._gas_names: + raise ValueError( + "'h2o' must be included in gas mapping as it is required to compute Dry air" + ) -@dataclass -class InterpolatedAtmosfereGases: - jtemp: Optional[np.ndarray] = None - fmajor: Optional[np.ndarray] = None - fminor: Optional[np.ndarray] = None - col_mix: Optional[np.ndarray] = None - tropo: Optional[np.ndarray] = None - jeta: Optional[np.ndarray] = None - jpress: Optional[np.ndarray] = None + # Set the gas names as coordinate in the dataset + self._dataset.coords["absorber_ext"] = np.array(("dry_air",) + self._gas_names) + def _initialize_pressure_levels( + self, atmosphere: xr.Dataset, inplace: bool = True + ) -> xr.Dataset | None: + """Initialize pressure levels with minimum pressure adjustment. -@xr.register_dataset_accessor("gas_optics") -class GasOpticsAccessor: - def __init__(self, xarray_obj, selected_gases=None): - self._obj = xarray_obj - self._selected_gases = selected_gases - self._gas_names = None - self._is_internal = None - self._gas_mappings = None - self._top_at_1 = None - self._vmr_ref = None - self.col_gas = None - - self._interpolated = InterpolatedAtmosfereGases() - self.gas_optics = GasOptics() + Args: + atmosphere: Dataset containing atmospheric conditions + inplace: Whether to modify atmosphere in-place or return a copy + + Returns: + Modified atmosphere dataset if inplace=False, otherwise None + """ + pres_level_var = atmosphere.mapping.get_var("pres_level") + + min_index = np.argmin(atmosphere[pres_level_var].data) + min_press = self._dataset["press_ref"].min().item() + sys.float_info.epsilon + atmosphere[pres_level_var][:, min_index] = min_press + + if not inplace: + return atmosphere @property - def gas_names(self): - """Gas names""" - if self._gas_names is None: - names = self._obj["gas_names"].values - self._gas_names = self.extract_names(names) - return self._gas_names + def _selected_gas_names(self) -> list[str]: + """List of selected gas names.""" + return list(self._gas_names) @property - def source_is_internal(self): - """Check if the source is internal""" - if self._is_internal is None: - variables = self._obj.data_vars - self._is_internal = "totplnk" in variables and "plank_fraction" in variables - return self._is_internal + def _selected_gas_names_ext(self) -> list[str]: + """List of selected gas names including dry air.""" + return ["dry_air"] + self._selected_gas_names + + def get_gases_columns( + self, atmosphere: xr.Dataset, gas_name_map: dict[str, str] + ) -> xr.DataArray: + """Get gas columns from atmospheric conditions. - def solar_source(self): - """Calculate the solar variability + Args: + atmosphere: Dataset containing atmospheric conditions + gas_name_map: Mapping between gas names and variable names Returns: - np.ndarray: Solar source + DataArray containing gas columns including dry air """ + pres_level_var = atmosphere.mapping.get_var("pres_level") + + gas_values = [] + for gas_map in gas_name_map.values(): + if gas_map in atmosphere.data_vars: + values = atmosphere[gas_map] + if hasattr(values, "units"): + values = values * float(values.units) + if values.ndim == 0: + values = xr.full_like( + atmosphere[pres_level_var].isel(level=0), values + ) + else: + values = xr.zeros_like(atmosphere[pres_level_var].isel(level=0)) + gas_values.append(values) - if self.source_is_internal: - raise NotExternalSourceError( - "Solar source is not available for internal sources." - ) - - if self.gas_optics.solar_source is None: - a_offset = 0.1495954 - b_offset = 0.00066696 - - solar_source_quiet = self._obj["solar_source_quiet"] - solar_source_facular = self._obj["solar_source_facular"] - solar_source_sunspot = self._obj["solar_source_sunspot"] - - mg_index = self._obj["mg_default"] - sb_index = self._obj["sb_default"] + gas_values = xr.concat( + gas_values, dim=pd.Index(gas_name_map.keys(), name="gas"), coords="minimal" + ) - self.gas_optics.solar_source = ( - solar_source_quiet - + (mg_index - a_offset) * solar_source_facular - + (sb_index - b_offset) * solar_source_sunspot - ).data + col_dry = self.get_col_dry(gas_values.sel(gas="h2o"), atmosphere, latitude=None) - def load_atmosferic_conditions(self, atmosferic_conditions: xr.Dataset): - """Load atmospheric conditions""" - self._atm_cond = atmosferic_conditions + gas_values = gas_values * col_dry + gas_values = xr.concat( + [col_dry.expand_dims(gas=["dry_air"]), gas_values], + dim="gas", + ) - # RRTMGP won't run with pressure less than its minimum. - # So we add a small value to the minimum pressure - min_index = np.argmin(self._atm_cond["pres_level"].data) - min_press = self._obj["press_ref"].min().item() + sys.float_info.epsilon - self._atm_cond["pres_level"][:, min_index] = min_press + return gas_values - self.get_col_gas() + def compute_problem( + self, atmosphere: xr.Dataset, gas_interpolation_data: xr.Dataset + ) -> xr.Dataset: + """Compute optical properties for radiative transfer problem. - self.interpolate() - self.compute_gas_taus() - if self.source_is_internal: - self.compute_planck() - else: - self.solar_source() + Args: + atmosphere: Dataset containing atmospheric conditions + gas_interpolation_data: Dataset containing interpolated gas data - return self.gas_optics + Raises: + NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError() - def get_col_gas(self): - if self._atm_cond is None: - raise MissingAtmosfericConditionsError() + def compute_sources( + self, atmosphere: xr.Dataset, gas_interpolation_data: xr.Dataset + ) -> xr.Dataset: + """Compute radiation sources. - ncol = len(self._atm_cond["site"]) - nlay = len(self._atm_cond["layer"]) - col_gas = [] - for gas_name in self.gas_mappings.values(): - # if gas_name is not available, fill it with zeros - if gas_name not in self._atm_cond.data_vars.keys(): - gas_values = np.zeros((ncol, nlay)) - else: - try: - scale = float(self._atm_cond[gas_name].units) - except AttributeError: - scale = 1.0 - gas_values = self._atm_cond[gas_name].values * scale - - if gas_values.ndim == 0: - gas_values = np.full((ncol, nlay), gas_values) - col_gas.append(gas_values) - - vmr_h2o = col_gas[self.gas_names.index("h2o")] - col_dry = self.get_col_dry( - vmr_h2o, self._atm_cond["pres_level"].data, latitude=None - ) - col_gas = [col_dry] + col_gas + Args: + atmosphere: Dataset containing atmospheric conditions + gas_interpolation_data: Dataset containing interpolated gas data - col_gas = np.stack(col_gas, axis=-1).astype(np.float64) - col_gas[:, :, 1:] = col_gas[:, :, 1:] * col_gas[:, :, :1] + Raises: + NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError() - self.col_gas = col_gas + def compute_boundary_conditions(self, atmosphere: xr.Dataset) -> xr.Dataset: + """Compute boundary conditions. - @property - def gas_mappings(self): - """Gas mappings""" - - if self._atm_cond is None: - raise MissingAtmosfericConditionsError() - - if self._gas_mappings is None: - gas_name_map = { - "h2o": "water_vapor", - "co2": "carbon_dioxide_GM", - "o3": "ozone", - "n2o": "nitrous_oxide_GM", - "co": "carbon_monoxide_GM", - "ch4": "methane_GM", - "o2": "oxygen_GM", - "n2": "nitrogen_GM", - "ccl4": "carbon_tetrachloride_GM", - "cfc11": "cfc11_GM", - "cfc12": "cfc12_GM", - "cfc22": "hcfc22_GM", - "hfc143a": "hfc143a_GM", - "hfc125": "hfc125_GM", - "hfc23": "hfc23_GM", - "hfc32": "hfc32_GM", - "hfc134a": "hfc134a_GM", - "cf4": "cf4_GM", - "no2": "no2", - } + Args: + atmosphere: Dataset containing atmospheric conditions - if self._selected_gases is not None: - gas_name_map = { - g: gas_name_map[g] - for g in self._selected_gases - if g in gas_name_map - } + Raises: + NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError() - gas_name_map = { - g: gas_name_map[g] for g in self.gas_names if g in gas_name_map - } - self._gas_mappings = gas_name_map - return self._gas_mappings + def interpolate( + self, atmosphere: xr.Dataset, gas_name_map: dict[str, str] + ) -> xr.Dataset: + """Interpolate gas optics data to atmospheric conditions. - @property - def top_at_1(self): - if self._top_at_1 is None: - if self._atm_cond is None: - raise MissingAtmosfericConditionsError() + Args: + atmosphere: Dataset containing atmospheric conditions + gas_name_map: Mapping between gas names and variable names - pres_layers = self._atm_cond["pres_layer"]["layer"] - self._top_at_1 = pres_layers[0] < pres_layers[-1] - return self._top_at_1.item() + Returns: + Dataset containing interpolated gas optics data + """ + # Get the gas columns from atmospheric conditions + gas_order = self._selected_gas_names_ext + gases_columns = self.get_gases_columns(atmosphere, gas_name_map).sel( + gas=gas_order + ) - @property - def vmr_ref(self): - if self._vmr_ref is None: - if self._atm_cond is None: - raise MissingAtmosfericConditionsError() - sel_gases = self.gas_mappings.keys() - vmr_idx = [i for i, g in enumerate(self._gas_names, 1) if g in sel_gases] - vmr_idx = [0] + vmr_idx - self._vmr_ref = self._obj["vmr_ref"].sel(absorber_ext=vmr_idx).values.T - return self._vmr_ref - - def interpolate(self): - ( - self._interpolated.jtemp, - self._interpolated.fmajor, - self._interpolated.fminor, - self._interpolated.col_mix, - self._interpolated.tropo, - self._interpolated.jeta, - self._interpolated.jpress, - ) = interpolation( - neta=len(self._obj["mixing_fraction"]), - flavor=self.flavors_sets, - press_ref=self._obj["press_ref"].values, - temp_ref=self._obj["temp_ref"].values, - press_ref_trop=self._obj["press_ref_trop"].values.item(), - vmr_ref=self.vmr_ref, - play=self._atm_cond["pres_layer"].values, - tlay=self._atm_cond["temp_layer"].values, - col_gas=self.col_gas, + site_dim = atmosphere.mapping.get_dim("site") + layer_dim = atmosphere.mapping.get_dim("layer") + + npres = self._dataset.sizes["pressure"] + ntemp = self._dataset.sizes["temperature"] + ngas = gases_columns["gas"].size + ncol = atmosphere[site_dim].size + nlay = atmosphere[layer_dim].size + nflav = self.flavors_sets["flavor"].size + neta = self._dataset["mixing_fraction"].size + + jtemp, fmajor, fminor, col_mix, tropo, jeta, jpress = xr.apply_ufunc( + interpolation, + ncol, # ncol + nlay, # nlay + ngas, # ngas + nflav, # nflav + neta, # neta + npres, # npres + ntemp, # ntemp + self.flavors_sets, # flavor + self._dataset["press_ref"], # press_ref + self._dataset["temp_ref"], # temp_ref + self._dataset["press_ref_trop"], # press_ref_trop (scalar) + self._dataset["vmr_ref"].sel(absorber_ext=gas_order), + atmosphere[atmosphere.mapping.get_var("pres_layer")], # play + atmosphere[atmosphere.mapping.get_var("temp_layer")], # tlay + gases_columns.sel(gas=gas_order), # col_gas + input_core_dims=[ + [], # ncol + [], # nlay + [], # ngas + [], # nflav + [], # neta + [], # npres + [], # ntemp + ["pair", "flavor"], # flavor + ["pressure"], # press_ref + ["temperature"], # temp_ref + [], # press_ref_trop + ["atmos_layer", "absorber_ext", "temperature"], # vmr_ref + [site_dim, layer_dim], # play + [site_dim, layer_dim], # tlay + [site_dim, layer_dim, "gas"], # col_gas + ], + output_core_dims=[ + [site_dim, layer_dim], # jtemp + [ + "eta_interp", + "press_interp", + "temp_interp", + site_dim, + layer_dim, + "flavor", + ], # fmajor + ["eta_interp", "temp_interp", site_dim, layer_dim, "flavor"], # fminor + ["temp_interp", site_dim, layer_dim, "flavor"], # col_mix + [site_dim, layer_dim], # tropo + ["pair", site_dim, layer_dim, "flavor"], # jeta + [site_dim, layer_dim], # jpress + ], + vectorize=True, + dask="allowed", ) - def compute_planck(self): - ( - self.gas_optics.sfc_src, - self.gas_optics.lay_src, - self.gas_optics.lev_src, - self.gas_optics.sfc_src_jac, - ) = compute_planck_source( - self._atm_cond["temp_layer"].values, - self._atm_cond["temp_level"].values, - self._atm_cond["surface_temperature"].values, - self.top_at_1, - self._interpolated.fmajor, - self._interpolated.jeta, - self._interpolated.tropo, - self._interpolated.jtemp, - self._interpolated.jpress, - self._obj["bnd_limits_gpt"].values.T, - self._obj["plank_fraction"].values.transpose(0, 2, 1, 3), - self._obj["temp_ref"].values.min(), - self._obj["temp_ref"].values.max(), - self._obj["totplnk"].values.T, - self.gpoint_flavor, + interpolation_results = xr.Dataset( + { + "temperature_index": jtemp, + "fmajor": fmajor, + "fminor": fminor, + "column_mix": col_mix, + "tropopause_mask": tropo, + "eta_index": jeta, + "pressure_index": jpress, + "gases_columns": gases_columns, + } ) - def compute_gas_taus(self): - minor_gases_lower = self.extract_names(self._obj["minor_gases_lower"].data) - minor_gases_upper = self.extract_names(self._obj["minor_gases_upper"].data) - # check if the index is correct - idx_minor_lower = self.get_idx_minor(self.gas_names, minor_gases_lower) - idx_minor_upper = self.get_idx_minor(self.gas_names, minor_gases_upper) + interpolation_results.attrs["dataset_mapping"] = atmosphere.attrs[ + "dataset_mapping" + ] - scaling_gas_lower = self.extract_names(self._obj["scaling_gas_lower"].data) - scaling_gas_upper = self.extract_names(self._obj["scaling_gas_upper"].data) + return interpolation_results - idx_minor_scaling_lower = self.get_idx_minor(self.gas_names, scaling_gas_lower) - idx_minor_scaling_upper = self.get_idx_minor(self.gas_names, scaling_gas_upper) + def tau_absorption( + self, atmosphere: xr.Dataset, gas_interpolation_data: xr.Dataset + ) -> xr.Dataset: + """Compute absorption optical depth. - tau_absorption = compute_tau_absorption( - self.idx_h2o, + Args: + atmosphere: Dataset containing atmospheric conditions + gas_interpolation_data: Dataset containing interpolated gas data + + Returns: + Dataset containing absorption optical depth + """ + site_dim = atmosphere.mapping.get_dim("site") + layer_dim = atmosphere.mapping.get_dim("layer") + + ncol = atmosphere[site_dim].size + nlay = atmosphere[layer_dim].size + ntemp = self._dataset["temperature"].size + neta = self._dataset["mixing_fraction"].size + npres = self._dataset["press_ref"].size + nbnd = self._dataset["bnd"].size + ngpt = self._dataset["gpt"].size + ngas = gas_interpolation_data["gas"].size + nflav = self.flavors_sets["flavor"].size + + nminorlower = self._dataset["minor_scales_with_density_lower"].size + nminorupper = self._dataset["minor_scales_with_density_upper"].size + nminorklower = self._dataset["contributors_lower"].size + nminorkupper = self._dataset["contributors_upper"].size + + minor_gases_lower = self.extract_names(self._dataset["minor_gases_lower"].data) + minor_gases_upper = self.extract_names(self._dataset["minor_gases_upper"].data) + # check if the index is correct + idx_minor_lower = self.get_idx_minor(minor_gases_lower) + idx_minor_upper = self.get_idx_minor(minor_gases_upper) + + scaling_gas_lower = self.extract_names(self._dataset["scaling_gas_lower"].data) + scaling_gas_upper = self.extract_names(self._dataset["scaling_gas_upper"].data) + + idx_minor_scaling_lower = self.get_idx_minor(scaling_gas_lower) + idx_minor_scaling_upper = self.get_idx_minor(scaling_gas_upper) + + pres_layer_var = atmosphere.mapping.get_var("pres_layer") + temp_layer_var = atmosphere.mapping.get_var("temp_layer") + + tau_absorption = xr.apply_ufunc( + compute_tau_absorption, + ncol, + nlay, + nbnd, + ngpt, + ngas, + nflav, + neta, + npres, + ntemp, + nminorlower, + nminorklower, + nminorupper, + nminorkupper, + self._selected_gas_names_ext.index("h2o"), self.gpoint_flavor, - self._obj["bnd_limits_gpt"].values.T, - self._obj["kmajor"].values, - self._obj["kminor_lower"].values, - self._obj["kminor_upper"].values, - self._obj["minor_limits_gpt_lower"].values.T, - self._obj["minor_limits_gpt_upper"].values.T, - self._obj["minor_scales_with_density_lower"].values.astype(bool), - self._obj["minor_scales_with_density_upper"].values.astype(bool), - self._obj["scale_by_complement_lower"].values.astype(bool), - self._obj["scale_by_complement_upper"].values.astype(bool), + self._dataset["bnd_limits_gpt"], + self._dataset["kmajor"], + self._dataset["kminor_lower"], + self._dataset["kminor_upper"], + self._dataset["minor_limits_gpt_lower"], + self._dataset["minor_limits_gpt_upper"], + self._dataset["minor_scales_with_density_lower"], + self._dataset["minor_scales_with_density_upper"], + self._dataset["scale_by_complement_lower"], + self._dataset["scale_by_complement_upper"], idx_minor_lower, idx_minor_upper, idx_minor_scaling_lower, idx_minor_scaling_upper, - self._obj["kminor_start_lower"].values, - self._obj["kminor_start_upper"].values, - self._interpolated.tropo, - self._interpolated.col_mix, - self._interpolated.fmajor, - self._interpolated.fminor, - self._atm_cond["pres_layer"].values, - self._atm_cond["temp_layer"].values, - self.col_gas, - self._interpolated.jeta, - self._interpolated.jtemp, - self._interpolated.jpress, + self._dataset["kminor_start_lower"], + self._dataset["kminor_start_upper"], + gas_interpolation_data["tropopause_mask"], + gas_interpolation_data["column_mix"], + gas_interpolation_data["fmajor"], + gas_interpolation_data["fminor"], + atmosphere[pres_layer_var], + atmosphere[temp_layer_var], + gas_interpolation_data["gases_columns"], + gas_interpolation_data["eta_index"], + gas_interpolation_data["temperature_index"], + gas_interpolation_data["pressure_index"], + input_core_dims=[ + [], + [], + [], + [], + [], + [], + [], + [], + [], + [], + [], + [], + [], + [], # idx_h2o + ["atmos_layer", "gpt"], # gpoint_flavor + ["pair", "bnd"], # bnd_limits_gpt + ["temperature", "mixing_fraction", "pressure_interp", "gpt"], # kmajor + [ + "temperature", + "mixing_fraction", + "contributors_lower", + ], # kminor_lower + [ + "temperature", + "mixing_fraction", + "contributors_upper", + ], # kminor_upper + ["pair", "minor_absorber_intervals_lower"], # minor_limits_gpt_lower + ["pair", "minor_absorber_intervals_upper"], # minor_limits_gpt_upper + ["minor_absorber_intervals_lower"], # minor_scales_with_density_lower + ["minor_absorber_intervals_upper"], # minor_scales_with_density_upper + ["minor_absorber_intervals_lower"], # scale_by_complement_lower + ["minor_absorber_intervals_upper"], # scale_by_complement_upper + ["minor_absorber_intervals_lower"], # idx_minor_lower + ["minor_absorber_intervals_upper"], # idx_minor_upper + ["minor_absorber_intervals_lower"], # idx_minor_scaling_lower + ["minor_absorber_intervals_upper"], # idx_minor_scaling_upper + ["minor_absorber_intervals_lower"], # kminor_start_lower + ["minor_absorber_intervals_upper"], # kminor_start_upper + [site_dim, layer_dim], # tropopause_mask + ["temp_interp", site_dim, layer_dim, "flavor"], # column_mix + [ + "eta_interp", + "press_interp", + "temp_interp", + site_dim, + layer_dim, + "flavor", + ], # fmajor + ["eta_interp", "temp_interp", site_dim, layer_dim, "flavor"], # fminor + [site_dim, layer_dim], # pres_layer + [site_dim, layer_dim], # temp_layer + [site_dim, layer_dim, "gas"], # gases_columns + ["pair", site_dim, layer_dim, "flavor"], # eta_index + [site_dim, layer_dim], # temperature_index + [site_dim, layer_dim], # pressure_index + ], + output_core_dims=[[site_dim, layer_dim, "gpt"]], + vectorize=True, + dask="allowed", ) - self.gas_optics.tau_absorption = tau_absorption - if self.source_is_internal: - self.gas_optics.tau = tau_absorption - self.gas_optics.ssa = np.full_like(tau_absorption, np.nan) - self.gas_optics.g = np.full_like(tau_absorption, np.nan) - else: - krayl = np.stack( - [self._obj["rayl_lower"].values, self._obj["rayl_upper"].values], - axis=-1, - ) - tau_rayleigh = compute_tau_rayleigh( - self.gpoint_flavor, - self._obj["bnd_limits_gpt"].values.T, - krayl, - self.idx_h2o, - self.col_gas[:, :, 0], - self.col_gas, - self._interpolated.fminor, - self._interpolated.jeta, - self._interpolated.tropo, - self._interpolated.jtemp, - ) - - self.gas_optics.tau_rayleigh = tau_rayleigh - self.gas_optics.tau = tau_absorption + tau_rayleigh - self.gas_optics.ssa = np.where( - self.gas_optics.tau > 2.0 * np.finfo(float).tiny, - tau_rayleigh / self.gas_optics.tau, - 0.0, - ) - self.gas_optics.g = np.zeros(self.gas_optics.tau.shape) - - @property - def idx_h2o(self): - return list(self.gas_names).index("h2o") + 1 + return tau_absorption.rename("tau").to_dataset() @property - def gpoint_flavor(self) -> npt.NDArray: + def gpoint_flavor(self) -> xr.DataArray: """Get the g-point flavors from the k-distribution file. Each g-point is associated with a flavor, which is a pair of key species. Returns: - np.ndarray: G-point flavors. + DataArray containing g-point flavors """ - key_species = self._obj["key_species"].values - - band_ranges = [ - [i] * (r.values[1] - r.values[0] + 1) - for i, r in enumerate(self._obj["bnd_limits_gpt"], 1) - ] - gpoint_bands = np.concatenate(band_ranges) - - key_species_rep = key_species.copy() - key_species_rep[np.all(key_species_rep == [0, 0], axis=2)] = [2, 2] - - # unique flavors - flist = self.flavors_sets.T.tolist() - - def key_species_pair2flavor(key_species_pair): - return flist.index(key_species_pair.tolist()) + 1 + band_sizes = ( + self._dataset["bnd_limits_gpt"].values[:, 1] + - self._dataset["bnd_limits_gpt"].values[:, 0] + + 1 + ) + gpoint_bands = xr.DataArray( + np.repeat(np.arange(1, len(band_sizes) + 1), band_sizes), + dims=["gpt"], + coords={"gpt": self._dataset.gpt}, + ) - flavors_bands = np.apply_along_axis( - key_species_pair2flavor, 2, key_species_rep - ).tolist() - gpoint_flavor = np.array([flavors_bands[gp - 1] for gp in gpoint_bands]).T + # key_species = self._dataset["key_species"] + key_species_rep = xr.where( + (self._dataset["key_species"] == 0).all("pair"), + np.array([2, 2]), + self._dataset["key_species"], + ) - return gpoint_flavor + matches = (self.flavors_sets == key_species_rep).all(dim="pair") + match_indices = ( + matches.argmax(dim="flavor") + 1 + ) # +1 because flavors are 1-indexed + # Create a mapping from band number to flavor index + band_to_flavor = match_indices.sel(bnd=np.arange(len(band_sizes))) + # Map each g-point to its corresponding flavor using the band number + return band_to_flavor.sel(bnd=gpoint_bands - 1) @property - def flavors_sets(self) -> npt.NDArray: + def flavors_sets(self) -> xr.DataArray: """Get the unique flavors from the k-distribution file. Returns: - np.ndarray: Unique flavors. + DataArray containing unique flavors """ - key_species = self._obj["key_species"].values - tot_flav = len(self._obj["bnd"]) * len(self._obj["atmos_layer"]) - npairs = len(self._obj["pair"]) - all_flav = np.reshape(key_species, (tot_flav, npairs)) - # (0,0) becomes (2,2) because absorption coefficients for these g-points will be 0. - all_flav[np.all(all_flav == [0, 0], axis=1)] = [2, 2] - # we do that instead of unique to preserv the order - _, idx = np.unique(all_flav, axis=0, return_index=True) - return all_flav[np.sort(idx)].T + # Calculate total number of flavors and pairs + n_bands = self._dataset["bnd"].size + n_layers = self._dataset["atmos_layer"].size + n_pairs = self._dataset["pair"].size + tot_flavors = n_bands * n_layers + + # Flatten key species array + all_flavors = np.reshape( + self._dataset["key_species"].data, (tot_flavors, n_pairs) + ) - @staticmethod - def get_idx_minor(gas_names, minor_gases): - """Index of each minor gas in col_gas + # Replace (0,0) pairs with (2,2) since these g-points have zero absorption + zero_mask = np.all(all_flavors == [0, 0], axis=1) + all_flavors[zero_mask] = [2, 2] + + # Get unique flavors while preserving original order + _, unique_indices = np.unique(all_flavors, axis=0, return_index=True) + unique_flavors = all_flavors[np.sort(unique_indices)] + + # Create xarray DataArray with flavor data + return xr.DataArray( + unique_flavors, + dims=["flavor", "pair"], + coords={ + "pair": np.arange(unique_flavors.shape[1]), + "flavor": np.arange(1, unique_flavors.shape[0] + 1), + }, + ) + + def get_idx_minor(self, minor_gases: list[str]) -> npt.NDArray[np.int32]: + """Get index of each minor gas in col_gas. Args: - gas_names (list): Gas names - minor_gases (list): List of minor gases + minor_gases: List of minor gases Returns: - list: Index of each minor gas in col_gas + Array containing indices of minor gases """ idx_minor_gas = [] for gas in minor_gases: try: - gas_idx = gas_names.index(gas) + 1 + gas_idx = self._selected_gas_names.index(gas) + 1 except ValueError: gas_idx = -1 idx_minor_gas.append(gas_idx) return np.array(idx_minor_gas, dtype=np.int32) @staticmethod - def extract_names(names): - """Extract names from arrays, decoding and removing the suffix + def extract_names(names: npt.NDArray) -> tuple[str, ...]: + """Extract names from arrays, decoding and removing the suffix. Args: - names (np.ndarray): Names + names: Array of encoded names Returns: - tuple: tuple of names + Tuple of decoded and cleaned names """ output = tuple(gas.tobytes().decode().strip().split("_")[0] for gas in names) return output @staticmethod - def get_col_dry(vmr_h2o, plev, latitude=None): - """Calculate the dry column of the atmosphere + def get_col_dry( + vmr_h2o: xr.DataArray, + atmosphere: xr.Dataset, + latitude: xr.DataArray | None = None, + ) -> xr.DataArray: + """Calculate the dry column of the atmosphere. Args: - vmr_h2o (np.ndarray): Water vapor volume mixing ratio - plev (np.ndarray): Pressure levels - latitude (np.ndarray): Latitude of the location + vmr_h2o: Water vapor volume mixing ratio + atmosphere: Dataset containing atmospheric conditions + latitude: Latitude of the location Returns: - np.ndarray: Dry column of the atmosphere + DataArray containing dry column of the atmosphere """ - ncol = plev.shape[0] - nlev = plev.shape[1] - col_dry = np.zeros((ncol, nlev - 1)) + site_dim = atmosphere.mapping.get_dim("site") + level_dim = atmosphere.mapping.get_dim("level") + layer_dim = atmosphere.mapping.get_dim("layer") + pres_level_var = atmosphere.mapping.get_var("pres_level") + + plev = atmosphere[pres_level_var] + # Convert latitude to g0 DataArray if latitude is not None: - g0 = HELMERT1 - HELMERT2 * np.cos(2.0 * np.pi * latitude / 180.0) + g0 = xr.DataArray( + HELMERT1 - HELMERT2 * np.cos(2.0 * np.pi * latitude / 180.0), + dims=[site_dim], + coords={site_dim: plev.site}, + ) + else: + g0 = xr.full_like(plev.isel(level=0), HELMERT1) + + # Calculate pressure difference between layers + delta_plev = np.abs(plev.diff(dim=level_dim)).rename({level_dim: layer_dim}) + + # Calculate factors using xarray operations + fact = 1.0 / (1.0 + vmr_h2o) + m_air = (M_DRY + M_H2O * vmr_h2o) * fact + + # Calculate col_dry using xarray operations + col_dry = 10.0 * delta_plev * AVOGAD * fact / (1000.0 * m_air * 100.0 * g0) + + return col_dry.rename("dry_air") + + def compute( + self, + atmosphere: xr.Dataset, + problem_type: str, + gas_name_map: dict[str, str] | None = None, + variable_mapping: AtmosphericMapping | None = None, + add_to_input: bool = True, + ) -> xr.Dataset | None: + """Compute gas optics for given atmospheric conditions. + + Args: + atmosphere: Dataset containing atmospheric conditions + problem_type: Type of radiative transfer problem to solve + gas_name_map: Optional mapping between gas names and variable names + variable_mapping: Optional mapping for atmospheric variables + add_to_input: Whether to add results to input dataset + + Returns: + Dataset containing gas optics results if add_to_input=False, + otherwise None + + Raises: + ValueError: If problem_type is invalid + """ + # Create and validate gas mapping + gas_mapping = GasMapping.create(self._gas_names, gas_name_map).validate() + + if variable_mapping is None: + variable_mapping = create_default_mapping() + # Set mapping in accessor + atmosphere.mapping.set_mapping(variable_mapping) + + # Modify pressure levels to avoid division by zero, runs inplace + self._initialize_pressure_levels(atmosphere) + + gas_interpolation_data = self.interpolate(atmosphere, gas_mapping) + problem = self.compute_problem(atmosphere, gas_interpolation_data) + sources = self.compute_sources(atmosphere, gas_interpolation_data) + boundary_conditions = self.compute_boundary_conditions(atmosphere) + + gas_optics = xr.merge([sources, problem, boundary_conditions]) + + # Add problem type to dataset attributes + if problem_type == "absorption" and self.is_internal: + problem_type = ProblemTypes.LW_ABSORPTION.value + elif problem_type == "two-stream" and self.is_internal: + problem_type = ProblemTypes.LW_2STREAM.value + elif problem_type == "direct" and not self.is_internal: + problem_type = ProblemTypes.SW_DIRECT.value + elif problem_type == "two-stream" and not self.is_internal: + problem_type = ProblemTypes.SW_2STREAM.value + else: + raise ValueError( + f"Invalid problem type: {problem_type} for {'LW' if self.is_internal else 'SW'} radiation" + ) + + if add_to_input: + atmosphere.update(gas_optics) + atmosphere.attrs["problem_type"] = problem_type + else: + output_ds = gas_optics + output_ds.attrs["problem_type"] = problem_type + output_ds.mapping.set_mapping(variable_mapping) + return output_ds + + +class LWGasOpticsAccessor(BaseGasOpticsAccessor): + """Accessor for internal (longwave) radiation sources. + + This class handles gas optics calculations specific to longwave radiation, including + computing absorption optical depths, Planck sources, and boundary conditions. + """ + + def compute_problem( + self, atmosphere: xr.Dataset, gas_interpolation_data: xr.Dataset + ) -> xr.Dataset: + """Compute absorption optical depths for longwave radiation. + + Args: + atmosphere: Dataset containing atmospheric conditions + gas_interpolation_data: Dataset containing interpolated gas properties + + Returns: + Dataset containing absorption optical depths + """ + return self.tau_absorption(atmosphere, gas_interpolation_data) + + def compute_sources( + self, atmosphere: xr.Dataset, gas_interpolation_data: xr.Dataset + ) -> xr.Dataset: + """Compute Planck source terms for longwave radiation. + + Args: + atmosphere: Dataset containing atmospheric conditions + gas_interpolation_data: Dataset containing interpolated gas properties + + Returns: + Dataset containing Planck source terms + """ + return self.compute_planck(atmosphere, gas_interpolation_data) + + def compute_boundary_conditions(self, atmosphere: xr.Dataset) -> xr.DataArray: + """Compute surface emissivity boundary conditions. + + Args: + atmosphere: Dataset containing atmospheric conditions + + Returns: + DataArray containing surface emissivity values + """ + surface_emissivity_var = atmosphere.mapping.get_var("surface_emissivity") + site_dim = atmosphere.mapping.get_dim("site") + + if surface_emissivity_var not in atmosphere.data_vars: + # Add surface emissivity directly to atmospheric conditions + return xr.DataArray( + np.ones((atmosphere.sizes[site_dim],)), + dims=[site_dim], + coords={ + site_dim: atmosphere[site_dim], + }, + ) + else: + return atmosphere[surface_emissivity_var] + + def compute_planck( + self, atmosphere: xr.Dataset, gas_interpolation_data: xr.Dataset + ) -> xr.Dataset: + """Compute Planck source terms for longwave radiation. + + Args: + atmosphere: Dataset containing atmospheric conditions + gas_interpolation_data: Dataset containing interpolated gas properties + + Returns: + Dataset containing Planck source terms including surface, layer and level sources + """ + site_dim = atmosphere.mapping.get_dim("site") + layer_dim = atmosphere.mapping.get_dim("layer") + level_dim = atmosphere.mapping.get_dim("level") + + temp_layer_var = atmosphere.mapping.get_var("temp_layer") + temp_level_var = atmosphere.mapping.get_var("temp_level") + surface_temperature_var = atmosphere.mapping.get_var("surface_temperature") + + # Check if the top layer is at the first level + top_at_1 = atmosphere[layer_dim][0] < atmosphere[layer_dim][-1] + + ncol = atmosphere.sizes[site_dim] + nlay = atmosphere.sizes[layer_dim] + nbnd = self._dataset.sizes["bnd"] + ngpt = self._dataset.sizes["gpt"] + nflav = self.flavors_sets.sizes["flavor"] + neta = self._dataset.sizes["mixing_fraction"] + npres = self._dataset.sizes["pressure"] + ntemp = self._dataset.sizes["temperature"] + nPlanckTemp = self._dataset.sizes["temperature_Planck"] + + sfc_src, lay_source, lev_source, sfc_src_jac = xr.apply_ufunc( + compute_planck_source, + ncol, + nlay, + nbnd, + ngpt, + nflav, + neta, + npres, + ntemp, + nPlanckTemp, + atmosphere[temp_layer_var], + atmosphere[temp_level_var], + atmosphere[surface_temperature_var], + top_at_1, + gas_interpolation_data["fmajor"], + gas_interpolation_data["eta_index"], + gas_interpolation_data["tropopause_mask"], + gas_interpolation_data["temperature_index"], + gas_interpolation_data["pressure_index"], + self._dataset["bnd_limits_gpt"], + self._dataset["plank_fraction"], + self._dataset["temp_ref"].min(), + self._dataset["temp_ref"].max(), + self._dataset["totplnk"], + self.gpoint_flavor, + input_core_dims=[ + [], + [], + [], + [], + [], + [], + [], + [], + [], # scalar dimensions + [site_dim, layer_dim], # tlay + [site_dim, level_dim], # tlev + [site_dim], # tsfc + [], # top_at_1 + [ + "eta_interp", + "press_interp", + "temp_interp", + site_dim, + layer_dim, + "flavor", + ], # fmajor + ["pair", site_dim, layer_dim, "flavor"], # jeta + [site_dim, layer_dim], # tropo + [site_dim, layer_dim], # jtemp + [site_dim, layer_dim], # jpress + ["pair", "bnd"], # band_lims_gpt + ["temperature", "mixing_fraction", "pressure_interp", "gpt"], # pfracin + [], # temp_ref_min + [], # temp_ref_max + ["temperature_Planck", "bnd"], # totplnk + ["atmos_layer", "gpt"], # gpoint_flavor + ], + output_core_dims=[ + ["site", "gpt"], # sfc_src + ["site", "layer", "gpt"], # lay_source + ["site", "level", "gpt"], # lev_source + ["site", "gpt"], # sfc_src_jac + ], + vectorize=True, + dask="allowed", + ) + + return xr.Dataset( + { + "surface_source": sfc_src, + "layer_source": lay_source, + "level_source": lev_source, + "surface_source_jacobian": sfc_src_jac, + } + ) + + +class SWGasOpticsAccessor(BaseGasOpticsAccessor): + """Accessor for external (shortwave) radiation sources. + + This class handles gas optics calculations specific to shortwave radiation, including + computing absorption and Rayleigh scattering optical depths, solar sources, and boundary conditions. + """ + + def compute_problem( + self, atmosphere: xr.Dataset, gas_interpolation_data: xr.Dataset + ) -> xr.Dataset: + """Compute optical properties for shortwave radiation. + + Args: + atmosphere: Dataset containing atmospheric conditions + gas_interpolation_data: Dataset containing interpolated gas properties + + Returns: + Dataset containing optical properties (tau, ssa, g) + """ + # Calculate absorption optical depth + tau_abs = self.tau_absorption(atmosphere, gas_interpolation_data) + + # Calculate Rayleigh scattering optical depth + tau_rayleigh = self.tau_rayleigh(gas_interpolation_data) + tau = tau_abs + tau_rayleigh + ssa = xr.where( + tau["tau"] > 2.0 * np.finfo(float).tiny, + tau_rayleigh["tau"] / tau["tau"], + 0.0, + ).rename("ssa") + g = xr.zeros_like(tau["tau"]).rename("g") + return xr.merge([tau, ssa, g]) + + def compute_sources(self, atmosphere: xr.Dataset, *args, **kwargs) -> xr.DataArray: + """Compute solar source terms. + + Args: + atmosphere: Dataset containing atmospheric conditions + *args: Variable length argument list + **kwargs: Arbitrary keyword arguments + + Returns: + DataArray containing top-of-atmosphere solar source + """ + a_offset = SOLAR_CONSTANTS["A_OFFSET"] + b_offset = SOLAR_CONSTANTS["B_OFFSET"] + + solar_source_quiet = self._dataset["solar_source_quiet"] + solar_source_facular = self._dataset["solar_source_facular"] + solar_source_sunspot = self._dataset["solar_source_sunspot"] + + mg_index = self._dataset["mg_default"] + sb_index = self._dataset["sb_default"] + + solar_source = ( + solar_source_quiet + + (mg_index - a_offset) * solar_source_facular + + (sb_index - b_offset) * solar_source_sunspot + ) + + total_solar_irradiance = atmosphere["total_solar_irradiance"] + + toa_flux = solar_source.broadcast_like(total_solar_irradiance) + def_tsi = toa_flux.sum(dim="gpt") + return (toa_flux * (total_solar_irradiance / def_tsi)).rename("toa_source") + + def compute_boundary_conditions(self, atmosphere: xr.Dataset) -> xr.Dataset: + """Compute surface and solar boundary conditions. + + Args: + atmosphere: Dataset containing atmospheric conditions + + Returns: + Dataset containing solar zenith angles, surface albedos and solar angle mask + """ + site_dim = atmosphere.mapping.get_dim("site") + layer_dim = atmosphere.mapping.get_dim("layer") + + solar_zenith_angle_var = atmosphere.mapping.get_var("solar_zenith_angle") + surface_albedo_var = atmosphere.mapping.get_var("surface_albedo") + surface_albedo_dir_var = atmosphere.mapping.get_var("surface_albedo_dir") + surface_albedo_dif_var = atmosphere.mapping.get_var("surface_albedo_dif") + + usecol_values = atmosphere[solar_zenith_angle_var] < ( + 90.0 - 2.0 * np.spacing(90.0) + ) + usecol_values = usecol_values.rename("solar_angle_mask") + mu0 = xr.where( + usecol_values, + np.cos(np.radians(atmosphere[solar_zenith_angle_var])), + 1.0, + ) + solar_zenith_angle = mu0.broadcast_like(atmosphere[layer_dim]).rename( + "solar_zenith_angle" + ) + + if surface_albedo_dir_var not in atmosphere.data_vars: + surface_albedo_direct = atmosphere[surface_albedo_var] + surface_albedo_direct = surface_albedo_direct.rename( + "surface_albedo_direct" + ) + surface_albedo_diffuse = atmosphere[surface_albedo_var] + surface_albedo_diffuse = surface_albedo_diffuse.rename( + "surface_albedo_diffuse" + ) + else: + surface_albedo_direct = atmosphere[surface_albedo_dir_var] + surface_albedo_direct = surface_albedo_direct.rename( + "surface_albedo_direct" + ) + surface_albedo_diffuse = atmosphere[surface_albedo_dif_var] + surface_albedo_diffuse = surface_albedo_diffuse.rename( + "surface_albedo_diffuse" + ) + + return xr.merge( + [ + solar_zenith_angle, + surface_albedo_direct, + surface_albedo_diffuse, + usecol_values, + ] + ) + + def tau_rayleigh(self, gas_interpolation_data: xr.Dataset) -> xr.Dataset: + """Compute Rayleigh scattering optical depth. + + Args: + gas_interpolation_data: Dataset containing interpolated gas properties + + Returns: + Dataset containing Rayleigh scattering optical depth + """ + # Combine upper and lower Rayleigh coefficients + krayl = xr.concat( + [self._dataset["rayl_lower"], self._dataset["rayl_upper"]], + dim=pd.Index(["lower", "upper"], name="rayl_bound"), + ) + + site_dim = gas_interpolation_data.mapping.get_dim("site") + layer_dim = gas_interpolation_data.mapping.get_dim("layer") + + tau_rayleigh = xr.apply_ufunc( + compute_tau_rayleigh, + gas_interpolation_data.sizes[site_dim], + gas_interpolation_data.sizes[layer_dim], + self._dataset.sizes["bnd"], + self._dataset.sizes["gpt"], + gas_interpolation_data.sizes["gas"], + self.flavors_sets.sizes["flavor"], + self._dataset.sizes["mixing_fraction"], + self._dataset.sizes["temperature"], + self.gpoint_flavor, # gpoint_flavor + self._dataset["bnd_limits_gpt"], # band_lims_gpt + krayl, # krayl + self._selected_gas_names_ext.index("h2o"), # idx_h2o + gas_interpolation_data["gases_columns"].sel(gas="dry_air"), # col_dry + gas_interpolation_data["gases_columns"].sel( + gas=self._selected_gas_names_ext + ), # col_gas + gas_interpolation_data["fminor"], # fminor + gas_interpolation_data["eta_index"], # jeta + gas_interpolation_data["tropopause_mask"], # tropo + gas_interpolation_data["temperature_index"], # jtemp + input_core_dims=[ + [], + [], + [], + [], + [], + [], + [], + [], # scalar dimensions + ["atmos_layer", "gpt"], # gpoint_flavor + ["pair", "bnd"], # band_lims_gpt + ["temperature", "mixing_fraction", "gpt", "rayl_bound"], # krayl + [], # idx_h2o + [site_dim, layer_dim], # col_dry + [site_dim, layer_dim, "gas"], # col_gas + ["eta_interp", "temp_interp", site_dim, layer_dim, "flavor"], # fminor + ["pair", site_dim, layer_dim, "flavor"], # jeta + [site_dim, layer_dim], # tropo + [site_dim, layer_dim], # jtemp + ], + output_core_dims=[[site_dim, layer_dim, "gpt"]], + vectorize=True, + dask="allowed", + ) + + return tau_rayleigh.rename("tau").to_dataset() + + +@xr.register_dataset_accessor("gas_optics") +class GasOpticsAccessor: + """Factory class that returns appropriate GasOptics implementation based on dataset contents. + + This class determines whether to return a longwave (LW) or shortwave (SW) gas optics + accessor by checking for the presence of internal source variables in the dataset. + + Args: + xarray_obj (xr.Dataset): The xarray Dataset containing gas optics data + selected_gases (list[str] | None): Optional list of gas names to include. + If None, all gases in the dataset will be used. + + Returns: + Union[LWGasOpticsAccessor, SWGasOpticsAccessor]: The appropriate gas optics accessor + based on whether internal source terms are present. + """ + + def __new__( + cls, xarray_obj: xr.Dataset, selected_gases: list[str] | None = None + ) -> Union[LWGasOpticsAccessor, SWGasOpticsAccessor]: + # Check if source is internal by looking for required LW variables + is_internal: bool = ( + "totplnk" in xarray_obj.data_vars + and "plank_fraction" in xarray_obj.data_vars + ) + + if is_internal: + return LWGasOpticsAccessor(xarray_obj, is_internal, selected_gases) else: - g0 = np.full(ncol, HELMERT1) # Assuming grav is a constant value - - # TODO: use numpy instead of loops - for ilev in range(nlev - 1): - for icol in range(ncol): - delta_plev = abs(plev[icol, ilev] - plev[icol, ilev + 1]) - fact = 1.0 / (1.0 + vmr_h2o[icol, ilev]) - m_air = (M_DRY + M_H2O * vmr_h2o[icol, ilev]) * fact - col_dry[icol, ilev] = ( - 10.0 - * delta_plev - * AVOGAD - * fact - / (1000.0 * m_air * 100.0 * g0[icol]) - ) - return col_dry + return SWGasOpticsAccessor(xarray_obj, is_internal, selected_gases) diff --git a/pyrte_rrtmgp/rte_solver.py b/pyrte_rrtmgp/rte_solver.py new file mode 100644 index 0000000..043f8e1 --- /dev/null +++ b/pyrte_rrtmgp/rte_solver.py @@ -0,0 +1,300 @@ +from typing import Optional + +import xarray as xr + +from pyrte_rrtmgp.constants import GAUSS_DS, GAUSS_WTS +from pyrte_rrtmgp.data_types import ProblemTypes +from pyrte_rrtmgp.kernels.rte import lw_solver_noscat, sw_solver_2stream + + +class RTESolver: + GAUSS_DS = GAUSS_DS + GAUSS_WTS = GAUSS_WTS + + def _compute_quadrature( + self, problem_ds: xr.Dataset, site_dim: str, nmus: int + ) -> tuple[xr.DataArray, xr.DataArray]: + """Compute quadrature weights and secants for radiative transfer calculations. + + Args: + problem_ds: Dataset containing the problem specification + site_dim: Name of the site dimension in the dataset + nmus: Number of quadrature angles to use + + Returns: + tuple containing: + ds (xr.DataArray): Quadrature secants (directional cosines) with dimensions + [site, gpt, n_quad_angs]. + weights (xr.DataArray): Quadrature weights with dimension [n_quad_angs]. + """ + n_quad_angs: int = nmus + ncol = problem_ds.sizes[site_dim] + ngpt = problem_ds.sizes["gpt"] + + # Extract quadrature secants for the specified number of angles + ds: xr.DataArray = xr.DataArray( + self.GAUSS_DS[0:n_quad_angs, n_quad_angs - 1], + dims=["n_quad_angs"], + coords={"n_quad_angs": range(n_quad_angs)}, + ) + # Expand dimensions to match problem size + ds = ds.expand_dims({site_dim: ncol, "gpt": ngpt}) + + # Extract quadrature weights for the specified number of angles + weights: xr.DataArray = xr.DataArray( + self.GAUSS_WTS[0:n_quad_angs, n_quad_angs - 1], + dims=["n_quad_angs"], + coords={"n_quad_angs": range(n_quad_angs)}, + ) + + return ds, weights + + def _compute_lw_fluxes_absorption( + self, problem_ds: xr.Dataset, spectrally_resolved: bool = False + ) -> xr.Dataset: + """Compute longwave fluxes for absorption-only radiative transfer. + + Args: + problem_ds: Dataset containing the problem specification with required variables: + - tau: Optical depth + - layer_source: Layer source function + - level_source: Level source function + - surface_emissivity: Surface emissivity + - surface_source: Surface source function + - surface_source_jacobian: Surface source Jacobian + Optional variables: + - incident_flux: Incident flux at top of atmosphere + - ssa: Single scattering albedo + - g: Asymmetry parameter + spectrally_resolved: If True, return spectrally resolved fluxes. + If False, return broadband fluxes. Defaults to False. + + Returns: + Dataset containing the computed fluxes: + - lw_flux_up_jacobian: Upward flux Jacobian + - lw_flux_up_broadband: Broadband upward flux + - lw_flux_down_broadband: Broadband downward flux + - lw_flux_up: Spectrally resolved upward flux + - lw_flux_down: Spectrally resolved downward flux + """ + + site_dim = problem_ds.mapping.get_dim("site") + layer_dim = problem_ds.mapping.get_dim("layer") + level_dim = problem_ds.mapping.get_dim("level") + + surface_emissivity_var = problem_ds.mapping.get_var("surface_emissivity") + + nmus: int = 1 + top_at_1: bool = problem_ds[layer_dim][0] < problem_ds[layer_dim][-1] + + if "incident_flux" not in problem_ds: + incident_flux: xr.DataArray = xr.zeros_like(problem_ds["surface_source"]) + else: + incident_flux = problem_ds["incident_flux"] + + if "gpt" not in problem_ds[surface_emissivity_var].dims: + problem_ds[surface_emissivity_var] = problem_ds[ + surface_emissivity_var + ].expand_dims({"gpt": problem_ds.sizes["gpt"]}, axis=1) + + ds, weights = self._compute_quadrature(problem_ds, site_dim, nmus) + ssa: xr.DataArray = ( + problem_ds["ssa"] if "ssa" in problem_ds else problem_ds["tau"].copy() + ) + g: xr.DataArray = ( + problem_ds["g"] if "g" in problem_ds else problem_ds["tau"].copy() + ) + + ( + solver_flux_up_jacobian, + solver_flux_up_broadband, + solver_flux_down_broadband, + solver_flux_up, + solver_flux_down, + ) = xr.apply_ufunc( + lw_solver_noscat, + problem_ds.sizes[site_dim], + problem_ds.sizes[layer_dim], + problem_ds.sizes["gpt"], + ds, + weights, + problem_ds["tau"], + ssa, + g, + problem_ds["layer_source"], + problem_ds["level_source"], + problem_ds[surface_emissivity_var], + problem_ds["surface_source"], + problem_ds["surface_source_jacobian"], + incident_flux, + kwargs={"do_broadband": not spectrally_resolved, "top_at_1": top_at_1}, + input_core_dims=[ + [], + [], + [], + [site_dim, "gpt", "n_quad_angs"], # ds + ["n_quad_angs"], # weights + [site_dim, layer_dim, "gpt"], # tau + [site_dim, layer_dim, "gpt"], # ssa + [site_dim, layer_dim, "gpt"], # g + [site_dim, layer_dim, "gpt"], # lay_source + [site_dim, level_dim, "gpt"], # lev_source + [site_dim, "gpt"], # sfc_emis + [site_dim, "gpt"], # sfc_src + [site_dim, "gpt"], # sfc_src_jac + [site_dim, "gpt"], # inc_flux + ], + output_core_dims=[ + [site_dim, level_dim], # solver_flux_up_jacobian + [site_dim, level_dim], # solver_flux_up_broadband + [site_dim, level_dim], # solver_flux_down_broadband + [site_dim, level_dim, "gpt"], # solver_flux_up + [site_dim, level_dim, "gpt"], # solver_flux_down + ], + vectorize=True, + dask="allowed", + ) + + return xr.Dataset( + { + "lw_flux_up_jacobian": solver_flux_up_jacobian, + "lw_flux_up_broadband": solver_flux_up_broadband, + "lw_flux_down_broadband": solver_flux_down_broadband, + "lw_flux_up": solver_flux_up, + "lw_flux_down": solver_flux_down, + } + ) + + def _compute_sw_fluxes( + self, problem_ds: xr.Dataset, spectrally_resolved: bool = False + ) -> xr.Dataset: + """Compute shortwave fluxes using two-stream solver. + + Args: + problem_ds: Dataset containing problem definition including optical properties, + surface properties and boundary conditions. + spectrally_resolved: If True, return spectrally resolved fluxes. + If False, return broadband fluxes. + + Returns: + Dataset containing computed shortwave fluxes: + - sw_flux_up_broadband: Upward broadband flux + - sw_flux_down_broadband: Downward broadband flux + - sw_flux_dir_broadband: Direct broadband flux + - sw_flux_up: Upward spectral flux + - sw_flux_down: Downward spectral flux + - sw_flux_dir: Direct spectral flux + """ + # Expand surface albedo dimensions if needed + if "gpt" not in problem_ds["surface_albedo_direct"].dims: + problem_ds["surface_albedo_direct"] = problem_ds[ + "surface_albedo_direct" + ].expand_dims({"gpt": problem_ds.sizes["gpt"]}, axis=1) + if "gpt" not in problem_ds["surface_albedo_diffuse"].dims: + problem_ds["surface_albedo_diffuse"] = problem_ds[ + "surface_albedo_diffuse" + ].expand_dims({"gpt": problem_ds.sizes["gpt"]}, axis=1) + + # Set diffuse incident flux + if "incident_flux_dif" not in problem_ds: + incident_flux_dif = xr.zeros_like(problem_ds["toa_source"]) + else: + incident_flux_dif = problem_ds["incident_flux_dif"] + + site_dim = problem_ds.mapping.get_dim("site") + layer_dim = problem_ds.mapping.get_dim("layer") + level_dim = problem_ds.mapping.get_dim("level") + + # Determine vertical orientation + top_at_1 = problem_ds[layer_dim][0] < problem_ds[layer_dim][-1] + + # Call solver + ( + solver_flux_up_broadband, + solver_flux_down_broadband, + solver_flux_dir_broadband, + solver_flux_up, + solver_flux_down, + solver_flux_dir, + ) = xr.apply_ufunc( + sw_solver_2stream, + problem_ds.sizes[site_dim], + problem_ds.sizes[layer_dim], + problem_ds.sizes["gpt"], + problem_ds["tau"], + problem_ds["ssa"], + problem_ds["g"], + problem_ds[problem_ds.mapping.get_var("solar_zenith_angle")], + problem_ds["surface_albedo_direct"], + problem_ds["surface_albedo_diffuse"], + problem_ds["toa_source"], + incident_flux_dif, + kwargs={"top_at_1": top_at_1, "do_broadband": not spectrally_resolved}, + input_core_dims=[ + [], + [], + [], + [site_dim, layer_dim, "gpt"], # tau + [site_dim, layer_dim, "gpt"], # ssa + [site_dim, layer_dim, "gpt"], # g + [site_dim, layer_dim], # mu0 + [site_dim, "gpt"], # sfc_alb_dir + [site_dim, "gpt"], # sfc_alb_dif + [site_dim, "gpt"], # inc_flux_dir + [site_dim, "gpt"], # inc_flux_dif + ], + output_core_dims=[ + [site_dim, level_dim, "gpt"], # solver_flux_up_broadband + [site_dim, level_dim, "gpt"], # solver_flux_down_broadband + [site_dim, level_dim, "gpt"], # solver_flux_dir_broadband + [site_dim, level_dim], # solver_flux_up + [site_dim, level_dim], # solver_flux_down + [site_dim, level_dim], # solver_flux_dir + ], + vectorize=True, + dask="allowed", + ) + + # Construct output dataset + fluxes = xr.Dataset( + { + "sw_flux_up_broadband": solver_flux_up_broadband, + "sw_flux_down_broadband": solver_flux_down_broadband, + "sw_flux_dir_broadband": solver_flux_dir_broadband, + "sw_flux_up": solver_flux_up, + "sw_flux_down": solver_flux_down, + "sw_flux_dir": solver_flux_dir, + } + ) + + return fluxes * problem_ds["solar_angle_mask"] + + def solve( + self, + problem_ds: xr.Dataset, + add_to_input: bool = True, + spectrally_resolved: bool = False, + ) -> Optional[xr.Dataset]: + """Solve radiative transfer problem based on problem type. + + Args: + problem_ds: Dataset containing problem definition and inputs + add_to_input: If True, add computed fluxes to input dataset. If False, return fluxes separately + spectrally_resolved: If True, return spectrally resolved fluxes. If False, return broadband fluxes + + Returns: + Dataset containing computed fluxes if add_to_input is False, None otherwise + """ + if problem_ds.attrs["problem_type"] == ProblemTypes.LW_ABSORPTION.value: + fluxes = self._compute_lw_fluxes_absorption(problem_ds, spectrally_resolved) + elif problem_ds.attrs["problem_type"] == ProblemTypes.SW_2STREAM.value: + fluxes = self._compute_sw_fluxes(problem_ds, spectrally_resolved) + else: + raise ValueError( + f"Unknown problem type: {problem_ds.attrs['problem_type']}" + ) + + if add_to_input: + problem_ds.assign_coords(fluxes.coords) + return None + return fluxes diff --git a/pyrte_rrtmgp/utils.py b/pyrte_rrtmgp/utils.py deleted file mode 100644 index a95d314..0000000 --- a/pyrte_rrtmgp/utils.py +++ /dev/null @@ -1,62 +0,0 @@ -import numpy as np -import xarray as xr - - -def get_usecols(solar_zenith_angle): - """Get the usecols values - - Args: - solar_zenith_angle (np.ndarray): Solar zenith angle in degrees - - Returns: - np.ndarray: Usecols values - """ - return solar_zenith_angle < 90.0 - 2.0 * np.spacing(90.0) - - -def compute_mu0(solar_zenith_angle, nlayer=None): - """Calculate the cosine of the solar zenith angle - - Args: - solar_zenith_angle (np.ndarray): Solar zenith angle in degrees - nlayer (int, optional): Number of layers. Defaults to None. - """ - usecol_values = get_usecols(solar_zenith_angle) - mu0 = np.where(usecol_values, np.cos(np.radians(solar_zenith_angle)), 1.0) - if nlayer is not None: - mu0 = np.stack([mu0] * nlayer).T - return mu0 - - -def compute_toa_flux(total_solar_irradiance, solar_source): - """Compute the top of atmosphere flux - - Args: - total_solar_irradiance (np.ndarray): Total solar irradiance - solar_source (np.ndarray): Solar source - - Returns: - np.ndarray: Top of atmosphere flux - """ - ncol = total_solar_irradiance.shape[0] - toa_flux = np.stack([solar_source] * ncol) - def_tsi = toa_flux.sum(axis=1) - return (toa_flux.T * (total_solar_irradiance / def_tsi)).T - - -def convert_xarray_args(func): - def wrapper(*args, **kwargs): - output_args = [] - for x in args: - if isinstance(x, xr.DataArray): - output_args.append(x.data) - else: - output_args.append(x) - for k, v in kwargs.items(): - if isinstance(v, xr.DataArray): - kwargs[k] = v.data - else: - kwargs[k] = v - return func(*output_args, **kwargs) - - return wrapper diff --git a/tests/lw_solver_test/noscat_test/lw_solver_output.npy b/tests/lw_solver_test/noscat_test/lw_solver_output.npy index fce95e2..b69e74e 100644 Binary files a/tests/lw_solver_test/noscat_test/lw_solver_output.npy and b/tests/lw_solver_test/noscat_test/lw_solver_output.npy differ diff --git a/tests/test_python_frontend/rld_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc b/tests/test_python_frontend/rld_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc deleted file mode 100644 index 5c2201b..0000000 Binary files a/tests/test_python_frontend/rld_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc and /dev/null differ diff --git a/tests/test_python_frontend/rlu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc b/tests/test_python_frontend/rlu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc deleted file mode 100644 index 14ab6b0..0000000 Binary files a/tests/test_python_frontend/rlu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc and /dev/null differ diff --git a/tests/test_python_frontend/rsd_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc b/tests/test_python_frontend/rsd_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc deleted file mode 100644 index fdff870..0000000 Binary files a/tests/test_python_frontend/rsd_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc and /dev/null differ diff --git a/tests/test_python_frontend/rsu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc b/tests/test_python_frontend/rsu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc deleted file mode 100644 index a45ae8a..0000000 Binary files a/tests/test_python_frontend/rsu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc and /dev/null differ diff --git a/tests/test_python_frontend/test_gas_optics.py b/tests/test_python_frontend/test_gas_optics.py deleted file mode 100644 index 85cd6c5..0000000 --- a/tests/test_python_frontend/test_gas_optics.py +++ /dev/null @@ -1,194 +0,0 @@ -import os - -import numpy as np -import pytest -import xarray as xr -from pyrte_rrtmgp import rrtmgp_gas_optics -from pyrte_rrtmgp.rrtmgp_data import download_rrtmgp_data -from pyrte_rrtmgp.kernels.rrtmgp import ( - compute_planck_source, - compute_tau_absorption, - compute_tau_rayleigh, - interpolation, -) - -from utils import convert_args_arrays - -ERROR_TOLERANCE = 1e-4 - -rte_rrtmgp_dir = download_rrtmgp_data() -clear_sky_example_files = f"{rte_rrtmgp_dir}/examples/rfmip-clear-sky/inputs" - -rfmip = xr.load_dataset( - f"{clear_sky_example_files}/multiple_input4MIPs_radiation_RFMIP_UColorado-RFMIP-1-2_none.nc" -) -rfmip = rfmip.sel(expt=0) # only one experiment -kdist = xr.load_dataset(f"{rte_rrtmgp_dir}/rrtmgp-gas-lw-g256.nc") -kdist_sw = xr.load_dataset(f"{rte_rrtmgp_dir}/rrtmgp-gas-sw-g224.nc") - -rrtmgp_gas_optics = kdist.gas_optics.load_atmosferic_conditions(rfmip) -rrtmgp_gas_optics_sw = kdist_sw.gas_optics.load_atmosferic_conditions(rfmip) - -# Prepare the arguments for the interpolation function -interpolation_args = [ - len(kdist["mixing_fraction"]), - kdist.gas_optics.flavors_sets, - kdist["press_ref"].values, - kdist["temp_ref"].values, - kdist["press_ref_trop"].values.item(), - kdist.gas_optics.vmr_ref, - rfmip["pres_layer"].values, - rfmip["temp_layer"].values, - kdist.gas_optics.col_gas, -] - -expected_output = ( - kdist.gas_optics._interpolated.jtemp, - kdist.gas_optics._interpolated.fmajor, - kdist.gas_optics._interpolated.fminor, - kdist.gas_optics._interpolated.col_mix, - kdist.gas_optics._interpolated.tropo, - kdist.gas_optics._interpolated.jeta, - kdist.gas_optics._interpolated.jpress, -) - - -@pytest.mark.parametrize( - "args, expected", - [(i, expected_output) for i in convert_args_arrays(interpolation_args)], -) -def test_compute_interpoaltion(args, expected): - result = interpolation(*args) - assert len(result) == len(expected) - for r, e in zip(result, expected): - assert r.shape == e.shape - assert np.isclose(r, e, atol=ERROR_TOLERANCE).all() - - -# Prepare the arguments for the compute_planck_source function -planck_source_args = [ - rfmip["temp_layer"].data, - rfmip["temp_level"].data, - rfmip["surface_temperature"].data, - kdist.gas_optics.top_at_1, - kdist.gas_optics._interpolated.fmajor, - kdist.gas_optics._interpolated.jeta, - kdist.gas_optics._interpolated.tropo, - kdist.gas_optics._interpolated.jtemp, - kdist.gas_optics._interpolated.jpress, - kdist["bnd_limits_gpt"].data.T, - kdist["plank_fraction"].data.transpose(0, 2, 1, 3), - kdist["temp_ref"].data.min(), - kdist["temp_ref"].data.max(), - kdist["totplnk"].data.T, - kdist.gas_optics.gpoint_flavor, -] - -expected_output = ( - rrtmgp_gas_optics.sfc_src, - rrtmgp_gas_optics.lay_src, - rrtmgp_gas_optics.lev_src, - rrtmgp_gas_optics.sfc_src_jac, -) - - -@pytest.mark.parametrize( - "args, expected", - [(i, expected_output) for i in convert_args_arrays(planck_source_args)], -) -def test_compute_planck_source(args, expected): - result = compute_planck_source(*args) - assert len(result) == len(expected) - for r, e in zip(result, expected): - assert r.shape == e.shape - assert np.isclose(r, e, atol=ERROR_TOLERANCE).all() - - -# Prepare the arguments for the compute_tau_absorption function -minor_gases_lower = kdist.gas_optics.extract_names(kdist["minor_gases_lower"].data) -minor_gases_upper = kdist.gas_optics.extract_names(kdist["minor_gases_upper"].data) -idx_minor_lower = kdist.gas_optics.get_idx_minor( - kdist.gas_optics.gas_names, minor_gases_lower -) -idx_minor_upper = kdist.gas_optics.get_idx_minor( - kdist.gas_optics.gas_names, minor_gases_upper -) - -scaling_gas_lower = kdist.gas_optics.extract_names(kdist["scaling_gas_lower"].data) -scaling_gas_upper = kdist.gas_optics.extract_names(kdist["scaling_gas_upper"].data) -idx_minor_scaling_lower = kdist.gas_optics.get_idx_minor( - kdist.gas_optics.gas_names, scaling_gas_lower -) -idx_minor_scaling_upper = kdist.gas_optics.get_idx_minor( - kdist.gas_optics.gas_names, scaling_gas_upper -) - -tau_absorption_args = [ - kdist.gas_optics.idx_h2o, - kdist.gas_optics.gpoint_flavor, - kdist["bnd_limits_gpt"].values.T, - kdist["kmajor"].values, - kdist["kminor_lower"].values, - kdist["kminor_upper"].values, - kdist["minor_limits_gpt_lower"].values.T, - kdist["minor_limits_gpt_upper"].values.T, - kdist["minor_scales_with_density_lower"].values.astype(bool), - kdist["minor_scales_with_density_upper"].values.astype(bool), - kdist["scale_by_complement_lower"].values.astype(bool), - kdist["scale_by_complement_upper"].values.astype(bool), - idx_minor_lower, - idx_minor_upper, - idx_minor_scaling_lower, - idx_minor_scaling_upper, - kdist["kminor_start_lower"].values, - kdist["kminor_start_upper"].values, - kdist.gas_optics._interpolated.tropo, - kdist.gas_optics._interpolated.col_mix, - kdist.gas_optics._interpolated.fmajor, - kdist.gas_optics._interpolated.fminor, - rfmip["pres_layer"].values, - rfmip["temp_layer"].values, - kdist.gas_optics.col_gas, - kdist.gas_optics._interpolated.jeta, - kdist.gas_optics._interpolated.jtemp, - kdist.gas_optics._interpolated.jpress, -] - - -@pytest.mark.parametrize( - "args, expected", - [ - (i, rrtmgp_gas_optics.tau_absorption) - for i in convert_args_arrays(tau_absorption_args) - ], -) -def test_compute_tau_absorption(args, expected): - result = compute_tau_absorption(*args) - assert np.isclose(result, expected, atol=ERROR_TOLERANCE).all() - - -# Prepare the arguments for the compute_tau_rayleigh function -tau_rayleigh_args = [ - kdist_sw.gas_optics.gpoint_flavor, - kdist_sw["bnd_limits_gpt"].values.T, - np.stack([kdist_sw["rayl_lower"].values, kdist_sw["rayl_upper"].values], axis=-1), - kdist_sw.gas_optics.idx_h2o, - kdist_sw.gas_optics.col_gas[:, :, 0], - kdist_sw.gas_optics.col_gas, - kdist_sw.gas_optics._interpolated.fminor, - kdist_sw.gas_optics._interpolated.jeta, - kdist_sw.gas_optics._interpolated.tropo, - kdist_sw.gas_optics._interpolated.jtemp, -] - - -@pytest.mark.parametrize( - "args, expected", - [ - (i, rrtmgp_gas_optics_sw.tau_rayleigh) - for i in convert_args_arrays(tau_rayleigh_args) - ], -) -def test_compute_tau_rayleigh(args, expected): - result = compute_tau_rayleigh(*args) - assert np.isclose(result, expected, atol=ERROR_TOLERANCE).all() diff --git a/tests/test_python_frontend/test_lw_solver.py b/tests/test_python_frontend/test_lw_solver.py index 878bb96..9aa08cd 100644 --- a/tests/test_python_frontend/test_lw_solver.py +++ b/tests/test_python_frontend/test_lw_solver.py @@ -2,43 +2,54 @@ import numpy as np import xarray as xr + from pyrte_rrtmgp import rrtmgp_gas_optics from pyrte_rrtmgp.rrtmgp_data import download_rrtmgp_data -from pyrte_rrtmgp.kernels.rte import lw_solver_noscat +from pyrte_rrtmgp.rrtmgp_gas_optics import GasOpticsFiles, load_gas_optics +from pyrte_rrtmgp.rte_solver import RTESolver -ERROR_TOLERANCE = 1e-4 +ERROR_TOLERANCE = 1e-7 rte_rrtmgp_dir = download_rrtmgp_data() -clear_sky_example_files = f"{rte_rrtmgp_dir}/examples/rfmip-clear-sky/inputs" +rfmip_dir = os.path.join(rte_rrtmgp_dir, "examples", "rfmip-clear-sky") +input_dir = os.path.join(rfmip_dir, "inputs") +ref_dir = os.path.join(rfmip_dir, "reference") -rfmip = xr.load_dataset( - f"{clear_sky_example_files}/multiple_input4MIPs_radiation_RFMIP_UColorado-RFMIP-1-2_none.nc" +atmosphere = xr.load_dataset( + os.path.join( + input_dir, "multiple_input4MIPs_radiation_RFMIP_UColorado-RFMIP-1-2_none.nc" + ) ) -rfmip = rfmip.sel(expt=0) # only one experiment -kdist = xr.load_dataset(f"{rte_rrtmgp_dir}/rrtmgp-gas-lw-g256.nc") +atmosphere = atmosphere.sel(expt=0) # only one experiment rlu = xr.load_dataset( - "tests/test_python_frontend/rlu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc" + os.path.join(ref_dir, "rlu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc"), + decode_cf=False, ) -ref_flux_up = rlu.isel(expt=0)["rlu"].values +ref_flux_up = rlu.isel(expt=0)["rlu"] rld = xr.load_dataset( - "tests/test_python_frontend/rld_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc" + os.path.join(ref_dir, "rld_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc"), + decode_cf=False, ) -ref_flux_down = rld.isel(expt=0)["rld"].values +ref_flux_down = rld.isel(expt=0)["rld"] def test_lw_solver_noscat(): - rrtmgp_gas_optics = kdist.gas_optics.load_atmosferic_conditions(rfmip) - - _, solver_flux_up, solver_flux_down, _, _ = lw_solver_noscat( - tau=rrtmgp_gas_optics.tau, - lay_source=rrtmgp_gas_optics.lay_src, - lev_source=rrtmgp_gas_optics.lev_src, - sfc_emis=rfmip["surface_emissivity"].data, - sfc_src=rrtmgp_gas_optics.sfc_src, - sfc_src_jac=rrtmgp_gas_optics.sfc_src_jac, - ) - - assert np.isclose(solver_flux_up, ref_flux_up, atol=ERROR_TOLERANCE).all() - assert np.isclose(solver_flux_down, ref_flux_down, atol=ERROR_TOLERANCE).all() + # Load gas optics with the new API + gas_optics_lw = load_gas_optics(gas_optics_file=GasOpticsFiles.LW_G256) + + # Compute gas optics for the atmosphere + gas_optics_lw.gas_optics.compute(atmosphere, problem_type="absorption") + + # Solve RTE with the new API + solver = RTESolver() + fluxes = solver.solve(atmosphere, add_to_input=False) + + # Compare results with reference data + assert np.isclose( + fluxes["lw_flux_up_broadband"], ref_flux_up, atol=ERROR_TOLERANCE + ).all() + assert np.isclose( + fluxes["lw_flux_down_broadband"], ref_flux_down, atol=ERROR_TOLERANCE + ).all() diff --git a/tests/test_python_frontend/test_sw_solver.py b/tests/test_python_frontend/test_sw_solver.py index 940d887..4db358b 100644 --- a/tests/test_python_frontend/test_sw_solver.py +++ b/tests/test_python_frontend/test_sw_solver.py @@ -1,66 +1,51 @@ import os import numpy as np -import pytest import xarray as xr + from pyrte_rrtmgp import rrtmgp_gas_optics from pyrte_rrtmgp.rrtmgp_data import download_rrtmgp_data -from pyrte_rrtmgp.kernels.rte import sw_solver_2stream -from pyrte_rrtmgp.utils import compute_mu0, compute_toa_flux, get_usecols +from pyrte_rrtmgp.rrtmgp_gas_optics import GasOpticsFiles, load_gas_optics +from pyrte_rrtmgp.rte_solver import RTESolver -ERROR_TOLERANCE = 1e-4 +ERROR_TOLERANCE = 1e-7 rte_rrtmgp_dir = download_rrtmgp_data() -clear_sky_example_files = f"{rte_rrtmgp_dir}/examples/rfmip-clear-sky/inputs" +rfmip_dir = os.path.join(rte_rrtmgp_dir, "examples", "rfmip-clear-sky") +input_dir = os.path.join(rfmip_dir, "inputs") +ref_dir = os.path.join(rfmip_dir, "reference") -rfmip = xr.load_dataset( - f"{clear_sky_example_files}/multiple_input4MIPs_radiation_RFMIP_UColorado-RFMIP-1-2_none.nc" +atmosphere = xr.load_dataset( + os.path.join( + input_dir, "multiple_input4MIPs_radiation_RFMIP_UColorado-RFMIP-1-2_none.nc" + ) ) -rfmip = rfmip.sel(expt=0) # only one experiment -kdist = xr.load_dataset(f"{rte_rrtmgp_dir}/rrtmgp-gas-sw-g224.nc") +atmosphere = atmosphere.sel(expt=0) rsu = xr.load_dataset( - "tests/test_python_frontend/rsu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc" + os.path.join(ref_dir, "rsu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc"), + decode_cf=False, ) -ref_flux_up = rsu.isel(expt=0)["rsu"].values +ref_flux_up = rsu.isel(expt=0)["rsu"] rsd = xr.load_dataset( - "tests/test_python_frontend/rsd_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc" + os.path.join(ref_dir, "rsd_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc"), + decode_cf=False, ) -ref_flux_down = rsd.isel(expt=0)["rsd"].values +ref_flux_down = rsd.isel(expt=0)["rsd"] def test_sw_solver_noscat(): - gas_optics = kdist.gas_optics.load_atmosferic_conditions(rfmip) - - surface_albedo = rfmip["surface_albedo"].data - total_solar_irradiance = rfmip["total_solar_irradiance"].data + # Load gas optics with new API + gas_optics_sw = load_gas_optics(gas_optics_file=GasOpticsFiles.SW_G224) - nlayer = len(rfmip["layer"]) - mu0 = compute_mu0(rfmip["solar_zenith_angle"].values, nlayer=nlayer) - - toa_flux = compute_toa_flux(total_solar_irradiance, gas_optics.solar_source) - - _, _, _, solver_flux_up, solver_flux_down, _ = sw_solver_2stream( - kdist.gas_optics.top_at_1, - gas_optics.tau, - gas_optics.ssa, - gas_optics.g, - mu0, - sfc_alb_dir=surface_albedo, - sfc_alb_dif=surface_albedo, - inc_flux_dir=toa_flux, - inc_flux_dif=None, - has_dif_bc=False, - do_broadband=True, - ) + # Load and compute gas optics with atmosphere data + gas_optics_sw.gas_optics.compute(atmosphere, problem_type="two-stream") - # RTE will fail if passed solar zenith angles greater than 90 degree. We replace any with - # nighttime columns with a default solar zenith angle. We'll mask these out later, of - # course, but this gives us more work and so a better measure of timing. - usecol = get_usecols(rfmip["solar_zenith_angle"].values) - solver_flux_up = solver_flux_up * usecol[:, np.newaxis] - solver_flux_down = solver_flux_down * usecol[:, np.newaxis] + # Solve using new rte_solve function + solver = RTESolver() + fluxes = solver.solve(atmosphere, add_to_input=False) - assert np.isclose(solver_flux_up, ref_flux_up, atol=ERROR_TOLERANCE).all() - assert np.isclose(solver_flux_down, ref_flux_down, atol=ERROR_TOLERANCE).all() + # Compare results + assert np.isclose(fluxes["sw_flux_up"], ref_flux_up, atol=ERROR_TOLERANCE).all() + assert np.isclose(fluxes["sw_flux_down"], ref_flux_down, atol=ERROR_TOLERANCE).all()