diff --git a/auxiliary_tools/cdat_regression_testing/665-streamflow/regression_test_png.ipynb b/auxiliary_tools/cdat_regression_testing/665-streamflow/regression_test_png.ipynb new file mode 100644 index 000000000..38bf60542 --- /dev/null +++ b/auxiliary_tools/cdat_regression_testing/665-streamflow/regression_test_png.ipynb @@ -0,0 +1,226 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CDAT Migration Regression Testing Notebook (`.png` files)\n", + "\n", + "This notebook is used to perform regression testing between the development and\n", + "production versions of a diagnostic set.\n", + "\n", + "## How to use\n", + "\n", + "PREREQUISITE: The diagnostic set's netCDF stored in `.json` files in two directories\n", + "(dev and `main` branches).\n", + "\n", + "1. Make a copy of this notebook under `auxiliary_tools/cdat_regression_testing/`.\n", + "2. Run `mamba create -n cdat_regression_test -y -c conda-forge \"python<3.12\" xarray netcdf4 dask pandas matplotlib-base ipykernel`\n", + "3. Run `mamba activate cdat_regression_test`\n", + "4. Update `SET_DIR` and `SET_NAME` in the copy of your notebook.\n", + "5. Run all cells IN ORDER.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup Code\n" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [], + "source": [ + "import glob\n", + "\n", + "from auxiliary_tools.cdat_regression_testing.utils import get_image_diffs\n", + "\n", + "SET_NAME = \"streamflow\"\n", + "SET_DIR = \"665-streamflow\"\n", + "\n", + "DEV_PATH = f\"/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/{SET_DIR}/{SET_NAME}/**\"\n", + "DEV_GLOB = sorted(glob.glob(DEV_PATH + \"/*.png\"))\n", + "DEV_NUM_FILES = len(DEV_GLOB)\n", + "\n", + "MAIN_PATH = f\"/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/{SET_NAME}/**\"\n", + "MAIN_GLOB = sorted(glob.glob(MAIN_PATH + \"/*.png\"))\n", + "MAIN_NUM_FILES = len(MAIN_GLOB)" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [], + "source": [ + "def _check_if_files_found():\n", + " if DEV_NUM_FILES == 0 or MAIN_NUM_FILES == 0:\n", + " raise IOError(\n", + " \"No files found at DEV_PATH and/or MAIN_PATH. \"\n", + " f\"Please check {DEV_PATH} and {MAIN_PATH}.\"\n", + " )\n", + "\n", + "\n", + "def _check_if_matching_filecount():\n", + " if DEV_NUM_FILES != MAIN_NUM_FILES:\n", + " raise IOError(\n", + " \"Number of files do not match at DEV_PATH and MAIN_PATH \"\n", + " f\"({DEV_NUM_FILES} vs. {MAIN_NUM_FILES}).\"\n", + " )\n", + "\n", + " print(f\"Matching file count ({DEV_NUM_FILES} and {MAIN_NUM_FILES}).\")\n", + "\n", + "\n", + "def _check_if_missing_files():\n", + " missing_count = 0\n", + "\n", + " for fp_main in MAIN_GLOB:\n", + " fp_dev = fp_main.replace(SET_DIR, \"main\")\n", + "\n", + " if fp_dev not in MAIN_GLOB:\n", + " print(f\"No production file found to compare with {fp_dev}!\")\n", + " missing_count += 1\n", + "\n", + " for fp_dev in DEV_GLOB:\n", + " fp_main = fp_main.replace(\"main\", SET_DIR)\n", + "\n", + " if fp_main not in DEV_GLOB:\n", + " print(f\"No development file found to compare with {fp_main}!\")\n", + " missing_count += 1\n", + "\n", + " print(f\"Number of files missing: {missing_count}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Check for matching and equal number of files\n" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [], + "source": [ + "_check_if_files_found()" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of files missing: 0\n" + ] + } + ], + "source": [ + "_check_if_missing_files()" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Matching file count (3 and 3).\n" + ] + } + ], + "source": [ + "_check_if_matching_filecount()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2 Compare the plots between branches\n", + "\n", + "- Compare \"ref\" and \"test\" files\n", + "- \"diff\" files are ignored because getting relative diffs for these does not make sense (relative diff will be above tolerance)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Comparing:\n", + " * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/streamflow/RIVER_DISCHARGE_OVER_LAND_LIQ_GSIM/annual_map.png\n", + " * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/665-streamflow/streamflow/RIVER_DISCHARGE_OVER_LAND_LIQ_GSIM/annual_map.png\n", + " * Plots are identical\n", + "Comparing:\n", + " * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/streamflow/RIVER_DISCHARGE_OVER_LAND_LIQ_GSIM/annual_scatter.png\n", + " * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/665-streamflow/streamflow/RIVER_DISCHARGE_OVER_LAND_LIQ_GSIM/annual_scatter.png\n", + " * Plots are identical\n", + "Comparing:\n", + " * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/streamflow/RIVER_DISCHARGE_OVER_LAND_LIQ_GSIM/seasonality_map.png\n", + " * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/665-streamflow/streamflow/RIVER_DISCHARGE_OVER_LAND_LIQ_GSIM/seasonality_map.png\n", + " * Plots are identical\n" + ] + } + ], + "source": [ + "for main_path, dev_path in zip(MAIN_GLOB, DEV_GLOB):\n", + " print(\"Comparing:\")\n", + " print(f\" * {main_path}\")\n", + " print(f\" * {dev_path}\")\n", + "\n", + " get_image_diffs(dev_path, main_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Results\n", + "\n", + "All plots are identical\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cdat_regression_test", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/auxiliary_tools/cdat_regression_testing/665-streamflow/run.cfg b/auxiliary_tools/cdat_regression_testing/665-streamflow/run.cfg new file mode 100644 index 000000000..52099f57e --- /dev/null +++ b/auxiliary_tools/cdat_regression_testing/665-streamflow/run.cfg @@ -0,0 +1,8 @@ +[#] +sets = ["streamflow"] +case_id = "RIVER_DISCHARGE_OVER_LAND_LIQ_GSIM" +variables = ["RIVER_DISCHARGE_OVER_LAND_LIQ"] +ref_name = "GSIM" +reference_name = "GSIM monthly streamflow" +regions = ["global"] +seasons = ["ANN"] diff --git a/auxiliary_tools/cdat_regression_testing/665-streamflow/run_script.py b/auxiliary_tools/cdat_regression_testing/665-streamflow/run_script.py new file mode 100644 index 000000000..20872c045 --- /dev/null +++ b/auxiliary_tools/cdat_regression_testing/665-streamflow/run_script.py @@ -0,0 +1,12 @@ +# %% +# python -m auxiliary_tools.cdat_regression_testing.665-streamflow.run_script +from auxiliary_tools.cdat_regression_testing.base_run_script import run_set + +SET_NAME = "streamflow" +SET_DIR = "665-streamflow" +CFG_PATH: str | None = None +# CFG_PATH: str | None = "/global/u2/v/vo13/E3SM-Project/e3sm_diags/auxiliary_tools/cdat_regression_testing/665-streamflow/run.cfg" +MULTIPROCESSING = True + +# %% +run_set(SET_NAME, SET_DIR, CFG_PATH, MULTIPROCESSING) diff --git a/e3sm_diags/driver/streamflow_driver.py b/e3sm_diags/driver/streamflow_driver.py index 13810780b..9be199f10 100644 --- a/e3sm_diags/driver/streamflow_driver.py +++ b/e3sm_diags/driver/streamflow_driver.py @@ -1,21 +1,18 @@ from __future__ import annotations import csv -import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List, Tuple -import cdms2 -import cdutil -import numpy +import numpy as np import scipy.io +import xarray as xr +import xcdat as xc -from e3sm_diags.driver import utils +from e3sm_diags.driver.utils.dataset_xr import Dataset from e3sm_diags.logger import custom_logger -from e3sm_diags.plot.cartopy.streamflow_plot import ( - plot_annual_map, - plot_annual_scatter, - plot_seasonality_map, -) +from e3sm_diags.plot.streamflow_plot_map import plot_annual_map +from e3sm_diags.plot.streamflow_plot_scatter import plot_annual_scatter +from e3sm_diags.plot.streamflow_plot_seasonality import plot_seasonality_map logger = custom_logger(__name__) @@ -23,405 +20,287 @@ from e3sm_diags.parameter.streamflow_parameter import StreamflowParameter -def get_drainage_area_error( - radius, resolution, lon_ref, lat_ref, area_upstream, area_ref -): - k_bound = len(range(-radius, radius + 1)) - k_bound *= k_bound - area_test = numpy.zeros((k_bound, 1)) - error_test = numpy.zeros((k_bound, 1)) - lat_lon_test = numpy.zeros((k_bound, 2)) - k = 0 - for i in range(-radius, radius + 1): - for j in range(-radius, radius + 1): - x = int( - 1 + ((lon_ref + j * resolution) - (-180 + resolution / 2)) / resolution - ) - y = int( - 1 + ((lat_ref + i * resolution) - (-90 + resolution / 2)) / resolution - ) - area_test[k] = area_upstream[x - 1, y - 1] / 1000000 - error_test[k] = numpy.abs(area_test[k] - area_ref) / area_ref - lat_lon_test[k, 0] = lat_ref + i * resolution - lat_lon_test[k, 1] = lon_ref + j * resolution - k += 1 - # The id of the center grid in the searching area - center_id = (k_bound - 1) / 2 +# Resolution of MOSART output. +RESOLUTION = 0.5 +# Search radius (number of grids around the center point). +SEARCH_RADIUS = 1 +# The max area error (percent) for all plots. +MAX_AREA_ERROR = 20 - lat_lon_ref = [lat_ref, lon_ref] - drainage_area_error = error_test[int(center_id)] - return drainage_area_error, lat_lon_ref +def run_diag(parameter: StreamflowParameter) -> StreamflowParameter: + """Get metrics for the streamflow set. -def get_seasonality(monthly): - monthly = monthly.astype(numpy.float64) - # See https://agupubs.onlinelibrary.wiley.com/doi/epdf/10.1029/2018MS001603 Equations 1 and 2 - if monthly.shape[0] != 12: - raise Exception( - "monthly.shape={} does not include 12 months".format(monthly.shape) - ) - num_years = monthly.shape[1] - p_k = numpy.zeros((12, 1)) - # The total streamflow for each year (sum of Q_ij in the denominator of Equation 1, for all j) - # 1 x num_years - total_streamflow = numpy.sum(monthly, axis=0) - for month in range(12): - # The streamflow for this month in each year (Q_ij in the numerator of Equation 1, for all j) - # 1 x num_years - streamflow_month_all_years = monthly[month, :] - # Proportion that this month contributes to streamflow that year. - # 1 x num_years - # For all i, divide streamflow_month_all_years[i] by total_streamflow[i] - streamflow_proportion = numpy.divide( - streamflow_month_all_years, total_streamflow + Parameters + ---------- + parameter : StreamflowParameter + The parameter for the diagnostic. + + Returns + ------- + StreamflowParameter + The parameter for the diagnostic with the result (completed or failed). + """ + gauges, is_ref_mat_file = _get_gauges(parameter) + + for var_key in parameter.variables: + logger.info(f"Variable: {var_key}") + + test_array, area_upstream = _get_test_data_and_area_upstream(parameter, var_key) + ref_array = _get_ref_data(parameter, var_key, is_ref_mat_file) + + export_data = _generate_export_data( + parameter, gauges, test_array, ref_array, area_upstream, is_ref_mat_file ) - # The sum is the sum over j in Equation 1. - # Dividing the sum of proportions by num_years gives the *average* proportion of annual streamflow during - # this month. - # Multiplying by 12 makes it so that Pk_i (`p_k[month]`) will be 1 if all months have equal streamflow and - # 12 if all streamflow occurs in one month. - # These steps produce the 12/n factor in Equation 1. - p_k[month] = numpy.nansum(streamflow_proportion) * 12 / num_years - # From Equation 2 - seasonality_index = numpy.max(p_k) - # `p_k == numpy.max(p_k)` produces a Boolean matrix, True if the value (i.e., streamflow) is the max value. - # `np.where(p_k == numpy.max(p_k))` produces the indices (i.e., months) where the max value is reached. - peak_month = numpy.where(p_k == numpy.max(p_k))[0] - # If more than one month has peak streamflow, simply define the peak month as the first one of the peak months. - # Month 0 is January, Month 1 is February, and so on. - peak_month = peak_month[0] - return seasonality_index, peak_month + if parameter.test_title == "": + parameter.test_title = parameter.test_name_yrs + if parameter.reference_title == "": + parameter.reference_title = parameter.ref_name_yrs + + # Plot the original ref and test (not regridded versions). + plot_seasonality_map(parameter, export_data) + plot_annual_map(parameter, export_data) + plot_annual_scatter(parameter, export_data) + + return parameter + + +def _get_gauges(parameter: StreamflowParameter) -> Tuple[np.ndarray, bool]: + """Get the gauges. + + Assume `model_vs_model` is an `nc` file and `model_vs_obs` is an `mat` file. + If `model_vs_obs`, the metadata file of GSIM that has observed gauge lat lon + and drainage area. This file includes 25765 gauges, which is a subset of the + entire dataset (30959 gauges). The removed gauges are associated with very + small drainage area (<1km2), which is not meaningful to be included. + + Parameters + ---------- + parameter : StreamflowParameter + The parameter. + + Returns + ------- + Tuple[np.ndarray, bool] + A tuple containing the gauges array and a boolean representing whether + the reference file is a mat file (True) or not (False). + + Raises + ------ + RuntimeError + Non-GSIM reference file specified without using parameter `.gauges_path` + attribute. + RuntimeError + Parameter run type is not supported. + """ + ref_path = parameter.reference_data_path.rstrip("/") -def run_diag(parameter: StreamflowParameter) -> StreamflowParameter: - # Assume `model` will always be a `nc` file. - # Assume `obs` will always be a `mat` file. - using_test_mat_file = False if parameter.run_type == "model_vs_model": - using_ref_mat_file = False + is_ref_mat_file = False + if parameter.gauges_path is None: - raise Exception( - "To use a non-GSIM reference, please specify streamflow_param.gauges_path. This might be {}/{}".format( - parameter.reference_data_path.rstrip("/"), - "GSIM/GSIM_catchment_characteristics_all_1km2.csv", - ) + raise RuntimeError( + "To use a non-GSIM reference, please specify streamflow_param.gauges_path. " + f"This might be {ref_path}/GSIM/GSIM_catchment_characteristics_all_1km2.csv" ) + else: gauges_path = parameter.gauges_path elif parameter.run_type == "model_vs_obs": - using_ref_mat_file = True - # The metadata file of GSIM that has observed gauge lat lon and drainage area - # This file includes 25765 gauges, which is a subset of the entire - # dataset (30959 gauges). The removed gauges are associated with very - # small drainage area (<1km2), which is not meaningful to be included. - gauges_path = "{}/GSIM/GSIM_catchment_characteristics_all_1km2.csv".format( - parameter.reference_data_path.rstrip("/") - ) + is_ref_mat_file = True + gauges_path = f"{ref_path}/GSIM/GSIM_catchment_characteristics_all_1km2.csv" else: - raise Exception( - "parameter.run_type={} not supported".format(parameter.run_type) - ) + raise RuntimeError(f"parameter.run_type={parameter.run_type} not supported") # Set path to the gauge metadata with open(gauges_path) as gauges_file: gauges_list = list(csv.reader(gauges_file)) + # Remove headers gauges_list.pop(0) - gauges = numpy.array(gauges_list) - if parameter.print_statements: - logger.info("gauges.shape={}".format(gauges.shape)) - - variables = parameter.variables - for var in variables: - ref_array = setup_ref(parameter, var, using_ref_mat_file) - area_upstream, test_array = setup_test(parameter, var, using_test_mat_file) - - # Resolution of MOSART output - resolution = 0.5 - # Search radius (number of grids around the center point) - radius = 1 - bins = numpy.floor(gauges[:, 7:9].astype(numpy.float64) / resolution) - # Move the ref lat lon to grid center - lat_lon = (bins + 0.5) * resolution - if parameter.print_statements: - logger.info("lat_lon.shape={}".format(lat_lon.shape)) - - # Define the export matrix - export = generate_export( - parameter, - lat_lon, - area_upstream, - gauges, - radius, - resolution, - using_ref_mat_file, - ref_array, - test_array, - ) + gauges = np.array(gauges_list) - # Remove the gauges with nan flow - # `export[:,0]` => get first column of export - # `numpy.isnan(export[:,0])` => Boolean column, True if value in export[x,0] is nan - # `export[numpy.isnan(export[:,0]),:]` => rows of `export` where the Boolean column was True - # Gauges will thus only be plotted if they have a non-nan value for both test and ref. - if parameter.print_statements: - logger.info( - "export.shape before removing ref nan means={}".format(export.shape) - ) - export = export[~numpy.isnan(export[:, 0]), :] - if parameter.print_statements: - logger.info( - "export.shape before removing test nan means={}".format(export.shape) - ) - export = export[~numpy.isnan(export[:, 1]), :] - if parameter.print_statements: - logger.info("export.shape after both nan removals={}".format(export.shape)) - - if area_upstream is not None: - # Set the max area error (percent) for all plots - max_area_error = 20 - # `export[:,2]` gives the third column of `export` - # `export[:,2]<=max_area_error` gives a Boolean array, - # `True` if the value in the third column of `export` is `<= max_area_error` - # `export[export[:,2]<=max_area_error,:]` is `export` with only the rows where the above is `True`. - export = export[export[:, 2] <= max_area_error, :] - if parameter.print_statements: - logger.info( - "export.shape after max_area_error cut={}".format(export.shape) - ) + return gauges, is_ref_mat_file - if parameter.print_statements: - logger.info("Variable: {}".format(var)) - if parameter.test_title == "": - parameter.test_title = parameter.test_name_yrs - if parameter.reference_title == "": - parameter.reference_title = parameter.ref_name_yrs +def _get_test_data_and_area_upstream( + parameter: StreamflowParameter, var_key: str +) -> Tuple[np.ndarray, np.ndarray]: + """Set up the test data. - # Seasonality - # Plot original ref and test, not regridded versions. - plot_seasonality_map(export, parameter) + Parameters + ---------- + parameter : StreamflowParameter + The parameter. + var_key : str + The key of the variable. - # Bias between test and ref as a percentage - # (Relative error as a percentage) - # 100*((annual_mean_test - annual_mean_ref) / annual_mean_ref) - bias = 100 * ((export[:, 1] - export[:, 0]) / export[:, 0]) - plot_annual_map(export, bias, parameter) + Returns + ------- + Tuple[np.ndarray, np.ndarray] + The test data and area upstream. + """ + test_ds = Dataset(parameter, data_type="test") + parameter.test_name_yrs = test_ds.get_name_yrs_attr() - # Scatterplot - # These arrays will have fewer entries than the original `export` matrix - # because of the nan removal steps. - xs = export[:, 0] - ys = export[:, 1] - zs = export[:, 2] - plot_annual_scatter(xs, ys, zs, parameter) + ds_test = test_ds.get_time_series_dataset(var_key) + + test_array = _get_var_data(ds_test, var_key) + + areatotal2 = ds_test["areatotal2"].values + area_upstream = np.transpose(areatotal2, (1, 0)).astype(np.float64) + + return test_array, area_upstream - return parameter +def _get_ref_data( + parameter: StreamflowParameter, var_key: str, is_ref_mat_file: bool +) -> np.ndarray: + """Set up the reference data. -def setup_ref(parameter, var, using_ref_mat_file): - if not using_ref_mat_file: - ref_data = utils.dataset.Dataset(parameter, ref=True) - parameter.ref_name_yrs = utils.general.get_name_and_yrs(parameter, ref_data) - ref_data_ts = ref_data.get_timeseries_variable(var) - var_array = ref_data_ts(cdutil.region.domain(latitude=(-90.0, 90, "ccb"))) - if parameter.print_statements: - logger.info("ref var original dimensions={}".format(var_array.shape)) - var_transposed = numpy.transpose(var_array, (2, 1, 0)) - if parameter.print_statements: - logger.info("ref var transposed dimensions={}".format(var_transposed.shape)) - ref_array = var_transposed.astype(numpy.float64) + Parameters + ---------- + parameter : StreamflowParameter + The parameter. + var_key : str + The key of the variable. + is_ref_mat_file : bool + If the reference data is from a mat file (True) or not (False). + + Returns + ------- + np.ndarray + The reference data. + """ + ref_ds = Dataset(parameter, data_type="ref") + + if not is_ref_mat_file: + parameter.ref_name_yrs = ref_ds.get_name_yrs_attr() + + ds_ref = ref_ds.get_time_series_dataset(var_key) + ref_array = _get_var_data(ds_ref, var_key) else: # Load the observed streamflow dataset (GSIM) # the data has been reorganized to a 1380 * 30961 matrix. 1380 is the month # number from 1901.1 to 2015.12. 30961 include two columns for year and month plus # streamflow at 30959 gauge locations reported by GSIM - ref_mat_file = "{}/GSIM/GSIM_198601_199512.mat".format( - parameter.reference_data_path.rstrip("/") - ) - if parameter.short_ref_name != "": - ref_name = parameter.short_ref_name - elif parameter.reference_name != "": - # parameter.ref_name is used to search though the reference data directories. - # parameter.reference_name is printed above ref plots. - ref_name = parameter.reference_name - else: - ref_name = "GSIM" - parameter.ref_name_yrs = "{} ({}-{})".format( - ref_name, parameter.ref_start_yr, parameter.ref_end_yr - ) + ref_path = parameter.reference_data_path.rstrip("/") + ref_mat_file = f"{ref_path}/GSIM/GSIM_198601_199512.mat" + parameter.ref_name_yrs = ref_ds.get_name_yrs_attr(default_name="GSIM") + ref_mat = scipy.io.loadmat(ref_mat_file) - ref_array = ref_mat["GSIM"].astype(numpy.float64) - if parameter.print_statements: - # GSIM: 1380 x 30961 - # wrmflow: 720 x 360 x 360 - logger.info("ref_array.shape={}".format(ref_array.shape)) + ref_array = ref_mat["GSIM"].astype(np.float64) return ref_array -def setup_test(parameter, var, using_test_mat_file): - # Load E3SM simulated streamflow dataset - if not using_test_mat_file: - # `Dataset` will take the time slice from test_start_yr to test_end_yr - test_data = utils.dataset.Dataset(parameter, test=True) - parameter.test_name_yrs = utils.general.get_name_and_yrs(parameter, test_data) - test_data_ts = test_data.get_timeseries_variable(var) - var_array = test_data_ts(cdutil.region.domain(latitude=(-90.0, 90, "ccb"))) - if parameter.print_statements: - logger.info("test var original dimensions={}".format(var_array.shape)) - var_transposed = numpy.transpose(var_array, (2, 1, 0)) - if parameter.print_statements: - logger.info( - "test var transposed dimensions={}".format(var_transposed.shape) - ) - test_array = var_transposed.astype(numpy.float64) - areatotal2 = test_data.get_static_variable("areatotal2", var) - area_upstream = numpy.transpose(areatotal2, (1, 0)).astype(numpy.float64) - if parameter.print_statements: - logger.info("area_upstream dimensions={}".format(area_upstream.shape)) - else: - area_upstream, test_array = debugging_case_setup_test(parameter) - if parameter.print_statements: - # For edison: 720x360x600 - logger.info("test_array.shape={}".format(test_array.shape)) - if isinstance(area_upstream, cdms2.tvariable.TransientVariable): - area_upstream = area_upstream.getValue() - - return area_upstream, test_array - - -def debugging_case_setup_test(parameter): - # This block is only for debugging -- i.e., when testing with a `mat` file. - files_in_test_data_path = os.listdir(parameter.test_data_path) - mat_files = list( - filter( - lambda file_name: file_name.endswith(".mat"), - files_in_test_data_path, - ) - ) - if len(mat_files) == 1: - mat_file = mat_files[0] - elif len(mat_files) > 1: - raise Exception( - "More than one .mat file in parameter.test_data_path={}".format( - parameter.test_data_path - ) - ) - else: - raise Exception( - "No .mat file in parameter.test_data_path={}".format( - parameter.test_data_path - ) - ) - test_mat_file = "{}/{}".format(parameter.test_data_path.rstrip("/"), mat_file) - parameter.test_name_yrs = "{} ({}-{})".format( - test_mat_file, parameter.test_start_yr, parameter.test_end_yr - ) - data_mat = scipy.io.loadmat(test_mat_file) - e3sm_flow = get_e3sm_flow(parameter, data_mat) - area_upstream = get_area_upstream(parameter, e3sm_flow) - test_array = get_test_array(parameter, e3sm_flow) - return area_upstream, test_array - - -def get_e3sm_flow(parameter, data_mat): - if "E3SMflow" in data_mat.keys(): - # `edison` file uses this block - e3sm_flow = data_mat["E3SMflow"] - if parameter.print_statements: - logger.info('e3sm_flow = data_mat["E3SMflow"]') - else: - # `test` file uses this block - e3sm_flow = data_mat - if parameter.print_statements: - logger.info("e3sm_flow = data_mat") - return e3sm_flow - - -def get_area_upstream(parameter, e3sm_flow): - try: - if e3sm_flow["uparea"].shape == (1, 1): - # `edison` file uses this block - area_upstream = e3sm_flow["uparea"][0][0].astype(numpy.float64) - if parameter.print_statements: - logger.info('e3sm_flow["uparea"] was indexed into for later use') - else: - area_upstream = e3sm_flow["uparea"].astype(numpy.float64) - if parameter.print_statements: - logger.info('e3sm_flow["uparea"] will be used') - except KeyError: - # `test` file uses this block - area_upstream = None - if parameter.print_statements: - logger.warning("WARNING: uparea not found and will thus not be used") - return area_upstream - - -def get_test_array(parameter, e3sm_flow): - if e3sm_flow["wrmflow"].shape == (1, 1): - # `edison` file uses this block - test_array = e3sm_flow["wrmflow"][0][0].astype(numpy.float64) - if parameter.print_statements: - logger.info('e3sm_flow["wrmflow"] was indexed into for later use') - else: - # `test` file uses this block - test_array = e3sm_flow["wrmflow"].astype(numpy.float64) - if parameter.print_statements: - logger.info('e3sm_flow["wrmflow"] will be used') +def _get_var_data(ds: xr.Dataset, var_key: str) -> np.ndarray: + """Get the variable data then subset on latitude and transpose. + + Parameters + ---------- + ds : xr.Dataset + The dataset object. + var_key : str + The key of the variable. + + Returns + ------- + np.ndarray + The variable data. + """ + da_var = ds[var_key].copy() + lat_dim = xc.get_dim_keys(da_var, axis="Y") + + da_var_reg = da_var.sel({lat_dim: slice(-90, 90)}) + var_transposed = np.transpose(da_var_reg.values, (2, 1, 0)) + test_array = var_transposed.astype(np.float64) + return test_array -def generate_export( - parameter, - lat_lon, - area_upstream, - gauges, - radius, - resolution, - using_ref_mat_file, - ref_array, - test_array, -): - # Annual mean of test, annual mean of ref, error for area, lat, lon - export = numpy.zeros((lat_lon.shape[0], 9)) - if parameter.print_statements: - logger.info("export.shape={}".format(export.shape)) +def _generate_export_data( + parameter: StreamflowParameter, + gauges: np.ndarray, + test_array: np.ndarray, + ref_array: np.ndarray, + area_upstream: np.ndarray, + is_ref_mat_file: bool, +) -> np.ndarray: + """Generate the export data. + + Parameters + ---------- + parameter : StreamflowParameter + The parameter. + gauges : np.ndarray + The gauges. + test_array : np.ndarray + The test data. + ref_array : np.ndarray + The reference data. + area_upstream : np.ndarray + The area upstream. + is_ref_mat_file : bool + If the reference data is a mat file or not. + + Returns + ------- + np.ndarray + The export data as a 2D array with the format: + - col 0: year + - col 1: month + - [i, 0]: annual mean for reference data + - [i, 1]: annual mean for test data + - [i, 2]: percentage of drainage area bias + - [i, 3]: seasonality index of reference data + - [i, 4]: peak month flow of reference data + - [i, 5]: seasonality index of test data + - [i, 6]: peak month flow of test data + - [i, 7:9]: Lat lon index of reference data. + Notes + ----- + TODO: This function should be refactored to make it readable and + maintainable. The number of code comments suggest that the code is not + understandable and needs to be explained line by line. + """ + # Center the reference lat lon to the grid center and use it + # to create the shape for the export data array. + bins = np.floor(gauges[:, 7:9].astype(np.float64) / RESOLUTION) + lat_lon = (bins + 0.5) * RESOLUTION + + export_data = np.zeros((lat_lon.shape[0], 9)) + for i in range(lat_lon.shape[0]): - if parameter.print_statements and (i % 1000 == 0): - logger.info("On gauge #{}".format(i)) if parameter.max_num_gauges and i >= parameter.max_num_gauges: break + lat_ref = lat_lon[i, 1] lon_ref = lat_lon[i, 0] + # Estimated drainage area (km^2) from ref - area_ref = gauges[i, 13].astype(numpy.float64) - - if area_upstream is not None: - drainage_area_error, lat_lon_ref = get_drainage_area_error( - radius, - resolution, - lon_ref, - lat_ref, - area_upstream, - area_ref, - ) - else: - # Use the center location - lat_lon_ref = [lat_ref, lon_ref] - if using_ref_mat_file: - origin_id = gauges[i, 1].astype(numpy.int64) + area_ref = gauges[i, 13].astype(np.float64) + drainage_area_error, lat_lon_ref = _get_drainage_area_error( + lon_ref, + lat_ref, + area_upstream, + area_ref, + ) + + if is_ref_mat_file: + origin_id = gauges[i, 1].astype(np.int64) # Column 0 -- year # Column 1 -- month # Column origin_id + 1 -- the ref streamflow from gauge with the corresponding origin_id extracted = ref_array[:, [0, 1, origin_id + 1]] - monthly_mean = numpy.zeros((12, 1)) + numpy.nan + monthly_mean = np.zeros((12, 1)) + np.nan # For GSIM, shape is (1380,) month_array = extracted[:, 1] for month in range(12): # Add 1 to month to account for the months being 1-indexed month_array_boolean = month_array == month + 1 - s = numpy.sum(month_array_boolean) + s = np.sum(month_array_boolean) if s > 0: # `extracted[:,1]`: for all x, examine `extracted[x,1]` # `extracted[:,1] == m`: Boolean array where 0 means the item in position [x,1] is NOT m, @@ -438,97 +317,232 @@ def generate_export( # [1]] # a[a[:,1] == 2, 2]: [[3], # [3]] - monthly_mean[month] = numpy.nanmean( - extracted[month_array_boolean, 2] - ) + monthly_mean[month] = np.nanmean(extracted[month_array_boolean, 2]) # This is ref annual mean streamflow - annual_mean_ref = numpy.mean(monthly_mean) - if using_ref_mat_file and numpy.isnan(annual_mean_ref): + annual_mean_ref = np.mean(monthly_mean) + if is_ref_mat_file and np.isnan(annual_mean_ref): # All elements of row i will be nan - export[i, :] = numpy.nan + export_data[i, :] = np.nan else: - if using_ref_mat_file: + if is_ref_mat_file: # Reshape extracted[:,2] into a 12 x ? matrix; -1 means to # calculate the size of the missing dimension. - # Note that `numpy.reshape(extracted[:, 2], (12,-1))` will not work. + # Note that `np.reshape(extracted[:, 2], (12,-1))` will not work. # We do need to go from (12n x 1) to (12 x n). # `reshape` alone would make the first row [January of year 1, February of year 1,...] # (i.e., 12 sequential rows with n entries) # We actually want the first row to be [January of year 1, January of year 2,...] # (i.e., n sequential columns with 12 entries) # So, we use `reshape` to slice into n segments of length 12 and then we `transpose`. - mmat = numpy.transpose(numpy.reshape(extracted[:, 2], (-1, 12))) - mmat_id = numpy.sum(mmat, axis=0).transpose() - if numpy.sum(~numpy.isnan(mmat_id), axis=0) > 0: + mmat = np.transpose(np.reshape(extracted[:, 2], (-1, 12))) + mmat_id = np.sum(mmat, axis=0).transpose() + if np.sum(~np.isnan(mmat_id), axis=0) > 0: # There's at least one year of record - monthly = mmat[:, ~numpy.isnan(mmat_id)] + monthly = mmat[:, ~np.isnan(mmat_id)] else: monthly = monthly_mean - seasonality_index_ref, peak_month_ref = get_seasonality(monthly) + seasonality_index_ref, peak_month_ref = _get_seasonality(monthly) else: ref_lon = int( - 1 + (lat_lon_ref[1] - (-180 + resolution / 2)) / resolution + 1 + (lat_lon_ref[1] - (-180 + RESOLUTION / 2)) / RESOLUTION ) ref_lat = int( - 1 + (lat_lon_ref[0] - (-90 + resolution / 2)) / resolution + 1 + (lat_lon_ref[0] - (-90 + RESOLUTION / 2)) / RESOLUTION ) - ref = numpy.squeeze(ref_array[ref_lon - 1, ref_lat - 1, :]) - # Note that `numpy.reshape(ref, (12,-1))` will not work. + ref = np.squeeze(ref_array[ref_lon - 1, ref_lat - 1, :]) + # Note that `np.reshape(ref, (12,-1))` will not work. # We do need to go from (12n x 1) to (12 x n). # `reshape` alone would make the first row [January of year 1, February of year 1,...] # (i.e., 12 sequential rows with n entries) # We actually want the first row to be [January of year 1, January of year 2,...] # (i.e., n sequential columns with 12 entries) # So, we use `reshape` to slice into n segments of length 12 and then we `transpose`. - mmat = numpy.transpose(numpy.reshape(ref, (-1, 12))) - monthly_mean_ref = numpy.nanmean(mmat, axis=1) - annual_mean_ref = numpy.mean(monthly_mean_ref) - if numpy.isnan(annual_mean_ref) == 1: + mmat = np.transpose(np.reshape(ref, (-1, 12))) + monthly_mean_ref = np.nanmean(mmat, axis=1) + annual_mean_ref = np.mean(monthly_mean_ref) + if np.isnan(annual_mean_ref) == 1: # The identified grid is in the ocean - monthly = numpy.ones((12, 1)) + monthly = np.ones((12, 1)) else: monthly = mmat - if isinstance(monthly, cdms2.tvariable.TransientVariable): - monthly = monthly.getValue() + seasonality_index_ref, peak_month_ref = _get_seasonality(monthly) - seasonality_index_ref, peak_month_ref = get_seasonality(monthly) - - test_lon = int(1 + (lat_lon_ref[1] - (-180 + resolution / 2)) / resolution) - test_lat = int(1 + (lat_lon_ref[0] - (-90 + resolution / 2)) / resolution) + test_lon = int(1 + (lat_lon_ref[1] - (-180 + RESOLUTION / 2)) / RESOLUTION) + test_lat = int(1 + (lat_lon_ref[0] - (-90 + RESOLUTION / 2)) / RESOLUTION) # For edison: 600x1 - test = numpy.squeeze(test_array[test_lon - 1, test_lat - 1, :]) + test = np.squeeze(test_array[test_lon - 1, test_lat - 1, :]) # For edison: 12x50 - # Note that `numpy.reshape(test, (12,-1))` will not work. + # Note that `np.reshape(test, (12,-1))` will not work. # We do need to go from (12n x 1) to (12 x n). # `reshape` alone would make the first row [January of year 1, February of year 1,...] # (i.e., 12 sequential rows with n entries) # We actually want the first row to be [January of year 1, January of year 2,...] # (i.e., n sequential columns with 12 entries) # So, we use `reshape` to slice into n segments of length 12 and then we `transpose`. - mmat = numpy.transpose(numpy.reshape(test, (-1, 12))) - monthly_mean_test = numpy.nanmean(mmat, axis=1) - annual_mean_test = numpy.mean(monthly_mean_test) - if numpy.isnan(annual_mean_test) == 1: + mmat = np.transpose(np.reshape(test, (-1, 12))) + monthly_mean_test = np.nanmean(mmat, axis=1) + annual_mean_test = np.mean(monthly_mean_test) + + if np.isnan(annual_mean_test) == 1: # The identified grid is in the ocean - monthly = numpy.ones((12, 1)) + monthly = np.ones((12, 1)) else: monthly = mmat - if isinstance(monthly, cdms2.tvariable.TransientVariable): - monthly = monthly.getValue() - - seasonality_index_test, peak_month_test = get_seasonality(monthly) - - export[i, 0] = annual_mean_ref - export[i, 1] = annual_mean_test - if area_upstream is not None: - export[i, 2] = ( - drainage_area_error * 100 - ) # From fraction to percentage of the drainage area bias - export[i, 3] = seasonality_index_ref # Seasonality index of ref - export[i, 4] = peak_month_ref # Max flow month of ref - export[i, 5] = seasonality_index_test # Seasonality index of test - export[i, 6] = peak_month_test # Max flow month of test - export[i, 7:9] = lat_lon_ref # latlon of ref - return export + seasonality_index_test, peak_month_test = _get_seasonality(monthly) + + # TODO: The export data structure should be turned into a dict. + export_data[i, 0] = annual_mean_ref + export_data[i, 1] = annual_mean_test + # Convert from fraction to percetange. + export_data[i, 2] = drainage_area_error * 100 + export_data[i, 3] = seasonality_index_ref + export_data[i, 4] = peak_month_ref + export_data[i, 5] = seasonality_index_test + export_data[i, 6] = peak_month_test + export_data[i, 7:9] = lat_lon_ref + + export_data = _remove_gauges_with_nan_flow(export_data, area_upstream) + + return export_data + + +def _get_drainage_area_error( + lon_ref: float, lat_ref: float, area_upstream: np.ndarray, area_ref: float +) -> Tuple[np.ndarray, List[float]]: + """Get the drainage area error. + + Parameters + ---------- + lon_ref : float + The reference longitude. + lat_ref : float + The reference latitude. + area_upstream : np.ndarray + The area upstream. + area_ref : float + The reference area. + + Returns + ------- + Tuple[np.ndarray, list[float, float]] + _description_ + """ + k_bound = len(range(-SEARCH_RADIUS, SEARCH_RADIUS + 1)) + k_bound *= k_bound + + area_test = np.zeros((k_bound, 1)) + error_test = np.zeros((k_bound, 1)) + lat_lon_test = np.zeros((k_bound, 2)) + k = 0 + + for i in range(-SEARCH_RADIUS, SEARCH_RADIUS + 1): + for j in range(-SEARCH_RADIUS, SEARCH_RADIUS + 1): + x = int( + 1 + ((lon_ref + j * RESOLUTION) - (-180 + RESOLUTION / 2)) / RESOLUTION + ) + y = int( + 1 + ((lat_ref + i * RESOLUTION) - (-90 + RESOLUTION / 2)) / RESOLUTION + ) + area_test[k] = area_upstream[x - 1, y - 1] / 1000000 + error_test[k] = np.abs(area_test[k] - area_ref) / area_ref + lat_lon_test[k, 0] = lat_ref + i * RESOLUTION + lat_lon_test[k, 1] = lon_ref + j * RESOLUTION + k += 1 + + # The id of the center grid in the searching area + center_id = (k_bound - 1) / 2 + + lat_lon_ref = [lat_ref, lon_ref] + drainage_area_error = error_test[int(center_id)] + + return drainage_area_error, lat_lon_ref + + +def _get_seasonality(monthly: np.ndarray) -> Tuple[int, float]: + """Get the seasonality. + + Parameters + ---------- + monthly : np.ndarray + The monthly data. + + Returns + ------- + Tuple[int, float] + A tuple including the seasonality index and the peak flow month. + + Raises + ------ + RuntimeError + If the monthly data does not include 12 months. + """ + monthly = monthly.astype(np.float64) + + # See https://agupubs.onlinelibrary.wiley.com/doi/epdf/10.1029/2018MS001603 Equations 1 and 2 + if monthly.shape[0] != 12: + raise RuntimeError(f"monthly.shape={monthly.shape} does not include 12 months.") + + num_years = monthly.shape[1] + p_k = np.zeros((12, 1)) + + # The total streamflow for each year (sum of Q_ij in the denominator of Equation 1, for all j) + # 1 x num_years + total_streamflow = np.sum(monthly, axis=0) + for month in range(12): + # The streamflow for this month in each year (Q_ij in the numerator of Equation 1, for all j) + # 1 x num_years + streamflow_month_all_years = monthly[month, :] + # Proportion that this month contributes to streamflow that year. + # 1 x num_years + # For all i, divide streamflow_month_all_years[i] by total_streamflow[i] + streamflow_proportion = np.divide(streamflow_month_all_years, total_streamflow) + # The sum is the sum over j in Equation 1. + # Dividing the sum of proportions by num_years gives the *average* proportion of annual streamflow during + # this month. + # Multiplying by 12 makes it so that Pk_i (`p_k[month]`) will be 1 if all months have equal streamflow and + # 12 if all streamflow occurs in one month. + # These steps produce the 12/n factor in Equation 1. + p_k[month] = np.nansum(streamflow_proportion) * 12 / num_years + + # From Equation 2 + seasonality_index = np.max(p_k) + # `p_k == np.max(p_k)` produces a Boolean matrix, True if the value (i.e., streamflow) is the max value. + # `np.where(p_k == np.max(p_k))` produces the indices (i.e., months) where the max value is reached. + peak_month = np.where(p_k == np.max(p_k))[0] + # If more than one month has peak streamflow, simply define the peak month as the first one of the peak months. + # Month 0 is January, Month 1 is February, and so on. + peak_month = peak_month[0] + + return seasonality_index, peak_month # type: ignore + + +def _remove_gauges_with_nan_flow( + export_data: np.ndarray, area_upstream: np.ndarray | None +) -> np.ndarray: + """Remove gauges with NaN flow. + + Gauges will only plotted if they have a non-nan value for both test and ref + data. + + Parameters + ---------- + export_data : np.ndarray + The export data. + area_upstream : np.ndarray | None + The optional area upstream. + + Returns + ------- + np.ndarray + The export with gauges that have NaN flow removed. + """ + export_data_new = np.array(export_data) + export_data_new = export_data_new[~np.isnan(export_data_new[:, 0]), :] + export_data_new = export_data_new[~np.isnan(export_data_new[:, 1]), :] + + if area_upstream is not None: + export_data_new = export_data_new[export_data_new[:, 2] <= MAX_AREA_ERROR, :] + + return export_data_new diff --git a/e3sm_diags/driver/utils/dataset_xr.py b/e3sm_diags/driver/utils/dataset_xr.py index 7645d0c60..0850f1fb2 100644 --- a/e3sm_diags/driver/utils/dataset_xr.py +++ b/e3sm_diags/driver/utils/dataset_xr.py @@ -173,7 +173,11 @@ def _get_derived_vars_map(self) -> DerivedVariablesMap: # Attribute related methods # -------------------------------------------------------------------------- - def get_name_yrs_attr(self, season: ClimoFreq | None = None) -> str: + def get_name_yrs_attr( + self, + season: ClimoFreq | None = None, + default_name: str | None = None, + ) -> str: """Get the diagnostic name and 'yrs_averaged' attr as a single string. This method is used to update either `parameter.test_name_yrs` or @@ -199,9 +203,9 @@ def get_name_yrs_attr(self, season: ClimoFreq | None = None) -> str: Replaces `e3sm_diags.driver.utils.general.get_name_and_yrs` """ if self.data_type == "test": - diag_name = self._get_test_name() + diag_name = self._get_test_name(default_name) elif self.data_type == "ref": - diag_name = self._get_ref_name() + diag_name = self._get_ref_name(default_name) if self.is_climo: if season is None: @@ -222,7 +226,7 @@ def get_name_yrs_attr(self, season: ClimoFreq | None = None) -> str: return f"{diag_name} ({yrs_averaged_attr})" - def _get_test_name(self) -> str: + def _get_test_name(self, default_name: str | None = None) -> str: """Get the diagnostic test name. Returns @@ -239,6 +243,9 @@ def _get_test_name(self) -> str: elif self.parameter.test_name != "": return self.parameter.test_name else: + if default_name is not None: + return default_name + # NOTE: This else statement is preserved from the previous CDAT # codebase to maintain the same behavior. if self.parameter.test_name == "": @@ -246,7 +253,7 @@ def _get_test_name(self) -> str: return self.parameter.test_name - def _get_ref_name(self) -> str: + def _get_ref_name(self, default_name: str | None = None) -> str: """Get the diagnostic reference name. Returns @@ -263,6 +270,10 @@ def _get_ref_name(self) -> str: elif self.parameter.reference_name != "": return self.parameter.reference_name else: + # Covers cases such as streamflow which set the ref name to "GSIM". + if default_name is not None: + return default_name + # NOTE: This else statement is preserved from the previous CDAT # codebase to maintain the same behavior. if self.parameter.ref_name == "": diff --git a/e3sm_diags/plot/cartopy/streamflow_plot.py b/e3sm_diags/plot/cartopy/streamflow_plot.py deleted file mode 100644 index a1aad2240..000000000 --- a/e3sm_diags/plot/cartopy/streamflow_plot.py +++ /dev/null @@ -1,731 +0,0 @@ -from __future__ import print_function - -import os - -import cartopy.crs as ccrs -import cartopy.feature as cfeature -import cdutil -import matplotlib -import numpy as np -import scipy.stats -from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter - -from e3sm_diags.derivations.default_regions import regions_specs -from e3sm_diags.driver.utils.general import get_output_dir -from e3sm_diags.logger import custom_logger - -matplotlib.use("Agg") -import matplotlib.colors as colors # isort:skip # noqa: E402 -import matplotlib.lines as lines # isort:skip # noqa: E402 -import matplotlib.pyplot as plt # isort:skip # noqa: E402 - -logger = custom_logger(__name__) - -plotTitle = {"fontsize": 11.5} -plotSideTitle = {"fontsize": 9.5} - -# Border padding relative to subplot axes for saving individual panels -# (left, bottom, width, height) in page coordinates -border = (-0.14, -0.06, 0.04, 0.08) - - -def add_cyclic(var): - lon = var.getLongitude() - return var(longitude=(lon[0], lon[0] + 360.0, "coe")) - - -def get_ax_size(fig, ax): - bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) - width, height = bbox.width, bbox.height - width *= fig.dpi - height *= fig.dpi - return width, height - - -def determine_tick_step(degrees_covered): - if degrees_covered > 180: - return 60 - if degrees_covered > 60: - return 30 - elif degrees_covered > 20: - return 10 - else: - return 1 - - -def plot_panel_seasonality_map( - plot_type, fig, proj, export, color_list, panel, parameter -): - if plot_type == "test": - panel_index = 0 - seasonality_index_export_index = 5 - peak_month_export_index = 6 - title = (None, parameter.test_title, None) - elif plot_type == "ref": - panel_index = 1 - seasonality_index_export_index = 3 - peak_month_export_index = 4 - title = (None, parameter.reference_title, None) - else: - raise Exception("Invalid plot_type={}".format(plot_type)) - - # Plot of streamflow gauges. Color -> peak month, marker size -> seasonality index. - ax = fig.add_axes(panel[panel_index], projection=proj) - region_str = parameter.regions[0] - region = regions_specs[region_str] - if "domain" in region.keys(): # type: ignore - # Get domain to plot - domain = region["domain"] # type: ignore - else: - # Assume global domain - domain = cdutil.region.domain(latitude=(-90.0, 90, "ccb")) - kargs = domain.components()[0].kargs - # lon_west, lon_east, lat_south, lat_north = (0, 360, -90, 90) - lon_west, lon_east, lat_south, lat_north = (-180, 180, -90, 90) - if "longitude" in kargs: - lon_west, lon_east, _ = kargs["longitude"] - if "latitude" in kargs: - lat_south, lat_north, _ = kargs["latitude"] - lon_covered = lon_east - lon_west - lon_step = determine_tick_step(lon_covered) - xticks = np.arange(lon_west, lon_east, lon_step) - # Subtract 0.50 to get 0 W to show up on the right side of the plot. - # If less than 0.50 is subtracted, then 0 W will overlap 0 E on the left side of the plot. - # If a number is added, then the value won't show up at all. - xticks = np.append(xticks, lon_east - 0.50) - lat_covered = lat_north - lat_south - lat_step = determine_tick_step(lat_covered) - yticks = np.arange(lat_south, lat_north, lat_step) - yticks = np.append(yticks, lat_north) - ax.set_extent([lon_west, lon_east, lat_south, lat_north], crs=proj) - proj_function = ccrs.PlateCarree - - # Stream gauges - si_2 = 2 - si_4 = 3 - si_6 = 4 - si_large = 5 - # `export` is the array of gauges. Each gauge has multiple fields -- e.g., lat is index 7 - for gauge in export: - lat = gauge[7] - lon = gauge[8] - seasonality_index = gauge[seasonality_index_export_index] - if seasonality_index < 2: - markersize = si_2 - elif seasonality_index < 4: - markersize = si_4 - elif seasonality_index < 6: - markersize = si_6 - elif seasonality_index <= 12: - markersize = si_large - else: - raise Exception("Invalid seasonality index={}".format(seasonality_index)) - if seasonality_index == 1: - color = "black" - else: - peak_month = int(gauge[peak_month_export_index]) - color = color_list[peak_month] - # https://scitools.org.uk/iris/docs/v1.9.2/examples/General/projections_and_annotations.html - # Place a single marker point for each gauge. - plt.plot( - lon, - lat, - marker="o", - color=color, - markersize=markersize, - transform=proj_function(), - ) - # NOTE: the "plt.annotate call" does not have a "transform=" keyword, - # so for this one we transform the coordinates with a Cartopy call. - at_x, at_y = ax.projection.transform_point(lon, lat, src_crs=proj_function()) - # https://matplotlib.org/3.1.1/gallery/text_labels_and_annotations/custom_legends.html - legend_elements = [ - lines.Line2D( - [0], - [0], - marker="o", - color="w", - label="1 <= SI < 2", - markerfacecolor="black", - markersize=si_2, - ), - lines.Line2D( - [0], - [0], - marker="o", - color="w", - label="2 <= SI < 4", - markerfacecolor="black", - markersize=si_4, - ), - lines.Line2D( - [0], - [0], - marker="o", - color="w", - label="4 <= SI < 6", - markerfacecolor="black", - markersize=si_6, - ), - lines.Line2D( - [0], - [0], - marker="o", - color="w", - label="6 <= SI <= 12", - markerfacecolor="black", - markersize=si_large, - ), - ] - seasonality_legend_title = "Seasonality (SI)" - plt.legend( - handles=legend_elements, - title=seasonality_legend_title, - prop={"size": 8}, - ) - - # Full world would be aspect 360/(2*180) = 1 - ax.set_aspect((lon_east - lon_west) / (2 * (lat_north - lat_south))) - ax.coastlines(lw=0.3) - ax.add_feature(cfeature.RIVERS) - if title[0] is not None: - ax.set_title(title[0], loc="left", fontdict=plotSideTitle) - if title[1] is not None: - ax.set_title(title[1], fontdict=plotTitle) - if title[2] is not None: - ax.set_title(title[2], loc="right", fontdict=plotSideTitle) - ax.set_xticks(xticks, crs=proj_function()) - ax.set_yticks(yticks, crs=proj_function()) - lon_formatter = LongitudeFormatter(zero_direction_label=True, number_format=".0f") - lat_formatter = LatitudeFormatter() - ax.xaxis.set_major_formatter(lon_formatter) - ax.yaxis.set_major_formatter(lat_formatter) - ax.tick_params(labelsize=8.0, direction="out", width=1) - ax.xaxis.set_ticks_position("bottom") - ax.yaxis.set_ticks_position("left") - - # Color bar - cbax = fig.add_axes( - ( - panel[panel_index][0] + 0.7535, - panel[panel_index][1] + 0.0515, - 0.0326, - 0.1792, - ) - ) - # https://matplotlib.org/tutorials/colors/colorbar_only.html - num_colors = len(color_list) - if parameter.print_statements: - logger.info("num_colors={}".format(num_colors)) - cmap = colors.ListedColormap(color_list) - cbar_label = "Peak month" - - bounds = list(range(num_colors)) - # Set ticks to be in between the bounds - ticks = list(map(lambda bound: bound + 0.5, bounds)) - # Add one more bound at the bottom of the colorbar. - # `bounds` should be one longer than `ticks`. - bounds += [bounds[-1] + 1] - if parameter.print_statements: - logger.info("bounds={}".format(bounds)) - norm = colors.BoundaryNorm(bounds, cmap.N) - cbar = fig.colorbar( - matplotlib.cm.ScalarMappable(cmap=cmap, norm=norm), - cax=cbax, - boundaries=bounds, - ticks=ticks, - spacing="uniform", - orientation="vertical", - label=cbar_label, - ) - # https://matplotlib.org/3.1.1/gallery/ticks_and_spines/colorbar_tick_labelling_demo.html - months = [ - "Jan", - "Feb", - "Mar", - "Apr", - "May", - "Jun", - "Jul", - "Aug", - "Sep", - "Oct", - "Nov", - "Dec", - ] - cbar.ax.set_yticklabels(months) - cbar.ax.invert_yaxis() - - w, h = get_ax_size(fig, cbax) - - cbar.ax.tick_params(labelsize=9.0, length=0) - - -def plot_seasonality_map(export, parameter): - if parameter.backend not in ["cartopy", "mpl", "matplotlib"]: - return - - # Position and sizes of subplot axes in page coordinates (0 to 1) - # (left, bottom, width, height) in page coordinates - panel = [ - (0.0900, 0.5500, 0.7200, 0.3000), - (0.0900, 0.1300, 0.7200, 0.3000), - ] - - # Create figure, projection - fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi) - proj = ccrs.PlateCarree(central_longitude=0) - - # test and ref color lists - # Selected from 'hsv' colormap: - color_list = [ - (0.05, 0.00, 0.99), - (0.03, 0.30, 0.98), - (0.12, 0.92, 0.99), - (0.13, 1.00, 0.65), - (0.14, 1.00, 0.05), - (0.98, 0.99, 0.04), - (0.99, 0.67, 0.04), - (0.99, 0.34, 0.03), - (0.99, 0.07, 0.03), - (0.99, 0.00, 0.53), - (0.68, 0.00, 1.00), - (0.29, 0.00, 1.00), - ] - - # First panel - plot_panel_seasonality_map("test", fig, proj, export, color_list, panel, parameter) - - # Second panel - plot_panel_seasonality_map("ref", fig, proj, export, color_list, panel, parameter) - - # Figure title - fig.suptitle(parameter.main_title_seasonality_map, x=0.5, y=0.97, fontsize=15) - - # Prepare to save figure - # get_output_dir => {parameter.results_dir}/{set_name}/{parameter.case_id} - # => {parameter.results_dir}/streamflow/{parameter.case_id} - output_dir = get_output_dir(parameter.current_set, parameter) - if parameter.print_statements: - logger.info("Output dir: {}".format(output_dir)) - # get_output_dir => {parameter.orig_results_dir}/{set_name}/{parameter.case_id} - # => {parameter.orig_results_dir}/streamflow/{parameter.case_id} - original_output_dir = get_output_dir(parameter.current_set, parameter) - if parameter.print_statements: - logger.info("Original output dir: {}".format(original_output_dir)) - # parameter.output_file_seasonality_map is defined in e3sm_diags/parameter/streamflow_parameter.py - # {parameter.results_dir}/streamflow/{parameter.case_id}/{parameter.output_file_seasonality_map} - file_path = os.path.join(output_dir, parameter.output_file_seasonality_map) - # {parameter.orig_results_dir}/streamflow/{parameter.case_id}/{parameter.output_file_seasonality_map} - original_file_path = os.path.join( - original_output_dir, parameter.output_file_seasonality_map - ) - - # Save figure - for f in parameter.output_format: - f = f.lower().split(".")[-1] - plot_suffix = "." + f - plot_file_path = file_path + plot_suffix - plt.savefig(plot_file_path) - # Get the filename that the user has passed in and display that. - original_plot_file_path = original_file_path + plot_suffix - # Always print, even without `parameter.print_statements` - logger.info(f"Plot saved in: {original_plot_file_path}") - - # Save individual subplots - for f in parameter.output_format_subplot: - page = fig.get_size_inches() - i = 0 - for p in panel: - # Extent of subplot - subpage = np.array(p).reshape(2, 2) - subpage[1, :] = subpage[0, :] + subpage[1, :] - subpage = subpage + np.array(border).reshape(2, 2) - subpage = list((subpage * page).flatten()) # type: ignore - extent = matplotlib.transforms.Bbox.from_extents(*subpage) - # Save subplot - subplot_suffix = ".%i." % i + f - subplot_file_path = file_path + subplot_suffix - plt.savefig(subplot_file_path, bbox_inches=extent) - # Get the filename that the user has passed in and display that. - original_subplot_file_path = original_file_path + subplot_suffix - # Always print, even without `parameter.print_statements` - logger.info(f"Sub-plot saved in: {original_subplot_file_path}") - i += 1 - - plt.close() - - -def plot_panel_annual_map(panel_index, fig, proj, export, bias_array, panel, parameter): - if panel_index == 0: - panel_type = "test" - elif panel_index == 1: - panel_type = "ref" - elif panel_index == 2: - panel_type = "bias" - else: - raise Exception("Invalid panel_index={}".format(panel_index)) - - # Plot of streamflow gauges. Color -> peak month, marker size -> seasonality index. - - # Position and sizes of subplot axes in page coordinates (0 to 1) - ax = fig.add_axes(panel[panel_index], projection=proj) - region_str = parameter.regions[0] - region = regions_specs[region_str] - if "domain" in region.keys(): # type: ignore - # Get domain to plot - domain = region["domain"] # type: ignore - else: - # Assume global domain - domain = cdutil.region.domain(latitude=(-90.0, 90, "ccb")) - kargs = domain.components()[0].kargs - # lon_west, lon_east, lat_south, lat_north = (0, 360, -90, 90) - lon_west, lon_east, lat_south, lat_north = (-180, 180, -90, 90) - if "longitude" in kargs: - lon_west, lon_east, _ = kargs["longitude"] - if "latitude" in kargs: - lat_south, lat_north, _ = kargs["latitude"] - lon_covered = lon_east - lon_west - lon_step = determine_tick_step(lon_covered) - xticks = np.arange(lon_west, lon_east, lon_step) - # Subtract 0.50 to get 0 W to show up on the right side of the plot. - # If less than 0.50 is subtracted, then 0 W will overlap 0 E on the left side of the plot. - # If a number is added, then the value won't show up at all. - xticks = np.append(xticks, lon_east - 0.50) - lat_covered = lat_north - lat_south - lat_step = determine_tick_step(lat_covered) - yticks = np.arange(lat_south, lat_north, lat_step) - yticks = np.append(yticks, lat_north) - ax.set_extent([lon_west, lon_east, lat_south, lat_north], crs=proj) - proj_function = ccrs.PlateCarree - - # Stream gauges - color_list, value_min, value_max, norm = setup_annual_map( - parameter, panel_type, bias_array - ) - plot_gauges_annual_map( - panel_type, - export, - bias_array, - value_min, - value_max, - color_list, - proj_function, - ax, - ) - - # Full world would be aspect 360/(2*180) = 1 - ax.set_aspect((lon_east - lon_west) / (2 * (lat_north - lat_south))) - ax.coastlines(lw=0.3) - ax.add_feature(cfeature.RIVERS) - if panel_type == "test": - title = parameter.test_title - elif panel_type == "ref": - title = parameter.reference_title - elif panel_type == "bias": - title = "Relative Bias" - else: - raise Exception("Invalid panel_type={}".format(panel_type)) - title = (None, title, None) - if title[0] is not None: - ax.set_title(title[0], loc="left", fontdict=plotSideTitle) - if title[1] is not None: - ax.set_title(title[1], fontdict=plotTitle) - if title[2] is not None: - ax.set_title(title[2], loc="right", fontdict=plotSideTitle) - ax.set_xticks(xticks, crs=proj_function()) - ax.set_yticks(yticks, crs=proj_function()) - lon_formatter = LongitudeFormatter(zero_direction_label=True, number_format=".0f") - lat_formatter = LatitudeFormatter() - ax.xaxis.set_major_formatter(lon_formatter) - ax.yaxis.set_major_formatter(lat_formatter) - ax.tick_params(labelsize=8.0, direction="out", width=1) - ax.xaxis.set_ticks_position("bottom") - ax.yaxis.set_ticks_position("left") - - # Color bar - # Position and sizes of subplot axes in page coordinates (0 to 1) - # (left, bottom, width, height) in page coordinates - cbax = fig.add_axes( - ( - panel[panel_index][0] + 0.6635, - panel[panel_index][1] + 0.0115, - 0.0326, - 0.1792, - ) - ) - cmap = colors.ListedColormap(color_list) - if panel_type in ["test", "ref"]: - cbar_label = "Mean annual discharge ($m^3$/$s$)" - elif panel_type == "bias": - cbar_label = "Bias of mean annual discharge (%)\n(test-ref)/ref" - else: - raise Exception("Invalid panel_type={}".format(panel_type)) - cbar = fig.colorbar( - matplotlib.cm.ScalarMappable(cmap=cmap, norm=norm), - cax=cbax, - label=cbar_label, - extend="both", - ) - w, h = get_ax_size(fig, cbax) - if panel_type in ["test", "ref"]: - pass - elif panel_type == "bias": - step_size = (value_max - value_min) // 5 - ticks = np.arange(int(value_min), int(value_max) + step_size, step_size) - cbar.ax.tick_params(labelsize=9.0, length=0) - cbar.ax.set_yticklabels(ticks) - else: - raise Exception("Invalid panel_type={}".format(panel_type)) - - -def setup_annual_map(parameter, panel_type, bias_array): - # Continuous colormap - colormap = plt.get_cmap("jet_r") - color_list = list(map(lambda index: colormap(index)[:3], range(colormap.N))) - if panel_type in ["test", "ref"]: - value_min, value_max = 1, 1e4 - # https://matplotlib.org/3.2.1/tutorials/colors/colormapnorms.html - norm = matplotlib.colors.LogNorm(vmin=value_min, vmax=value_max) - elif panel_type == "bias": - if parameter.print_statements: - value_min = np.floor(np.min(bias_array)) - value_max = np.ceil(np.max(bias_array)) - logger.info( - "Bias of mean annual discharge {} min={}, max={}".format( - panel_type, value_min, value_max - ) - ) - - value_min = -100 - value_max = 100 - norm = matplotlib.colors.Normalize() - else: - raise Exception("Invalid panel_type={}".format(panel_type)) - return color_list, value_min, value_max, norm - - -def plot_gauges_annual_map( - panel_type, export, bias_array, value_min, value_max, color_list, proj_function, ax -): - # `export` is the array of gauges. Each gauge has multiple fields -- e.g., lat is index 7 - for gauge, i in zip(export, range(len(export))): - if panel_type == "test": - # Test mean annual discharge - value = gauge[1] - elif panel_type == "ref": - # Ref mean annual discharge - value = gauge[0] - elif panel_type == "bias": - # Bias - value = bias_array[i] - else: - raise Exception("Invalid panel_type={}".format(panel_type)) - if np.isnan(value): - continue - if value < value_min: - value = value_min - elif value > value_max: - value = value_max - if panel_type in ["test", "ref"]: - # Logarithmic Rescale (min-max normalization) to [-1,1] range - normalized_value = (np.log10(value) - np.log10(value_min)) / ( - np.log10(value_max) - np.log10(value_min) - ) - elif panel_type == "bias": - # Rescale (min-max normalization) to [-1,1] range - normalized_value = (value - value_min) / (value_max - value_min) - else: - raise Exception("Invalid panel_type={}".format(panel_type)) - lat = gauge[7] - lon = gauge[8] - - color = color_list[int(normalized_value * (len(color_list) - 1))] - # https://scitools.org.uk/iris/docs/v1.9.2/examples/General/projections_and_annotations.html - # Place a single marker point for each gauge. - plt.plot( - lon, - lat, - marker="o", - markersize=2, - color=color, - transform=proj_function(), - ) - # NOTE: the "plt.annotate call" does not have a "transform=" keyword, - # so for this one we transform the coordinates with a Cartopy call. - at_x, at_y = ax.projection.transform_point(lon, lat, src_crs=proj_function()) - - -def plot_annual_map(export, bias, parameter): - if parameter.backend not in ["cartopy", "mpl", "matplotlib"]: - return - - # Position and sizes of subplot axes in page coordinates (0 to 1) - # (left, bottom, width, height) in page coordinates - panel = [ - (0.1691, 0.6810, 0.6465, 0.2258), - (0.1691, 0.3961, 0.6465, 0.2258), - (0.1691, 0.1112, 0.6465, 0.2258), - ] - - # Create figure, projection - fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi) - proj = ccrs.PlateCarree(central_longitude=0) - - # First panel - plot_panel_annual_map(0, fig, proj, export, bias, panel, parameter) - - # Second panel - plot_panel_annual_map(1, fig, proj, export, bias, panel, parameter) - - # Third panel - plot_panel_annual_map(2, fig, proj, export, bias, panel, parameter) - - # Figure title - fig.suptitle(parameter.main_title_annual_map, x=0.5, y=0.97, fontsize=15) - - # Prepare to save figure - # get_output_dir => {parameter.results_dir}/{set_name}/{parameter.case_id} - # => {parameter.results_dir}/streamflow/{parameter.case_id} - output_dir = get_output_dir(parameter.current_set, parameter) - if parameter.print_statements: - logger.info("Output dir: {}".format(output_dir)) - # get_output_dir => {parameter.orig_results_dir}/{set_name}/{parameter.case_id} - # => {parameter.orig_results_dir}/streamflow/{parameter.case_id} - original_output_dir = get_output_dir(parameter.current_set, parameter) - if parameter.print_statements: - logger.info("Original output dir: {}".format(original_output_dir)) - # parameter.output_file_annual_map is defined in e3sm_diags/parameter/streamflow_parameter.py - # {parameter.results_dir}/streamflow/{parameter.case_id}/{parameter.output_file_annual_map} - file_path = os.path.join(output_dir, parameter.output_file_annual_map) - # {parameter.orig_results_dir}/streamflow/{parameter.case_id}/{parameter.output_file_annual_map} - original_file_path = os.path.join( - original_output_dir, parameter.output_file_annual_map - ) - - # Save figure - for f in parameter.output_format: - f = f.lower().split(".")[-1] - plot_suffix = "." + f - plot_file_path = file_path + plot_suffix - plt.savefig(plot_file_path) - # Get the filename that the user has passed in and display that. - original_plot_file_path = original_file_path + plot_suffix - # Always print, even without `parameter.print_statements` - logger.info(f"Plot saved in: {original_plot_file_path}") - - # Save individual subplots - for f in parameter.output_format_subplot: - page = fig.get_size_inches() - i = 0 - for p in panel: - # Extent of subplot - subpage = np.array(p).reshape(2, 2) - subpage[1, :] = subpage[0, :] + subpage[1, :] - subpage = subpage + np.array(border).reshape(2, 2) - subpage = list((subpage * page).flatten()) # type: ignore - extent = matplotlib.transforms.Bbox.from_extents(*subpage) - # Save subplot - subplot_suffix = ".%i." % i + f - subplot_file_path = file_path + subplot_suffix - plt.savefig(subplot_file_path, bbox_inches=extent) - # Get the filename that the user has passed in and display that. - original_subplot_file_path = original_file_path + subplot_suffix - # Always print, even without `parameter.print_statements` - logger.info(f"Sub-plot saved in: {original_subplot_file_path}") - - i += 1 - - plt.close() - - -def plot_annual_scatter(xs, ys, zs, parameter): - # Position and sizes of subplot axes in page coordinates (0 to 1) - # (left, bottom, width, height) in page coordinates - panel = [(0.0900, 0.2000, 0.7200, 0.6000)] - - fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi) - - ax = fig.add_axes(panel[0]) - cmap = plt.get_cmap("jet") - ax.scatter(xs, ys, label="Scatterplot", marker="o", s=10, c=zs, cmap=cmap) - r, _ = scipy.stats.pearsonr(xs, ys) - r2 = r * r - r2_str = "{0:.2f}".format(r2) - bounds = [0.01, 100000] - ax.plot(bounds, bounds, color="red", linestyle="-") - ax.set_xscale("log") - ax.set_yscale("log") - ax.set_xlabel( - "{} streamflow ($m^3$/$s$)".format(parameter.reference_title), - fontsize=12, - ) - ax.set_ylabel("{} streamflow ($m^3$/$s$)".format(parameter.test_title), fontsize=12) - ax.set_xlim(bounds[0], bounds[1]) - ax.set_ylim(bounds[0], bounds[1]) - ax.tick_params(axis="both", labelsize=12) - - # Color bar - # Position and sizes of subplot axes in page coordinates (0 to 1) - # (left, bottom, width, height) in page coordinates - cbax = fig.add_axes( - (panel[0][0] + 0.7535, panel[0][1] + 0.0515, 0.0326, 0.1792 * 2) - ) - cbar_label = "Drainage area bias (%)" - cbar = fig.colorbar(matplotlib.cm.ScalarMappable(cmap=cmap), cax=cbax) - cbar.ax.set_ylabel(cbar_label, fontsize=12) - w, h = get_ax_size(fig, cbax) - zs_max = np.ceil(np.max(zs)) - zs_min = np.floor(np.min(zs)) - step_size = (zs_max - zs_min) // 5 - try: - ticks = np.arange(zs_min, zs_max + step_size, step_size) - cbar.ax.set_yticklabels(ticks) - except ValueError: - # `zs` has invalid values (likely from no area_upstream being found). - # Just use default colorbar. - pass - cbar.ax.tick_params(labelsize=12.0, length=0) - - # Figure title - if parameter.main_title_annual_scatter == "": - main_title_annual_scatter = "Annual mean streamflow\n{} vs {}".format( - parameter.test_title, parameter.reference_title - ) - else: - main_title_annual_scatter = parameter.main_title_annual_scatter - ax.set_title(main_title_annual_scatter, loc="center", y=1.05, fontsize=15) - - legend_title = "$R^2$={}, (n={})".format(r2_str, xs.shape[0]) - ax.legend(handles=[], title=legend_title, loc="upper left", prop={"size": 12}) - - # Prepare to save figure - # get_output_dir => {parameter.results_dir}/{set_name}/{parameter.case_id} - # => {parameter.results_dir}/streamflow/{parameter.case_id} - output_dir = get_output_dir(parameter.current_set, parameter) - if parameter.print_statements: - logger.info("Output dir: {}".format(output_dir)) - # get_output_dir => {parameter.orig_results_dir}/{set_name}/{parameter.case_id} - # => {parameter.orig_results_dir}/streamflow/{parameter.case_id} - original_output_dir = get_output_dir(parameter.current_set, parameter) - if parameter.print_statements: - logger.info("Original output dir: {}".format(original_output_dir)) - # parameter.output_file_annual_scatter is defined in e3sm_diags/parameter/streamflow_parameter.py - # {parameter.results_dir}/streamflow/{parameter.case_id}/{parameter.output_file_annual_scatter} - file_path = os.path.join(output_dir, parameter.output_file_annual_scatter) - # {parameter.orig_results_dir}/streamflow/{parameter.case_id}/{parameter.output_file_annual_scatter} - original_file_path = os.path.join( - original_output_dir, parameter.output_file_annual_scatter - ) - - # Save figure - for f in parameter.output_format: - f = f.lower().split(".")[-1] - plot_suffix = "." + f - plot_file_path = file_path + plot_suffix - plt.savefig(plot_file_path) - # Get the filename that the user has passed in and display that. - original_plot_file_path = original_file_path + plot_suffix - logger.info(f"Plot saved in: {original_plot_file_path}") - - plt.close() diff --git a/e3sm_diags/plot/streamflow_plot_map.py b/e3sm_diags/plot/streamflow_plot_map.py new file mode 100644 index 000000000..394da4c8c --- /dev/null +++ b/e3sm_diags/plot/streamflow_plot_map.py @@ -0,0 +1,325 @@ +from typing import List, Tuple, Union + +import cartopy.crs as ccrs +import cartopy.feature as cfeature +import matplotlib +import numpy as np + +from e3sm_diags.derivations.default_regions_xr import REGION_SPECS +from e3sm_diags.logger import custom_logger +from e3sm_diags.parameter.streamflow_parameter import StreamflowParameter +from e3sm_diags.plot.utils import ( + _configure_titles, + _configure_x_and_y_axes, + _get_x_ticks, + _get_y_ticks, + _save_plot, +) + +matplotlib.use("Agg") +import matplotlib.colors as colors # isort:skip # noqa: E402 +import matplotlib.pyplot as plt # isort:skip # noqa: E402 + +logger = custom_logger(__name__) + +# Border padding relative to subplot axes for saving individual panels +# (left, bottom, width, height) in page coordinates. +BORDER_PADDING = (-0.14, -0.06, 0.04, 0.08) + +PROJECTION = ccrs.PlateCarree(central_longitude=0) +PROJECTION_FUNC = ccrs.PlateCarree + +# Position and sizes of subplot axes in page coordinates (0 to 1) +# (left, bottom, width, height) in page coordinates. +ANNUAL_MAP_PANEL_CFG = [ + (0.1691, 0.6810, 0.6465, 0.2258), + (0.1691, 0.3961, 0.6465, 0.2258), + (0.1691, 0.1112, 0.6465, 0.2258), +] + + +def plot_annual_map(parameter: StreamflowParameter, export_data: np.ndarray): + """Plot the streamflow annual map. + + Parameters + ---------- + parameter : StreamflowParameter + The streamflow parameter. + export_data : np.ndarray + The export data. + """ + fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi) + fig.suptitle(parameter.main_title_annual_map, x=0.5, y=0.97, fontsize=15) + + # Bias between test and ref as a percentage. + # Relative error as a percentage. + # 100*((annual_mean_test - annual_mean_ref) / annual_mean_ref) + bias = 100 * ((export_data[:, 1] - export_data[:, 0]) / export_data[:, 0]) + + _plot_panel_annual_map(0, fig, parameter, export_data, bias) + _plot_panel_annual_map(1, fig, parameter, export_data, bias) + _plot_panel_annual_map(2, fig, parameter, export_data, bias) + + # NOTE: Need to set the output filename to the name of the specific + # streamflow plot before saving the plot, otherwise the filename will + # be blank. + parameter.output_file = parameter.output_file_annual_map + _save_plot(fig, parameter, border_padding=BORDER_PADDING) + + plt.close() + + +def _plot_panel_annual_map( + panel_index: int, + fig: plt.Figure, + parameter: StreamflowParameter, + export_data: np.ndarray, + bias_array: np.ndarray, +): + """Plot the panel for each annual map based on the data type. + + Parameters + ---------- + panel_index : int + The panel index. + fig : plt.Figure + The figure object. + parameter : StreamflowParameter + The streamflow parameter. + export_data : np.ndarray + The export data. + bias_array : np.ndarray + The bias array. + """ + if panel_index == 0: + panel_type = "test" + elif panel_index == 1: + panel_type = "ref" + elif panel_index == 2: + panel_type = "bias" + + # Get region info and X and Y plot ticks. + # -------------------------------------------------------------------------- + region_key = parameter.regions[0] + region_specs = REGION_SPECS[region_key] + + # Get the region's domain slices for latitude and longitude if set, or + # use the default value. If both are not set, then the region type is + # considered "global". + lat_slice = region_specs.get("lat", (-90, 90)) # type: ignore + lon_slice = region_specs.get("lon", (-180, 180)) # type: ignore + + # Boolean flags for configuring plots. + is_global_domain = lat_slice == (-90, 90) and lon_slice == (-180, 180) + is_lon_full = lon_slice == (-180, 180) + + # Determine X and Y ticks using longitude and latitude domains respectively. + lon_west, lon_east = lon_slice + x_ticks = _get_x_ticks( + lon_west, + lon_east, + is_global_domain, + is_lon_full, + tick_step_func=_determine_tick_step, + ) + + lat_south, lat_north = lat_slice + y_ticks = _get_y_ticks(lat_south, lat_north, tick_step_func=_determine_tick_step) + + # Get the figure Axes object using the projection above and configure the + # aspect ratio, coastlines, and add RIVERS. + # -------------------------------------------------------------------------- + ax = fig.add_axes(ANNUAL_MAP_PANEL_CFG[panel_index], projection=PROJECTION) + ax.set_extent([lon_west, lon_east, lat_south, lat_north], crs=PROJECTION) + ax.set_aspect((lon_east - lon_west) / (2 * (lat_north - lat_south))) + ax.coastlines(lw=0.3) + ax.add_feature(cfeature.RIVERS) + + # Plot of streamflow gauges. + # -------------------------------------------------------------------------- + color_list, value_min, value_max, norm = setup_annual_map(panel_type) + _plot_gauges( + ax, + panel_type, + export_data, + bias_array, + value_min, + value_max, + color_list, + ) + + # Configure the titles, x and y axes. + # -------------------------------------------------------------------------- + if panel_type == "test": + title = parameter.test_title + elif panel_type == "ref": + title = parameter.reference_title + elif panel_type == "bias": + title = "Relative Bias" + + _configure_titles(ax, (None, title, None)) + _configure_x_and_y_axes( + ax, x_ticks, y_ticks, ccrs.PlateCarree(), parameter.current_set + ) + + # Configure the colorbar. + # -------------------------------------------------------------------------- + cbax = fig.add_axes( + ( + ANNUAL_MAP_PANEL_CFG[panel_index][0] + 0.6635, + ANNUAL_MAP_PANEL_CFG[panel_index][1] + 0.0115, + 0.0326, + 0.1792, + ) + ) + cmap = colors.ListedColormap(color_list) + + if panel_type in ["test", "ref"]: + cbar_label = "Mean annual discharge ($m^3$/$s$)" + elif panel_type == "bias": + cbar_label = "Bias of mean annual discharge (%)\n(test-ref)/ref" + + cbar = fig.colorbar( + matplotlib.cm.ScalarMappable(cmap=cmap, norm=norm), + cax=cbax, + label=cbar_label, + extend="both", + ) + + if panel_type == "bias": + step_size = (value_max - value_min) // 5 + ticks = np.arange(int(value_min), int(value_max) + step_size, step_size) + cbar.ax.tick_params(labelsize=9.0, length=0) + cbar.ax.set_yticklabels(ticks) + + +def setup_annual_map( + panel_type: str, +) -> Tuple[ + List[str], + float, + float, + Union[matplotlib.colors.LogNorm, matplotlib.colors.Normalize], +]: + """Set up the annual map based on the panel type. + + Parameters + ---------- + panel_type : str + The panel type. + + Returns + ------- + Tuple[ List[str], float, float, matplotlib.colors.LogNorm | matplotlib.colors.Normalize ] + A tuple for the color list, the minimum value, the maximum value, and + the color normalization. + """ + colormap = plt.get_cmap("jet_r") + color_list = list(map(lambda index: colormap(index)[:3], range(colormap.N))) + + if panel_type in ["test", "ref"]: + value_min, value_max = 1, 1e4 + norm = matplotlib.colors.LogNorm(vmin=value_min, vmax=value_max) + elif panel_type == "bias": + value_min = -100 + value_max = 100 + norm = matplotlib.colors.Normalize() + + return color_list, value_min, value_max, norm + + +def _plot_gauges( + ax: plt.Axes, + panel_type: str, + export_data: np.ndarray, + bias_array: np.ndarray, + value_min: float, + value_max: float, + color_list: List[str], +): + """Plot the streamflow gauges. + + This function plots each each gauge as a single marker point. + + Parameters + ---------- + ax : plt.Axes + The matplotlib axes object. + panel_type : str + The panel type. + export_data : np.ndarray + The export data. + bias_array : np.ndarray + The bias array. + value_min : float + The minimum value of the map. + value_max : float + The maximum value of the map. + color_list : List[str] + The list of colors to use for markers. + """ + for gauge, i in zip(export_data, range(len(export_data))): + if panel_type == "test": + value = gauge[1] + elif panel_type == "ref": + value = gauge[0] + elif panel_type == "bias": + value = bias_array[i] + + if np.isnan(value): + continue + + if value < value_min: + value = value_min + elif value > value_max: + value = value_max + + if panel_type in ["test", "ref"]: + # Logarithmic Rescale (min-max normalization) to [-1,1] range + normalized_value = (np.log10(value) - np.log10(value_min)) / ( + np.log10(value_max) - np.log10(value_min) + ) + elif panel_type == "bias": + # Rescale (min-max normalization) to [-1,1] range + normalized_value = (value - value_min) / (value_max - value_min) + + lat = gauge[7] + lon = gauge[8] + + color = color_list[int(normalized_value * (len(color_list) - 1))] + + plt.plot( + lon, + lat, + marker="o", + markersize=2, + color=color, + transform=PROJECTION_FUNC(), + ) + + # NOTE: the "plt.annotate call" does not have a "transform=" keyword, + # so for this one we transform the coordinates with a Cartopy call. + ax.projection.transform_point(lon, lat, src_crs=PROJECTION_FUNC()) + + +def _determine_tick_step(degrees_covered: float) -> int: + """Determine the number of tick steps based on the degrees covered by the axis. + + Parameters + ---------- + degrees_covered : float + The degrees covered by the axis. + + Returns + ------- + int + The number of tick steps. + """ + if degrees_covered > 180: + return 60 + if degrees_covered > 60: + return 30 + elif degrees_covered > 20: + return 10 + else: + return 1 diff --git a/e3sm_diags/plot/streamflow_plot_scatter.py b/e3sm_diags/plot/streamflow_plot_scatter.py new file mode 100644 index 000000000..7aced0dd6 --- /dev/null +++ b/e3sm_diags/plot/streamflow_plot_scatter.py @@ -0,0 +1,122 @@ +import matplotlib +import numpy as np +import scipy.stats + +from e3sm_diags.logger import custom_logger +from e3sm_diags.parameter.streamflow_parameter import StreamflowParameter +from e3sm_diags.plot.utils import _save_plot + +matplotlib.use("Agg") +import matplotlib.pyplot as plt # isort:skip # noqa: E402 + +logger = custom_logger(__name__) + +# Position and sizes of subplot axes in page coordinates (0 to 1) +# (left, bottom, width, height) in page coordinates +ANNUAL_SCATTER_PANEL_CFG = [(0.0900, 0.2000, 0.7200, 0.6000)] + + +def plot_annual_scatter(parameter: StreamflowParameter, export_data: np.ndarray): + """Plot the streamflow annual scatter. + + Parameters + ---------- + parameter : StreamflowParameter + The streamflow parameter. + export_data : np.ndarray + The export data. + """ + fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi) + + ann_mean_ref = export_data[:, 0] + ann_mean_test = export_data[:, 1] + pct_drainage_area_bias = export_data[:, 2] + + # Get the figure Axes object and configure axes. + # -------------------------------------------------------------------------- + ax = fig.add_axes(ANNUAL_SCATTER_PANEL_CFG[0]) + + bounds = [0.01, 100000] + ax.plot(bounds, bounds, color="red", linestyle="-") + + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlabel( + f"{parameter.reference_title} streamflow ($m^3$/$s$)", + fontsize=12, + ) + ax.set_ylabel(f"{parameter.test_title} streamflow ($m^3$/$s$)", fontsize=12) + ax.set_xlim(bounds[0], bounds[1]) + ax.set_ylim(bounds[0], bounds[1]) + ax.tick_params(axis="both", labelsize=12) + + # Configure the title. + # -------------------------------------------------------------------------- + if parameter.main_title_annual_scatter == "": + main_title_annual_scatter = ( + f"Annual mean streamflow\n{parameter.test_title} vs " + f"{parameter.reference_title}" + ) + else: + main_title_annual_scatter = parameter.main_title_annual_scatter + + ax.set_title(main_title_annual_scatter, loc="center", y=1.05, fontsize=15) + + # Configure the legend. + # -------------------------------------------------------------------------- + r, _ = scipy.stats.pearsonr(ann_mean_ref, ann_mean_test) + r2 = r * r + r2_str = "{0:.2f}".format(r2) + + legend_title = f"$R^2$={r2_str}, (n={ann_mean_ref.shape[0]})" + ax.legend(handles=[], title=legend_title, loc="upper left", prop={"size": 12}) + + # Configure the color map. + # -------------------------------------------------------------------------- + cmap = plt.get_cmap("jet") + ax.scatter( + ann_mean_ref, + ann_mean_test, + label="Scatterplot", + marker="o", + s=10, + c=pct_drainage_area_bias, + cmap=cmap, + ) + + # Configure the colorbar. + # -------------------------------------------------------------------------- + cbax = fig.add_axes( + ( + ANNUAL_SCATTER_PANEL_CFG[0][0] + 0.7535, + ANNUAL_SCATTER_PANEL_CFG[0][1] + 0.0515, + 0.0326, + 0.1792 * 2, + ) + ) + cbar_label = "Drainage area bias (%)" + cbar = fig.colorbar(matplotlib.cm.ScalarMappable(cmap=cmap), cax=cbax) + cbar.ax.set_ylabel(cbar_label, fontsize=12) + cbar.ax.tick_params(labelsize=12.0, length=0) + + pct_drainage_area_max = np.ceil(np.max(pct_drainage_area_bias)) + pct_drainage_area_min = np.floor(np.min(pct_drainage_area_bias)) + step_size = (pct_drainage_area_max - pct_drainage_area_min) // 5 + + try: + ticks = np.arange( + pct_drainage_area_min, pct_drainage_area_max + step_size, step_size + ) + cbar.ax.set_yticklabels(ticks) + except ValueError: + # `pct_drainage_area_bias` has invalid values (likely from no area_upstream being found). + # Just use default colorbar. + pass + + # NOTE: Need to set the output filename to the name of the specific + # streamflow plot before saving the plot, otherwise the filename will + # be blank. + parameter.output_file = parameter.output_file_annual_scatter + _save_plot(fig, parameter) + + plt.close() diff --git a/e3sm_diags/plot/streamflow_plot_seasonality.py b/e3sm_diags/plot/streamflow_plot_seasonality.py new file mode 100644 index 000000000..2e026f3dd --- /dev/null +++ b/e3sm_diags/plot/streamflow_plot_seasonality.py @@ -0,0 +1,343 @@ +import cartopy.crs as ccrs +import cartopy.feature as cfeature +import matplotlib +import numpy as np + +from e3sm_diags.derivations.default_regions_xr import REGION_SPECS +from e3sm_diags.logger import custom_logger +from e3sm_diags.parameter.streamflow_parameter import StreamflowParameter +from e3sm_diags.plot.utils import ( + _configure_titles, + _configure_x_and_y_axes, + _get_x_ticks, + _get_y_ticks, + _save_plot, +) + +matplotlib.use("Agg") +import matplotlib.colors as colors # isort:skip # noqa: E402 +import matplotlib.lines as lines # isort:skip # noqa: E402 +import matplotlib.pyplot as plt # isort:skip # noqa: E402 + +logger = custom_logger(__name__) + +# Border padding relative to subplot axes for saving individual panels +# (left, bottom, width, height) in page coordinates +BORDER_PADDING = (-0.14, -0.06, 0.04, 0.08) + +# Position and sizes of subplot axes in page coordinates (0 to 1) +# (left, bottom, width, height) in page coordinates. +PANEL_CFG = [ + (0.0900, 0.5500, 0.7200, 0.3000), + (0.0900, 0.1300, 0.7200, 0.3000), +] + +# Test and ref color lists, slected from 'hsv' colormap. +COLOR_LIST = [ + (0.05, 0.00, 0.99), + (0.03, 0.30, 0.98), + (0.12, 0.92, 0.99), + (0.13, 1.00, 0.65), + (0.14, 1.00, 0.05), + (0.98, 0.99, 0.04), + (0.99, 0.67, 0.04), + (0.99, 0.34, 0.03), + (0.99, 0.07, 0.03), + (0.99, 0.00, 0.53), + (0.68, 0.00, 1.00), + (0.29, 0.00, 1.00), +] + +# Dictionary mapping seasonality index to marker size. +SEASONALITY_INDEX = {"si_2": 2, "si_4": 3, "si_6": 4, "si_large": 5} +# Legend elements based on the marker size using the seasonality index dict. +LEGEND_ELEMENTS = [ + lines.Line2D( + [0], + [0], + marker="o", + color="w", + label="1 <= SI < 2", + markerfacecolor="black", + markersize=SEASONALITY_INDEX["si_2"], + ), + lines.Line2D( + [0], + [0], + marker="o", + color="w", + label="2 <= SI < 4", + markerfacecolor="black", + markersize=SEASONALITY_INDEX["si_4"], + ), + lines.Line2D( + [0], + [0], + marker="o", + color="w", + label="4 <= SI < 6", + markerfacecolor="black", + markersize=SEASONALITY_INDEX["si_6"], + ), + lines.Line2D( + [0], + [0], + marker="o", + color="w", + label="6 <= SI <= 12", + markerfacecolor="black", + markersize=SEASONALITY_INDEX["si_large"], + ), +] + +# Projections to use for the seasonality map. +PROJECTION = ccrs.PlateCarree(central_longitude=0) +PROJECTION_FUNC = ccrs.PlateCarree + +# Month labels for the Y Axis. +MONTHS_Y_AXIS_LABEL = [ + "Jan", + "Feb", + "Mar", + "Apr", + "May", + "Jun", + "Jul", + "Aug", + "Sep", + "Oct", + "Nov", + "Dec", +] + + +def plot_seasonality_map(parameter: StreamflowParameter, export_data: np.ndarray): + """Plot the streamflow seasonality map. + + Parameters + ---------- + parameter : StreamflowParameter + The streamflow parameter. + export_data : np.ndarray + The export data. + """ + fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi) + fig.suptitle(parameter.main_title_seasonality_map, x=0.5, y=0.97, fontsize=15) + + _plot_panel_seasonality_map(fig, parameter, "test", export_data) + _plot_panel_seasonality_map(fig, parameter, "ref", export_data) + + # NOTE: Need to set the output filename to the name of the specific + # streamflow plot before saving the plot, otherwise the filename will + # be blank. + parameter.output_file = parameter.output_file_seasonality_map + _save_plot(fig, parameter, PANEL_CFG, BORDER_PADDING) + + plt.close() + + +def _plot_panel_seasonality_map( + fig: plt.Figure, + parameter: StreamflowParameter, + plot_type: str, + export_data: np.ndarray, +): + """Plot the panel each seasonality map. + + Parameters + ---------- + fig : plt.Figure + The figure object. + parameter : StreamflowParameter + The parameter. + plot_type : str + The plot type. + export_data : np.ndarray + The export data. + """ + if plot_type == "test": + panel_idx = 0 + seasonality_idx = 5 + peak_month_idx = 6 + title = (None, parameter.test_title, None) + elif plot_type == "ref": + panel_idx = 1 + seasonality_idx = 3 + peak_month_idx = 4 + title = (None, parameter.reference_title, None) + + # Get region info and X and Y plot ticks. + # -------------------------------------------------------------------------- + region_key = parameter.regions[0] + region_specs = REGION_SPECS[region_key] + + # Get the region's domain slices for latitude and longitude if set, or + # use the default value. If both are not set, then the region type is + # considered "global". + lat_slice = region_specs.get("lat", (-90, 90)) # type: ignore + lon_slice = region_specs.get("lon", (-180, 180)) # type: ignore + + # Boolean flags for configuring plots. + is_global_domain = lat_slice == (-90, 90) and lon_slice == (-180, 180) + is_lon_full = lon_slice == (-180, 180) + + # Determine X and Y ticks using longitude and latitude domains respectively. + lon_west, lon_east = lon_slice + x_ticks = _get_x_ticks( + lon_west, + lon_east, + is_global_domain, + is_lon_full, + tick_step_func=_determine_tick_step, + ) + + lat_south, lat_north = lat_slice + y_ticks = _get_y_ticks(lat_south, lat_north, tick_step_func=_determine_tick_step) + + # Get the figure Axes object using the projection above and configure the + # aspect ratio, coastlines, and add RIVERS. + # -------------------------------------------------------------------------- + ax = fig.add_axes(PANEL_CFG[panel_idx], projection=PROJECTION) + ax.set_extent([lon_west, lon_east, lat_south, lat_north], crs=PROJECTION) + ax.set_aspect((lon_east - lon_west) / (2 * (lat_north - lat_south))) + ax.coastlines(lw=0.3) + ax.add_feature(cfeature.RIVERS) + + # Plot the streamflow gauges. + # -------------------------------------------------------------------------- + _plot_gauges(ax, export_data, seasonality_idx, peak_month_idx) + + # Configure legend. + # -------------------------------------------------------------------------- + plt.legend(handles=LEGEND_ELEMENTS, title="Seasonality (SI)", prop={"size": 8}) + + # Configure the titles, x and y axes. + # -------------------------------------------------------------------------- + _configure_titles(ax, title) + _configure_x_and_y_axes( + ax, x_ticks, y_ticks, ccrs.PlateCarree(), parameter.current_set + ) + # Configure the colorbar. + # -------------------------------------------------------------------------- + cbax = fig.add_axes( + ( + PANEL_CFG[panel_idx][0] + 0.7535, + PANEL_CFG[panel_idx][1] + 0.0515, + 0.0326, + 0.1792, + ) + ) + + cmap = colors.ListedColormap(COLOR_LIST) + + # Set ticks to be in between the bounds + num_colors = len(COLOR_LIST) + bounds = list(range(num_colors)) + ticks = list(map(lambda bound: bound + 0.5, bounds)) + + # Add one more bound at the bottom of the colorbar. + # `bounds` should be one longer than `ticks`. + bounds += [bounds[-1] + 1] + norm = colors.BoundaryNorm(bounds, cmap.N) + cbar = fig.colorbar( + matplotlib.cm.ScalarMappable(cmap=cmap, norm=norm), + cax=cbax, + boundaries=bounds, + ticks=ticks, + spacing="uniform", + orientation="vertical", + label="Peak month", + ) + + cbar.ax.set_yticklabels(MONTHS_Y_AXIS_LABEL) + cbar.ax.invert_yaxis() + + cbar.ax.tick_params(labelsize=9.0, length=0) + + +def _plot_gauges( + ax: plt.axes, + export_data: np.ndarray, + seasonality_idx: int, + peak_month_idx: int, +): + """Plot the streamflow gauges. + + This function plots each each gauge as a single marker point. + + Parameters + ---------- + ax : plt.axes + The matplotlib axes object. + export : np.ndarray + An array of gauges, with each gauge having multiple fields (e.g., lat is + index 7). + seasonality_idx_export_idx : int + The index of the seasonality based on the export index, which determines + the size of the plot marker for each gauge. + peak_month_export_idx : int + The index of the peak month export that determines the color of the + plot marker for each gauge. + + Raises + ------ + RuntimeError + Invalid seasonality index found. + """ + for gauge in export_data: + lat = gauge[7] + lon = gauge[8] + seasonality_index = gauge[seasonality_idx] + + if seasonality_index < 2: + markersize = SEASONALITY_INDEX["si_2"] + elif seasonality_index < 4: + markersize = SEASONALITY_INDEX["si_4"] + elif seasonality_index < 6: + markersize = SEASONALITY_INDEX["si_6"] + elif seasonality_index <= 12: + markersize = SEASONALITY_INDEX["si_large"] + else: + raise RuntimeError(f"Invalid seasonality index={seasonality_index}") + + if seasonality_index == 1: + color = "black" + else: + peak_month = int(gauge[peak_month_idx]) + color = COLOR_LIST[peak_month] # type: ignore + + plt.plot( + lon, + lat, + marker="o", + color=color, + markersize=markersize, + transform=PROJECTION_FUNC(), + ) + + # NOTE: The "plt.annotate call" does not have a "transform=" keyword, + # so for this one we transform the coordinates with a Cartopy call. + ax.projection.transform_point(lon, lat, src_crs=PROJECTION_FUNC()) + + +def _determine_tick_step(degrees_covered: float) -> int: + """Determine the number of tick steps based on the degrees covered by the axis. + + Parameters + ---------- + degrees_covered : float + The degrees covered by the axis. + + Returns + ------- + int + The number of tick steps. + """ + if degrees_covered > 180: + return 60 + if degrees_covered > 60: + return 30 + elif degrees_covered > 20: + return 10 + else: + return 1