diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index af639c1..f7683ad 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,70 +14,70 @@ on: - cron: '0 0 * * *' # daily jobs: - build: - name: Build py${{ matrix.python-version }} @ ${{ matrix.os }} 🐍 - runs-on: ${{ matrix.os }} - strategy: - matrix: - python-version: ['3.8', '3.9', '3.10'] - os: ['ubuntu-latest'] - ymlfile: ['environment.yml'] - include: - - os: 'windows-latest' - python-version: '3.10' - ymlfile: 'environment.yml' - steps: - - uses: actions/checkout@v2 - with: - submodules: true - - uses: conda-incubator/setup-miniconda@v3 - with: - miniconda-version: 'latest' - auto-update-conda: true - python-version: ${{ matrix.python-version }} - environment-file: ${{ matrix.ymlfile }} - activate-environment: qa4sm_reader # todo: must match with name in environment.yml - auto-activate-base: false - - name: Print environment infos - shell: bash -l {0} - run: | - conda info -a - conda list - pip list - which pip - which python - - name: Export Environment - shell: bash -l {0} - run: | - mkdir -p artifacts - filename=env_py${{ matrix.python-version }}_${{ matrix.os }}.yml - conda env export --no-builds | grep -v "prefix" > artifacts/$filename - - name: Upload Artifacts - uses: actions/upload-artifact@v4 - with: - name: Artifacts-py${{ matrix.python-version }}-${{ matrix.os }} - path: artifacts/* - - name: Install package and test - shell: bash -l {0} - run: | - pip install -e . - pytest - - name: Upload Coverage - shell: bash -l {0} - run: | - pip install coveralls && coveralls --service=github - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - COVERALLS_FLAG_NAME: ${{ matrix.python-version }} - COVERALLS_PARALLEL: true - coveralls: - name: Submit Coveralls 👚 - needs: build - runs-on: ubuntu-latest - container: python:3-slim - steps: - - name: Finished - run: | - pip3 install --upgrade coveralls && coveralls --service=github --finish - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + build: + name: Build py${{ matrix.python-version }} @ ${{ matrix.os }} 🐍 + runs-on: ${{ matrix.os }} + strategy: + matrix: + python-version: [ '3.8', '3.9', '3.10' ] + os: ["ubuntu-latest"] + ymlfile: ["environment.yml"] + include: + - os: "windows-latest" + python-version: "3.10" + ymlfile: "environment.yml" + steps: + - uses: actions/checkout@v2 + with: + submodules: true + - uses: conda-incubator/setup-miniconda@v3 + with: + miniconda-version: "latest" + auto-update-conda: true + python-version: ${{ matrix.python-version }} + environment-file: ${{ matrix.ymlfile }} + activate-environment: qa4sm_reader # todo: must match with name in environment.yml + auto-activate-base: false + - name: Print environment infos + shell: bash -l {0} + run: | + conda info -a + conda list + pip list + which pip + which python + - name: Export Environment + shell: bash -l {0} + run: | + mkdir -p artifacts + filename=env_py${{ matrix.python-version }}_${{ matrix.os }}.yml + conda env export --no-builds | grep -v "prefix" > artifacts/$filename + - name: Upload Artifacts + uses: actions/upload-artifact@v4 + with: + name: Artifacts-py${{ matrix.python-version }}-${{ matrix.os }} + path: artifacts/* + - name: Install package and test + shell: bash -l {0} + run: | + pip install -e . 
+ pytest + - name: Upload Coverage + shell: bash -l {0} + run: | + pip install coveralls && coveralls --service=github + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + COVERALLS_FLAG_NAME: ${{ matrix.python-version }} + COVERALLS_PARALLEL: true + coveralls: + name: Submit Coveralls 👚 + needs: build + runs-on: ubuntu-latest + container: python:3-slim + steps: + - name: Finished + run: | + pip3 install --upgrade coveralls && coveralls --service=github --finish + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index 3f3f6f3..377d659 100644 --- a/.gitignore +++ b/.gitignore @@ -58,6 +58,7 @@ MANIFEST # tests/test_data tests/test_data/old tests/test_results +tests/test_qa4sm_data # personal testing stuff test.py @@ -72,3 +73,4 @@ tests/test_data/out/* .coverage* .artifacts/* .vscode +.logs diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..c1547af --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +graft src/qa4sm_reader/static diff --git a/environment.yml b/environment.yml index 2f1e630..fe8d979 100644 --- a/environment.yml +++ b/environment.yml @@ -20,3 +20,4 @@ dependencies: - pygeogrids - pytest - pytest-cov + - pytesmo diff --git a/setup.cfg b/setup.cfg index b00c85a..01a73c7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,7 +36,8 @@ install_requires = parse scipy pygeogrids - + pytesmo + # The usage of test_requires is discouraged, see `Dependency Management` docs # tests_require = pytest; pytest-cov # Require a specific Python version, e.g. Python 2.7 or >= 3.4 diff --git a/src/qa4sm_reader/comparing.py b/src/qa4sm_reader/comparing.py index a7d72eb..28c82c8 100644 --- a/src/qa4sm_reader/comparing.py +++ b/src/qa4sm_reader/comparing.py @@ -21,9 +21,9 @@ class QA4SMComparison: take some time, the class can be updated keeping memory of what has already been initialized """ def __init__(self, - paths: list or str, + paths: Union[list, str], extent: tuple = None, - get_intersection: bool = True): + get_intersection: bool = True) -> None: """ Initialise the QA4SMImages from the paths to netCDF files specified diff --git a/src/qa4sm_reader/custom_intra_annual_windows_example.json b/src/qa4sm_reader/custom_intra_annual_windows_example.json new file mode 100644 index 0000000..3ad9212 --- /dev/null +++ b/src/qa4sm_reader/custom_intra_annual_windows_example.json @@ -0,0 +1,92 @@ +{ + "seasons": { + "S1": [ + [12, 1], + [2, 28] + ], + "S2": [ + [3, 1], + [5, 31] + ], + "S3": [ + [6, 1], + [8, 31] + ], + "S4": [ + [9, 1], + [11, 30] + ] + }, + "months": { + "Jan": [ + [1, 1], + [1, 31] + ], + "Feb": [ + [2, 1], + [2, 28] + ], + "Mar": [ + [3, 1], + [3, 31] + ], + "Apr": [ + [4, 1], + [4, 30] + ], + "May": [ + [5, 1], + [5, 31] + ], + "Jun": [ + [6, 1], + [6, 30] + ], + "Jul": [ + [7, 1], + [7, 31] + ], + "Aug": [ + [8, 1], + [8, 31] + ], + "Sep": [ + [9, 1], + [9, 30] + ], + "Oct": [ + [10, 1], + [10, 31] + ], + "Nov": [ + [11, 1], + [11, 30] + ], + "Dec": [ + [12, 1], + [12, 31] + ] + }, + "custom": { + "star wars month": [ + [5, 1], + [5, 31] + ], + "halloween season": [ + [10, 1], + [10, 31] + ], + "advent": [ + [12, 1], + [12, 24] + ], + "movember": [ + [11, 1], + [11, 30] + ], + "christmas": [ + [12, 24], + [12, 26] + ] + } +} diff --git a/src/qa4sm_reader/globals.py b/src/qa4sm_reader/globals.py index 5588314..7f132ea 100644 --- a/src/qa4sm_reader/globals.py +++ b/src/qa4sm_reader/globals.py @@ -6,8 +6,10 @@ import warnings import cartopy.crs as ccrs +import matplotlib import matplotlib.colors as cl -import matplotlib.pyplot as plt 
+import numpy as np +import os # PLOT DEFAULT SETTINGS # ===================================================== @@ -40,16 +42,31 @@ # === boxplot_basic defaults === boxplot_printnumbers = True # Print 'median', 'nObs', 'stdDev' to the boxplot_basic. -boxplot_height = 6 +boxplot_height = 7 #$ increased by 1 to house logo boxplot_width = 2.1 # times (n+1), where n is the number of boxes. boxplot_title_len = 8 * boxplot_width # times the number of boxes. maximum length of plot title in chars. tick_size = 8.5 +#TODO: remove eventually, as watermarlk string no longer needed # === watermark defaults === watermark = u'made with QA4SM (qa4sm.eu)' # Watermark string watermark_pos = 'bottom' # Default position ('top' or 'bottom' or None) watermark_fontsize = 8 # fontsize in points (matplotlib uses 72ppi) -watermark_pad = 5 # padding above/below watermark in points (matplotlib uses 72ppi) +watermark_pad = 50 # padding above/below watermark in points (matplotlib uses 72ppi) + +#$$ +# === watermark logo defaults === +watermark_logo_position = 'lower_center' +watermark_logo_scale = 0.1 # height of the logo relative to the height of the figure +watermark_logo_offset_comp_plots = (0, -0.1) +watermark_logo_offset_metadata_plots = (0, -0.08) +watermark_logo_offset_map_plots = (0, -0.15) +watermark_logo_offset_bar_plots = (0, -0.1) +watermark_logo_offset_box_plots = (0, -0.15) +watermark_logo_pth = os.path.join( + os.path.dirname( + os.path.abspath(__file__)), 'static', 'images', 'logo', + 'QA4SM_logo_long.png') # === filename template === ds_fn_templ = "{i}-{ds}.{var}" @@ -93,28 +110,27 @@ def get_status_colors(): # function to get custom cmap for calculation errors # limited to 14 different error entries to produce distinct colors - cmap = plt.cm.get_cmap('Set3', len(status) - 2) + cmap = cl.ListedColormap(matplotlib.colormaps['Set3'].colors[:len(status) - 2]) colors = [cmap(i) for i in range(cmap.N)] colors.insert(0, (0, 0.66666667, 0.89019608, 1.0)) colors.insert(0, (0.45882353, 0.08235294, 0.11764706, 1.0)) cmap = cl.ListedColormap(colors=colors) return cmap - _cclasses = { - 'div_better': plt.cm.get_cmap( + 'div_better': matplotlib.colormaps[ 'RdYlBu' - ), # diverging: 1 good, 0 special, -1 bad (pearson's R, spearman's rho') - 'div_worse': plt.cm.get_cmap( + ], # diverging: 1 good, 0 special, -1 bad (pearson's R, spearman's rho') + 'div_worse': matplotlib.colormaps[ 'RdYlBu_r' - ), # diverging: 1 bad, 0 special, -1 good (difference of bias) + ], # diverging: 1 bad, 0 special, -1 good (difference of bias) 'div_neutr': - plt.cm.get_cmap('RdYlGn'), # diverging: zero good, +/- neutral: (bias) - 'seq_worse': plt.cm.get_cmap( + matplotlib.colormaps['RdYlGn'], # diverging: zero good, +/- neutral: (bias) + 'seq_worse': matplotlib.colormaps[ 'YlGn_r' - ), # sequential: increasing value bad (p_R, p_rho, rmsd, ubRMSD, RSS) - 'seq_better': plt.cm.get_cmap( - 'YlGn'), # sequential: increasing value good (n_obs, STDerr) + ], # sequential: increasing value bad (p_R, p_rho, rmsd, ubRMSD, RSS) + 'seq_better': matplotlib.colormaps[ + 'YlGn'], # sequential: increasing value good (n_obs, STDerr) 'qua_neutr': get_status_colors(), # qualitative category with 2 forced colors } @@ -311,30 +327,105 @@ def get_metric_units(dataset, raise_error=False): return "n.a." 
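# --- Example (illustration only, not part of the patch) ---------------------
# The colormap changes in this hunk replace `plt.cm.get_cmap(name, lutsize)`,
# which matplotlib deprecated in 3.7 and later removed. A minimal sketch of
# the registry-based replacement used above; `n_colors` stands in for
# `len(status) - 2` from `get_status_colors()`:
import matplotlib
import matplotlib.colors as cl

n_colors = 12  # placeholder; 'Set3' ships 12 qualitative colors
cmap = cl.ListedColormap(matplotlib.colormaps['Set3'].colors[:n_colors])
assert cmap.N == n_colors
div_better = matplotlib.colormaps['RdYlBu']  # continuous maps: plain lookup
# -----------------------------------------------------------------------------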
- -# label name for all metrics -_metric_name = { # from /qa4sm/validator/validation/globals.py +COMMON_METRICS = { 'R': 'Pearson\'s r', 'p_R': 'Pearson\'s r p-value', - 'rho': 'Spearman\'s ρ', - 'p_rho': 'Spearman\'s ρ p-value', 'RMSD': 'Root-mean-square deviation', 'BIAS': 'Bias (difference of means)', 'n_obs': '# observations', 'urmsd': 'Unbiased root-mean-square deviation', 'RSS': 'Residual sum of squares', - 'tau': 'Kendall rank correlation', - 'p_tau': 'Kendall tau p-value', 'mse': 'Mean square error', 'mse_corr': 'Mean square error correlation', 'mse_bias': 'Mean square error bias', 'mse_var': 'Mean square error variance', +} + +TC_METRICS = { 'snr': 'Signal-to-noise ratio', 'err_std': 'Error standard deviation', 'beta': 'TC scaling coefficient', +} + +READER_EXCLUSIVE_METRICS = { + 'rho': 'Spearman\'s ρ', + 'p_rho': 'Spearman\'s ρ p-value', + 'tau': 'Kendall rank correlation', + 'p_tau': 'Kendall tau p-value', 'status': 'Validation errors' } +QA4SM_EXCLUSIVE_METRICS = { + 'rho': 'Spearman\'s rho', + 'p_rho': 'Spearman\'s rho p-value', + # 'tau': 'Kendall rank correlation', # currently QA4SM is hardcoded not to calculate kendall tau + # 'p_tau': 'Kendall tau p-value', # needs to be changed once tau is calculated again + 'status': '# status', +} + +_metric_name = {**COMMON_METRICS, **READER_EXCLUSIVE_METRICS, **TC_METRICS} + +METRICS = {**COMMON_METRICS, **QA4SM_EXCLUSIVE_METRICS} + +NON_METRICS = [ + 'gpi', + 'lon', + 'lat', + 'clay_fraction', + 'climate_KG', + 'climate_insitu', + 'elevation', + 'instrument', + 'latitude', + 'lc_2000', + 'lc_2005', + 'lc_2010', + 'lc_insitu', + 'longitude', + 'network', + 'organic_carbon', + 'sand_fraction', + 'saturation', + 'silt_fraction', + 'station', + 'timerange_from', + 'timerange_to', + 'variable', + 'instrument_depthfrom', + 'instrument_depthto', + 'frm_class', +] + +METADATA_TEMPLATE = { + 'other_ref': None, + 'ismn_ref': { + 'clay_fraction': np.float32([np.nan]), + 'climate_KG': np.array([' ' * 256]), + 'climate_insitu': np.array([' ' * 256]), + 'elevation': np.float32([np.nan]), + 'instrument': np.array([' ' * 256]), + 'latitude': np.float32([np.nan]), + 'lc_2000': np.float32([np.nan]), + 'lc_2005': np.float32([np.nan]), + 'lc_2010': np.float32([np.nan]), + 'lc_insitu': np.array([' ' * 256]), + 'longitude': np.float32([np.nan]), + 'network': np.array([' ' * 256]), + 'organic_carbon': np.float32([np.nan]), + 'sand_fraction': np.float32([np.nan]), + 'saturation': np.float32([np.nan]), + 'silt_fraction': np.float32([np.nan]), + 'station': np.array([' ' * 256]), + 'timerange_from': np.array([' ' * 256]), + 'timerange_to': np.array([' ' * 256]), + 'variable': np.array([' ' * 256]), + 'instrument_depthfrom': np.float32([np.nan]), + 'instrument_depthto': np.float32([np.nan]), + # only available for FRM4SM ISMN version(s) + 'frm_class': np.array([' ' * 256]), + } +} + # BACKUPS # ===================================================== # to fallback to in case the dataset attributes in the .nc file are @@ -618,7 +709,7 @@ def get_resolution_info(dataset, raise_error=False): "climate_insitu": ("climate in-situ", climate_classes, "classes", None), "elevation": ("elevation", None, "continuous", "[m]"), "instrument": ("instrument type", None, "discrete", - None), # todo: improve labels (too packed) + None), #todo: improve labels (too packed) "lc_2000": ("land cover class (2000)", lc_classes, "classes", None), "lc_2005": ("land cover class (2005)", lc_classes, "classes", None), "lc_2010": ("land cover class (2010)", lc_classes, "classes", None), @@ 
-651,3 +742,98 @@ def get_resolution_info(dataset, raise_error=False): 'p_tau', 'status', ] + +METRIC_TEMPLATE = '_between_{ds1}_and_{ds2}' +METRIC_CI_TEMPLATE = '{metric}_ci_{bound}_between_{ds1}_and_{ds2}_{ending}' + + +# intra-annual valdiation metric related settings +# ===================================================== + +DEFAULT_TSW = 'bulk' # default temporal sub-window (in the case of no temporal sub-windowing) +TEMPORAL_SUB_WINDOW_NC_COORD_NAME = 'tsw' # name of the period coordinate in the netcdf file (Temporal Sub-Window) + +TEMPORAL_SUB_WINDOW_SEPARATOR = '|' + +INTRA_ANNUAL_METRIC_TEMPLATE = ["{tsw}", TEMPORAL_SUB_WINDOW_SEPARATOR, + "{metric}"] #$$ + +INTRA_ANNUAL_TCOL_METRIC_TEMPLATE = ["{tsw}", TEMPORAL_SUB_WINDOW_SEPARATOR, + "{metric}", "_", "{number}-{dataset}", + "_between_"] + +# default temporal sub windows +TEMPORAL_SUB_WINDOWS = { + "seasons": { + "S1": [[12, 1], [2, 28]], + "S2": [[3, 1], [5, 31]], + "S3": [[6, 1], [8, 31]], + "S4": [[9, 1], [11, 30]], + }, + "months": { + "Jan": [[1, 1], [1, 31]], + "Feb": [[2, 1], [2, 28]], + "Mar": [[3, 1], [3, 31]], + "Apr": [[4, 1], [4, 30]], + 'May': [[5, 1], [5, 31]], + "Jun": [[6, 1], [6, 30]], + "Jul": [[7, 1], [7, 31]], + "Aug": [[8, 1], [8, 31]], + "Sep": [[9, 1], [9, 30]], + "Oct": [[10, 1], [10, 31]], + "Nov": [[11, 1], [11, 30]], + "Dec": [[12, 1], [12, 31]], + } +} + +CLUSTERED_BOX_PLOT_STYLE = { + 'fig_params': { + 'title_fontsize': 20, + 'y_labelsize': 18, + 'tick_labelsize': 16, + 'legend_fontsize': 12, + }, + 'colors': { + 'Teal Blue': '#00778F', + 'Mustard Yellow': '#FFD166', + 'Sage Green': '#8FB339', + 'Coral Pink': '#EF476F', + 'Steel Gray': '#6A0572', + } +} + +CLUSTERED_BOX_PLOT_SAVENAME = 'comparison_boxplot_{metric}.{filetype}' + + + +# netCDF transcription related settings +# ===================================================== +OLD_NCFILE_SUFFIX = '.old' + +IMPLEMENTED_COMPRESSIONS = ['zlib'] + +ALLOWED_COMPRESSION_LEVELS = [None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + +BAD_METRICS = ['time'] + +DATASETS = [ + 'C3S_combined', + 'ISMN', + 'GLDAS', + 'SMAP_L3', + 'ASCAT', + 'ESA_CCI_SM_combined', + 'ESA_CCI_SM_active', + 'ESA_CCI_SM_passive', + 'SMOS_IC', + 'ERA5', + 'ERA5_LAND', + 'CGLS_CSAR_SSM1km', + 'CGLS_SCATSAR_SWI1km', + 'SMOS_L3', + 'SMOS_L2', + 'SMAP_L2', + 'SMOS_SBPCA', +] + +MAX_NUM_DS_PER_VAL_RUN = 6 diff --git a/src/qa4sm_reader/handlers.py b/src/qa4sm_reader/handlers.py index 8e0562d..314b3d7 100644 --- a/src/qa4sm_reader/handlers.py +++ b/src/qa4sm_reader/handlers.py @@ -1,10 +1,16 @@ # -*- coding: utf-8 -*- +from dataclasses import dataclass import warnings from qa4sm_reader import globals from parse import * import warnings as warn import re +from typing import List, Optional, Tuple, Dict, Any, Union + +import matplotlib +import matplotlib.axes +from matplotlib.figure import Figure class MixinVarmeta: @@ -43,7 +49,7 @@ def id(self): else: return self.ref_ds[0] - def get_varmeta(self) -> (tuple, tuple, tuple, tuple): + def get_varmeta(self) -> Tuple[Tuple, Tuple, Tuple, Tuple]: """ Get the datasets from the current variable. 
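# --- Example (illustration only, not part of the patch) ---------------------
# How the templates defined above compose into variable names: a temporal
# sub-window is glued to the metric short name with
# TEMPORAL_SUB_WINDOW_SEPARATOR, and the pairwise suffix follows
# METRIC_TEMPLATE. The dataset labels are illustrative.
SEP = '|'  # TEMPORAL_SUB_WINDOW_SEPARATOR
TEMPLATE = ['{tsw}', SEP, '{metric}']  # INTRA_ANNUAL_METRIC_TEMPLATE

name = ''.join(part.format(tsw='S1', metric='R') for part in TEMPLATE)
assert name == 'S1|R'
suffix = '_between_{ds1}_and_{ds2}'.format(ds1='0-ISMN', ds2='1-C3S_combined')
assert name + suffix == 'S1|R_between_0-ISMN_and_1-C3S_combined'
# -----------------------------------------------------------------------------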
Each dataset is provided with shape (id, dict{names}) @@ -128,6 +134,11 @@ def _ref_dc(self) -> int: ref_dc = 0 try: + # print(f'globals._ref_ds_attr: {globals._ref_ds_attr}') + # print(f'self.meta: {self.meta}') + # print( + # f'parse(globals._ds_short_name_attr, val_ref): {parse(globals._ds_short_name_attr, self.meta[globals._ref_ds_attr])}' + # ) val_ref = self.meta[globals._ref_ds_attr] ref_dc = parse(globals._ds_short_name_attr, val_ref)[0] except KeyError as e: @@ -298,7 +309,7 @@ def others(self) -> list: return others_meta - def dataset_metadata(self, id: int, element: str or list = None) -> tuple: + def dataset_metadata(self, id: int, element: Union[str, list] = None) -> tuple: """ Get the metadata for the dataset specified by the id. This function is used by the QA4SMMetricVariable class @@ -415,7 +426,7 @@ def _parse_wrap(self, pattern, g): self.varname) return parse(pattern, self.varname) - def _parse_varname(self) -> (str, int, dict): + def _parse_varname(self) -> Tuple[str, int, dict]: """ Parse the name to get the metric, group and variable data @@ -542,3 +553,23 @@ def has_CIs(self): break return it_does + +#$$ +@dataclass() +class ClusteredBoxPlotContainer: + '''Container for the figure and axes of a clustered boxplot. + See `qa4sm_reader.plotting_methods.figure_template` for usage. + ''' + fig: matplotlib.figure.Figure + ax_box: matplotlib.axes.Axes + ax_median: Optional[matplotlib.axes.Axes] = None + ax_iqr: Optional[matplotlib.axes.Axes] = None + ax_n: Optional[matplotlib.axes.Axes] = None + +#$$ +@dataclass(frozen=True) +class CWContainer: + '''Container for the centers and widths of the boxplots. Used for the plotting of the clustered boxplots.''' + centers: List[float] + widths: List[float] + name: Optional[str] = 'Generic Dataset' diff --git a/src/qa4sm_reader/img.py b/src/qa4sm_reader/img.py index e46bb5d..dbe8e91 100644 --- a/src/qa4sm_reader/img.py +++ b/src/qa4sm_reader/img.py @@ -2,25 +2,13 @@ from qa4sm_reader import globals import qa4sm_reader.handlers as hdl from qa4sm_reader.plotting_methods import _format_floats, combine_soils, combine_depths, average_non_additive - +from qa4sm_reader.utils import transcribe from pathlib import Path import warnings -import numpy as np import xarray as xr import pandas as pd -from typing import Union - - -def extract_periods(filepath) -> np.array: - """Get periods from .nc""" - dataset = xr.open_dataset(filepath) - if globals.period_name in dataset.dims: - return dataset[globals.period_name].values - - else: - return np.array([None]) - +from typing import Union, Tuple, Optional class SpatialExtentError(Exception): """Class to handle errors derived from the spatial extent of validations""" @@ -31,7 +19,7 @@ class QA4SMImg(object): """A tool to analyze the results of a validation, which are stored in a netCDF file.""" def __init__(self, filepath, - period=None, + period=globals.DEFAULT_TSW, extent=None, ignore_empty=True, metrics=None, @@ -46,8 +34,8 @@ def __init__(self, ---------- filepath : str Path to the results netcdf file (as created by QA4SM) - period : Any, optional (default: None) - If results for multiple validation periods are stored in file, + period : str, optional (default: `globals.DEFAULT_TSW`) + if results for multiple validation periods, i.e. multiple temporal sub-windows, are stored in file, load this period. 
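# --- Example (illustration only, not part of the patch) ---------------------
# A sketch of how the two new handler dataclasses are meant to be filled; the
# four-row axes layout is an assumption, the real one lives in
# `qa4sm_reader.plotting_methods.figure_template`.
import matplotlib.pyplot as plt
from qa4sm_reader.handlers import ClusteredBoxPlotContainer, CWContainer

fig, (ax_box, ax_median, ax_iqr, ax_n) = plt.subplots(4, 1)
cbpc = ClusteredBoxPlotContainer(fig=fig, ax_box=ax_box, ax_median=ax_median,
                                 ax_iqr=ax_iqr, ax_n=ax_n)
# per-dataset box centers/widths within each cluster of boxes
cw = CWContainer(centers=[1.0, 2.0, 3.0], widths=[0.2, 0.2, 0.2],
                 name='1-C3S_combined')
# -----------------------------------------------------------------------------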
extent : tuple, optional (default: None) Area to subset the values for -> (min_lon, max_lon, min_lat, max_lat) @@ -84,17 +72,36 @@ def __init__(self, except AttributeError: self.ref_dataset_grid_stepsize = 'nan' - def _open_ds(self, extent=None, period=None, engine='h5netcdf'): - """Open .nc as xarray datset, with selected extent""" + def _open_ds(self, extent: Optional[Tuple]=None, period:Optional[str]=globals.DEFAULT_TSW, engine:Optional[str]='h5netcdf') -> xr.Dataset: + """Open .nc as `xarray.Datset`, with selected extent and period. + + Parameters + ---------- + extent : tuple, optional (default: None) + Area to subset the values for -> (min_lon, max_lon, min_lat, max_lat) + period : str, optional (default: `globals.DEFAULT_TSW`) + if results for multiple validation periods, i.e. multiple temporal sub-windows, are stored in file, + load this period. + engine: str, optional (default: h5netcdf) + Engine used by xarray to read data from file. + + Returns + ------- + ds : xarray.Dataset + Dataset with the validation results + """ dataset = xr.load_dataset( self.filepath, drop_variables="time", engine=engine, ) - if period is not None: - ds = dataset.sel(dict(period=period)) - else: - ds = dataset + + if not globals.TEMPORAL_SUB_WINDOW_NC_COORD_NAME in dataset.dims: + dataset = transcribe(self.filepath) + + + selection = {globals.TEMPORAL_SUB_WINDOW_NC_COORD_NAME: period} # allows for flexible loading of both the dimension and temproal sub-window + ds = dataset.sel(selection) # drop non-spatial variables (e.g.'time') if globals.time_name in ds.variables: ds = ds.drop_vars(globals.time_name) @@ -326,7 +333,7 @@ def group_vars(self, filter_parms: dict): return vars - def group_metrics(self, metrics: list = None) -> (dict, dict, dict): + def group_metrics(self, metrics: list = None) -> Union[None, Tuple[dict, dict, dict]]: """ Load and group all metrics from file @@ -394,7 +401,7 @@ def _ds2df(self, varnames: list = None) -> pd.DataFrame: return df - def metric_df(self, metrics: str or list): + def metric_df(self, metrics: Union[str, list]) -> pd.DataFrame: """ Group all variables for the metric in a common data frame diff --git a/src/qa4sm_reader/intra_annual_temp_windows.py b/src/qa4sm_reader/intra_annual_temp_windows.py new file mode 100644 index 0000000..b8d27ce --- /dev/null +++ b/src/qa4sm_reader/intra_annual_temp_windows.py @@ -0,0 +1,457 @@ +from qa4sm_reader.globals import TEMPORAL_SUB_WINDOWS, DEFAULT_TSW + +from pytesmo.validation_framework.metric_calculators_adapters import TsDistributor +from pytesmo.time_series.grouping import YearlessDatetime + +from typing import Optional, List, Tuple, Dict, Union +from pathlib import Path +from datetime import datetime +import json +from abc import ABC, abstractmethod +import numpy as np + + +class InvalidTemporalSubWindowError(Exception): + '''Exception raised when an invalid temporal sub-window is provided.''' + + def __init__(self, tsw, valid_tsw): + super().__init__( + f'The provided temporal sub-window ({tsw}) is invalid. Please provide one of these valid temporal sub-windows: {valid_tsw}.' + ) + + +class TemporalSubWindowsDefault(ABC): + ''' + Class to load default temporal sub-window definitions from the `validator.validation.globals` file. + Alternatively, the user can provide a custom JSON file containing the definitions. 
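# --- Example (illustration only, not part of the patch) ---------------------
# With `_open_ds` selecting along the new `tsw` coordinate (and transcribing
# legacy files on the fly), callers just name a sub-window. The file path
# below is hypothetical.
from qa4sm_reader.img import QA4SMImg

img_bulk = QA4SMImg('0-ISMN.soil moisture_with_1-C3S.sm.nc')  # 'bulk' default
img_s1 = QA4SMImg('0-ISMN.soil moisture_with_1-C3S.sm.nc', period='S1')
# -----------------------------------------------------------------------------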
+ temporal sub-window definitions are stored in dictionaries with the following structure: + + {"seasons": + {"S1": [[12, 1], [2, 28]] # December 1st to February 28th + "S2": [[3, 1], [5, 31]] # March 1st to May 31st + "S3": [[6, 1], [8, 31]] # June 1st to August 31st + "S4": [[9, 1], [11, 30]] }} # September 1st to November 30th + + These dictionaries will be loaded as properties of the class, so that they can be accessed later on and be treated as the default. + + Parameters + ---------- + + custom_file : str, optional + JSON File containing the temporal sub-window definitions in the same format as in `validator.validation.globals`, by default None. If None, the default as defined in `validator.validation.globals` will be used. + + ''' + + def __init__(self, custom_file: Optional[str] = None) -> None: + self.custom_file = custom_file + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(custom_file={self.custom_file})' + + def __str__(self) -> str: + return f'{self.__class__.__name__}({Path(self.custom_file).name})' + + def _load_json_data(self, json_path: str) -> dict: + '''Reads and loads the JSON file into a dictionary. + + Parameters + ---------- + json_path : str + Path to the JSON file containing the temporal sub-window definitions. + + Returns + ------- + dict + Dictionary containing the default temporal sub-window definitions. + ''' + + with open(json_path, 'r') as f: + return json.load(f) + + @abstractmethod + def _get_available_temp_sub_wndws(self): + pass + + +class NewSubWindow: + """ + Class to store the name and the begin and end date of a new temporal sub-window. + + Parameters + ---------- + name : str + Name of the new temporal sub-window. + begin_date : datetime or YearlessDatetime + Begin date of the new temporal sub-window. + end_date : datetime or YearlessDatetime + End date of the new temporal sub-window. + + """ + + def __init__(self, name: str, begin_date: Union[datetime, + YearlessDatetime], + end_date: Union[datetime, YearlessDatetime]) -> None: + self.name = str(name) + + if type(begin_date) != type(end_date): + raise TypeError( + f"`begin_date` and `end_date` must be of the same type, not '{type(begin_date).__name__}' and '{type(end_date).__name__}'" + ) + + if not isinstance(begin_date, (datetime, YearlessDatetime)): + raise TypeError( + f"`begin_date` must be of type 'datetime' or 'YearlessDatetime', not '{type(begin_date).__name__}'" + ) + else: + self.begin_date = begin_date + if not isinstance(end_date, (datetime, YearlessDatetime)): + raise TypeError( + f"`end_date` must be of type 'datetime' or 'YearlessDatetime', not '{type(end_date).__name__}'" + ) + else: + self.end_date = end_date + + if isinstance(begin_date, datetime) and isinstance( + end_date, datetime) and self.begin_date > self.end_date: + raise ValueError( + f"begin_date ({self.begin_date}) must be before end_date ({self.end_date})" + ) + + def __str__(self) -> str: + return f'{self.__class__.__name__}({self.name}, {self.begin_date}, {self.end_date})' + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(name={self.name}, begin_date={self.begin_date}, end_date={self.end_date})' + + @property + def begin_date_pretty(self) -> str: + """ + Returns the begin date in a pretty format. + + Returns + ------- + str + Pretty formatted begin date. + """ + return self.begin_date.strftime('%Y-%m-%d') + + @property + def end_date_pretty(self) -> str: + """ + Returns the end date in a pretty format. + + Returns + ------- + str + Pretty formatted end date. 
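# --- Example (illustration only, not part of the patch) ---------------------
# The two ways `NewSubWindow` accepts dates, as a sketch; mixing the two date
# types raises a TypeError.
from datetime import datetime
from pytesmo.time_series.grouping import YearlessDatetime
from qa4sm_reader.intra_annual_temp_windows import NewSubWindow

# recurring window, re-applied in every year of the validation period
advent = NewSubWindow('advent', YearlessDatetime(12, 1),
                      YearlessDatetime(12, 24))
# absolute window with concrete years; begin_date must not lie after end_date
bulk = NewSubWindow('bulk', datetime(2000, 1, 1), datetime(2020, 12, 31))
print(bulk.begin_date_pretty)  # '2000-01-01'
# -----------------------------------------------------------------------------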
+ """ + return self.end_date.strftime('%Y-%m-%d') + + +class TemporalSubWindowsCreator(TemporalSubWindowsDefault): + '''Class to create custom temporal sub-windows, based on the default definitions. + + Parameters + ---------- + temporal_sub_window_type : str + Type of temporal sub-window to be created. Must be one of the available default types. Officially, "months" and "seasons" are implemented. The user can implement their own defaults, though (see `TemporalSubWindowsDefault`). Default is "months". + overlap : int, optional + Number of days to be added/subtracted to the beginning/end of the temporal sub-window. Default is 0. + custom_file : str, optional + Path to the JSON file containing the temporal sub-window definitions, by default None (meaning the defaults as defined in `validator.validation.globals` will be used) + ''' + + def __init__(self, + temporal_sub_window_type: Optional[str] = 'months', + overlap: Optional[int] = 0, + custom_file: Optional[str] = None): + self.overlap = int(np.round(overlap)) + self.temporal_sub_window_type = temporal_sub_window_type + super().__init__(custom_file=custom_file) + + self.available_temp_sub_wndws = self._get_available_temp_sub_wndws() + if not self.available_temp_sub_wndws: + raise FileNotFoundError( + f'Invalid custom file path. Please provide a valid JSON file containing the temporal sub-window definitions.' + ) + elif self.temporal_sub_window_type not in self.available_temp_sub_wndws: + raise InvalidTemporalSubWindowError(self.temporal_sub_window_type, + self.available_temp_sub_wndws) + raise KeyError( + f'Invalid temporal sub-window type. Available types are: {self.available_temp_sub_wndws}' + ) + + self.custom_temporal_sub_windows = self._custom_temporal_sub_windows() + self.additional_temp_sub_wndws_container = {} + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(temporal_sub_window_type={self.temporal_sub_window_type}, overlap={self.overlap})' + + def __str__(self) -> str: + return f'{self.__class__.__name__}({self.temporal_sub_window_type}, {self.overlap})' + + def __date_to_doy(self, date_tuple: Tuple[int, int]) -> int: + '''Converts a date list [month, day] to a day of the year (doy). Leap years are neglected. + + Parameters + ---------- + date_tuple : List[int] + List containing the month and day of the date to be converted. The year is not required, as it is not used in the conversion. + + Returns + ------- + int + Day of the year (doy) corresponding to the date provided. + ''' + + _doy = YearlessDatetime(date_tuple[0], date_tuple[1]).doy + if _doy > 60: # assume NO leap year + _doy -= 1 + + return _doy + + def __doy_to_date(self, doy: int) -> Tuple[int, int]: + '''Converts a day of the year (doy) to a date tuple (month, day). Leap years are neglected. + + Parameters + ---------- + doy : int + Day of the year (doy) to be converted. + + Returns + ------- + Tuple[int] + Tuple containing the month and day corresponding to the doy provided. + ''' + + date = datetime.strptime(str(doy), '%j') + month = date.strftime('%m') + day = date.strftime('%d') + + return int(month), int(day) + + def __update_date(self, date: Tuple[int, int], + overlap_direction: float) -> Tuple[int, int]: + '''Updates a date tuple (month, day) by adding/subtracting the overlap value to/from it. + + Parameters + ---------- + date : Tuple[int] + Date to be updated. + overlap_direction : float + Direction of the overlap. Must be either -1 or +1. -1: subtract overlap value from date; +1: add overlap value to date. 
+ + Returns + ------- + Tuple[int] + Updated date tuple. + ''' + + overlap_direction = overlap_direction / abs( + overlap_direction) # making sure it's either -1 or +1 + _doy = self.__date_to_doy(date) + _doy += int(overlap_direction * self.overlap) + + if _doy < 1: + _doy = 365 - abs(_doy) + elif _doy > 365: + _doy = _doy - 365 + + return self.__doy_to_date(_doy) + + def _custom_temporal_sub_window(self): + return { + key: (self.__update_date(val[0], overlap_direction=-1), + self.__update_date(val[1], overlap_direction=+1)) + for key, val in self.temporal_sub_windows_dict[ + self.temporal_sub_window_type].items() + } + + def _get_available_temp_sub_wndws(self) -> Union[List[str], None]: + if not self.custom_file: + self.temporal_sub_windows_dict = TEMPORAL_SUB_WINDOWS + return list(self.temporal_sub_windows_dict.keys()) + elif Path(self.custom_file).is_file(): + self.temporal_sub_windows_dict = self._load_json_data( + self.custom_file) + return list(self.temporal_sub_windows_dict.keys()) + else: + return None + + def _custom_temporal_sub_windows(self) -> Dict[str, TsDistributor]: + '''Custom temporal sub-window based, on the default definitions and the overlap value. + + Parameters + ---------- + None + + Returns + ------- + dict[str, TsDistributor] + Dictionary containing the custom temporal sub-window definitions. + ''' + + def tsdistributor(_begin_date: Tuple[int], + _end_date: Tuple[int]) -> TsDistributor: + return TsDistributor(yearless_date_ranges=[(YearlessDatetime( + *_begin_date), YearlessDatetime(*_end_date))]) + + return { + key: tsdistributor(val[0], val[1]) + for key, val in self._custom_temporal_sub_window().items() + } + + def add_temp_sub_wndw( + self, + new_temp_sub_wndw: NewSubWindow, + insert_as_first_wndw: Optional[bool] = False + ) -> Union[None, Dict[str, TsDistributor]]: + '''Adds a new custom temporal sub-window to the existing ones. + + Parameters + ---------- + new_temp_sub_wndw : NewSubWindow + Dataclass containing the name, begin date, and end date of the new temporal sub-window. + insert_as_first_wndw : bool, optional + If True, the new temporal sub-window will be inserted as new first element in `TemporalSubWindowsCreator.custom_temporal_sub_windows`. Default is False. + + Returns + ------- + Union[None, Dict[str, TsDistributor]] + None if the new temp_sub_wndw already exists. Otherwise, the dictionary containing the custom temporal sub-window definitions. + + ''' + + self.additional_temp_sub_wndws_container[ + new_temp_sub_wndw.name] = new_temp_sub_wndw + try: + if new_temp_sub_wndw.name in self.custom_temporal_sub_windows: + print( + f'temporal sub-window "{new_temp_sub_wndw.name}" already exists. Overwriting not possible.\ + Please choose a different name. If you want to overwrite the existing temporal sub-window, \ + use the `overwrite_temp_sub_wndw` method instead.' 
+ ) + return None + elif insert_as_first_wndw: + _new_first_element = { + new_temp_sub_wndw.name: + TsDistributor(date_ranges=[(new_temp_sub_wndw.begin_date, + new_temp_sub_wndw.end_date)]) + } + self.custom_temporal_sub_windows = { + **_new_first_element, + **self.custom_temporal_sub_windows + } + return self.custom_temporal_sub_windows + else: + self.custom_temporal_sub_windows[ + new_temp_sub_wndw.name] = TsDistributor( + date_ranges=[(new_temp_sub_wndw.begin_date, + new_temp_sub_wndw.end_date)]) + return self.custom_temporal_sub_windows + except Exception as e: + print(f'Error: {e}') + return None + + def overwrite_temp_sub_wndw( + self, new_temp_sub_wndw: NewSubWindow + ) -> Union[None, Dict[str, TsDistributor]]: + '''Overwrites an existing temporal sub-window with a new definition. + + Parameters + ---------- + new_temp_sub_wndw : NewSubWindow + Dataclass containing the name, begin date, and end date of the new temporal sub-window. + + Returns + ------- + Union[None, Dict[str, TsDistributor]] + None if the new temp_sub_wndw does not exist. Otherwise, the dictionary containing the custom temporal sub-window definitions. + + ''' + + self.additional_temp_sub_wndws_container[ + new_temp_sub_wndw.name] = new_temp_sub_wndw + try: + if new_temp_sub_wndw.name not in self.custom_temporal_sub_windows: + print( + f'temporal sub-window "{new_temp_sub_wndw.name}" does not exist. Overwriting not possible.\ + Please choose a different name. If you want to add a new temporal sub-window, \ + use the `add_temp_sub_wndw` method instead.') + return None + elif isinstance( + new_temp_sub_wndw.begin_date, datetime + ) and isinstance( + new_temp_sub_wndw.end_date, datetime + ) and new_temp_sub_wndw.begin_date < new_temp_sub_wndw.end_date: + self.custom_temporal_sub_windows[ + new_temp_sub_wndw.name] = TsDistributor( + date_ranges=[(new_temp_sub_wndw.begin_date, + new_temp_sub_wndw.end_date)]) + return self.custom_temporal_sub_windows + elif isinstance(new_temp_sub_wndw.begin_date, + YearlessDatetime) and isinstance( + new_temp_sub_wndw.end_date, YearlessDatetime): + self.custom_temporal_sub_windows[ + new_temp_sub_wndw.name] = TsDistributor( + yearless_date_ranges=[(new_temp_sub_wndw.begin_date, + new_temp_sub_wndw.end_date)]) + return self.custom_temporal_sub_windows + + except Exception as e: + print(f'Error: {e}') + return None + + @property + def names(self) -> List[str]: + '''Returns the names of the temporal sub-windows. + + Parameters + ---------- + None + + Returns + ------- + List[str] + List containing the names of the temporal sub-windows. + ''' + + return list(self.custom_temporal_sub_windows.keys()) + + @property + def metadata(self) -> Dict: + '''Returns the metadata of the temporal sub-windows. + + Parameters + ---------- + None + + Returns + ------- + Dict[str, Union[str, List[str]]] + Dictionary containing the metadata of the temporal sub-windows. 
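# --- Example (illustration only, not part of the patch) ---------------------
# A sketch of driving `TemporalSubWindowsCreator`: default seasons widened by
# 15 days on both ends, definitions loaded from the shipped example JSON, and
# a bulk window prepended so it is processed first. Paths are illustrative.
from datetime import datetime
from qa4sm_reader.intra_annual_temp_windows import (NewSubWindow,
                                                    TemporalSubWindowsCreator)
from qa4sm_reader.globals import DEFAULT_TSW

seasons = TemporalSubWindowsCreator('seasons', overlap=15)
custom = TemporalSubWindowsCreator(
    'custom',
    custom_file='src/qa4sm_reader/custom_intra_annual_windows_example.json')
seasons.add_temp_sub_wndw(
    NewSubWindow(DEFAULT_TSW, datetime(2000, 1, 1), datetime(2020, 12, 31)),
    insert_as_first_wndw=True)
print(seasons.names)  # ['bulk', 'S1', 'S2', 'S3', 'S4']
# -----------------------------------------------------------------------------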
+ ''' + + def _date_formatter(_date: Tuple[int, int]) -> str: + return f'{_date[0]:02d}-{_date[1]:02d}' + + metadata_dict = { + 'Temporal sub-window type': + self.temporal_sub_window_type, + 'Overlap': + f'{self.overlap} days', + 'Pretty Names [MM-DD]': (', ').join([ + f'{key}: {_date_formatter(val[0])} to {_date_formatter(val[1])}' + for key, val in self._custom_temporal_sub_window().items() + ]) + } + + if self._custom_temporal_sub_window().items() != self.names: + unique_tsws = list( + set(self.names) - + set(self._custom_temporal_sub_window().keys())) + if DEFAULT_TSW in unique_tsws: + metadata_dict[ + DEFAULT_TSW] = f'This is the default case "{DEFAULT_TSW}". The user specified it to range from {self.additional_temp_sub_wndws_container[DEFAULT_TSW].begin_date_pretty} to {self.additional_temp_sub_wndws_container[DEFAULT_TSW].end_date_pretty}. Note: This specified time interval might differ from the actual time interval in which all datasets are available. Refer to the datasets section (https://qa4sm.eu/ui/datasets) for more information.' + + return metadata_dict diff --git a/src/qa4sm_reader/netcdf_transcription.py b/src/qa4sm_reader/netcdf_transcription.py new file mode 100644 index 0000000..e4f501d --- /dev/null +++ b/src/qa4sm_reader/netcdf_transcription.py @@ -0,0 +1,762 @@ +import xarray as xr +import numpy as np +from typing import List, Dict, Optional, Union, Tuple +import os +import calendar +import time +import shutil +import tempfile +import sys +from pathlib import Path + +from qa4sm_reader.intra_annual_temp_windows import TemporalSubWindowsCreator, InvalidTemporalSubWindowError +from qa4sm_reader.globals import METRICS, TC_METRICS, NON_METRICS, METADATA_TEMPLATE, \ + IMPLEMENTED_COMPRESSIONS, ALLOWED_COMPRESSION_LEVELS, \ + INTRA_ANNUAL_METRIC_TEMPLATE, INTRA_ANNUAL_TCOL_METRIC_TEMPLATE, \ + TEMPORAL_SUB_WINDOW_SEPARATOR, DEFAULT_TSW, TEMPORAL_SUB_WINDOW_NC_COORD_NAME, \ + MAX_NUM_DS_PER_VAL_RUN, DATASETS, OLD_NCFILE_SUFFIX + + +class TemporalSubWindowMismatchError(Exception): + '''Exception raised when the temporal sub-windows provided do not match the ones present in the provided netCDF file.''' + + def __init__(self, provided, expected): + super().__init__( + f'The temporal sub-windows provided ({provided}) do not match the ones present in the provided netCDF file ({expected}).' + ) + + + +class Pytesmo2Qa4smResultsTranscriber: + """ + Transcribes (=converts) the pytesmo results netCDF4 file format to a more user friendly format, that + is used by QA4SM. + + Parameters + ---------- + pytesmo_results : str + Path to results netCDF4 written by `qa4sm.validation.validation.check_and_store_results`, which is in the old `pytesmo` format. + intra_annual_slices : Union[None, TemporalSubWindowsCreator] + The temporal sub-windows for the results. Default is None, which means that no temporal sub-windows are + used, but only the 'bulk'. If an instance of `valdiator.validation.TemporalSubWindowsCreator` is provided, + the temporal sub-windows are used as provided by the TemporalSubWindowsCreator instance. + keep_pytesmo_ncfile : Optional[bool] + Whether to keep the original pytesmo results netCDF file. Default is False. \ + If True, the original file is kept and indicated by the suffix `qa4sm_reader.globals.OLD_NCFILE_SUFFIX`. 
+ """ + + def __init__(self, + pytesmo_results: str, + intra_annual_slices: Union[None, + TemporalSubWindowsCreator] = None, + keep_pytesmo_ncfile: Optional[bool] = False): + + self.original_pytesmo_ncfile = str(pytesmo_results) + + # windows workaround + # windows keeps a file lock on the original file, which prevents it from being renamed or deleted + # to circumvent this, the file is copied to a temporary directory and the copy is used instead + + if sys.platform.startswith("win"): + if not isinstance(pytesmo_results, Path): + pytesmo_results = Path(pytesmo_results) + + _tmp_dir = Path(tempfile.mkdtemp()) + tmp_dir = _tmp_dir / pytesmo_results.parent.name + + if not tmp_dir.exists(): + tmp_dir.mkdir() + + new_pytesmo_results = tmp_dir / pytesmo_results.name + shutil.copy(pytesmo_results, new_pytesmo_results) + pytesmo_results = str(new_pytesmo_results) + + self.pytesmo_ncfile = f'{pytesmo_results}' + if not Path(pytesmo_results).is_file(): + self.exists = False + raise FileNotFoundError( + f'\n\nFile {pytesmo_results} not found. Please provide a valid path to a pytesmo results netCDF file.' + ) + return None + else: + self.exists = True + + # make sure the intra-annual slices from the argument are the same as the ones present in the pytesmo results + pytesmo_results_tsws = Pytesmo2Qa4smResultsTranscriber.get_tsws_from_ncfile( + pytesmo_results) + if isinstance(intra_annual_slices, TemporalSubWindowsCreator): + self.provided_tsws = intra_annual_slices.names + elif intra_annual_slices is None: + self.provided_tsws = intra_annual_slices + else: + raise InvalidTemporalSubWindowError(intra_annual_slices, + ['months', 'seasons']) + + if self.provided_tsws != pytesmo_results_tsws: + print( + f'The temporal sub-windows provided ({self.provided_tsws}) do not match the ones present in the provided netCDF file ({pytesmo_results_tsws}).' + ) + raise TemporalSubWindowMismatchError(self.provided_tsws, + pytesmo_results_tsws) + + self.intra_annual_slices: Union[ + None, TemporalSubWindowsCreator] = intra_annual_slices + self._temporal_sub_windows: Union[ + None, TemporalSubWindowsCreator] = intra_annual_slices + + self._default_non_metrics: List[str] = NON_METRICS + + self.METADATA_TEMPLATE: Dict[str, Union[None, Dict[str, Union[ + np.ndarray, np.float32, np.array]]]] = METADATA_TEMPLATE + + self.temporal_sub_windows_checker_called: bool = False + self.only_default_case: bool = True + + with xr.open_dataset(pytesmo_results) as pr: + self.pytesmo_results: xr.Dataset = pr + + self.keep_pytesmo_ncfile = keep_pytesmo_ncfile + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(pytesmo_results="{self.pytesmo_ncfile}", intra_annual_slices={self.intra_annual_slices.__repr__()})' + + def __str__(self) -> str: + return f'{self.__class__.__name__}("{Path(self.pytesmo_ncfile).name}", {self.intra_annual_slices})' + + def temporal_sub_windows_checker( + self) -> Tuple[bool, Union[List[str], None]]: + """ + Checks the temporal sub-windows and returns which case of temporal sub-window is used, as well as a list of the + temporal sub-windows. Keeps track of whether the method has been called before. + + Returns + ------- + Tuple[bool, Union[List[str], None]] + A tuple indicating the temporal sub-window type and the list of temporal sub-windows. 
+            bulk case: (True, [`globals.DEFAULT_TSW`]),
+            intra-annual windows case: (False, list of temporal sub-windows)
+        """
+
+        self.temporal_sub_windows_checker_called = True
+        if self.intra_annual_slices is None:
+            return True, [DEFAULT_TSW]
+        elif isinstance(self.intra_annual_slices, TemporalSubWindowsCreator):
+            return False, self.provided_tsws
+        else:
+            raise InvalidTemporalSubWindowError(self.intra_annual_slices,
+                                                ['months', 'seasons'])
+
+    @property
+    def non_metrics_list(self) -> List[str]:
+        """
+        Get the non-metrics from the pytesmo results.
+
+        Returns
+        -------
+        List[str]
+            A list of non-metric names.
+        """
+
+        non_metrics_lst = []
+        for non_metric in self._default_non_metrics:
+            if non_metric in self.pytesmo_results:
+                non_metrics_lst.append(non_metric)
+        return non_metrics_lst
+
+    def is_valid_metric_name(self, metric_name):
+        """
+        Checks if a given metric name is valid, based on the defined `globals.INTRA_ANNUAL_METRIC_TEMPLATE`.
+
+        Parameters
+        ----------
+        metric_name : str
+            The metric name to be checked.
+
+        Returns
+        -------
+        bool
+            True if the metric name is valid, False otherwise.
+        """
+        valid_prefixes = [
+            "".join(
+                template.format(tsw=tsw, metric=metric)
+                for template in INTRA_ANNUAL_METRIC_TEMPLATE)
+            for tsw in self.provided_tsws for metric in METRICS
+        ]
+        return any(metric_name.startswith(prefix) for prefix in valid_prefixes)
+
+    def is_valid_tcol_metric_name(self, tcol_metric_name):
+        """
+        Checks if a given metric name is a valid TCOL metric name, based on the defined `globals.INTRA_ANNUAL_TCOL_METRIC_TEMPLATE`.
+
+        Parameters
+        ----------
+        tcol_metric_name : str
+            The metric name to be checked.
+
+        Returns
+        -------
+        bool
+            True if the metric name is valid, False otherwise.
+        """
+        valid_prefixes = [
+            "".join(
+                template.format(
+                    tsw=tsw, metric=metric, number=number, dataset=dataset)
+                for template in INTRA_ANNUAL_TCOL_METRIC_TEMPLATE)
+            for tsw in self.provided_tsws for metric in TC_METRICS
+            for number in range(MAX_NUM_DS_PER_VAL_RUN) for dataset in DATASETS
+        ]
+        return any(
+            tcol_metric_name.startswith(prefix) for prefix in valid_prefixes)
+
+    @property
+    def metrics_list(self) -> List[str]:
+        """Get the list of metric variable names. The whole procedure is based on the premise that metric names of validations with intra-annual
+        temporal sub-windows are of the form: `metric_long_name = 'intra_annual_window{validator.validation.globals.TEMPORAL_SUB_WINDOW_SEPARATOR}metric_short_name'`. If this is not the
+        case, the 'bulk' case is assumed and the metric names are taken to be the metric
+        short names themselves.
+
+        Returns
+        -------
+        List[str]
+            The list of metric variable names.
+        """
+
+        # check if the metric names are of the form: `metric_long_name = 'intra_annual_window{TEMPORAL_SUB_WINDOW_SEPARATOR}metric_short_name'` and if not, assume the 'bulk' case
+
+        _metrics = [
+            metric for metric in self.pytesmo_results
+            if self.is_valid_metric_name(metric)
+            or self.is_valid_tcol_metric_name(metric)
+        ]
+
+        if len(_metrics) != 0:  # intra-annual case
+            return list(set(_metrics))
+        else:  # bulk case
+            return [
+                long for long in self.pytesmo_results
+                if long not in self.non_metrics_list
+            ]
+
+    def get_pytesmo_attrs(self) -> None:
+        """
+        Get the attributes of the pytesmo results and add them to the transcribed dataset.
+ """ + for attr in self.pytesmo_results.attrs: + self.transcribed_dataset.attrs[attr] = self.pytesmo_results.attrs[ + attr] + + def handle_n_obs(self) -> None: + """ + Each data variable of the flavor 'n_obs_between_*' contains the same data. Thus, only one is kept and renamned\ + to plain 'n_obs'. + """ + + _n_obs_vars = sorted( + [var for var in self.transcribed_dataset if 'n_obs' in var]) + + if _n_obs_vars[0] != 'n_obs': + self.transcribed_dataset = self.transcribed_dataset.drop_vars( + _n_obs_vars[1:]) + self.transcribed_dataset = self.transcribed_dataset.rename( + {_n_obs_vars[0]: 'n_obs'}) + + def drop_obs_dim(self) -> None: + """ + Drops the 'obs' dimension from the transcribed dataset, if it exists. + """ + if 'obs' in self.transcribed_dataset.dims: + self.transcribed_dataset = self.transcribed_dataset.drop_dims( + 'obs') + + @staticmethod + def update_dataset_var(ds: xr.Dataset, var: str, coord_key: str, + coord_val: str, data_vals: List) -> xr.Dataset: + """ + Update a variable of given coordinate in the dataset. + + Parameters + ---------- + ds : xr.Dataset + The dataset to be updated. + var : str + The variable to be updated. + coord_key : str + The name of the coordinate of the variable to be updated. + coord_val : str + The value of the coordinate of the variable to be updated. + data_vals : List + The data to be updated. + + Returns + ------- + xr.Dataset + The updated dataset. + """ + + ds[var] = ds[var].copy( + ) # ugly, but necessary, as xr.Dataset objects are immutable + ds[var].loc[{coord_key: coord_val}] = data_vals + + return ds + + def get_transcribed_dataset(self) -> xr.Dataset: + """ + Get the transcribed dataset, containing all metric and non-metric data provided by the pytesmo results. Metadata + is not yet included. + + + Returns + ------- + xr.Dataset + The transcribed, metadata-less dataset. 
+ """ + self.only_default_case, self.provided_tsws = self.temporal_sub_windows_checker( + ) + + self.pytesmo_results[ + TEMPORAL_SUB_WINDOW_NC_COORD_NAME] = self.provided_tsws + + metric_vars = self.metrics_list + self.transcribed_dataset = xr.Dataset() + + for var_name in metric_vars: + new_name = var_name + if not self.only_default_case: + _tsw, new_name = new_name.split(TEMPORAL_SUB_WINDOW_SEPARATOR) + + if new_name not in self.transcribed_dataset: + # takes the data associated with the metric new_name and adds it as a new variabel + # more precisely, it assigns copies of this data to each temporal sub-window, which is the new dimension + self.transcribed_dataset[new_name] = self.pytesmo_results[ + var_name].expand_dims( + { + TEMPORAL_SUB_WINDOW_NC_COORD_NAME: + self.provided_tsws + }, + axis=-1) + else: + # the variable already exists, but we need to update it with the real data (as opposed to the data of the first temporal sub-window, which is the default behaviour of expand_dims()) + self.transcribed_dataset = Pytesmo2Qa4smResultsTranscriber.update_dataset_var( + ds=self.transcribed_dataset, + var=new_name, + coord_key=TEMPORAL_SUB_WINDOW_NC_COORD_NAME, + coord_val=_tsw, + data_vals=self.pytesmo_results[var_name].data) + + # Copy attributes from the original variable to the new variable + self.transcribed_dataset[new_name].attrs = self.pytesmo_results[ + var_name].attrs + + # Add non-metric variables directly + self.transcribed_dataset = self.transcribed_dataset.merge( + self.pytesmo_results[self.non_metrics_list]) + + self.get_pytesmo_attrs() + self.handle_n_obs() + self.drop_obs_dim() + + self.transcribed_dataset[ + TEMPORAL_SUB_WINDOW_NC_COORD_NAME].attrs = dict( + long_name="temporal sub-window", + standard_name="temporal sub-window", + units="1", + valid_range=[0, len(self.provided_tsws)], + axis="T", + description="temporal sub-window name for the dataset", + temporal_sub_window_type="No temporal sub-windows used" + if self.only_default_case is True else self. + _temporal_sub_windows.metadata['Temporal sub-window type'], + overlap="No temporal sub-windows used" + if self.only_default_case is True else + self._temporal_sub_windows.metadata['Overlap'], + intra_annual_window_definition="No temporal sub-windows used" + if self.only_default_case is True else + self._temporal_sub_windows.metadata['Pretty Names [MM-DD]'], + ) + + try: + _dict = { + 'attr_name': DEFAULT_TSW, + 'attr_value': self._temporal_sub_windows.metadata[DEFAULT_TSW] + } + self.transcribed_dataset[ + TEMPORAL_SUB_WINDOW_NC_COORD_NAME].attrs.update( + {_dict['attr_name']: _dict['attr_value']}) + except AttributeError: + pass + + self.pytesmo_results.close() + + return self.transcribed_dataset + + def build_outname(self, root: str, keys: List[Tuple[str]]) -> str: + """ + Build the output name for the NetCDF file. Slight alteration of the original function from pytesmo + `pytesmo.validation_framework.results_manager.build_filename`. + + Parameters + ---------- + root : str + The root path, where the file is to be written to. + keys : List[Tuple[str]] + The keys of the pytesmo results. + + Returns + ------- + str + The output name for the NetCDF file. 
+ + """ + + ds_names = [] + for key in keys: + for ds in key: + if isinstance(ds, tuple): + ds_names.append(".".join(list(ds))) + else: + ds_names.append(ds) + + fname = "_with_".join(ds_names) + ext = "nc" + if len(str(Path(root) / f"{fname}.{ext}")) > 255: + ds_names = [str(ds[0]) for ds in key] + fname = "_with_".join(ds_names) + + if len(str(Path(root) / f"{fname}.{ext}")) > 255: + fname = "validation" + self.outname = Path(root) / f"{fname}.{ext}" + return self.outname + + def write_to_netcdf(self, + path: str, + mode: Optional[str] = 'w', + format: Optional[str] = 'NETCDF4', + engine: Optional[str] = 'netcdf4', + encoding: Optional[dict] = None, + compute: Optional[bool] = True, + **kwargs) -> str: + """ + Write the transcribed dataset to a NetCDF file, based on `xarray.Dataset.to_netcdf` + + Parameters + ---------- + path : str + The path to write the NetCDF file + mode : Optional[str], optional + The mode to open the NetCDF file, by default 'w' + format : Optional[str], optional + The format of the NetCDF file, by default 'NETCDF4' + engine : Optional[str], optional + The engine to use, by default 'netcdf4' + encoding : Optional[dict], optional + The encoding to use, by default {"zlib": True, "complevel": 5} + compute : Optional[bool], optional + Whether to compute the dataset, by default True + **kwargs : dict + Keyword arguments passed to `xarray.Dataset.to_netcdf`. + + Returns + ------- + str + The path to the NetCDF file. + """ + # Default encoding applied to all variables + if encoding is None: + encoding = {} + for var in self.transcribed_dataset.variables: + if not np.issubdtype(self.transcribed_dataset[var].dtype, + np.object_): + encoding[str(var)] = {'zlib': True, 'complevel': 1} + else: + encoding[str(var)] = {'zlib': False} + + try: + self.pytesmo_results.close() + Path(self.original_pytesmo_ncfile).rename( + self.original_pytesmo_ncfile + OLD_NCFILE_SUFFIX) + except PermissionError as e: + shutil.copy(self.original_pytesmo_ncfile, + self.original_pytesmo_ncfile + OLD_NCFILE_SUFFIX) + + if not self.keep_pytesmo_ncfile: + retry_count = 5 + for i in range(retry_count): + try: + self.pytesmo_results.close() + Path(self.original_pytesmo_ncfile + + OLD_NCFILE_SUFFIX).unlink() + break + except PermissionError: + if i < retry_count - 1: + time.sleep(1) + + for var in self.transcribed_dataset.data_vars: + # Check if the data type is Unicode (string type) + if self.transcribed_dataset[var].dtype.kind == 'U': + # Find the maximum string length in this variable + max_len = self.transcribed_dataset[var].str.len().max().item() + + # Create a character array of shape (n, max_len), where n is the number of strings + char_array = np.array([ + list(s.ljust(max_len)) + for s in self.transcribed_dataset[var].values + ], + dtype=f'S1') + + # Create a new DataArray for the character array with an extra character dimension + self.transcribed_dataset[var] = xr.DataArray( + char_array, + dims=(self.transcribed_dataset[var].dims[0], + f"{var}_char"), + coords={ + self.transcribed_dataset[var].dims[0]: + self.transcribed_dataset[var].coords[ + self.transcribed_dataset[var].dims[0]] + }, + attrs=self.transcribed_dataset[var]. 
+ attrs # Preserve original attributes if needed + ) + + # Ensure the dataset is closed + if isinstance(self.transcribed_dataset, xr.Dataset): + self.transcribed_dataset.close() + + # Write the transcribed dataset to a new NetCDF file + self.transcribed_dataset.to_netcdf( + path=path, + mode=mode, + encoding=encoding, + ) + + return path + + def compress(self, + path: str, + compression: str = 'zlib', + complevel: int = 5) -> None: + """ + Opens the generated results netCDF file and writes to a new netCDF file with new compression parameters. The smaller of both files is then deleted and the remainign one named according to the original file. + + Parameters + ---------- + fpath: str + Path to the netCDF file to be re-compressed. + compression: str + Compression algorithm to be used. Currently only 'zlib' is implemented. + complevel: int + Compression level to be used. The higher the level, the better the compression, but the longer it takes. + + Returns + ------- + None + """ + + if compression in IMPLEMENTED_COMPRESSIONS and complevel in ALLOWED_COMPRESSION_LEVELS: + + def encoding_params(ds: xr.Dataset, compression: str, + complevel: int) -> dict: + return { + str(var): { + compression: True, + 'complevel': complevel + } + for var in ds.variables + if not np.issubdtype(ds[var].dtype, np.object_) + } + + try: + with xr.open_dataset(path) as ds: + parent_dir = Path(path).parent + file = Path(path).name + re_name = parent_dir / f're_{file}' + ds.to_netcdf(re_name, + mode='w', + format='NETCDF4', + encoding=encoding_params( + ds, compression, complevel)) + print(f'\n\nRe-compression finished\n\n') + + # for small initial files, the re-compressed file might be larger than the original + # delete the larger file and keep the smaller; rename the smaller file to the original name if necessary + fpath_size = os.path.getsize(path) + re_name_size = os.path.getsize(re_name) + + if fpath_size < re_name_size: + Path(re_name).unlink() + else: + Path(path).unlink() + Path(re_name).rename(path) + + return True + + except Exception as e: + print( + f'\n\nRe-compression failed. {e}\nContinue without re-compression.\n\n' + ) + return False + + else: + raise NotImplementedError( + f'\n\nRe-compression failed. Compression has to be {IMPLEMENTED_COMPRESSIONS} and compression levels other than {ALLOWED_COMPRESSION_LEVELS} are not supported. Continue without re-compression.\n\n' + ) + + @staticmethod + def get_tsws_from_qa4sm_ncfile(ncfile: str) -> Union[List[str], None]: + """ + Get the temporal sub-windows from a proper QA4SM NetCDF file. + + Parameters + ---------- + ncfile : str + The path to the NetCDF file. + + Returns + ------- + List[str] + The temporal sub-windows. + """ + + with xr.open_dataset(ncfile) as ds: + try: + return list(ds[TEMPORAL_SUB_WINDOW_NC_COORD_NAME].values) + except KeyError: + return None + + @staticmethod + def get_tsws_from_pytesmo_ncfile(ncfile: str) -> Union[List[str], None]: + """ + Get the temporal sub-windows from a pytesmo NetCDF file. + + **ATTENTION**: Only retrieves the temporal sub-windows if they are explicitly stated in the data variable names \ + present in the file. An implicit presence of the bulk case in pytesmo files is not detected. + + Parameters + ---------- + ncfile : str + The path to the NetCDF file. + + Returns + ------- + List[str] + The temporal sub-windows. 
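# --- Example (illustration only, not part of the patch) ---------------------
# Both `write_to_netcdf` and `compress` build the same shape of per-variable
# encoding dict; a standalone sketch of that helper under the defaults named
# above:
import numpy as np
import xarray as xr

def zlib_encoding(ds: xr.Dataset, complevel: int = 5) -> dict:
    # object/string-typed variables cannot be zlib-compressed and are skipped
    return {str(var): {'zlib': True, 'complevel': complevel}
            for var in ds.variables
            if not np.issubdtype(ds[var].dtype, np.object_)}

# usage: ds.to_netcdf('out.nc', encoding=zlib_encoding(ds))
# -----------------------------------------------------------------------------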
+ """ + + with xr.open_dataset(ncfile) as ds: + try: + out = list({ + data_var.split(TEMPORAL_SUB_WINDOW_SEPARATOR)[0] + for data_var in list(ds.data_vars) + if TEMPORAL_SUB_WINDOW_SEPARATOR in data_var + and any([metric in data_var for metric in METRICS]) + }) + if not out: + return None + return out + + except KeyError: + return None + + @staticmethod + def get_tsws_from_ncfile(ncfile: str) -> Union[List[str], None]: + """ + Get the temporal sub-windows from a QA4SM or pytesmo NetCDF file. + + **ATTENTION**: Only retrieves the temporal sub-windows if they are explicitly stated in the data variable names \ + present in the file. An implicit presence of the bulk case is not detected. + + Parameters + ---------- + ncfile : str + The path to the NetCDF file. + + Returns + ------- + Union[List[str], None] + A list of temporal sub-windows or None if the file does not contain any. + """ + + def sort_tsws(tsw_list: List[str]) -> List[str]: + '''Sort the temporal sub-windows in the order of the calendar months, the seasons, \ + and the custom temporal sub-windows. Only sorts if temporal sub-windows of only one \ + kind are present; + + Parameters + ---------- + tsw_list : List[str] + The list of temporal sub-windows. + + Returns + ------- + List[str] + The sorted list of temporal sub-windows. + ''' + if not tsw_list: + return tsw_list + + bulk_present = DEFAULT_TSW in tsw_list + if bulk_present: + tsw_list.remove(DEFAULT_TSW) + + month_order = { + month: index + for index, month in enumerate(calendar.month_abbr) if month + } + seasons_1_order = {f'S{i}': i - 1 for i in range(1, 5)} + seasons_2_order = { + season: index + for index, season in enumerate(['DJF', 'MAM', 'JJA', 'SON']) + } + + def get_custom_tsws(tsw_list): + customs = [ + tsw for tsw in tsw_list + if tsw not in month_order and tsw not in seasons_1_order + and tsw not in seasons_2_order + ] + return customs, list(set(tsw_list) - set(customs)) + + custom_tsws, tsw_list = get_custom_tsws(tsw_list) + lens = {len(tsw) for tsw in tsw_list} + + if lens == {2} and all( + tsw.startswith('S') + for tsw in tsw_list): # seasons like S1, S2, S3, S4 + _presorted = sorted(tsw_list, key=seasons_1_order.get) + + elif lens == {3} and all( + tsw in seasons_2_order + for tsw in tsw_list): # seasons like DJF, MAM, JJA, SON + _presorted = sorted(tsw_list, key=seasons_2_order.get) + + elif lens == {3} and all(tsw.isalpha() + for tsw in tsw_list) and all( + tsw in month_order for tsw in tsw_list + ): # months like Jan, Feb, Mar, Apr, ... 
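+                # e.g. ['Dec', 'Jan', 'Mar'] is returned as ['Jan', 'Mar', 'Dec']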
+ _presorted = sorted(tsw_list, key=month_order.get) + + else: + _presorted = tsw_list + + return ([DEFAULT_TSW] + if bulk_present else []) + _presorted + custom_tsws + + tsws = Pytesmo2Qa4smResultsTranscriber.get_tsws_from_qa4sm_ncfile( + ncfile) + if not tsws: + tsws = Pytesmo2Qa4smResultsTranscriber.get_tsws_from_pytesmo_ncfile( + ncfile) + return sort_tsws(tsws) + + +if __name__ == '__main__': + pth = '/tmp/qa4sm/basic/0-ISMN.soil moisture_with_1-C3S.sm.nc' + + transcriber = Pytesmo2Qa4smResultsTranscriber(pytesmo_results=pth, + intra_annual_slices=None, + keep_pytesmo_ncfile=True) + ds = transcriber.get_transcribed_dataset() + print('writing to netcdf') + transcriber.write_to_netcdf( + path='/tmp/qa4sm/basic/0-ISMN.soil moisture_with_1-C3S.sm.nc.new') diff --git a/src/qa4sm_reader/plot_all.py b/src/qa4sm_reader/plot_all.py index 33a872f..1b3f9e0 100644 --- a/src/qa4sm_reader/plot_all.py +++ b/src/qa4sm_reader/plot_all.py @@ -1,15 +1,21 @@ -# -*- coding: utf-8 -*- +# %% +# # -*- coding: utf-8 -*- import os -import warnings -from typing import Union +from typing import Union, List, Tuple, Dict +from itertools import chain import pandas as pd -from qa4sm_reader.plotter import QA4SMPlotter -from qa4sm_reader.img import QA4SMImg, extract_periods +from qa4sm_reader.plotter import QA4SMPlotter, QA4SMCompPlotter +from qa4sm_reader.img import QA4SMImg +from qa4sm_reader.netcdf_transcription import Pytesmo2Qa4smResultsTranscriber import qa4sm_reader.globals as globals +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path def plot_all(filepath: str, + temporal_sub_windows: List[str] = None, metrics: list = None, extent: tuple = None, out_dir: str = None, @@ -18,7 +24,7 @@ def plot_all(filepath: str, save_metadata: Union[str, bool] = 'never', save_csv: bool = True, engine: str = 'h5netcdf', - **plotting_kwargs) -> tuple: + **plotting_kwargs) -> Tuple[List[str], List[str], List[str], List[str]]: """ Creates boxplots for all metrics and map plots for all variables. Saves the output in a folder-structure. @@ -27,6 +33,8 @@ def plot_all(filepath: str, ---------- filepath : str path to the *.nc file to be processed. + temporal_sub_windows : List[str], optional (default: None) + List of temporal sub-windows to be processed. If None, all periods present are automatically extracted from the file. metrics : set or list, optional (default: None) metrics to be plotted. 
If None, all metrics with data are plotted extent : tuple, optional (default: None) @@ -60,7 +68,10 @@ def plot_all(filepath: str, fnames_mapplots: list lists of filenames for created mapplots and boxplots fnames_csv: list + fnames_cbplot: list + list of filenames for created comparison boxplots """ + if isinstance(save_metadata, bool): if not save_metadata: save_metadata = 'never' @@ -76,8 +87,13 @@ def plot_all(filepath: str, # initialise image and plotter fnames_bplot, fnames_mapplot, fnames_csv = [], [], [] - periods = extract_periods(filepath) + if temporal_sub_windows is None: + periods = Pytesmo2Qa4smResultsTranscriber.get_tsws_from_ncfile(filepath) + else: + periods = np.array(temporal_sub_windows) + for period in periods: + print(f'period: {period}') img = QA4SMImg( filepath, period=period, @@ -96,6 +112,7 @@ def plot_all(filepath: str, for metric in metrics: metric_bplots, metric_mapplots = plotter.plot_metric( metric=metric, + period=period, out_types=out_type, save_all=save_all, **plotting_kwargs) @@ -106,27 +123,50 @@ def plot_all(filepath: str, fnames_mapplot.extend(metric_mapplots) if img.metadata and (save_metadata != 'never'): if save_metadata == 'always': - kwargs = { - 'meta_boxplot_min_samples': 0 - } + kwargs = {'meta_boxplot_min_samples': 0} else: kwargs = { - 'meta_boxplot_min_samples': globals.meta_boxplot_min_samples + 'meta_boxplot_min_samples': + globals.meta_boxplot_min_samples } - fnames_bplot.extend( - plotter.plot_save_metadata( - metric, - out_types=out_type, - **kwargs - )) + if period == globals.DEFAULT_TSW: + kwargs['period'] = globals.DEFAULT_TSW + fnames_bplot.extend( + plotter.plot_save_metadata(metric, + out_types=out_type, + **kwargs)) if save_csv: - out_csv = plotter.save_stats() + out_csv = plotter.save_stats(period=period) fnames_csv.append(out_csv) - return fnames_bplot, fnames_mapplot, fnames_csv + #$$ + # ? move somewhere else? 
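+    # Comparison boxplots: one clustered boxplot per metric, comparing all
+    # temporal sub-windows side by side. They are only created when the bulk
+    # window plus at least one sub-window is present; metadata variables,
+    # triple-collocation metrics and n_obs are excluded.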
+    fnames_cbplot = []
+    if isinstance(out_type, str):
+        out_type = [out_type]
+    metrics_not_to_plot = list(set(chain(globals._metadata_exclude, globals.metric_groups[3], ['n_obs'])))  # metadata, tcol metrics, n_obs
+    if globals.DEFAULT_TSW in periods and len(periods) > 1:
+        cbp = QA4SMCompPlotter(filepath)
+        if not os.path.isdir(os.path.join(out_dir, 'comparison_boxplots')):
+            os.makedirs(os.path.join(out_dir, 'comparison_boxplots'))
+
+        for available_metric in cbp.metric_kinds_available:
+            if available_metric in metrics.keys(
+            ) and available_metric not in metrics_not_to_plot:
+                spth = [Path(out_dir) / 'comparison_boxplots' /
+                        f'{globals.CLUSTERED_BOX_PLOT_SAVENAME.format(metric=available_metric, filetype=_out_type)}'
+                        for _out_type in out_type]
+                _fig = cbp.plot_cbp(
+                    chosen_metric=available_metric,
+                    out_name=spth,
+                )
+                plt.close(_fig)
+                fnames_cbplot.extend(spth)
+
+    return fnames_bplot, fnames_mapplot, fnames_csv, fnames_cbplot
 
 
 def get_img_stats(
diff --git a/src/qa4sm_reader/plotter.py b/src/qa4sm_reader/plotter.py
index de169ac..d940337 100644
--- a/src/qa4sm_reader/plotter.py
+++ b/src/qa4sm_reader/plotter.py
@@ -1,18 +1,26 @@
 # -*- coding: utf-8 -*-
+import re
 import warnings
 from pathlib import Path
-
+import os
+from warnings import warn
 import pandas as pd
-from typing import Union
+import xarray as xr
+from typing import Generator, Union, List, Tuple, Dict, Optional
 import numpy as np
+import itertools
+
 import matplotlib.pyplot as plt
+import matplotlib
+from matplotlib.patches import Rectangle
 
 from qa4sm_reader.img import QA4SMImg
 import qa4sm_reader.globals as globals
 from qa4sm_reader import plotting_methods as plm
-
+from qa4sm_reader.plotting_methods import ClusteredBoxPlot, patch_styling
 from qa4sm_reader.exceptions import PlotterError
-from warnings import warn
+import qa4sm_reader.handlers as hdl
+from qa4sm_reader.utils import note, filter_out_self_combination_tcmetric_vars
 
 
 class QA4SMPlotter:
@@ -45,10 +55,11 @@ def __init__(self, image: QA4SMImg, out_dir: str = None):
 
     def get_dir(self, out_dir: str) -> Path:
         """Use output path if specified, otherwise same directory as the one storing the netCDF file"""
+        # if out_dir and globals.DEFAULT_TSW not in out_dir:
         if out_dir:
             out_dir = Path(out_dir)  # use directory if specified
             if not out_dir.exists():
-                out_dir.mkdir()  # make if not existing
+                os.makedirs(out_dir)  # make if not existing
         else:
             out_dir = self.img.filepath.parent  # use default otherwise
@@ -215,7 +226,7 @@ def _filenames_lut(type: str) -> str:
         except KeyError:
             raise PlotterError(f"type '{type}' is not in the lookup table")
 
-    def create_title(self, Var, type: str) -> str:
+    def create_title(self, Var, type: str, period: str = None) -> str:
         """
         Create title of the plot
 
@@ -229,10 +240,11 @@ def create_title(self, Var, type: str) -> str:
         parts = [globals._metric_name[Var.metric]]
         parts.extend(self._get_parts_name(Var=Var, type=type))
         title = self._titles_lut(type=type).format(*parts)
-
+        if period:
+            title = f'{period}: {title}'
         return title
 
-    def create_filename(self, Var, type: str) -> str:
+    def create_filename(self, Var, type: str, period: str = None) -> str:
         """
         Create name of the file
 
@@ -264,6 +276,8 @@
         parts.extend([mds_meta[0], mds_meta[1]['short_name'], Var.metric])
         name = name.format(*parts)
 
+        if period:
+            name = f'{period}_{name}'
         return name
 
@@ -273,7 +287,7 @@ def _yield_values(
             tc: bool = False,
             stats: bool = True,
             mean_ci:
bool = True, - ) -> tuple: + ) -> Generator[pd.DataFrame, hdl.MetricVariable, pd.DataFrame]: """ Get iterable with pandas dataframes for all variables of a metric to plot @@ -298,6 +312,9 @@ def _yield_values( Vars = self.img._iter_vars(type="metric", filter_parms={"metric": metric}) + if metric in globals.TC_METRICS: + Vars = filter_out_self_combination_tcmetric_vars(Vars) + for n, Var in enumerate(Vars): values = Var.values[Var.varname] # changes if it's a common-type Var @@ -338,6 +355,7 @@ def _boxplot_definition(self, metric: str, df: pd.DataFrame, type: str, + period: str = None, ci=None, offset=0.07, Var=None, @@ -387,15 +405,22 @@ def _boxplot_definition(self, Var = Var break if not type == "metadata": - title = self.create_title(Var, type=type) + title = self.create_title(Var, type=type, period=period) ax.set_title(title, pad=globals.title_pad) - # add watermark if self.img.has_CIs: offset = 0.08 # offset smaller as CI variables have a larger caption if Var.g == 0: offset = 0.03 # offset larger as common metrics have a shorter caption - if globals.watermark_pos not in [None, False]: - plm.make_watermark(fig, offset=offset) + + # fig.tight_layout() + + plm.add_logo_to_figure(fig = fig, + logo_path = globals.watermark_logo_pth, + position = globals.watermark_logo_position, + offset = globals.watermark_logo_offset_box_plots, + scale = globals.watermark_logo_scale, + ) + return fig, ax @@ -403,6 +428,7 @@ def _barplot_definition(self, metric: str, df: pd.DataFrame, type: str, + period: str = None, Var=None) -> tuple: """ Define parameters of plot @@ -434,15 +460,19 @@ def _barplot_definition(self, Var = Var break - title = self.create_title(Var, type=type) + title = self.create_title(Var, type=type, period=period) ax.set_title(title, pad=globals.title_pad) # add watermark - if globals.watermark_pos not in [None, False]: - plm.make_watermark(fig, for_barplot=True) - - def _save_plot(self, out_name: str, out_types: str = 'png') -> list: + plm.add_logo_to_figure(fig = fig, + logo_path = globals.watermark_logo_pth, + position = globals.watermark_logo_position, + offset = globals.watermark_logo_offset_bar_plots, + scale = globals.watermark_logo_scale, + ) + + def _save_plot(self, out_name: str, out_types: Optional[Union[List[str], str]] = 'png') -> list: """ Save plot with name to self.out_dir @@ -450,8 +480,8 @@ def _save_plot(self, out_name: str, out_types: str = 'png') -> list: ---------- out_name: str name of output file - out_types: str or list - extensions which the files should be saved in + out_types: str or list of str, Optional + extensions which the files should be saved in. Default is 'png' Returns ------- @@ -463,7 +493,7 @@ def _save_plot(self, out_name: str, out_types: str = 'png') -> list: for ext in out_types: fname = self._standard_filename(out_name, out_type=ext) if fname.exists(): - warn('Overwriting file {}'.format(fname.name)) + warn(f'Overwriting file {fname.name}') try: plt.savefig(fname, dpi='figure', bbox_inches='tight') except ValueError: @@ -474,8 +504,9 @@ def _save_plot(self, out_name: str, out_types: str = 'png') -> list: def boxplot_basic(self, metric: str, + period: str = None, out_name: str = None, - out_types: str = 'png', + out_types: Optional[Union[List[str], str]] = 'png', save_files: bool = False, **plotting_kwargs) -> Union[list, None]: """ @@ -489,8 +520,8 @@ def boxplot_basic(self, into one plot. 
out_name: str name of output file - out_types: str or list - extensions which the files should be saved in + out_types: str or list of str, Optional + extensions which the files should be saved in. Default is 'png' save_files: bool, optional. Default is False wether to save the file in the output directory plotting_kwargs: arguments for _boxplot_definition function @@ -516,14 +547,17 @@ def boxplot_basic(self, fig, ax = self._boxplot_definition(metric=metric, df=values, type='boxplot_basic', + period=period, ci=ci, Var=Var, **plotting_kwargs) if not out_name: - out_name = self.create_filename(Var, type='boxplot_basic') + out_name = self.create_filename(Var, + type='boxplot_basic', + period=period) # save or return plotting objects if save_files: - fnames = self._save_plot(out_name, out_types=out_types) + fnames.extend(self._save_plot(out_name, out_types=out_types)) plt.close('all') return fnames @@ -533,8 +567,9 @@ def boxplot_basic(self, def boxplot_tc(self, metric: str, + period: str = None, out_name: str = None, - out_types: str = 'png', + out_types: Optional[Union[List[str], str]] = 'png', save_files: bool = False, **plotting_kwargs) -> list: """ @@ -548,8 +583,8 @@ def boxplot_tc(self, into one plot. out_name: str name of output file - out_types: str or list - extensions which the files should be saved in + out_types: str or list of str, Optional + extensions which the files should be saved in. Default is 'png' save_files: bool, optional. Default is False wether to save the file in the output directory plotting_kwargs: arguments for _boxplot_definition function @@ -589,17 +624,19 @@ def boxplot_tc(self, df=df, ci=ci_id, type='boxplot_tc', + period=period, Var=Var, **plotting_kwargs) # save. Below workaround to avoid same names if not out_name: - save_name = self.create_filename(Var, type='boxplot_tc') + save_name = self.create_filename(Var, + type='boxplot_tc', + period=period) else: save_name = out_name # save or return plotting objects if save_files: - fns = self._save_plot(save_name, out_types=out_types) - fnames.extend(fns) + fnames.extend(self._save_plot(save_name, out_types=out_types)) plt.close('all') if save_files: @@ -608,7 +645,8 @@ def boxplot_tc(self, def barplot( self, metric: str, - out_types: str = 'png', + period: str = None, + out_types: Optional[Union[List[str], str]] = 'png', save_files: bool = False, ) -> Union[list, None]: """ @@ -620,8 +658,8 @@ def barplot( ---------- metric : str metric that is collected from the file for all datasets. - out_types: str or list - extensions which the files should be saved in + out_types: str or list of str, Optional + extensions which the files should be saved in. Default is 'png' save_files: bool, optional. 
Default is False wether to save the file in the output directory @@ -644,14 +682,16 @@ def barplot( self._barplot_definition(metric=metric, df=values, type='barplot_basic', + period=period, Var=Var) - out_name = self.create_filename(Var, type='barplot_basic') - + out_name = self.create_filename(Var, + type='barplot_basic', + period=period) # save or return plotting objects if save_files: fnames.extend(self._save_plot(out_name, out_types=out_types)) - plt.close('all') + plt.close('all') if fnames: return fnames @@ -659,12 +699,12 @@ def barplot( def mapplot_var( self, Var, - out_name: str = None, - out_types: str = 'png', + period: str = None, + out_types: Optional[Union[List[str], str]] = 'png', save_files: bool = False, compute_dpi: bool = True, **style_kwargs, - ) -> Union[list, tuple]: + ) -> Union[list, Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]]: """ Plots values to a map, using the values as color. Plots a scatterplot for ISMN and a image plot for other input values. @@ -675,8 +715,8 @@ def mapplot_var( Var in the image to make the map for. out_name: str name of output file - out_types: str or list - extensions which the files should be saved in + out_types: str or list of str, Optional + extensions which the files should be saved in. Default is 'png' save_files: bool, optional. Default is False wether to save the file in the output directory compute_dpi : bool, optional. Default is True. @@ -693,6 +733,7 @@ def mapplot_var( ------- fnames: list of file names with all the extensions """ + fnames = [] ref_meta, mds_meta, other_meta, scl_meta = Var.get_varmeta() metric = Var.metric ref_grid_stepsize = self.img.ref_dataset_grid_stepsize @@ -723,42 +764,55 @@ def mapplot_var( metric=metric, ref_short=ref_meta[1]['short_name'], ref_grid_stepsize=ref_grid_stepsize, - plot_extent= - None, # if None, extent is automatically adjusted (as opposed to img.extent) + plot_extent=None, # if None, extent is automatically adjusted (as opposed to img.extent) scl_short=scl_short, **style_kwargs) # title and plot settings depend on the metric group if Var.varname.startswith('status'): - title = self.create_title(Var=Var, type='mapplot_status') - save_name = self.create_filename(Var=Var, type="mapplot_status") + title = self.create_title(Var=Var, + type='mapplot_status', + period=period) + save_name = self.create_filename(Var=Var, + type="mapplot_status", + period=period) elif Var.g == 0: title = "{} between all datasets".format( globals._metric_name[metric]) - save_name = self.create_filename(Var, type='mapplot_common') + if period: + title = f'{period}: {title}' + save_name = self.create_filename(Var, + type='mapplot_common', + period=period) elif Var.g == 2: - title = self.create_title(Var=Var, type='mapplot_basic') - save_name = self.create_filename(Var, type='mapplot_double') + title = self.create_title(Var=Var, + type='mapplot_basic', + period=period) + save_name = self.create_filename(Var, + type='mapplot_double', + period=period) else: - title = self.create_title(Var=Var, type='mapplot_tc') - save_name = self.create_filename(Var, type='mapplot_tc') - - # overwrite output file name if given - if out_name: - save_name = out_name + title = self.create_title(Var=Var, + type='mapplot_tc', + period=period) + save_name = self.create_filename(Var, + type='mapplot_tc', + period=period) # use title for plot, make watermark ax.set_title(title, pad=globals.title_pad) - if globals.watermark_pos not in [None, False]: - plm.make_watermark(fig, - globals.watermark_pos, - for_map=True, - offset=0.04) + 
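+        # the text watermark is replaced by a logo overlay here; logo path, anchor
+        # position, per-plot-type offset and scale are all configured in globals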
+ plm.add_logo_to_figure(fig = fig, + logo_path = globals.watermark_logo_pth, + position = globals.watermark_logo_position, + offset = globals.watermark_logo_offset_map_plots, + scale = globals.watermark_logo_scale, + ) # save file or just return the image if save_files: - fnames = self._save_plot(save_name, out_types=out_types) - + fnames.extend(self._save_plot(save_name, out_types=out_types)) + plt.close('all') return fnames else: @@ -766,7 +820,8 @@ def mapplot_var( def mapplot_metric(self, metric: str, - out_types: str = 'png', + period: str = None, + out_types: Optional[Union[List[str], str]] = 'png', save_files: bool = False, **plotting_kwargs) -> list: """ @@ -776,8 +831,8 @@ def mapplot_metric(self, ---------- metric : str Name of a metric. File is searched for variables for that metric. - out_types: str or list - extensions which the files should be saved in + out_types: str or list of str, Optional + extensions which the files should be saved in. Default is 'png' save_files: bool, optional. Default is False wether to save the file in the output directory plotting_kwargs: arguments for mapplot function @@ -794,7 +849,7 @@ def mapplot_metric(self, continue if not (np.isnan(Var.values.to_numpy()).all() or Var.is_CI): fns = self.mapplot_var(Var, - out_name=None, + period=period, out_types=out_types, save_files=save_files, **plotting_kwargs) @@ -803,14 +858,15 @@ def mapplot_metric(self, continue if save_files: fnames.extend(fns) - plt.close('all') + plt.close('all') if fnames: return fnames def plot_metric(self, metric: str, - out_types: str = 'png', + period: str = None, + out_types: Optional[Union[List[str], str]] = 'png', save_all: bool = True, **plotting_kwargs) -> tuple: """ @@ -820,8 +876,8 @@ def plot_metric(self, ---------- metric: str name of the metric - out_types: str or list - extensions which the files should be saved in + out_types: str or list of str, Optional + extensions which the files should be saved in. Default is 'png' save_all: bool, optional. Default is True. all plotted images are saved to the output directory plotting_kwargs: arguments for mapplot function. 
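[Editor's note] A minimal usage sketch for the period-aware plotting API added in
this patch (the file path and metric name are hypothetical; per the gating in the
hunk below, mapplots are only produced for the bulk window):

    from qa4sm_reader.img import QA4SMImg
    from qa4sm_reader.plotter import QA4SMPlotter
    from qa4sm_reader.netcdf_transcription import Pytesmo2Qa4smResultsTranscriber

    ncfile = "validation.nc"  # hypothetical results file
    periods = Pytesmo2Qa4smResultsTranscriber.get_tsws_from_ncfile(ncfile) or []
    for period in periods:
        img = QA4SMImg(ncfile, period=period)
        plotter = QA4SMPlotter(img, out_dir="plots")
        # boxplots are created for every window; mapplots only when
        # period == globals.DEFAULT_TSW
        bplots, mapplots = plotter.plot_metric("R", period=period, save_all=True)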
@@ -830,23 +886,30 @@ def plot_metric(self, if Metric.name == 'status': fnames_bplot = self.barplot(metric='status', + period=period, out_types=out_types, save_files=save_all) elif Metric.g == 0 or Metric.g == 2: fnames_bplot = self.boxplot_basic(metric=metric, + period=period, out_types=out_types, save_files=save_all, **plotting_kwargs) elif Metric.g == 3: fnames_bplot = self.boxplot_tc(metric=metric, + period=period, out_types=out_types, save_files=save_all, **plotting_kwargs) - fnames_mapplot = self.mapplot_metric(metric=metric, - out_types=out_types, - save_files=save_all, - **plotting_kwargs) + if period == globals.DEFAULT_TSW: + fnames_mapplot = self.mapplot_metric(metric=metric, + period=period, + out_types=out_types, + save_files=save_all, + **plotting_kwargs) + else: + fnames_mapplot = None return fnames_bplot, fnames_mapplot @@ -884,18 +947,20 @@ def meta_single(self, the boxplot ax : matplotlib.axes.Axes """ + values = [] for data, Var, var_ci in self._yield_values(metric=metric, stats=False, mean_ci=False): values.append(data) - if not values: raise PlotterError(f"No valid values for {metric}") values = pd.concat(values, axis=1) + # override values from metric if df is not None: values = df + # get meta and select only metric values with metadata available meta_values = self.img.metadata[metadata].values.dropna() values = values.reindex(index=meta_values.index) @@ -990,7 +1055,7 @@ def meta_combo( if binned_values is None: raise PlotterError( f"Could not bin metadata {metadata} with function {bin_funct}") - # dictionary with subset values + values_subset = { a_bin: values.reindex(index=binned_values[a_bin].index) for a_bin in binned_values.keys() @@ -1014,20 +1079,27 @@ def plot_metadata(self, metadata: str, metadata_discrete: str = None, save_file: bool = False, - out_types: str = 'png', + out_types: Optional[Union[List[str], str]] = 'png', + period: str = None, **plotting_kwargs): """ Wrapper built around the 'meta_single' or 'meta_combo' functions to produce a metadata-based boxplot of a metric. Parameters - __________ + ---------- metric : str name of metric to plot metadata : str name of metadata to subdivide the metric results metadata_discrete : str name of the metadata of the type 'discrete' + save_file : bool, optional + whether to save the plot to the output directory. Default is False + out_types : str or list of str, optional + extensions which the files should be saved in. 
Default is 'png' + period: str, optional + temporal sub-window to use Retrun ------ @@ -1035,6 +1107,7 @@ def plot_metadata(self, the boxplot ax : matplotlib.axes.Axes """ + fnames = [] if metadata_discrete is None: fig, ax = self.meta_single(metric=metric, metadata=metadata, @@ -1061,16 +1134,27 @@ def plot_metadata(self, title = self._titles_lut("metadata").format( globals._metric_name[metric], ", ".join(meta_names), self.img.datasets.ref["pretty_title"]) + if period: #$$ + title = f'{period}: {title}' fig.suptitle(title) - plm.make_watermark(fig=fig, offset=0) + fig.subplots_adjust(bottom=0.2) + + plm.add_logo_to_figure(fig = fig, + logo_path = globals.watermark_logo_pth, + position = globals.watermark_logo_position, + offset = globals.watermark_logo_offset_metadata_plots, + scale = globals.watermark_logo_scale, + ) if save_file: out_name = self._filenames_lut("metadata").format( metric, "_and_".join(metadata_tuple)) - out_name = self._save_plot(out_name, out_types=out_types) - - return out_name + if period: + out_name = f'{period}_{out_name}' + fnames.extend(self._save_plot(out_name, out_types=out_types)) + plt.close('all') + return fnames else: return fig, ax @@ -1078,8 +1162,9 @@ def plot_metadata(self, def plot_save_metadata( self, metric, - out_types: str = 'png', + out_types: Optional[Union[List[str], str]] = 'png', meta_boxplot_min_samples: int = 5, + period: str = None, ): """ Plots and saves three metadata boxplots per metric (defined in globals.py): @@ -1092,8 +1177,8 @@ def plot_save_metadata( ---------- metric : str name of metric - out_types: str or list, optional - extensions which the files should be saved in + out_types: str or list of str, optional + extensions which the files should be saved in. Default is 'png' meta_boxplot_min_samples: int, optional minimum number of samples per bin required to plot a metadata boxplot @@ -1108,6 +1193,9 @@ def plot_save_metadata( if metric in globals._metadata_exclude: return filenames + if not period: #$$ + return filenames + for meta_type, meta_keys in globals.out_metadata_plots.items(): try: # the presence of instrument_depth in the out file depends on the ismn release version @@ -1118,7 +1206,9 @@ def plot_save_metadata( *meta_keys, save_file=True, out_types=out_types, - meta_boxplot_min_samples=meta_boxplot_min_samples) + meta_boxplot_min_samples=meta_boxplot_min_samples, + period=period, + ) filenames.extend(outfiles) else: @@ -1131,12 +1221,632 @@ def plot_save_metadata( return filenames - def save_stats(self): + def save_stats(self, period: str = None) -> str: """Saves the summary statistics to a .csv file and returns the name""" table = self.img.stats_df() filename = self._filenames_lut("table") + '.csv' + if period: + filename = f'{period}_{filename}' filepath = self.out_dir.joinpath(filename) - table.to_csv(path_or_buf=filepath) return filepath + +#$$ +class QA4SMCompPlotter: + """ + Class to create plots containing the calculated metric for all temporal sub-window, default case excldued + + Parameters + ---------- + + results_file : str + path to the .nc results file + include_default_case : bool, default is False + whether to include the bulk case in the plots + + Returns + ------- + QA4SMCompPlotter object + """ + + def __init__(self, + results_file: str, + include_default_case: bool = False) -> None: + self.results_file = results_file + self.include_default_case = include_default_case + if os.path.isfile(results_file): + with xr.open_dataset(results_file) as ds: + self.ds: xr.Dataset = ds + self.datasets = 
hdl.QA4SMDatasets(self.ds.attrs)
+                self.ref_dataset: Dict = self.datasets.ref
+                self.candidate_datasets: List[Dict] = self.datasets.others
+                # self.metrics_in_ds = self.__ds_metrics()
+                self.metric_kinds_available: List = list(
+                    self.metrics_in_ds.keys())
+                self.metric_lut: Dict = self.metrics_ds_grouped_lut(
+                    include_ci=False)
+                # self.df = self._ds2df()
+                # self.check_for_unexpected_metrics()
+
+                self.cbp: ClusteredBoxPlot = ClusteredBoxPlot(
+                    anchor_list=np.linspace(1, len(self.tsws_used),
+                                            len(self.tsws_used)),
+                    no_of_ds=len(self.candidate_datasets),
+                    space_per_box_cluster=0.9,
+                    rel_indiv_box_width=0.9,
+                )
+
+        else:
+            warnings.warn(
+                f'FileNotFoundError: The file {results_file} does not exist. Please check the file path and try again.'
+            )
+            return None
+
+    @property
+    def temp_sub_win_dependent_vars(self) -> List[str]:
+        _list = []
+        for var_name in self.ds.data_vars:
+            if 'tsw' in self.ds[var_name].dims:
+                _list.append(var_name)
+        return _list
+
+    @property
+    def metrics_in_ds(self) -> Dict[str, List[str]]:
+        """
+        Returns a dictionary of the metrics in the dataset, where each individual metric kind is a key \
+        and the values are lists of variables in the dataset that are associated with the respective metric kind.
+
+        Returns
+        -------
+        dict
+            dictionary of metrics in the dataset
+
+        """
+        return {
+            metric: [
+                var_name for var_name in self.temp_sub_win_dependent_vars
+                if var_name.startswith(f'{metric}_')
+            ]
+            for metric in globals._colormaps.keys() if any(
+                var_name.startswith(f'{metric}_')
+                for var_name in self.temp_sub_win_dependent_vars)
+        }  # globals._colormaps just so happens to contain all metrics
+
+    def check_for_unexpected_metrics(self) -> bool:
+        """
+        Checks whether metrics are present in the dataset that were not specified in `globals.METRICS` and adds them to \
+        `QA4SMCompPlotter.metrics_in_ds`.
+
+        Returns
+        -------
+        bool
+            True if no unexpected metrics are found in the dataset, False otherwise
+        """
+
+        flattened_list = list(
+            set(itertools.chain.from_iterable(self.metrics_in_ds.values())))
+        elements_not_in_flattened_list = set(
+            self.temp_sub_win_dependent_vars) - set(flattened_list)
+        _list = list(
+            set([
+                m.split('_between')[0] for m in elements_not_in_flattened_list
+            ]))
+        grouped_dict = {
+            prefix:
+            [element for element in _list if element.startswith(prefix)]
+            for prefix in set([element.split('_')[0] for element in _list])
+        }
+
+        for prefix, elements in grouped_dict.items():
+            self.metrics_in_ds[prefix] = elements
+
+        if len(elements_not_in_flattened_list) > 0:
+            warnings.warn(
+                f"Following metrics were found in the dataset that were not specified in `globals.METRICS` and have \
+                been added to `QA4SMCompPlotter.metrics_in_ds`: {elements_not_in_flattened_list}"
+            )
+            return False
+
+        return True
+
+    def metrics_ds_grouped_lut(self,
+                               include_ci: Optional[bool] = False
+                               ) -> Dict[str, List[str]]:
+        """
+        Returns a dictionary that, for each metric kind, maps the QA4SM dataset combinations to the variable \
+        used to compute said metric.
+
+        Parameters
+        ----------
+        include_ci : bool, default is False
+            Whether to include the confidence intervals of a specific metric in the output
+
+        Returns
+        -------
+        dict
+            dictionary of grouped metrics in the dataset
+        """
+        _metric_lut = {}
+
+        def parse_metric_string(
+                metric_string: str) -> Union[Tuple[str, str], None]:
+            pattern = globals.METRIC_TEMPLATE.format(
+                # named groups: one or more digits, a hyphen, then one or more word characters
+                ds1=r'(?P<ds1>\d+-\w+)',
+                ds2=r'(?P<ds2>\d+-\w+)',
+            )
+
+            match = re.search(pattern, metric_string)
+            if match:
+                return match.group('ds1'), match.group('ds2')
+            else:
+                return None
+
+        def purge_ci_metrics(_dict: Dict) -> Dict:
+            return {
+                ds_combi:
+                [metric for metric in metric_values if "ci" not in metric][0]
+                for ds_combi, metric_values in _dict.items()
+            }
+
+        for metric_kind, metrics_in_ds in self.metrics_in_ds.items():
+
+            parsed_metrics = set([
+                pp for metric in metrics_in_ds
+                if (pp := parse_metric_string(metric)) is not None
+            ])
+
+            grouped_dict = {
+                metric: [
+                    item for item in metrics_in_ds
+                    if parse_metric_string(item) == metric
+                ]
+                for metric in parsed_metrics
+            }
+
+            if not include_ci:
+                grouped_dict = purge_ci_metrics(grouped_dict)
+
+            _metric_lut[metric_kind] = grouped_dict
+
+        return _metric_lut
+
+    @property
+    def tsws_used(self):
+        """
+        Get all temporal sub-windows used in the validation. The default TSW is only \
+        included if `include_default_case` was set on the instance.
+
+        Returns
+        -------
+        tsws_used : list
+            list of all TSWs used in the validation
+        """
+
+        temp_sub_wins_names = [
+            tsw
+            for tsw in self.ds.coords[globals.TEMPORAL_SUB_WINDOW_NC_COORD_NAME].values
+            if tsw != globals.DEFAULT_TSW
+        ]
+
+        if self.include_default_case:
+            temp_sub_wins_names.append(globals.DEFAULT_TSW)
+
+        return temp_sub_wins_names
+
+    def get_specific_metric_df(self, specific_metric: str) -> pd.DataFrame:
+        """
+        Get the DataFrame for a single **specific** metric (e.g. "R_between_0-ISMN_and_1-SMOS_L3") from a QA4SM netCDF \
+        file with temporal sub-windows.
+ + Parameters + ---------- + specific_metric : str + Name of the specific metric + + Returns + ------- + pd.DataFrame + DataFrame for this specific metric + """ + + _data_dict = {} + _data_dict['lat'] = self.ds['lat'].values + _data_dict['lon'] = self.ds['lon'].values + _data_dict['gpi'] = self.ds['gpi'].values + for tsw in self.tsws_used: + selection = {globals.TEMPORAL_SUB_WINDOW_NC_COORD_NAME: tsw} + + _data_dict[tsw] = self.ds[specific_metric].sel( + selection).values.astype(np.float32) + + df = pd.DataFrame(_data_dict) + df.set_index(['lat', 'lon', 'gpi'], inplace=True) + + return df + + def get_metric_df(self, generic_metric: str) -> pd.DataFrame: + """ + Get the DataFrame for a single **generic** metric/metric kind (e.g. "R") from a QA4SM netCDF file with \ + temporal sub-windows. + + Parameters + ---------- + generic_metric : str + Name of the generic metric/metric kind + + Returns + ------- + pd.DataFrame + Multilevel DataFrame for this generic metric/metric kind, whereas the two column levels are all candidate \ + datasets and the temporal sub-windows + """ + + df_dict = { + ds_combi[1]: + self.get_specific_metric_df(specific_metric=specific_metric) + for ds_combi, specific_metric in + self.metric_lut[generic_metric].items() + } + return pd.concat(df_dict.values(), keys=df_dict.keys(), axis=1) + + @staticmethod + @note( + "This method is redundant, as it yields the same result as `QA4SMCompPlotter.tsws_used()`. \ + It is kept as a static method for debugging purposes." + ) + def get_tsws_from_df(df: pd.DataFrame) -> List[str]: + """ + Get all temporal sub-windows used in the validation from a DataFrame as returned by \ + `QA4SMCompPlotter.get_metric_df()` + + Parameters + ---------- + df : pd.DataFrame + DataFrame with the temporal sub-windows + + Returns + ------- + tsws_used : list + list of all TSWs used in the validation + """ + return df.columns.levels[1].unique().tolist() + + @staticmethod + def get_datasets_from_df(df: pd.DataFrame) -> List[str]: + """ + Get all candiate datasets used in the validation from a DataFrame as returned by \ + `QA4SMCompPlotter.get_metric_df()` + + Parameters + ---------- + df : pd.DataFrame + DataFrame with the datasets + + Returns + ------- + datasets_used : list + list of all datasets used in the validation + """ + return sorted(df.columns.levels[0].unique().tolist()) + + def create_title(self, Var, type: str) -> str: + """ + Create title of the plot + + Parameters + ---------- + Var: MetricVar + variable for a metric + type: str + type of plot + """ + parts = [globals._metric_name[Var.metric]] + parts.extend(QA4SMPlotter._get_parts_name(Var=Var, type=type)) + title = QA4SMPlotter._titles_lut(type=type).format(*parts) + + return title + + def create_label(self, Var) -> str: + """ + Create y-label of the plot + + Parameters + ---------- + + Var: MetricVar + variable for a metric + + Returns + ------- + label: str + y-label for the plot + """ + parts = [globals._metric_name[Var.metric]] + parts.append(globals._metric_description[Var.metric].format( + globals.get_metric_units(self.ref_dataset['short_name']))) + return "{}{}".format(*parts) + + def get_metric_vars(self, generic_metric: str) -> Dict[str, str]: + _dict = {} + + _df = self.get_metric_df(generic_metric=generic_metric) + for dataset in self.get_datasets_from_df(_df): + for ds_combi, specific_metric in self.metrics_ds_grouped_lut( + )[generic_metric].items(): + if dataset in ds_combi: + _Var = hdl.MetricVariable(varname=specific_metric, + global_attrs=self.ds.attrs) + if _Var.values 
== None: + _Var.values = _df.loc[:, (dataset, slice(None))] + + _dict[dataset] = _Var + + return _dict + + def get_legend_entries(self, generic_metric: str) -> Dict[str, str]: + return { + f'{Var.metric_ds[0]}-{Var.metric_ds[1]["short_name"]}': + # 'hello': + self.cbp.label_template.format( + dataset_name=Var.metric_ds[1]["pretty_name"], + dataset_version=Var.metric_ds[1] + ["pretty_version"], # Replace with your actual dataset version + variable_name=Var.metric_ds[1] + ["pretty_variable"], # Replace with your actual variable name + unit=Var.metric_ds[1]["mu"]) + for Var in self.get_metric_vars(generic_metric).values() + } + + def _load_vars(self, empty=False, only_metrics=False) -> list: + """ + Create a list of common variables and fill each with values + + Parameters + ---------- + empty : bool, default is False + if True, Var.values is an empty dataframe + only_metrics : bool, default is False + if True, only variables for metric scores are kept (i.e. not gpi, idx ...) + + Returns + ------- + vars : list + list of QA4SMVariable objects for the validation variables + """ + vars = [] + for varname in self.metric_kinds_available: + df = self.get_metric_df(generic_metric=varname) + if empty: + values = None + else: + # lat, lon are in varnames but not in datasframe (as they are the index) + try: + # values = df + values = df + except: # KeyError: + values = None + + Var = hdl.QA4SMVariable(varname, self.ds.attrs, + values=df).initialize() + + if only_metrics and isinstance(Var, hdl.MetricVariable): + vars.append(Var) + elif not only_metrics: + vars.append(Var) + + return vars + + def _iter_vars(self, + type: str = None, + name: str = None, + filter_parms: dict = None) -> iter: + """ + Iter through QA4SMVariable objects that are in the file + + Parameters + ---------- + type : str, default is None + One of 'metric', 'ci', 'metadata' can be specified to only iterate through the specific group + name : str, default is None + yield a specific variable by its name + filter_parms : dict + dictionary with QA4SMVariable attributes as keys and filter value as values (e.g. {g: 0}) + """ + type_lut = { + "metric": hdl.MetricVariable, + "ci": hdl.ConfidenceInterval, + "metadata": hdl.Metadata, + } + for Var in self._load_vars(): + if name: + if name in [Var.varname, Var.pretty_name]: + yield Var + break + else: + continue + if type and not isinstance(Var, type_lut[type]): + continue + if filter_parms: + for key, val in filter_parms.items(): + if getattr(Var, + key) == val: # check all attribute individually + check = True + else: + check = False # does not match requirements + break + if check != True: + continue + + yield Var + + def plot_cbp(self, + chosen_metric: str, + out_name: Optional[Union[List, List[str]]] = None) -> matplotlib.figure.Figure: + """ + Plot a Clustered Boxplot for a chosen metric + + Parameters + ---------- + chosen_metric : str + name of the metric + out_name : str or list of str, optional + name of the output file. 
Default is None + + Returns + ------- + fig : matplotlib.figure.Figure + the boxplot + + """ + + def get_metric_vars( + generic_metric: str) -> Dict[str, hdl.MetricVariable]: + _dict = {} + + for dataset in self.get_datasets_from_df(metric_df): + for ds_combi, specific_metric in self.metrics_ds_grouped_lut( + )[generic_metric].items(): + if dataset in ds_combi: + _Var = hdl.MetricVariable(varname=specific_metric, + global_attrs=self.ds.attrs) + if _Var.values == None: + _Var.values = metric_df.loc[:, + (dataset, slice(None))] + + _dict[dataset] = _Var + + return _dict + + def get_legend_entries(cbp_obj: ClusteredBoxPlot, + generic_metric: str) -> Dict[str, str]: + return { + f'{Var.metric_ds[0]}-{Var.metric_ds[1]["short_name"]}': + cbp_obj.label_template.format( + dataset_name=Var.metric_ds[1]["pretty_name"], + dataset_version=Var.metric_ds[1]["pretty_version"], + variable_name=Var.metric_ds[1]["pretty_variable"], + unit=Var.metric_ds[1]["mu"]) + for Var in get_metric_vars(generic_metric).values() + } + + metric_df = self.get_metric_df(chosen_metric) + Vars = get_metric_vars(chosen_metric) + + legend_entries = get_legend_entries(cbp_obj=self.cbp, + generic_metric=chosen_metric) + + centers_and_widths = self.cbp.centers_and_widths( + anchor_list=self.cbp.anchor_list, + no_of_ds=self.cbp.no_of_ds, + space_per_box_cluster=0.9, + rel_indiv_box_width=0.8) + + figwidth = globals.boxplot_width * (len(metric_df.columns) + 1 + ) # otherwise it's too narrow + figsize = [figwidth, globals.boxplot_height] + fig_kwargs = { + 'figsize': figsize, + 'dpi': 'figure', + 'bbox_inches': 'tight' + } + + cbp_fig = self.cbp.figure_template(incl_median_iqr_n_axs=False, + fig_kwargs=fig_kwargs) + + legend_handles = [] + for dc_num, (dc_val_name, Var) in enumerate(Vars.items()): + _df = Var.values + bp = cbp_fig.ax_box.boxplot( + _df.dropna().values, + positions=centers_and_widths[dc_num].centers, + widths=centers_and_widths[dc_num].widths, + showfliers=False, + patch_artist=True, + ) + + for box in bp['boxes']: + box.set(color=list(globals.CLUSTERED_BOX_PLOT_STYLE['colors']. 
+ values())[dc_num]) + + legend_handles.append( + Rectangle( + (0, 0), + 1, + 1, + color=list( + globals.CLUSTERED_BOX_PLOT_STYLE['colors'].values()) + [dc_num], + alpha=0.7, + label=legend_entries[dc_val_name])) + + patch_styling( + bp, + list(globals.CLUSTERED_BOX_PLOT_STYLE['colors'].values()) + [dc_num]) + + if self.cbp.no_of_ds >= 3: + _ncols = 3 + else: + _ncols = self.cbp.no_of_ds + + cbp_fig.ax_box.legend( + handles=legend_handles, + fontsize=globals.CLUSTERED_BOX_PLOT_STYLE['fig_params'] + ['legend_fontsize'], + ncols=_ncols) + + xtick_pos = self.cbp.centers_and_widths( + anchor_list=self.cbp.anchor_list, + no_of_ds=1, + space_per_box_cluster=0.7, + rel_indiv_box_width=0.8) + cbp_fig.ax_box.set_xticks([]) + cbp_fig.ax_box.set_xticklabels([]) + cbp_fig.ax_box.set_xticks(xtick_pos[0].centers) + + def get_xtick_labels(df: pd.DataFrame) -> List: + _count_dict = df.count().to_dict() + return [ + f"{tsw[1]}\nN: {count}" for tsw, count in _count_dict.items() + ] + + cbp_fig.ax_box.set_xticklabels(get_xtick_labels(_df), ) + cbp_fig.ax_box.tick_params( + axis='both', + labelsize=globals.CLUSTERED_BOX_PLOT_STYLE['fig_params'] + ['tick_labelsize']) + + _dummy_xticks = [ + cbp_fig.ax_box.axvline(x=(a + b) / 2, color='lightgrey') for a, b + in zip(xtick_pos[0].centers[:-1], xtick_pos[0].centers[1:]) + ] + cbp_fig.fig.suptitle( + self.create_title(Var, type='boxplot_basic'), + fontsize=globals.CLUSTERED_BOX_PLOT_STYLE['fig_params'] + ['title_fontsize']) + cbp_fig.ax_box.set_ylabel( + self.create_label(Var), + fontsize=globals.CLUSTERED_BOX_PLOT_STYLE['fig_params'] + ['y_labelsize'], + ) + + spth = [Path(f"{globals.CLUSTERED_BOX_PLOT_SAVENAME.format(metric = chosen_metric, filetype = 'png')}")] + if out_name: + spth = out_name + + [cbp_fig.fig.savefig( + fname=outname, + dpi=fig_kwargs['dpi'], + bbox_inches=fig_kwargs['bbox_inches'], + ) for outname in spth] + + return cbp_fig.fig diff --git a/src/qa4sm_reader/plotting_methods.py b/src/qa4sm_reader/plotting_methods.py index 5e542d2..1596747 100644 --- a/src/qa4sm_reader/plotting_methods.py +++ b/src/qa4sm_reader/plotting_methods.py @@ -2,32 +2,43 @@ """ Contains helper functions for plotting qa4sm results. """ +from logging import handlers from qa4sm_reader import globals from qa4sm_reader.exceptions import PlotterError +from qa4sm_reader.handlers import ClusteredBoxPlotContainer, CWContainer +from qa4sm_reader.utils import note import numpy as np import pandas as pd import os.path -from typing import Union +from typing import Union, List, Tuple, Dict, Optional, Any import copy import seaborn as sns +import matplotlib +import matplotlib.axes +import matplotlib.cbook as cbook +import matplotlib.image as mpimg import matplotlib.pyplot as plt import matplotlib.colors as mcol import matplotlib.ticker as mticker import matplotlib.gridspec as gridspec from matplotlib.patches import Patch, PathPatch -from matplotlib.lines import Line2D + from cartopy import config as cconfig import cartopy.feature as cfeature from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER +import cartopy.crs as ccrs from pygeogrids.grids import BasicGrid, genreg_grid from shapely.geometry import Polygon, Point import warnings +import os +from collections import namedtuple + cconfig['data_dir'] = os.path.join(os.path.dirname(__file__), 'cartopy') @@ -101,9 +112,9 @@ def geotraj_to_geo2d(df, index=globals.index_names, grid_stepsize=None): """ Converts geotraj (list of lat, lon, value) to a regular grid over lon, lat. 
The values in df needs to be sampled from a regular grid, the order does not matter. - When used with plt.imshow(), specify data_extent to make sure, + When used with plt.imshow(), specify data_extent to make sure, the pixels are exactly where they are expected. - + Parameters ---------- df : pandas.DataFrame @@ -190,8 +201,8 @@ def get_value_range(ds, try: v_min = ranges[metric][0] v_max = ranges[metric][1] - if (v_min is None and v_max is - None): # get quantile range and make symmetric around 0. + if (v_min is None and v_max is None + ): # get quantile range and make symmetric around 0. v_min, v_max = get_quantiles(ds, quantiles) v_max = max( abs(v_min), @@ -262,12 +273,12 @@ def get_plot_extent(df, grid_stepsize=None, grid=False) -> tuple: whether the values in df is on a equally spaced grid (for use in mapplot) df : pandas.DataFrame Plot values. - + Returns ------- extent : tuple | list (x_min, x_max, y_min, y_max) in Data coordinates. - + """ lat, lon, gpi = globals.index_names if grid and grid_stepsize in ['nan', None]: @@ -312,22 +323,29 @@ def get_plot_extent(df, grid_stepsize=None, grid=False) -> tuple: return extent -def init_plot(figsize, dpi, add_cbar=None, projection=None) -> tuple: +def init_plot(figsize, dpi, add_cbar=None, projection=None, fig_template = None) -> tuple: """Initialize mapplot""" if not projection: projection = globals.crs - fig = plt.figure(figsize=figsize, dpi=dpi) + + if fig_template is None: + # fig, ax_main = plt.subplots(figsize=figsize, dpi=dpi) + fig = plt.figure(figsize=figsize, dpi=dpi) + else: + fig = fig_template.fig + ax_main = fig_template.ax_main + + if add_cbar: gs = gridspec.GridSpec(nrows=2, ncols=1, height_ratios=[19, 1]) - ax = fig.add_subplot(gs[0], projection=projection) + ax_main = fig.add_subplot(gs[0], projection=projection) cax = fig.add_subplot(gs[1]) else: gs = gridspec.GridSpec(nrows=1, ncols=1) - ax = fig.add_subplot(gs[0], projection=projection) + ax_main = fig.add_subplot(gs[0], projection=projection) cax = None - return fig, ax, cax - + return fig, ax_main, cax def get_extend_cbar(metric): """ @@ -447,11 +465,16 @@ def style_map( return ax + +@note( + "DeprecationWarning: The function `qa4sm_reader.plotting_methods.make_watermark()` is deprecated and will be removed in the next release. Use `qa4sm_reader.plotting_methods.add_logo_to_figure` instead to add a logo." +) def make_watermark(fig, placement=globals.watermark_pos, for_map=False, offset=0.03, - for_barplot=False): + for_barplot=False, + fontsize=globals.watermark_fontsize): """ Adds a watermark to fig and adjusts the current axis to make sure there is enough padding around the watermarks. @@ -472,56 +495,146 @@ def make_watermark(fig, """ # ax = fig.gca() # pos1 = ax.get_position() #fraction of figure - fontsize = globals.watermark_fontsize pad = globals.watermark_pad height = fig.get_size_inches()[1] offset = offset + (( (fontsize + pad) / globals.matplotlib_ppi) / height) * 2.2 if placement == 'top': - plt.annotate(globals.watermark, - xy=[0.5, 1], - xytext=[-pad, -pad], - fontsize=fontsize, - color='grey', - horizontalalignment='center', - verticalalignment='top', - xycoords='figure fraction', - textcoords='offset points') + plt.annotate( + globals.watermark, + xy=[0.5, 1], + xytext=[-pad, -pad], + fontsize=fontsize, + color='white', #TODO! 
change back to grey + horizontalalignment='center', + verticalalignment='top', + xycoords='figure fraction', + textcoords='offset points') top = fig.subplotpars.top fig.subplots_adjust(top=top - offset) elif for_map or for_barplot: if for_barplot: - plt.suptitle(globals.watermark, - color='grey', - fontsize=fontsize, - x=-0.07, - y=0.5, - va='center', - rotation=90) + plt.suptitle( + globals.watermark, + color='white', #TODO! change back to grey + fontsize=fontsize, + x=-0.07, + y=0.5, + va='center', + rotation=90) else: - plt.suptitle(globals.watermark, - color='grey', - fontsize=fontsize, - y=0, - ha='center') + plt.suptitle( + globals.watermark, + color='white', #TODO! change back to grey + fontsize=fontsize, + y=0, + ha='center') elif placement == 'bottom': - plt.annotate(globals.watermark, - xy=[0.5, 0], - xytext=[pad, pad], - fontsize=fontsize, - color='grey', - horizontalalignment='center', - verticalalignment='bottom', - xycoords='figure fraction', - textcoords='offset points') + plt.annotate( + globals.watermark, + xy=[0.5, 0], + xytext=[pad, pad], + fontsize=fontsize, + color='white', #TODO! change back to grey + horizontalalignment='center', + verticalalignment='bottom', + xycoords='figure fraction', + textcoords='offset points') bottom = fig.subplotpars.bottom if not for_map: fig.subplots_adjust(bottom=bottom + offset) else: raise NotImplementedError +#$$ +Offset = namedtuple('offset', ['x', 'y']) # helper for offset in add_logo_to_figure +def add_logo_to_figure(fig: matplotlib.figure.Figure, + logo_path: Optional[str] = globals.watermark_logo_pth, + position: Optional[str] = globals.watermark_logo_position, + offset: Optional[Union[Tuple, Offset]] = (0., -0.15), + scale: Optional[float] = 0.15) -> None: + """ + Add a logo to an existing figure. This is done by creating an additional axis in the figure, at the location\ + specified by `position`. The logo is then placed on this axis. + + Parameters + ---------- + fig: matplotlib.figure.Figure + The figure to add the logo to. The figure should have at least one axis, otherwise an axis is created.z + + logo_path: Optional[str] + Path to the logo image. If the path does not exist, a warning is raised and the function returns. Default is\ + `globals.watermark_logo_pth`. + + position: Optional[str] + The position of the logo in the figure. Valid values are 'lower_left', 'lower_center', 'lower_right',\ + 'upper_left', 'upper_center', 'upper_right'. Default is `globals.watermark_logo_position`. + + offset: Optional[Tuple | Offset] + Offset of the logo from the right edge of the subplot (right lower corner of the main plot).\ + The first value is the x-offset, the second value is the y-offset. Default is (0., 0). + + scale: Optional[float] + Scale of the logo relative to the figure height (= fraction of figure height). Valid values are (0, 1].\ + Default is 0.15. + + Returns + ------- + None + """ + + if not fig.get_axes(): + warnings.warn("No axes found in the figure. Creating a new one.") + fig.add_subplot(111) + + if not os.path.exists(logo_path): + warnings.warn(f"No logo found at the specified path: '{logo_path}'. Skipping logo addition.") + print(f"No logo found at the specified path: '{logo_path}'. 
Skipping logo addition.") + return + + with cbook.get_sample_data(logo_path) as file: + im = mpimg.imread(file) + + # Get the dimensions of the image + height, width, _ = im.shape + + fig_height_pixels = fig.get_figheight() * fig.dpi + + logo_height_pixels = scale * fig_height_pixels + logo_width_pixels = width * logo_height_pixels / height + + # Convert back to figure coordinates + logo_width_fig = logo_width_pixels / fig.dpi + + if not isinstance(offset, Offset): + offset = Offset(*offset) + + + if 'left' in position: + left = 1 - (logo_width_fig) + offset.x + elif 'center' in position: + left = 0.5 - (logo_width_fig / 2) + offset.x + elif 'right' in position: # 'right' in position + left = 0 + offset.x + + if 'lower' in position: + bottom = offset.y + elif 'upper' in position: # 'upper' in position + bottom = 1 - offset.y + + # Define the new position of ax_logo + # [left, bottom, width, height] + ax_logo_pos = [left, bottom, logo_width_fig, scale] + + # Add a new axis to the figure at the position of ax_logo to house the logo + ax_logo = fig.add_axes(ax_logo_pos) + ax_logo.imshow(im) + + # Hide the axis + ax_logo.axis('off') + def _make_cbar(fig, im, @@ -716,12 +829,17 @@ def boxplot( values = df.copy() center_pos = np.arange(len(values.columns)) * 2 # make plot + ax = axis if axis is None: fig, ax = plt.subplots(figsize=figsize, dpi=dpi) + else: + fig = None ticklabels = values.columns # styling of the boxes kwargs = {"patch_artist": True, "return_type": "dict"} + for key, value in plotting_kwargs.items(): + kwargs[key] = value # changes necessary to have confidence intervals in the plot # could be an empty list or could be 'None', if de-selected from the kwargs if ci: @@ -750,29 +868,41 @@ def boxplot( patch_styling(low, 'skyblue') patch_styling(up, 'tomato') - cen = values.boxplot(positions=center_pos, + if not 'positions' in kwargs: + positions = center_pos + else: + positions = kwargs['positions'] + del kwargs['positions'] + + if not 'widths' in kwargs: + widths = 0.3 + else: + widths = kwargs['widths'] + del kwargs['widths'] + + cen = values.boxplot(positions=positions, showfliers=False, - widths=0.3, + widths=widths, ax=ax, **kwargs) patch_styling(cen, 'white') + if ci: low_ci = Patch(color='skyblue', alpha=0.7, label='Lower CI') up_ci = Patch(color='tomato', alpha=0.7, label='Upper CI') # _CI_difference(fig, ax, ci) - plt.legend(handles=[low_ci, up_ci], fontsize=8, loc="best") + ax.legend(handles=[low_ci, up_ci], fontsize=8, loc="best") # provide y label if label is not None: plt.ylabel(label, weight='normal') - ax.set_xticks(center_pos) + ax.set_xticks(positions) ax.set_xticklabels(ticklabels) ax.tick_params(labelsize=globals.tick_size) ax.grid(axis='x') ax.spines['right'].set_visible(False) ax.spines['top'].set_visible(False) - if axis is None: - return fig, ax + return fig, ax def _replace_status_values(ser): @@ -805,6 +935,7 @@ def barplot( label=None, figsize=None, dpi=100, + axis=None, ) -> tuple: """ Create a barplot from the validation errors in df. @@ -822,6 +953,8 @@ def barplot( Figure size in inches. The default is globals.map_figsize. dpi : int, optional Resolution for raster graphic output. The default is globals.dpi. + axis : matplotlib Axis obj. 
+ if provided, the plot will be shown on it Returns ------- @@ -829,7 +962,12 @@ def barplot( the boxplot ax : matplotlib.axes.Axes """ - fig, ax = plt.subplots(figsize=figsize, dpi=dpi) + + ax = axis + if axis is None: + fig, ax = plt.subplots(figsize=figsize, dpi=dpi) + else: + fig = None values = df.copy() values = values[[values.keys()[0]]] @@ -1231,10 +1369,12 @@ def bplot_multiple(to_plot, y_axis, n_bars, **kwargs) -> tuple: def _dict2df(to_plot_dict: dict, meta_key: str) -> pd.DataFrame: """Transform a dictionary into a DataFrame for catplotting""" to_plot_df = [] + for range, values in to_plot_dict.items(): range_grouped = [] for ds in values: - values_ds = values[ds].to_frame(name="values") + values_ds = values[ds] + values_ds = values_ds.to_frame(name="values") values_ds["Dataset"] = ds values_ds[meta_key] = "\n[".join(range.split(" [")) range_grouped.append(values_ds) @@ -1389,8 +1529,9 @@ def boxplot_metadata( metric_label = "values" meta_key = metadata_values.columns[0] # sort data according to the metadata type - type = globals.metadata[meta_key][2] - bin_funct = bin_function_lut(type) + metadata_type = globals.metadata[meta_key][2] + + bin_funct = bin_function_lut(metadata_type) to_plot = bin_funct( df=df, metadata_values=metadata_values, @@ -1413,6 +1554,7 @@ def boxplot_metadata( elif isinstance(to_plot, pd.DataFrame): generate_plot = bplot_catplot + out = generate_plot( to_plot=to_plot, y_axis=ax_label, @@ -1428,20 +1570,20 @@ def boxplot_metadata( return fig, axes -def mapplot(df, - metric, - ref_short, - scl_short=None, - ref_grid_stepsize=None, - plot_extent=None, +def mapplot(df: pd.DataFrame, + metric: str, + ref_short : str, + scl_short: Optional[str] = None, + ref_grid_stepsize: Optional[float] = None, + plot_extent: Optional[Tuple[float, float, float, float]] = None, colormap=None, - projection=None, - add_cbar=True, - label=None, - figsize=globals.map_figsize, - dpi=globals.dpi_min, - diff_map=False, - **style_kwargs) -> tuple: + projection: Optional[ccrs.Projection] = None, + add_cbar: Optional[bool] = True, + label: Optional[str] = None, + figsize: Optional[Tuple[float, float]] = globals.map_figsize, + dpi: Optional[int] = globals.dpi_min, + diff_map: Optional[bool] = False, + **style_kwargs: Dict) -> Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]: """ Create an overview map from df using values as color. Plots a scatterplot for ISMN and an image plot for other input values. @@ -1810,3 +1952,152 @@ def average_non_additive(values: Union[pd.Series, np.array], # Back transform the result return np.tanh(mean) + +#$$ +class ClusteredBoxPlot: + """ + Class to create an empty figure object with one main axis and optionally three sub-axis. It is used to create a template for the clustered boxplot, which can then be filled with data. 
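+
+    A typical use (sketch): instantiate with the cluster anchor positions and the
+    number of datasets, e.g. ``ClusteredBoxPlot(anchor_list=np.arange(1, 4), no_of_ds=2)``,
+    then create the figure via ``figure_template()`` and draw the boxes at the
+    positions returned by ``centers_and_widths()``.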
+ """ + + def __init__(self, + anchor_list: Union[List[float], np.ndarray], + no_of_ds: int, + space_per_box_cluster: Optional[float] = 0.9, + rel_indiv_box_width: Optional[float] = 0.9): + self.anchor_list = anchor_list + self.no_of_ds = no_of_ds + self.space_per_box_cluster = space_per_box_cluster + self.rel_indiv_box_width = rel_indiv_box_width + + # xticklabel and legend label templates + # self.xticklabel_template = "{tsw}:\n{dataset_name}\n({dataset_version})\nVariable: {variable_name} [{unit}]\n Median: {median:.3e}\n IQR: {iqr:.3e}\nN: {count}" + self.xticklabel_template = "Median: {median:.3e}\n IQR: {iqr:.3e}\nN: {count}" + self.label_template = "{dataset_name} ({dataset_version})\nVariable: {variable_name} [{unit}]" + + @staticmethod + def centers_and_widths( + anchor_list: Union[List[float], np.ndarray], + no_of_ds: int, + space_per_box_cluster: Optional[float] = 0.9, + rel_indiv_box_width: Optional[float] = 0.9) -> List[CWContainer]: + """ + Function to calculate the centers and widths of the boxes of a clustered boxplot. The function returns a list of tuples, each containing the center and width of a box in the clustered boxplot. The output can then be used as indices for creating the boxes a boxplot using `matplotlib.pyplot.boxplot()` + + Parameters + ---------- + + anchor_list: Union[List[float], np.ndarray] + A list of floats representing the anchor points for each box cluster + no_of_ds: int + The number of datasets, i.e. the number of boxes in each cluster + space_per_box_cluster: float + The space each box cluster can occupy, 0.9 per default. This value should be <= 1 for a clustered boxplot to prevent overlap between neighboring clusters and boxes + rel_indiv_box_width: float + The relative width of the individual boxes in a cluster, 0.9 per default. This value should be <= 1 to prevent overlap between neighboring boxes + + Returns + ------- + + List[CWContainer] + A list of CWContainer objects. Each dataset present has its own CWContainer object, each containing the centers and widths of the boxes in the clustered boxplot + + """ + + b_lb_list = [ + -space_per_box_cluster / 2 + anchor for anchor in anchor_list + ] # list of lower bounds for each box cluster + b_ub_list = [ + space_per_box_cluster / 2 + anchor for anchor in anchor_list + ] # list of upper bounds for each box cluster + + _centers = sorted([(b_ub - b_lb) / (no_of_ds + 1) + b_lb + i * + ((b_ub - b_lb) / (no_of_ds + 1)) + for i in range(int(no_of_ds)) + for b_lb, b_ub in zip(b_lb_list, b_ub_list)]) + _widths = [ + rel_indiv_box_width * (_centers[0] - b_lb_list[0]) + for _center in _centers + ] + + return [ + CWContainer(name=f'ds_{ds}', + centers=_centers[ds::no_of_ds], + widths=_widths[ds::no_of_ds]) + for ds in range(int(no_of_ds)) + ] + + @staticmethod + def figure_template(incl_median_iqr_n_axs: Optional[bool] = False, + **fig_kwargs) -> ClusteredBoxPlotContainer: + """ + Function to create a figure template for e.g. a clustered boxplot. The function returns a \ + ClusteredBoxPlotContainer object, which contains the figure and the subplots for the boxplot as well as \ + optionally the median, IQR and N values. The layout is as follows: the axes are arranged in a 2x1 grid, \ + with the boxplot in the upper subplot and the median, IQR and N values in the lower subplot. \ + The lower subplot is further divided into three subplots, one for each value. + + Parameters + ---------- + incl_median_iqr_n_axs: Optional[bool] + If True, creates three subplots with median, IQR and N values for each box. 
+
+        """
+
+        b_lb_list = [
+            -space_per_box_cluster / 2 + anchor for anchor in anchor_list
+        ]  # list of lower bounds for each box cluster
+        b_ub_list = [
+            space_per_box_cluster / 2 + anchor for anchor in anchor_list
+        ]  # list of upper bounds for each box cluster
+
+        _centers = sorted([(b_ub - b_lb) / (no_of_ds + 1) + b_lb + i *
+                           ((b_ub - b_lb) / (no_of_ds + 1))
+                           for i in range(int(no_of_ds))
+                           for b_lb, b_ub in zip(b_lb_list, b_ub_list)])
+        _widths = [
+            rel_indiv_box_width * (_centers[0] - b_lb_list[0])
+            for _ in _centers
+        ]
+
+        return [
+            CWContainer(name=f'ds_{ds}',
+                        centers=_centers[ds::no_of_ds],
+                        widths=_widths[ds::no_of_ds])
+            for ds in range(int(no_of_ds))
+        ]
+
+    @staticmethod
+    def figure_template(incl_median_iqr_n_axs: bool = False,
+                        **fig_kwargs) -> ClusteredBoxPlotContainer:
+        """
+        Function to create a figure template for e.g. a clustered boxplot. The function returns a \
+        ClusteredBoxPlotContainer object, which contains the figure and the subplots for the boxplot as well as \
+        optionally the median, IQR and N values. The layout is as follows: the axes are arranged in a 2x1 grid, \
+        with the boxplot in the upper subplot and the median, IQR and N values in the lower subplot. \
+        The lower subplot is further divided into three subplots, one for each value.
+
+        Parameters
+        ----------
+        incl_median_iqr_n_axs: bool
+            If True, creates three subplots with median, IQR and N values for each box. If False, only the \
+            boxplot is created. Default is False
+        fig_kwargs: dict
+            Keyword arguments for the figure
+
+        Returns
+        -------
+        ClusteredBoxPlotContainer
+            A ClusteredBoxPlotContainer object containing the figure and the subplots for the boxplot, median, \
+            IQR and N values
+        """
+
+        if 'figsize' in fig_kwargs:
+            _fig = plt.figure(figsize=fig_kwargs['figsize'])
+        else:
+            _fig = plt.figure(figsize=(15, 10.5))
+
+        if not incl_median_iqr_n_axs:
+            ax_box = _fig.add_subplot(111)
+            ax_median, ax_iqr, ax_n = None, None, None
+        else:
+            # Create a main gridspec for ax_box and the subplots below
+            gs_main = gridspec.GridSpec(2, 1, height_ratios=[2, 1], hspace=0.2)
+
+            # Subgridspec for ax_box (top subplot)
+            gs_top = gridspec.GridSpecFromSubplotSpec(1,
+                                                      1,
+                                                      subplot_spec=gs_main[0])
+
+            # Subgridspec for ax_median, ax_iqr and ax_n (bottom subplots)
+            gs_bottom = gridspec.GridSpecFromSubplotSpec(
+                3,
+                1,
+                height_ratios=[1, 1, 1],
+                subplot_spec=gs_main[1],
+                hspace=0)
+            ax_box = plt.subplot(gs_top[0])
+            ax_median = plt.subplot(gs_bottom[0], sharex=ax_box)
+            ax_iqr = plt.subplot(gs_bottom[1], sharex=ax_box)
+            ax_n = plt.subplot(gs_bottom[2], sharex=ax_box)
+
+        for _ax in [ax_box, ax_median, ax_iqr, ax_n]:
+            try:
+                _ax.tick_params(labelsize=globals.tick_size)
+                _ax.spines['right'].set_visible(False)
+                _ax.spines['top'].set_visible(False)
+            except AttributeError:
+                pass
+
+        add_logo_to_figure(
+            fig=_fig,
+            logo_path=globals.watermark_logo_pth,
+            position=globals.watermark_logo_position,
+            offset=globals.watermark_logo_offset_comp_plots,
+            scale=globals.watermark_logo_scale,
+        )
+
+        return ClusteredBoxPlotContainer(fig=_fig,
+                                         ax_box=ax_box,
+                                         ax_median=ax_median,
+                                         ax_iqr=ax_iqr,
+                                         ax_n=ax_n)
diff --git a/src/qa4sm_reader/static/images/logo/QA4SM_logo_long.png b/src/qa4sm_reader/static/images/logo/QA4SM_logo_long.png
new file mode 100755
index 0000000..0aec08f
Binary files /dev/null and b/src/qa4sm_reader/static/images/logo/QA4SM_logo_long.png differ
diff --git a/src/qa4sm_reader/utils.py b/src/qa4sm_reader/utils.py
new file mode 100644
index 0000000..3ce2ea6
--- /dev/null
+++ b/src/qa4sm_reader/utils.py
@@ -0,0 +1,99 @@
+from functools import wraps
+import inspect
+import logging
+from typing import Any, Callable, TypeVar, Union, List
+import qa4sm_reader.globals as globals
+from qa4sm_reader.handlers import QA4SMVariable
+from qa4sm_reader.netcdf_transcription import Pytesmo2Qa4smResultsTranscriber
+import xarray as xr
+from pathlib import PosixPath
+
+T = TypeVar('T', bound=Callable[..., Any])
+
+
+def note(note_text: Any) -> Callable[[T], T]:
+    """
+    Factory function creating a decorator that prints a note before the decorated function is executed.
+
+    Parameters
+    ----------
+    note_text : Any
+        The note to be printed.
+
+    Returns
+    -------
+    Callable[[T], T]
+        The decorator to be applied to a function.
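+
+    Example
+    -------
+    Illustrative sketch only:
+
+    >>> @note("this is printed first")
+    ... def greet():
+    ...     return "hi"
+    >>> result = greet()  # prints the note, then runs greet()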
+ """ + + def decorator(func: T) -> T: + + @wraps(func) + def wrapper(*args, **kwargs): + print(f'\n\n{note_text}\n\n') + return func(*args, **kwargs) + + return wrapper + + return decorator + + +def log_function_call(func: Callable) -> Callable[[T], T]: + '''Decorator that logs the function call with its arguments and their values.''' + @wraps(func) + def wrapper(*args, **kwargs): + frame = inspect.currentframe().f_back + func_name = frame.f_code.co_name + local_vars = frame.f_locals + logging.info(f'**{func_name}**({", ".join(f"{k}={v}" for k, v in local_vars.items())})') + return func(*args, **kwargs) + return wrapper + + +def transcribe(file_path: Union[str, PosixPath]) -> Union[None, xr.Dataset]: + '''If the dataset is not in the new format, transcribe it to the new format. + This is done under the assumption that the dataset is a `pytesmo` dataset and corresponds to a default\ + validation, i.e. no temporal sub-windows are present. + + Parameters + ---------- + file_path : str or PosixPath + path to the file to be transcribed + + Returns + ------- + dataset : xr.Dataset + the transcribed dataset + ''' + + temp_sub_wdw_instance = None # bulk case, no temporal sub-windows + + transcriber = Pytesmo2Qa4smResultsTranscriber( + pytesmo_results=file_path, + intra_annual_slices=temp_sub_wdw_instance, + keep_pytesmo_ncfile=False) + + if transcriber.exists: + return transcriber.get_transcribed_dataset() + + +def filter_out_self_combination_tcmetric_vars(variables: List[QA4SMVariable]) -> List[QA4SMVariable]: + """ + Filters out the 'self-combination' temporal collocation metric varriables, referring to variables that \ + match the pattern: {METRIC}_{DATASET_A}_between_{DATASET_A}_and_{WHATEVER}. The occurence of these \ + metric vars is a consequence of reference dataset tcol metric vas being written to the file + + Parameters + ---------- + variables : List[QA4SMVariable] + list of variables to be filtered + + Returns + ------- + List[QA4SMVariable] + the filtered list of variables + """ + + return [var for var in variables if var.metric_ds != var.ref_ds] diff --git a/tests/conftest.py b/tests/conftest.py index e0015f2..1c45f95 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,7 @@ # -*- coding: utf-8 -*- +import pytest +from pathlib import Path + """ Dummy conftest.py for qa4sm_reader. 
diff --git a/tests/conftest.py b/tests/conftest.py
index e0015f2..1c45f95 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,4 +1,7 @@
 # -*- coding: utf-8 -*-
+import pytest
+from pathlib import Path
+
 """
 Dummy conftest.py for qa4sm_reader.
@@ -6,3 +9,18 @@ Read more about conftest.py under:
 https://pytest.org/latest/plugins.html
 """
+
+def pytest_collection_modifyitems(items):
+    # Run test_utils::test_transcribe first, since it transcribes the test
+    # files that other tests rely on
+    first_test = None
+    for item in items:
+        if item.name == "test_transcribe":
+            first_test = item
+            break
+    if first_test:
+        items.insert(0, items.pop(items.index(first_test)))
+
+
+@pytest.fixture(scope="session")
+def TEST_DATA_DIR():
+    return Path(__file__).parent / 'test_data'
diff --git a/tests/test_data/intra_annual/custom_intra_annual_windows.json b/tests/test_data/intra_annual/custom_intra_annual_windows.json
new file mode 100644
index 0000000..e96749d
--- /dev/null
+++ b/tests/test_data/intra_annual/custom_intra_annual_windows.json
@@ -0,0 +1,92 @@
+{
+    "seasons": {
+        "S1": [
+            [12, 1],
+            [2, 28]
+        ],
+        "S2": [
+            [3, 1],
+            [5, 31]
+        ],
+        "S3": [
+            [6, 1],
+            [8, 31]
+        ],
+        "S4": [
+            [9, 1],
+            [11, 30]
+        ]
+    },
+    "months": {
+        "Jan": [
+            [1, 1],
+            [1, 31]
+        ],
+        "Feb": [
+            [2, 1],
+            [2, 28]
+        ],
+        "Mar": [
+            [3, 1],
+            [3, 31]
+        ],
+        "Apr": [
+            [4, 1],
+            [4, 30]
+        ],
+        "May": [
+            [5, 1],
+            [5, 31]
+        ],
+        "Jun": [
+            [6, 1],
+            [6, 30]
+        ],
+        "Jul": [
+            [7, 1],
+            [7, 31]
+        ],
+        "Aug": [
+            [8, 1],
+            [8, 31]
+        ],
+        "Sep": [
+            [9, 1],
+            [9, 30]
+        ],
+        "Oct": [
+            [10, 1],
+            [10, 31]
+        ],
+        "Nov": [
+            [11, 1],
+            [11, 30]
+        ],
+        "Dec": [
+            [12, 1],
+            [12, 31]
+        ]
+    },
+    "custom": {
+        "star wars month": [
+            [5, 1],
+            [5, 31]
+        ],
+        "halloween season": [
+            [10, 1],
+            [10, 31]
+        ],
+        "advent": [
+            [12, 1],
+            [12, 24]
+        ],
+        "movember": [
+            [11, 1],
+            [11, 30]
+        ],
+        "christmas": [
+            [12, 24],
+            [12, 26]
+        ]
+    }
+}
diff --git a/tests/test_data/intra_annual/monthly/0-ISMN.soil_moisture_with_1-C3S.sm_tsw_months_pytesmo.nc b/tests/test_data/intra_annual/monthly/0-ISMN.soil_moisture_with_1-C3S.sm_tsw_months_pytesmo.nc
new file mode 100644
index 0000000..dc9b91e
Binary files /dev/null and b/tests/test_data/intra_annual/monthly/0-ISMN.soil_moisture_with_1-C3S.sm_tsw_months_pytesmo.nc differ
diff --git a/tests/test_data/intra_annual/monthly/0-ISMN.soil_moisture_with_1-C3S.sm_tsw_months_qa4sm.nc b/tests/test_data/intra_annual/monthly/0-ISMN.soil_moisture_with_1-C3S.sm_tsw_months_qa4sm.nc
new file mode 100644
index 0000000..7040e8d
Binary files /dev/null and b/tests/test_data/intra_annual/monthly/0-ISMN.soil_moisture_with_1-C3S.sm_tsw_months_qa4sm.nc differ
diff --git a/tests/test_data/intra_annual/seasonal/0-ERA5.swvl1_with_1-ESA_CCI_SM_combined.sm_with_2-ESA_CCI_SM_combined.sm_with_3-ESA_CCI_SM_combined.sm_with_4-ESA_CCI_SM_combined.sm.CI_tsw_seasons_pytesmo.nc b/tests/test_data/intra_annual/seasonal/0-ERA5.swvl1_with_1-ESA_CCI_SM_combined.sm_with_2-ESA_CCI_SM_combined.sm_with_3-ESA_CCI_SM_combined.sm_with_4-ESA_CCI_SM_combined.sm.CI_tsw_seasons_pytesmo.nc
new file mode 100644
index 0000000..c734516
Binary files /dev/null and b/tests/test_data/intra_annual/seasonal/0-ERA5.swvl1_with_1-ESA_CCI_SM_combined.sm_with_2-ESA_CCI_SM_combined.sm_with_3-ESA_CCI_SM_combined.sm_with_4-ESA_CCI_SM_combined.sm.CI_tsw_seasons_pytesmo.nc differ
diff --git a/tests/test_data/intra_annual/seasonal/0-ERA5.swvl1_with_1-ESA_CCI_SM_combined.sm_with_2-ESA_CCI_SM_combined.sm_with_3-ESA_CCI_SM_combined.sm_with_4-ESA_CCI_SM_combined.sm.CI_tsw_seasons_qa4sm.nc b/tests/test_data/intra_annual/seasonal/0-ERA5.swvl1_with_1-ESA_CCI_SM_combined.sm_with_2-ESA_CCI_SM_combined.sm_with_3-ESA_CCI_SM_combined.sm_with_4-ESA_CCI_SM_combined.sm.CI_tsw_seasons_qa4sm.nc
new file mode 100644
index 0000000..7ab9863
Binary files /dev/null and b/tests/test_data/intra_annual/seasonal/0-ERA5.swvl1_with_1-ESA_CCI_SM_combined.sm_with_2-ESA_CCI_SM_combined.sm_with_3-ESA_CCI_SM_combined.sm_with_4-ESA_CCI_SM_combined.sm.CI_tsw_seasons_qa4sm.nc differ
diff --git a/tests/test_image.py b/tests/test_image.py
index 51e9636..bb1ee36 100644
--- a/tests/test_image.py
+++ b/tests/test_image.py
@@ -70,7 +70,7 @@ def test_load_vars(img):
     Vars = img._load_vars()
     assert len(Vars) == len(img.varnames)
     Metr_Vars = img._load_vars(only_metrics=True)
-    assert len(Metr_Vars) == len(Vars) - 21
+    assert len(Metr_Vars) == len(Vars) - 22

 def test_iter_vars(img):
@@ -120,13 +120,15 @@ def test_ds2df(img):

 def test_metric_df(img):
     df = img.metric_df(['R'])
     assert list(df.columns) == [
         'R_between_0-ERA5_LAND_and_1-C3S_combined',
         'R_ci_lower_between_0-ERA5_LAND_and_1-C3S_combined',
         'R_ci_upper_between_0-ERA5_LAND_and_1-C3S_combined',
         'R_between_0-ERA5_LAND_and_2-SMOS_IC',
         'R_ci_lower_between_0-ERA5_LAND_and_2-SMOS_IC',
-        'R_ci_upper_between_0-ERA5_LAND_and_2-SMOS_IC', 'idx'
+        'R_ci_upper_between_0-ERA5_LAND_and_2-SMOS_IC',
+        globals.TEMPORAL_SUB_WINDOW_NC_COORD_NAME,
+        'idx',
     ]

diff --git a/tests/test_intra_annual_temp_windows.py b/tests/test_intra_annual_temp_windows.py
new file mode 100644
index 0000000..c8c0a1b
--- /dev/null
+++ b/tests/test_intra_annual_temp_windows.py
@@ -0,0 +1,513 @@
+import os
+from pathlib import Path
+import pytest
+import json
+from copy import deepcopy
+from datetime import datetime
+from pytesmo.validation_framework.metric_calculators_adapters import TsDistributor
+from pytesmo.time_series.grouping import YearlessDatetime
+
+from qa4sm_reader.intra_annual_temp_windows import TemporalSubWindowsDefault, TemporalSubWindowsCreator, NewSubWindow, InvalidTemporalSubWindowError
+
+
+@pytest.fixture
+def default_monthly_sub_windows_no_overlap():
+    # default 'months' defined in globals.py
+    return TemporalSubWindowsCreator(temporal_sub_window_type="months",
+                                     overlap=0,
+                                     custom_file=None)
+
+
+@pytest.fixture
+def default_seasonal_sub_windows_no_overlap():
+    # default 'seasons' defined in globals.py
+    return TemporalSubWindowsCreator(temporal_sub_window_type="seasons",
+                                     overlap=0,
+                                     custom_file=None)
+
+
+@pytest.fixture
+def seasonal_sub_windows_positive_overlap():
+    # the overlap is given in days, can be positive or negative, and is
+    # applied to both ends of each temporal sub-window;
+    # a positive overlap results in temporal sub-windows that overlap with each other
+    return TemporalSubWindowsCreator(temporal_sub_window_type="seasons",
+                                     overlap=5,
+                                     custom_file=None)
+
+
+@pytest.fixture
+def seasonal_sub_windows_negative_overlap():
+    # a negative overlap results in temporal sub-windows that have gaps between them
+    return TemporalSubWindowsCreator(temporal_sub_window_type="seasons",
+                                     overlap=-5,
+                                     custom_file=None)
+
+
+@pytest.fixture
+def temporal_sub_windows_custom():
+    # load custom temporal sub-windows from a json file
+    return TemporalSubWindowsCreator(
+        temporal_sub_window_type='custom',
+        overlap=0,
+        custom_file=Path(__file__).resolve().parent.parent / 'tests' /
+        'test_data' / 'intra_annual' / 'custom_intra_annual_windows.json')
+
+
+@pytest.fixture
+def additional_temp_sub_window():
+    # create a new temporal sub-window, to be used in addition to the default ones
+    return NewSubWindow(name="Feynman",
+                        begin_date=datetime(1918, 5, 11),
+                        end_date=datetime(1988, 2, 15))
+
+
+#------------------------- Tests for TemporalSubWindowsDefault class -----------------------------------------------------------------------
+
+
+class TemporalSubWindowsConcrete(TemporalSubWindowsDefault):
+    # used to test the abstract class TemporalSubWindowsDefault
+    def _get_available_temp_sub_wndws(self):
+        return {"seasons": {"S1": [[12, 1], [2, 28]], "S2": [[3, 1], [5, 31]]}}
+
+
+def test_initialization():
+    temp_sub_windows = TemporalSubWindowsConcrete(custom_file='test.json')
+    assert temp_sub_windows.custom_file == 'test.json'
+
+
+def test_load_json_data(tmp_path):
+    test_data = {
+        "seasons": {
+            "S1": [[12, 1], [2, 28]],
+            "S2": [[3, 1], [5, 31]]
+        }
+    }
+    test_file = tmp_path / "test.json"
+    with open(test_file, 'w') as f:
+        json.dump(test_data, f)
+
+    temp_sub_windows = TemporalSubWindowsConcrete()
+    loaded_data = temp_sub_windows._load_json_data(test_file)
+    assert loaded_data == test_data
+
+
+def test_get_available_temp_sub_wndws():
+    temp_sub_windows = TemporalSubWindowsConcrete()
+    available_windows = temp_sub_windows._get_available_temp_sub_wndws()
+    assert available_windows == {
+        "seasons": {
+            "S1": [[12, 1], [2, 28]],
+            "S2": [[3, 1], [5, 31]]
+        }
+    }
+
+
+#------------------------- Tests for NewSubWindow class -----------------------------------------------------------------------
+
+
+def test_new_sub_window_attributes(additional_temp_sub_window):
+    # used to generate a proper description of the temporal sub-window dimension in the netCDF file
+    assert additional_temp_sub_window.begin_date_pretty == '1918-05-11'
+    assert additional_temp_sub_window.end_date_pretty == '1988-02-15'
+
+
+def test_faulty_new_sub_window():
+    # begin_date and end_date must be datetime objects
+    with pytest.raises((TypeError, AttributeError)):
+        NewSubWindow(name="Test Window",
+                     begin_date="2023-01-01",
+                     end_date=datetime.now())
+
+    with pytest.raises((TypeError, AttributeError)):
+        NewSubWindow(name="Test Window",
+                     begin_date=datetime.now(),
+                     end_date="2023-12-31")
+
+    # begin_date must be before end_date, because the date is NOT a yearless date
+    with pytest.raises(ValueError):
+        NewSubWindow(name="Test Window",
+                     begin_date=datetime(5000, 1, 1),
+                     end_date=datetime(1000, 1, 1))
+
+    # both begin_date and end_date must be instances of the same class
+    with pytest.raises(TypeError):
+        NewSubWindow(name="Test Window",
+                     begin_date=datetime(5000, 1, 1),
+                     end_date=YearlessDatetime(1, 1))
+
+
+#------------------------- Tests for TemporalSubWindowsCreator class ----------------------------------------------------------
+
+
+def test_default_monthly_sub_windows_attributes(
+        default_monthly_sub_windows_no_overlap,
+        default_seasonal_sub_windows_no_overlap):
+    assert default_monthly_sub_windows_no_overlap.temporal_sub_window_type == "months"
+
+    assert default_seasonal_sub_windows_no_overlap.temporal_sub_window_type == "seasons"
+
+    assert default_monthly_sub_windows_no_overlap.overlap == default_seasonal_sub_windows_no_overlap.overlap == 0
+
+    assert default_monthly_sub_windows_no_overlap.custom_file == default_seasonal_sub_windows_no_overlap.custom_file == None
+    assert default_monthly_sub_windows_no_overlap.available_temp_sub_wndws == default_seasonal_sub_windows_no_overlap.available_temp_sub_wndws == [
+        'seasons', 'months'
+    ]
+
+    assert default_monthly_sub_windows_no_overlap.names == [
+        'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct',
+        'Nov', 'Dec'
+    ]
+
+    # included so that if the definition of months changes in the globals.py file, the test will fail
+    expected_months = {
+        'Jan': ((1, 1), (1, 31)),
+        'Feb': ((2, 1), (2, 28)),
+        'Mar': ((3, 1), (3, 31)),
+        'Apr': ((4, 1), (4, 30)),
+        'May': ((5, 1), (5, 31)),
+        'Jun': ((6, 1), (6, 30)),
+        'Jul': ((7, 1), (7, 31)),
+        'Aug': ((8, 1), (8, 31)),
+        'Sep': ((9, 1), (9, 30)),
+        'Oct': ((10, 1), (10, 31)),
+        'Nov': ((11, 1), (11, 30)),
+        'Dec': ((12, 1), (12, 31)),
+    }
+    for month, (begin, end) in expected_months.items():
+        assert default_monthly_sub_windows_no_overlap.custom_temporal_sub_windows[
+            month].yearless_date_ranges == TsDistributor(
+                yearless_date_ranges=[(YearlessDatetime(*begin),
+                                       YearlessDatetime(*end))
+                                      ]).yearless_date_ranges
+
+    assert default_monthly_sub_windows_no_overlap.additional_temp_sub_wndws_container == {}
+
+    # used to generate a proper description of the temporal sub-window dimension in the netCDF file
+    assert default_monthly_sub_windows_no_overlap.metadata == {
+        'Temporal sub-window type':
+        'months',
+        'Overlap':
+        '0 days',
+        'Pretty Names [MM-DD]':
+        'Jan: 01-01 to 01-31, Feb: 02-01 to 02-28, Mar: 03-01 to 03-31, Apr: 04-01 to 04-30, May: 05-01 to 05-31, Jun: 06-01 to 06-30, Jul: 07-01 to 07-31, Aug: 08-01 to 08-31, Sep: 09-01 to 09-30, Oct: 10-01 to 10-31, Nov: 11-01 to 11-30, Dec: 12-01 to 12-31'
+    }
+
+
+def test_default_seasonal_sub_windows_attributes(
+        default_seasonal_sub_windows_no_overlap):
+    assert default_seasonal_sub_windows_no_overlap.temporal_sub_window_type == "seasons"
+
+    assert default_seasonal_sub_windows_no_overlap.overlap == 0
+
+    assert default_seasonal_sub_windows_no_overlap.custom_file is None
+
+    assert default_seasonal_sub_windows_no_overlap.available_temp_sub_wndws == [
+        'seasons', 'months'
+    ]
+
+    assert default_seasonal_sub_windows_no_overlap.names == [
+        'S1', 'S2', 'S3', 'S4'
+    ]
+
+    # included so that if the definition of seasons changes in the globals.py file, the test will fail
+    expected_seasons = {
+        'S1': ((12, 1), (2, 28)),
+        'S2': ((3, 1), (5, 31)),
+        'S3': ((6, 1), (8, 31)),
+        'S4': ((9, 1), (11, 30)),
+    }
+    for season, (begin, end) in expected_seasons.items():
+        assert default_seasonal_sub_windows_no_overlap.custom_temporal_sub_windows[
+            season].yearless_date_ranges == TsDistributor(
+                yearless_date_ranges=[(YearlessDatetime(*begin),
+                                       YearlessDatetime(*end))
+                                      ]).yearless_date_ranges
+
+    assert default_seasonal_sub_windows_no_overlap.additional_temp_sub_wndws_container == {}
+
+    # used to generate a proper description of the temporal sub-window dimension in the netCDF file
+    assert default_seasonal_sub_windows_no_overlap.metadata == {
+        'Temporal sub-window type':
+        'seasons',
+        'Overlap':
+        '0 days',
+        'Pretty Names [MM-DD]':
+        'S1: 12-01 to 02-28, S2: 03-01 to 05-31, S3: 06-01 to 08-31, S4: 09-01 to 11-30'
+    }
+
+
+def test_faulty_temporal_sub_windows_creator():
+    # temporal_sub_window_type must be either 'months' or 'seasons'
+    with pytest.raises(InvalidTemporalSubWindowError):
+        TemporalSubWindowsCreator(
+            temporal_sub_window_type="not-a-default-value",
+            overlap=0,
+            custom_file=None)
+
+
+def test_load_custom_temporal_sub_windows(temporal_sub_windows_custom):
+    # 'temporal_sub_window_type' corresponds to the temporal sub-windows defined in the provided json file;
+    # the file may contain any number of temporal sub-window sets, but one set is selected per
+    # TemporalSubWindowsCreator instance via the keyword argument 'temporal_sub_window_type'
+
+    assert temporal_sub_windows_custom.custom_file == Path(__file__).resolve(
+    ).parent.parent / 'tests' / 'test_data' / 'intra_annual' / 'custom_intra_annual_windows.json'
+
+    assert temporal_sub_windows_custom.temporal_sub_window_type == 'custom'
+
+    assert temporal_sub_windows_custom.overlap == 0
+
+    assert temporal_sub_windows_custom.available_temp_sub_wndws == [
+        'seasons', 'months', 'custom'
+    ]
+
+    assert temporal_sub_windows_custom.names == [
+        'star wars month', 'halloween season', 'advent', 'movember',
+        'christmas'
+    ]
+
+    assert temporal_sub_windows_custom.custom_temporal_sub_windows[
+        'star wars month'].yearless_date_ranges == TsDistributor(
+            yearless_date_ranges=[(
+                YearlessDatetime(5, 1),
+                YearlessDatetime(5, 31))]).yearless_date_ranges
+    assert 
temporal_sub_windows_custom.custom_temporal_sub_windows[ + 'halloween season'].yearless_date_ranges == TsDistributor( + yearless_date_ranges=[( + YearlessDatetime(10, 1), + YearlessDatetime(10, 31))]).yearless_date_ranges + assert temporal_sub_windows_custom.custom_temporal_sub_windows[ + 'advent'].yearless_date_ranges == TsDistributor(yearless_date_ranges=[( + YearlessDatetime(12, 1), + YearlessDatetime(12, 24))]).yearless_date_ranges + assert temporal_sub_windows_custom.custom_temporal_sub_windows[ + 'movember'].yearless_date_ranges == TsDistributor( + yearless_date_ranges=[( + YearlessDatetime(11, 1), + YearlessDatetime(11, 30))]).yearless_date_ranges + assert temporal_sub_windows_custom.custom_temporal_sub_windows[ + 'christmas'].yearless_date_ranges == TsDistributor( + yearless_date_ranges=[( + YearlessDatetime(12, 24), + YearlessDatetime(12, 26))]).yearless_date_ranges + + +def test_load_nonexistent_custom_temporal_sub_windows(): + with pytest.raises(FileNotFoundError): + TemporalSubWindowsCreator(temporal_sub_window_type='whatever', + overlap=0, + custom_file='i_dont_exist.json') + + +def test_overlap_parameter(seasonal_sub_windows_positive_overlap, + seasonal_sub_windows_negative_overlap): + # overlap is added to both ends of the temporal sub-windows + assert seasonal_sub_windows_positive_overlap.overlap == 5 + assert seasonal_sub_windows_positive_overlap.custom_temporal_sub_windows[ + 'S1'].yearless_date_ranges == TsDistributor(yearless_date_ranges=[( + YearlessDatetime(11, 26), + YearlessDatetime(3, 5))]).yearless_date_ranges + assert seasonal_sub_windows_positive_overlap.custom_temporal_sub_windows[ + 'S2'].yearless_date_ranges == TsDistributor(yearless_date_ranges=[( + YearlessDatetime(2, 24), + YearlessDatetime(6, 5))]).yearless_date_ranges + assert seasonal_sub_windows_positive_overlap.custom_temporal_sub_windows[ + 'S3'].yearless_date_ranges == TsDistributor(yearless_date_ranges=[( + YearlessDatetime(5, 27), + YearlessDatetime(9, 5))]).yearless_date_ranges + assert seasonal_sub_windows_positive_overlap.custom_temporal_sub_windows[ + 'S4'].yearless_date_ranges == TsDistributor(yearless_date_ranges=[( + YearlessDatetime(8, 27), + YearlessDatetime(12, 5))]).yearless_date_ranges + assert seasonal_sub_windows_positive_overlap.metadata == { + 'Temporal sub-window type': + 'seasons', + 'Overlap': + '5 days', + 'Pretty Names [MM-DD]': + 'S1: 11-26 to 03-05, S2: 02-24 to 06-05, S3: 05-27 to 09-05, S4: 08-27 to 12-05' + } + + # overlap is subtracted from both ends of the temporal sub-windows + assert seasonal_sub_windows_negative_overlap.overlap == -5 + assert seasonal_sub_windows_negative_overlap.custom_temporal_sub_windows[ + 'S1'].yearless_date_ranges == TsDistributor(yearless_date_ranges=[( + YearlessDatetime(12, 6), + YearlessDatetime(2, 23))]).yearless_date_ranges + assert seasonal_sub_windows_negative_overlap.custom_temporal_sub_windows[ + 'S2'].yearless_date_ranges == TsDistributor(yearless_date_ranges=[( + YearlessDatetime(3, 6), + YearlessDatetime(5, 26))]).yearless_date_ranges + assert seasonal_sub_windows_negative_overlap.custom_temporal_sub_windows[ + 'S3'].yearless_date_ranges == TsDistributor(yearless_date_ranges=[( + YearlessDatetime(6, 6), + YearlessDatetime(8, 26))]).yearless_date_ranges + assert seasonal_sub_windows_negative_overlap.custom_temporal_sub_windows[ + 'S4'].yearless_date_ranges == TsDistributor(yearless_date_ranges=[( + YearlessDatetime(9, 6), + YearlessDatetime(11, 25))]).yearless_date_ranges + assert seasonal_sub_windows_negative_overlap.metadata == { 
+ 'Temporal sub-window type': + 'seasons', + 'Overlap': + '-5 days', + 'Pretty Names [MM-DD]': + 'S1: 12-06 to 02-23, S2: 03-06 to 05-26, S3: 06-06 to 08-26, S4: 09-06 to 11-25' + } + + # overlap is rounded to the nearest integer + float_overlap = TemporalSubWindowsCreator( + temporal_sub_window_type="seasons", overlap=5.2, custom_file=None) + + assert float_overlap.overlap == 5 + + assert [ + x.yearless_date_ranges + for x in float_overlap.custom_temporal_sub_windows.values() + ] == [ + x.yearless_date_ranges for x in seasonal_sub_windows_positive_overlap. + custom_temporal_sub_windows.values() + ] + + assert float_overlap.metadata == seasonal_sub_windows_positive_overlap.metadata + + # make sure cyclic boundaries are handled correctly + # no overlap and +/-365 days overlap should be the same + aa = TemporalSubWindowsCreator(temporal_sub_window_type="seasons", + overlap=0, + custom_file=None) + aa_plus = TemporalSubWindowsCreator(temporal_sub_window_type="seasons", + overlap=365, + custom_file=None) + aa_minus = TemporalSubWindowsCreator(temporal_sub_window_type="seasons", + overlap=-365, + custom_file=None) + # +376 days overlap should be the same as +11 days overlap and -354 days overlap + bb = TemporalSubWindowsCreator(temporal_sub_window_type="seasons", + overlap=376, + custom_file=None) + bb_plus = TemporalSubWindowsCreator(temporal_sub_window_type="seasons", + overlap=11, + custom_file=None) + bb_minus = TemporalSubWindowsCreator(temporal_sub_window_type="seasons", + overlap=-354, + custom_file=None) + + assert [ + x.yearless_date_ranges + for x in aa.custom_temporal_sub_windows.values() + ] == [ + x.yearless_date_ranges + for x in aa_plus.custom_temporal_sub_windows.values() + ] == [ + x.yearless_date_ranges + for x in aa_minus.custom_temporal_sub_windows.values() + ] + assert [ + x.yearless_date_ranges + for x in bb.custom_temporal_sub_windows.values() + ] == [ + x.yearless_date_ranges + for x in bb_plus.custom_temporal_sub_windows.values() + ] == [ + x.yearless_date_ranges + for x in bb_minus.custom_temporal_sub_windows.values() + ] + + +def test_add_temporal_sub_window(seasonal_sub_windows_positive_overlap, + additional_temp_sub_window): + seasonal_sub_windows_positive_overlap.add_temp_sub_wndw( + additional_temp_sub_window) + + assert seasonal_sub_windows_positive_overlap.names == [ + 'S1', 'S2', 'S3', 'S4', 'Feynman' + ] + + assert seasonal_sub_windows_positive_overlap.custom_temporal_sub_windows[ + 'Feynman'].date_ranges == [(datetime(1918, 5, + 11), datetime(1988, 2, 15))] + + # if a new window is to be added, it should not have a name that already exists. In this case, this new window should not be added + name_exists = NewSubWindow(name="S1", + begin_date=YearlessDatetime(5, 11), + end_date=YearlessDatetime(2, 15)) + + seasonal_sub_windows_positive_overlap_copy = deepcopy( + seasonal_sub_windows_positive_overlap) + + seasonal_sub_windows_positive_overlap_copy.add_temp_sub_wndw(name_exists) + + assert seasonal_sub_windows_positive_overlap_copy.names == seasonal_sub_windows_positive_overlap.names + + assert [ + x.yearless_date_ranges + for x in seasonal_sub_windows_positive_overlap_copy. + custom_temporal_sub_windows.values() + ] == [ + x.yearless_date_ranges for x in seasonal_sub_windows_positive_overlap. 
custom_temporal_sub_windows.values()
+    ]
+
+    # if a new window is added and specified to become the first window, it should be added at the beginning
+    seasonal_sub_windows_positive_overlap_copy.add_temp_sub_wndw(
+        NewSubWindow('I am first', YearlessDatetime(1, 1),
+                     YearlessDatetime(2, 2)),
+        insert_as_first_wndw=True)
+    assert seasonal_sub_windows_positive_overlap_copy.names[0] == 'I am first'
+
+    # if an existing window is to be overwritten, it should exist.
+    seasonal_sub_windows_positive_overlap_copy = deepcopy(
+        seasonal_sub_windows_positive_overlap)
+    seasonal_sub_windows_positive_overlap_copy.overwrite_temp_sub_wndw(
+        name_exists)
+
+    assert seasonal_sub_windows_positive_overlap_copy.names == seasonal_sub_windows_positive_overlap.names
+
+    assert seasonal_sub_windows_positive_overlap_copy.custom_temporal_sub_windows[
+        'S1'].yearless_date_ranges == TsDistributor(yearless_date_ranges=[(
+            YearlessDatetime(5, 11),
+            YearlessDatetime(2, 15))]).yearless_date_ranges
+
+    # when overwriting an existing window, it should be possible to use a new datatype for the dates (but always either datetime or YearlessDatetime)
+    seasonal_sub_windows_positive_overlap_copy.overwrite_temp_sub_wndw(
+        NewSubWindow('S1', datetime(2023, 1, 1), datetime(2023, 12, 31)))
+
+    assert seasonal_sub_windows_positive_overlap_copy.custom_temporal_sub_windows[
+        'S1'].yearless_date_ranges is None
+    assert seasonal_sub_windows_positive_overlap_copy.custom_temporal_sub_windows[
+        'S1'].date_ranges == TsDistributor(
+            date_ranges=[(datetime(2023, 1, 1),
+                          datetime(2023, 12, 31))]).date_ranges
diff --git a/tests/test_netcdf_transcription.py b/tests/test_netcdf_transcription.py
new file mode 100644
index 0000000..e33eda8
--- /dev/null
+++ b/tests/test_netcdf_transcription.py
@@ -0,0 +1,857 @@
+import pytest
+from copy import deepcopy
+from datetime import datetime
+import xarray as xr
+import shutil
+from pathlib import Path
+from typing import Union, Optional, Tuple, List
+import logging
+import numpy as np
+import tempfile
+
+from qa4sm_reader.netcdf_transcription import Pytesmo2Qa4smResultsTranscriber, TemporalSubWindowMismatchError
+from qa4sm_reader.intra_annual_temp_windows import TemporalSubWindowsCreator, NewSubWindow, InvalidTemporalSubWindowError
+import qa4sm_reader.globals as globals
+from qa4sm_reader.utils import log_function_call
+import qa4sm_reader.plot_all as pa
+
+log_file_path = Path(
+    __file__).parent.parent / '.logs' / "test_netcdf_transcription.log"
+if not log_file_path.parent.exists():
+    log_file_path.parent.mkdir(parents=True, exist_ok=True)
+
+logging.basicConfig(level=logging.DEBUG,
+                    format='%(asctime)s - %(levelname)s - %(message)s',
+                    filename=str(log_file_path))
+
+
+#-------------------------------------------------------Fixtures--------------------------------------------------------
+@pytest.fixture(scope="module")
+def tmp_paths():
+    '''Fixture to keep track of temporary directories created during a test run and clean them up afterwards'''
+    paths = []
+    yield paths
+
+    for path in paths:
+        shutil.rmtree(path, ignore_errors=True)
+
+
+@pytest.fixture
+def monthly_tsws() -> TemporalSubWindowsCreator:
+    return TemporalSubWindowsCreator(temporal_sub_window_type='months',
+                                     overlap=0,
+                                     custom_file=None)
+
+
+@pytest.fixture
+def monthly_tsws_incl_bulk(monthly_tsws) -> TemporalSubWindowsCreator:
+    bulk_wndw = NewSubWindow('bulk', datetime(1950, 1, 1),
+                             datetime(2020, 1, 1))
+    # add_temp_sub_wndw modifies the creator in place, so return the creator
+    # itself (mirrors the seasonal fixture below)
+    monthly_tsws.add_temp_sub_wndw(bulk_wndw, insert_as_first_wndw=True)
+    return monthly_tsws
+
+
+@pytest.fixture
+def seasonal_tsws() -> TemporalSubWindowsCreator:
+    return TemporalSubWindowsCreator(temporal_sub_window_type='seasons',
+                                     overlap=0,
+                                     custom_file=None)
+
+
+@pytest.fixture
+def seasonal_tsws_incl_bulk() -> TemporalSubWindowsCreator:
+    seasonal_tsws = TemporalSubWindowsCreator(
+        temporal_sub_window_type='seasons', overlap=0, custom_file=None)
+    bulk_wndw = NewSubWindow('bulk', datetime(1950, 1, 1),
+                             datetime(2020, 1, 1))
+    seasonal_tsws.add_temp_sub_wndw(bulk_wndw, insert_as_first_wndw=True)
+    return seasonal_tsws
+
+
+@pytest.fixture
+def seasonal_pytesmo_file(TEST_DATA_DIR) -> Path:
+    return Path(
+        TEST_DATA_DIR / 'intra_annual' / 'seasonal' /
+        '0-ERA5.swvl1_with_1-ESA_CCI_SM_combined.sm_with_2-ESA_CCI_SM_combined.sm_with_3-ESA_CCI_SM_combined.sm_with_4-ESA_CCI_SM_combined.sm.CI_tsw_seasons_pytesmo.nc'
+    )
+
+
+@pytest.fixture
+def seasonal_qa4sm_file(TEST_DATA_DIR) -> Path:
+    return Path(
+        TEST_DATA_DIR / 'intra_annual' / 'seasonal' /
+        '0-ERA5.swvl1_with_1-ESA_CCI_SM_combined.sm_with_2-ESA_CCI_SM_combined.sm_with_3-ESA_CCI_SM_combined.sm_with_4-ESA_CCI_SM_combined.sm.CI_tsw_seasons_qa4sm.nc'
+    )
+
+
+@pytest.fixture
+def monthly_pytesmo_file(TEST_DATA_DIR) -> Path:
+    return Path(TEST_DATA_DIR / 'intra_annual' / 'monthly' /
+                '0-ISMN.soil_moisture_with_1-C3S.sm_tsw_months_pytesmo.nc')
+
+
+@pytest.fixture
+def monthly_qa4sm_file(TEST_DATA_DIR) -> Path:
+    return Path(TEST_DATA_DIR / 'intra_annual' / 'monthly' /
+                '0-ISMN.soil_moisture_with_1-C3S.sm_tsw_months_qa4sm.nc')
+
+
+#------------------Helper functions------------------------
+
+
+@log_function_call
+def get_tmp_whole_test_data_dir(
+        TEST_DATA_DIR: Path, tmp_paths: List[Path]) -> Tuple[Path, List[Path]]:
+    '''Copy the whole test data directory to a temporary directory and return the path to the temporary directory
+
+    Parameters
+    ----------
+
+    TEST_DATA_DIR: Path
+        The path to the test data directory
+    tmp_paths: List[Path]
+        **Don't modify this list directly. Keeps track of created tmp dirs during a test run**
+
+    Returns
+    -------
+
+    Tuple[Path, List[Path]]
+        A tuple containing the path to the temporary directory and the list of temporary directories that have been created during the test
+    '''
+    if isinstance(TEST_DATA_DIR, str):
+        TEST_DATA_DIR = Path(TEST_DATA_DIR)
+    temp_dir = Path(tempfile.mkdtemp())
+    shutil.copytree(TEST_DATA_DIR, temp_dir / TEST_DATA_DIR.name)
+
+    return temp_dir / TEST_DATA_DIR.name, tmp_paths
+
+
+@log_function_call
+def get_tmp_single_test_file(test_file: Path,
+                             tmp_paths: List[Path]) -> Tuple[Path, List[Path]]:
+    '''Copy a single test file to a temporary directory and return the path to the temporary file
+
+    Parameters
+    ----------
+
+    test_file: Path
+        The path to the test file to be copied
+    tmp_paths: List[Path]
+        **Don't modify this list directly. Keeps track of created tmp files during a test run**
+
+    Returns
+    -------
+
+    Tuple[Path, List[Path]]
+        A tuple containing the path to the temporary file and the list of temporary files that have been created during the test
+    '''
+    if isinstance(test_file, str):
+        test_file = Path(test_file)
+    temp_dir = Path(tempfile.mkdtemp())
+    temp_file_path = temp_dir / test_file.name
+    shutil.copy(test_file, temp_file_path)
+    return temp_file_path, tmp_paths
+
+
+@log_function_call
+def run_test_transcriber(
+        ncfile: Path,
+        intra_annual_slices: Union[None, TemporalSubWindowsCreator],
+        keep_pytesmo_ncfile: bool,
+        write_outfile: Optional[bool] = True
+) -> Tuple[Pytesmo2Qa4smResultsTranscriber, xr.Dataset]:
+    '''Run a test on the transcriber with the given parameters
+
+    Parameters
+    ----------
+
+    ncfile: Path
+        The path to the netcdf file to be transcribed
+    intra_annual_slices: Union[None, TemporalSubWindowsCreator]
+        The temporal sub-windows to be used for the transcription
+    keep_pytesmo_ncfile: bool
+        Whether to keep the original pytesmo nc file
+    write_outfile: Optional[bool]
+        Whether to write the transcribed dataset to a new netcdf file. Default is True
+
+    Returns
+    -------
+    Tuple[Pytesmo2Qa4smResultsTranscriber, xr.Dataset]
+        A tuple containing the transcriber instance and the transcribed dataset
+    '''
+
+    transcriber = Pytesmo2Qa4smResultsTranscriber(
+        pytesmo_results=ncfile,
+        intra_annual_slices=intra_annual_slices,
+        keep_pytesmo_ncfile=keep_pytesmo_ncfile)
+
+    logging.info(f"{transcriber=}")
+
+    assert transcriber.exists
+
+    transcriber.output_file_name = ncfile
+    transcribed_ds = transcriber.get_transcribed_dataset()
+
+    assert isinstance(transcribed_ds, xr.Dataset)
+
+    if write_outfile:
+        transcriber.write_to_netcdf(transcriber.output_file_name)
+        assert Path(transcriber.output_file_name).exists()
+
+        if keep_pytesmo_ncfile:
+            assert Path(ncfile.parent,
+                        ncfile.name + globals.OLD_NCFILE_SUFFIX).exists()
+        else:
+            assert not Path(ncfile.parent,
+                            ncfile.name + globals.OLD_NCFILE_SUFFIX).exists()
+
+    return transcriber, transcribed_ds
+
+
+#------------------Check that all required consts are defined------------------
+@log_function_call
+def test_qr_globals_attributes():
+    attributes = [
+        'METRICS', 'TC_METRICS', 'NON_METRICS', 'METADATA_TEMPLATE',
+        'IMPLEMENTED_COMPRESSIONS', 'ALLOWED_COMPRESSION_LEVELS',
+        'INTRA_ANNUAL_METRIC_TEMPLATE', 'INTRA_ANNUAL_TCOL_METRIC_TEMPLATE',
+        'TEMPORAL_SUB_WINDOW_SEPARATOR', 'DEFAULT_TSW',
+        'TEMPORAL_SUB_WINDOW_NC_COORD_NAME', 'MAX_NUM_DS_PER_VAL_RUN',
+        'DATASETS', 'TEMPORAL_SUB_WINDOWS'
+    ]
+
+    # every single one of the required attributes must be defined
+    assert all(attr in dir(globals) for attr in attributes)
+
+    assert 'zlib' in globals.IMPLEMENTED_COMPRESSIONS
+
+    assert globals.ALLOWED_COMPRESSION_LEVELS == [None, *list(range(10))]
+
+    assert globals.INTRA_ANNUAL_METRIC_TEMPLATE == [
+        "{tsw}", globals.TEMPORAL_SUB_WINDOW_SEPARATOR, "{metric}"
+    ]
+
+    assert globals.INTRA_ANNUAL_TCOL_METRIC_TEMPLATE == [
+        "{tsw}", globals.TEMPORAL_SUB_WINDOW_SEPARATOR, "{metric}", "_",
+        "{number}-{dataset}", "_between_"
+    ]
+
+    assert len(globals.TEMPORAL_SUB_WINDOW_SEPARATOR) == 1
+
+    assert globals.TEMPORAL_SUB_WINDOWS == {
+        "seasons": {
+            "S1": [[12, 1], [2, 28]],
+            "S2": [[3, 1], [5, 31]],
+            "S3": [[6, 1], [8, 31]],
+            "S4": [[9, 1], [11, 30]],
+        },
+        "months": {
+            "Jan": [[1, 1], [1, 31]],
+            "Feb": [[2, 1], [2, 28]],
+            "Mar": [[3, 1], [3, 31]],
+            "Apr": [[4, 1], [4, 30]],
+            "May": [[5, 1], [5, 31]],
+            "Jun": [[6, 1], [6, 30]],
+            "Jul": [[7, 1], [7, 31]],
+            "Aug": [[8, 1], [8, 31]],
+            "Sep": [[9, 1], [9, 30]],
+            "Oct": [[10, 1], [10, 31]],
+            "Nov": [[11, 1], [11, 30]],
+            "Dec": [[12, 1], [12, 31]],
+        }
+    }
"Aug": [[8, 1], [8, 31]], + "Sep": [[9, 1], [9, 30]], + "Oct": [[10, 1], [10, 31]], + "Nov": [[11, 1], [11, 30]], + "Dec": [[12, 1], [12, 31]], + } + } + + +# ------------------Test Pytesmo2Qa4smResultsTranscriber------------------------------------------------------- +#-------------------------------------------------------------------------------------------------------------- + +#------------Test instantiation of Pytesmo2Qa4smResultsTranscriber, attrs and basic functionalities------------ + + +@log_function_call +def test_on_non_existing_file(): + with pytest.raises(FileNotFoundError): + transcriber = Pytesmo2Qa4smResultsTranscriber( + pytesmo_results='non_existing.nc', + intra_annual_slices=None, + keep_pytesmo_ncfile=False) + + +@log_function_call +def test_invalid_temp_subwins(seasonal_tsws_incl_bulk, + tmp_paths, + TEST_DATA_DIR, + test_file: Optional[Path] = None): + logging.info( + f'test_invalid_temp_subwins: {seasonal_tsws_incl_bulk=}, {tmp_paths=}, {TEST_DATA_DIR=}, {test_file=}' + ) + if test_file is None: + test_file = Path(TEST_DATA_DIR / 'basic' / + '0-ISMN.soil moisture_with_1-C3S.sm.nc') + + # Test that the transcriber raises an InvalidTemporalSubWindowError when the intra_annual_slices parameter is neither None nor a TemporalSubWindowsCreator instance + tmp_test_file = get_tmp_single_test_file(test_file, tmp_paths)[0] + with pytest.raises(InvalidTemporalSubWindowError): + _, ds = run_test_transcriber(tmp_test_file, + intra_annual_slices='faulty', + keep_pytesmo_ncfile=False) + ds.close() + + +@log_function_call +def test_invalid_temporalsubwindowscreator(seasonal_tsws_incl_bulk, + tmp_paths, + TEST_DATA_DIR, + test_file: Optional[Path] = None): + logging.info( + f'test_invalid_temporalsubwindowscreator: {seasonal_tsws_incl_bulk=}, {tmp_paths=}, {TEST_DATA_DIR=}, {test_file=}' + ) + if test_file is None: + test_file = Path(TEST_DATA_DIR / 'basic' / + '0-ISMN.soil moisture_with_1-C3S.sm.nc') + + # Test that the transcriber raises an InvalidTemporalSubWindowError when the intra_annual_slices parameter is a faulty TemporalSubWindowsCreator instance + tmp_test_file = get_tmp_single_test_file(test_file, tmp_paths)[0] + + with pytest.raises(InvalidTemporalSubWindowError): + _, ds = run_test_transcriber( + tmp_test_file, + intra_annual_slices=TemporalSubWindowsCreator('gibberish'), + keep_pytesmo_ncfile=False) + ds.close() + + +@log_function_call +def test_temp_subwin_mismatch(seasonal_tsws_incl_bulk, + tmp_paths, + TEST_DATA_DIR, + test_file: Optional[Path] = None): + logging.info( + f'test_temp_subwin_mismatch: {seasonal_tsws_incl_bulk=}, {tmp_paths=}, {TEST_DATA_DIR=}, {test_file=}' + ) + if test_file is None: + test_file = Path(TEST_DATA_DIR / 'basic' / + '0-ISMN.soil moisture_with_1-C3S.sm.nc') + + # Test that the transcriber raises a TemporalSubWindowMismatchError when the intra_annual_slices parameter is a TemporalSubWindowsCreator instance that does not match the temporal sub-windows in the pytesmo_results file + tmp_test_file = get_tmp_single_test_file(test_file, tmp_paths)[0] + with pytest.raises(TemporalSubWindowMismatchError): + _, ds = run_test_transcriber( + tmp_test_file, + intra_annual_slices=seasonal_tsws_incl_bulk, + keep_pytesmo_ncfile=False) + ds.close() + + +@log_function_call +def test_keep_pytesmo_ncfile(TEST_DATA_DIR, test_file: Optional[Path] = None): + if test_file is None: + test_file = Path(TEST_DATA_DIR / 'basic' / + '0-ISMN.soil moisture_with_1-C3S.sm.nc') + tmp_test_file = get_tmp_single_test_file(test_file, tmp_paths)[0] + transcriber, ds = 
+
+
+@log_function_call
+def test_dont_keep_pytesmo_ncfile(TEST_DATA_DIR, tmp_paths,
+                                  test_file: Optional[Path] = None):
+    if test_file is None:
+        test_file = Path(TEST_DATA_DIR / 'basic' /
+                         '0-ISMN.soil moisture_with_1-C3S.sm.nc')
+    tmp_test_file = get_tmp_single_test_file(test_file, tmp_paths)[0]
+    _, ds = run_test_transcriber(tmp_test_file,
+                                 intra_annual_slices=None,
+                                 keep_pytesmo_ncfile=False)
+    ds.close()
+
+
+@log_function_call
+def test_ncfile_compression(TEST_DATA_DIR, tmp_paths,
+                            test_file: Optional[Path] = None):
+    if test_file is None:
+        test_file = Path(TEST_DATA_DIR / 'basic' /
+                         '0-ISMN.soil moisture_with_1-C3S.sm.nc')
+    tmp_test_file = get_tmp_single_test_file(test_file, tmp_paths)[0]
+    transcriber, ds = run_test_transcriber(tmp_test_file,
+                                           intra_annual_slices=None,
+                                           keep_pytesmo_ncfile=False,
+                                           write_outfile=True)
+
+    # only zlib compression is implemented so far, with compression levels 0-9;
+    # each invalid call must raise on its own, hence one pytest.raises block per call
+    with pytest.raises(NotImplementedError):
+        transcriber.compress(transcriber.output_file_name, 'not_implemented',
+                             0)
+    with pytest.raises(NotImplementedError):
+        transcriber.compress(transcriber.output_file_name, 'zlib', -1)
+    with pytest.raises(NotImplementedError):
+        transcriber.compress(transcriber.output_file_name, 'not_implemented',
+                             -1)
+
+    # test the case of a non-existing file
+    assert not transcriber.compress('non_existing_file.nc', 'zlib', 0)
+
+    # test successful compression with zlib and compression level 9
+    assert transcriber.compress(transcriber.output_file_name, 'zlib', 9)
+
+    # test successful compression with defaults
+    assert transcriber.compress(transcriber.output_file_name)
+
+    ds.close()
+
+
+#-------------------Test default case (= no temporal sub-windows)--------------------------------------------
+
+
+@log_function_call
+def test_bulk_case_transcription(TEST_DATA_DIR, tmp_paths):
+    # Test transcription of all original test data nc files (== bulk case files)
+    tmp_test_data_dir, _ = get_tmp_whole_test_data_dir(TEST_DATA_DIR,
+                                                       tmp_paths)
+    nc_files = [
+        path for path in Path(tmp_test_data_dir).rglob('*.nc')
+        if 'intra_annual' not in str(path)
+    ]
+    logging.info(f"Found {len(nc_files)} .nc files for transcription.")
+
+    for ncf in nc_files:
+        _, ds = run_test_transcriber(ncf,
+                                     intra_annual_slices=None,
+                                     keep_pytesmo_ncfile=False,
+                                     write_outfile=True)
+        assert ds.sel(
+            {globals.TEMPORAL_SUB_WINDOW_NC_COORD_NAME:
+             globals.DEFAULT_TSW}) == globals.DEFAULT_TSW
+        logging.info(f"Successfully transcribed file: {ncf}")
+        ds.close()
+
+    if tmp_test_data_dir.exists():
+        shutil.rmtree(tmp_test_data_dir, ignore_errors=True)
+
+
+#-------------------------------------------Test with intra-annual metrics---------------------------------------------
+
+
+@log_function_call
+def test_correct_file_transcription(seasonal_pytesmo_file, seasonal_qa4sm_file,
+                                    monthly_pytesmo_file, monthly_qa4sm_file):
+    '''
+    Test the transcription of the test files with the correct temporal sub-windows and the correct output nc files
+    '''
+
+    # test that the test files exist
+    assert seasonal_pytesmo_file.exists()
+    assert seasonal_qa4sm_file.exists()
+    assert monthly_pytesmo_file.exists()
+    assert monthly_qa4sm_file.exists()
+
+    # instantiate proper TemporalSubWindowsCreator instances for the corresponding test files
+    bulk_tsw = NewSubWindow(
+        'bulk', datetime(1900, 1, 1), datetime(2000, 1, 1)
+    )  # if ever the default changes away from 'bulk', this will need to be taken into account
+
+    seasons_tsws = TemporalSubWindowsCreator('seasons')
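+    # add_temp_sub_wndw(..., insert_as_first_wndw=True) prepends the new window,
+    # so afterwards the expected name order is ['bulk', 'S1', 'S2', 'S3', 'S4']
+    # (illustrative; the same behaviour is asserted in
+    # test_add_temporal_sub_window in tests/test_intra_annual_temp_windows.py)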
+    seasons_tsws.add_temp_sub_wndw(bulk_tsw, insert_as_first_wndw=True)
+
+    monthly_tsws = TemporalSubWindowsCreator('months')
+    monthly_tsws.add_temp_sub_wndw(bulk_tsw, insert_as_first_wndw=True)
+
+    # make sure the above defined temporal sub-windows are indeed the ones on the expected output nc files
+    assert seasons_tsws.names == Pytesmo2Qa4smResultsTranscriber.get_tsws_from_ncfile(
+        seasonal_qa4sm_file)
+    assert monthly_tsws.names == Pytesmo2Qa4smResultsTranscriber.get_tsws_from_ncfile(
+        monthly_qa4sm_file)
+
+    # instantiate transcribers for the test files
+    seasonal_transcriber = Pytesmo2Qa4smResultsTranscriber(
+        pytesmo_results=seasonal_pytesmo_file,
+        intra_annual_slices=seasons_tsws,
+        keep_pytesmo_ncfile=False
+    )  # deletion or keeping of the original pytesmo nc file only triggers when the transcriber is written to a new file, which is not the case here
+
+    monthly_transcriber = Pytesmo2Qa4smResultsTranscriber(
+        pytesmo_results=monthly_pytesmo_file,
+        intra_annual_slices=monthly_tsws,
+        keep_pytesmo_ncfile=False)
+
+    assert seasonal_transcriber.exists
+    assert monthly_transcriber.exists
+
+    # get the transcribed datasets
+    seasonal_transcribed_ds = seasonal_transcriber.get_transcribed_dataset()
+    monthly_transcribed_ds = monthly_transcriber.get_transcribed_dataset()
+
+    # check that the transcribed datasets are indeed xarray.Dataset instances
+    assert isinstance(seasonal_transcribed_ds, xr.Dataset)
+    assert isinstance(monthly_transcribed_ds, xr.Dataset)
+
+    # check that the transcribed datasets are equal to the expected output files
+    # xr.testing.assert_equal(ds1, ds2) runs a more detailed comparison of the two datasets as compared to ds1.equals(ds2)
+    # NOTE: no context manager here, since the expected datasets are still
+    # accessed (drop_vars, comparisons) further down and a `with` block would
+    # close them on exit
+    expected_seasonal_ds = xr.open_dataset(seasonal_qa4sm_file)
+    expected_monthly_ds = xr.open_dataset(monthly_qa4sm_file)
+
+    #!NOTE: pytesmo/QA4SM offer the possibility to calculate Kendall's tau, but currently this metric is deactivated.
+    #! Therefore, in a real validation run no tau related metrics will be transcribed to the QA4SM file, even though they might be present in the pytesmo file.
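+    #! (Such a variable would follow the usual naming pattern, e.g.
+    #! 'S1|tau_between_0-ERA5_and_1-ESA_CCI_SM_combined'; name given for
+    #! illustration only.)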
+
+    # drop the tau related metrics from the expected datasets
+    for var in expected_seasonal_ds.data_vars:
+        if 'tau' in var:
+            logging.info(
+                f"Dropping variable {var} from expected seasonal dataset")
+            expected_seasonal_ds = expected_seasonal_ds.drop_vars(var)
+
+    for var in expected_monthly_ds.data_vars:
+        if 'tau' in var:
+            logging.info(
+                f"Dropping variable {var} from expected monthly dataset")
+            expected_monthly_ds = expected_monthly_ds.drop_vars(var)
+
+    # xr.testing.assert_equal raises an AssertionError if the datasets differ,
+    # so a bare call is sufficient here
+    xr.testing.assert_equal(monthly_transcribed_ds, expected_monthly_ds)
+    xr.testing.assert_equal(seasonal_transcribed_ds, expected_seasonal_ds)
+
+    # the method above does not check attrs of the datasets, so we do it here
+    # Creation date and qa4sm_reader version might differ, so we exclude them from the comparison
+    datasets = [
+        monthly_transcribed_ds, expected_monthly_ds, seasonal_transcribed_ds,
+        expected_seasonal_ds
+    ]
+    attrs_to_be_excluded = ['date_created', 'qa4sm_version']
+    for ds in datasets:
+        for attr in attrs_to_be_excluded:
+            if attr in ds.attrs:
+                del ds.attrs[attr]
+
+    assert seasonal_transcribed_ds.attrs == expected_seasonal_ds.attrs
+    assert monthly_transcribed_ds.attrs == expected_monthly_ds.attrs
+
+    # Compare the coordinate attributes
+    for coord in seasonal_transcribed_ds.coords:
+        for attr in seasonal_transcribed_ds[coord].attrs:
+            if isinstance(seasonal_transcribed_ds[coord].attrs[attr],
+                          (list, np.ndarray)):
+                assert np.array_equal(
+                    seasonal_transcribed_ds[coord].attrs[attr],
+                    expected_seasonal_ds[coord].attrs[attr]
+                ), f"Attributes for coordinate {coord} do not match in seasonal dataset"
+            else:
+                assert seasonal_transcribed_ds[coord].attrs[
+                    attr] == expected_seasonal_ds[coord].attrs[
+                        attr], f"Attributes for coordinate {coord} do not match in seasonal dataset: '{seasonal_transcribed_ds[coord].attrs[attr]}' != '{expected_seasonal_ds[coord].attrs[attr]}'"
+
+    for coord in monthly_transcribed_ds.coords:
+        for attr in monthly_transcribed_ds[coord].attrs:
+            if isinstance(monthly_transcribed_ds[coord].attrs[attr],
+                          (list, np.ndarray)):
+                assert np.array_equal(
+                    monthly_transcribed_ds[coord].attrs[attr],
+                    expected_monthly_ds[coord].attrs[attr]
+                ), f"Attributes for coordinate {coord} do not match in monthly dataset"
+            else:
+                assert monthly_transcribed_ds[coord].attrs[
+                    attr] == expected_monthly_ds[coord].attrs[
+                        attr], f"Attributes for coordinate {coord} do not match in monthly dataset: '{monthly_transcribed_ds[coord].attrs[attr]}' != '{expected_monthly_ds[coord].attrs[attr]}'"
+
+    seasonal_transcribed_ds.close()
+    monthly_transcribed_ds.close()
+
+
+#TODO: refactoring
+@log_function_call
+def test_plotting(seasonal_qa4sm_file, monthly_qa4sm_file, tmp_paths):
+    '''
+    Test the plotting of the test files with temporal sub-windows beyond the bulk case (the bulk case is covered in other tests)
+    '''
+
+    tmp_seasonal_file, _ = get_tmp_single_test_file(seasonal_qa4sm_file,
+                                                    tmp_paths)
+    tmp_seasonal_dir = tmp_seasonal_file.parent
+
+    tmp_monthly_file, _ = get_tmp_single_test_file(monthly_qa4sm_file,
+                                                   tmp_paths)
+    tmp_monthly_dir = tmp_monthly_file.parent
+
+    # check the output directories
+
+    pa.plot_all(
+        filepath=tmp_seasonal_file,
+        temporal_sub_windows=Pytesmo2Qa4smResultsTranscriber.
+ get_tsws_from_ncfile(tmp_seasonal_file), + out_dir=tmp_seasonal_dir, + save_all=True, + out_type=['png', 'svg'], + ) + + metrics_not_plotted = [ + *globals.metric_groups[0], *globals.metric_groups[3], + *globals._metadata_exclude + ] + + tsw_dirs_expected = Pytesmo2Qa4smResultsTranscriber.get_tsws_from_ncfile( + tmp_seasonal_file) + if globals.DEFAULT_TSW in tsw_dirs_expected: + tsw_dirs_expected.remove( + globals.DEFAULT_TSW) # we're not checking the default case here + + for tsw in tsw_dirs_expected: + assert Path( + tmp_seasonal_dir / + tsw).is_dir(), f"{tmp_seasonal_dir / tsw} is not a directory" + + # only metrics and tcol metrics get their dedicated plots for each temporal sub-window + for metric in [ + *list(globals.METRICS.keys()), *list(globals.TC_METRICS.keys()) + ]: + if metric in metrics_not_plotted: + continue + assert Path( + tmp_seasonal_dir / tsw / f"{tsw}_boxplot_{metric}.png" + ).exists( + ), f"{tmp_seasonal_dir / tsw / f'{tsw}_boxplot_{metric}.png'} does not exist" + assert Path( + tmp_seasonal_dir / tsw / f"{tsw}_boxplot_{metric}.svg" + ).exists( + ), f"{tmp_seasonal_dir / tsw / f'{tsw}_boxplot_{metric}.svg'} does not exist" + + assert Path( + tmp_seasonal_dir / tsw / f'{tsw}_statistics_table.csv' + ).is_file( + ), f"{tmp_seasonal_dir / tsw / f'{tsw}_statistics_table.csv'} does not exist" + + # check intra-annual-metric-exclusive comparison boxplots + assert Path(tmp_seasonal_dir / 'comparison_boxplots').is_dir() + for metric in globals.METRICS: + if metric in metrics_not_plotted: + continue + assert Path( + tmp_seasonal_dir / 'comparison_boxplots' / + globals.CLUSTERED_BOX_PLOT_SAVENAME.format(metric=metric, + filetype='png') + ).exists( + ), f"{tmp_seasonal_dir / 'comparison_boxplots' / globals.CLUSTERED_BOX_PLOT_SAVENAME.format(metric=metric, filetype='png')} does not exist" + assert Path( + tmp_seasonal_dir / 'comparison_boxplots' / + globals.CLUSTERED_BOX_PLOT_SAVENAME.format(metric=metric, + filetype='svg') + ).exists( + ), f"{tmp_seasonal_dir / 'comparison_boxplots' / globals.CLUSTERED_BOX_PLOT_SAVENAME.format(metric=metric, filetype='svg')} does not exist" + + # now check the file with monthly temporal sub-windows and without tcol metrics + + pa.plot_all( + filepath=tmp_monthly_file, + temporal_sub_windows=Pytesmo2Qa4smResultsTranscriber. 
+        get_tsws_from_ncfile(tmp_monthly_file),
+        out_dir=tmp_monthly_dir,
+        save_all=True,
+        save_metadata=True,
+        out_type=['png', 'svg'],
+    )
+
+    tsw_dirs_expected = Pytesmo2Qa4smResultsTranscriber.get_tsws_from_ncfile(
+        tmp_monthly_file)
+    if globals.DEFAULT_TSW in tsw_dirs_expected:
+        tsw_dirs_expected.remove(globals.DEFAULT_TSW)
+
+    for t, tsw in enumerate(tsw_dirs_expected):
+        assert (tmp_monthly_dir / tsw).is_dir(), \
+            f"{tmp_monthly_dir / tsw} is not a directory"
+
+        # no tcol metrics present here
+        for metric in globals.METRICS:
+            if metric in metrics_not_plotted:
+                continue
+            # tsw-specific plots
+            for ext in ['png', 'svg']:
+                plot_file = tmp_monthly_dir / tsw / f"{tsw}_boxplot_{metric}.{ext}"
+                assert plot_file.exists(), f"{plot_file} does not exist"
+
+            if t == 0:
+                # comparison boxplots of the monthly file (checked only once)
+                assert (tmp_monthly_dir / 'comparison_boxplots').is_dir()
+                for ext in ['png', 'svg']:
+                    boxplot_file = tmp_monthly_dir / 'comparison_boxplots' / \
+                        globals.CLUSTERED_BOX_PLOT_SAVENAME.format(
+                            metric=metric, filetype=ext)
+                    assert boxplot_file.exists(), \
+                        f"{boxplot_file} does not exist"
+
+        stats_table = tmp_monthly_dir / tsw / f'{tsw}_statistics_table.csv'
+        assert stats_table.is_file(), f"{stats_table} does not exist"
+
+
+@log_function_call
+def test_write_to_netcdf_default(TEST_DATA_DIR, tmp_paths):
+    temp_netcdf_file: Path = get_tmp_single_test_file(
+        Path(TEST_DATA_DIR / 'basic' /
+             '0-ISMN.soil moisture_with_1-C3S.sm.nc'), tmp_paths)[0]
+    transcriber = Pytesmo2Qa4smResultsTranscriber(
+        pytesmo_results=temp_netcdf_file)
+
+    transcribed_ds = transcriber.get_transcribed_dataset()
+    # write to NetCDF
+    transcriber.write_to_netcdf(temp_netcdf_file)
+
+    # check that the file was created
+    assert temp_netcdf_file.exists()
+
+    # close the datasets
+    transcriber.pytesmo_results.close()
+    transcriber.transcribed_dataset.close()
+    transcribed_ds.close()
+
+
+@log_function_call
+def test_write_to_netcdf_custom_encoding(TEST_DATA_DIR, tmp_paths):
+    temp_netcdf_file: Path = get_tmp_single_test_file(
+        Path(TEST_DATA_DIR / 'basic' /
+             '0-ISMN.soil moisture_with_1-C3S.sm.nc'), tmp_paths)[0]
+    transcriber = Pytesmo2Qa4smResultsTranscriber(
+        pytesmo_results=temp_netcdf_file)
+
+    transcribed_ds = transcriber.get_transcribed_dataset()
+
+    # compress all variables except object-dtype ones (e.g. strings)
+    custom_encoding = {
+        str(var): {
+            'zlib': True,
+            'complevel': 1
+        }
+        for var in transcribed_ds.variables
+        if not np.issubdtype(transcribed_ds[var].dtype, np.object_)
+    }
+
+    # write to NetCDF with the custom encoding
+    transcriber.write_to_netcdf(temp_netcdf_file, encoding=custom_encoding)
+
+    # check that the file was created
+    assert temp_netcdf_file.exists()
+
+    # close the datasets
+    transcriber.pytesmo_results.close()
+    transcriber.transcribed_dataset.close()
+    transcribed_ds.close()
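+
+# Note on the encoding above (a sketch of the reasoning, not asserted by the
+# test): object-dtype variables are stored by the netCDF4 backend as
+# variable-length types, for which zlib compression is not supported, which
+# is presumably why they are excluded from `custom_encoding`.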
+
+
+def test_get_transcribed_dataset(TEST_DATA_DIR, tmp_paths):
+    temp_netcdf_file = get_tmp_single_test_file(
+        Path(TEST_DATA_DIR / 'basic' /
+             '0-ISMN.soil moisture_with_1-C3S.sm.nc'), tmp_paths)[0]
+    transcriber = Pytesmo2Qa4smResultsTranscriber(
+        pytesmo_results=temp_netcdf_file)
+
+    # get the transcribed dataset and check that it is an xarray Dataset
+    transcribed_dataset = transcriber.get_transcribed_dataset()
+    assert isinstance(transcribed_dataset, xr.Dataset)
+
+    # close the datasets
+    transcriber.pytesmo_results.close()
+    transcriber.transcribed_dataset.close()
+    transcribed_dataset.close()
+
+
+@log_function_call
+def test_is_valid_metric_name(seasonal_pytesmo_file, seasonal_tsws_incl_bulk):
+    # Create a mock transcriber
+    mock_tsws = seasonal_tsws_incl_bulk
+    mock_transcriber = Pytesmo2Qa4smResultsTranscriber(
+        pytesmo_results=seasonal_pytesmo_file,
+        intra_annual_slices=mock_tsws,
+        keep_pytesmo_ncfile=False)
+
+    # Test valid metric names
+    tsws = mock_tsws.names
+    sep = '|'
+    dataset_combi = '_between_0-ERA5_and_3-ESA_CCI_SM_combined'
+    valid_metrics = globals.METRICS.keys()
+
+    valid_metric_names = [
+        f'{tsw}{sep}{metric}{dataset_combi}' for tsw in tsws
+        for metric in valid_metrics
+    ]
+    for metric_name in valid_metric_names:
+        assert mock_transcriber.is_valid_metric_name(metric_name)
+
+    # Test invalid metric names with metrics that don't even exist
+    nonsense_metrics = ['nonsense_metric_1', 'nonsense_metric_2']
+    nonsense_metric_names = [
+        f'{tsw}{sep}{metric}{dataset_combi}' for tsw in tsws
+        for metric in nonsense_metrics
+    ]
+    for metric_name in nonsense_metric_names:
+        assert not mock_transcriber.is_valid_metric_name(metric_name)
+
+    # Test tcol metric names; these are valid tcol metrics, but not valid
+    # plain metric names
+    tcol_metrics = globals.TC_METRICS.keys()
+    tcol_metric_names = [
+        f'{tsw}{sep}{metric}{dataset_combi}' for tsw in tsws
+        for metric in tcol_metrics
+    ]
+    for metric_name in tcol_metric_names:
+        assert not mock_transcriber.is_valid_metric_name(metric_name)
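+
+# For reference: the metric variable names validated above follow the
+# pattern f'{tsw}{sep}{metric}{dataset_combi}', e.g. (hypothetical name)
+# 'S1|R_between_0-ERA5_and_3-ESA_CCI_SM_combined'.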
+
+
+@log_function_call
+def test_is_valid_tcol_metric_name(seasonal_pytesmo_file,
+                                   seasonal_tsws_incl_bulk):
+    # Create a mock transcriber
+    mock_tsws = seasonal_tsws_incl_bulk
+    mock_transcriber = Pytesmo2Qa4smResultsTranscriber(
+        pytesmo_results=seasonal_pytesmo_file,
+        intra_annual_slices=mock_tsws,
+        keep_pytesmo_ncfile=False)
+
+    tcol_metric_names = [
+        'S1|snr_1-ESA_CCI_SM_combined_between_0-ERA5_and_1-ESA_CCI_SM_combined_and_2-ESA_CCI_SM_combined',
+        'S1|snr_2-ESA_CCI_SM_combined_between_0-ERA5_and_1-ESA_CCI_SM_combined_and_2-ESA_CCI_SM_combined',
+        'S1|snr_3-ESA_CCI_SM_combined_between_0-ERA5_and_1-ESA_CCI_SM_combined_and_3-ESA_CCI_SM_combined',
+        'S1|snr_4-ESA_CCI_SM_combined_between_0-ERA5_and_1-ESA_CCI_SM_combined_and_4-ESA_CCI_SM_combined',
+    ]  # amongst others
+
+    for metric_name in tcol_metric_names:
+        assert mock_transcriber.is_valid_tcol_metric_name(metric_name)
+
+    # confidence-interval variables are not transcribed and must be rejected
+    tcol_metrics_not_transcribed = [
+        'S1|snr_ci_lower_1-ESA_CCI_SM_combined_between_0-ERA5_and_1-ESA_CCI_SM_combined_and_2-ESA_CCI_SM_combined',
+        'S1|snr_ci_lower_2-ESA_CCI_SM_combined_between_0-ERA5_and_1-ESA_CCI_SM_combined_and_2-ESA_CCI_SM_combined',
+        'S1|err_std_ci_lower_1-ESA_CCI_SM_combined_between_0-ERA5_and_1-ESA_CCI_SM_combined_and_2-ESA_CCI_SM_combined',
+        'S1|err_std_ci_lower_2-ESA_CCI_SM_combined_between_0-ERA5_and_1-ESA_CCI_SM_combined_and_2-ESA_CCI_SM_combined',
+        'S1|beta_ci_lower_1-ESA_CCI_SM_combined_between_0-ERA5_and_1-ESA_CCI_SM_combined_and_2-ESA_CCI_SM_combined',
+        'S1|beta_ci_lower_2-ESA_CCI_SM_combined_between_0-ERA5_and_1-ESA_CCI_SM_combined_and_2-ESA_CCI_SM_combined',
+        'S1|snr_ci_lower_1-ESA_CCI_SM_combined_between_0-ERA5_and_1-ESA_CCI_SM_combined_and_3-ESA_CCI_SM_combined',
+        'S1|snr_ci_lower_3-ESA_CCI_SM_combined_between_0-ERA5_and_1-ESA_CCI_SM_combined_and_3-ESA_CCI_SM_combined',
+        'S1|err_std_ci_lower_1-ESA_CCI_SM_combined_between_0-ERA5_and_1-ESA_CCI_SM_combined_and_3-ESA_CCI_SM_combined',
+        'S1|err_std_ci_lower_3-ESA_CCI_SM_combined_between_0-ERA5_and_1-ESA_CCI_SM_combined_and_3-ESA_CCI_SM_combined',
+        'S1|beta_ci_lower_1-ESA_CCI_SM_combined_between_0-ERA5_and_1-ESA_CCI_SM_combined_and_3-ESA_CCI_SM_combined',
+        'S1|beta_ci_lower_3-ESA_CCI_SM_combined_between_0-ERA5_and_1-ESA_CCI_SM_combined_and_3-ESA_CCI_SM_combined',
+        'S1|snr_ci_lower_1-ESA_CCI_SM_combined_between_0-ERA5_and_1-ESA_CCI_SM_combined_and_4-ESA_CCI_SM_combined',
+    ]
+
+    for metric_name in tcol_metrics_not_transcribed:
+        assert not mock_transcriber.is_valid_tcol_metric_name(metric_name)
+
+
+if __name__ == '__main__':
+    test_file = Path('/tmp/test_dir/0-ISMN.soil_moisture_with_1-C3S.sm.nc')
+    # transcriber, ds = run_test_transcriber(test_file,
+    #                                        intra_annual_slices=None,
+    #                                        keep_pytesmo_ncfile=True)
+    # transcriber.pytesmo_results.close()
+    # ds.close()
+
+    test_bulk_case_transcription()
diff --git a/tests/test_plot_all.py b/tests/test_plot_all.py
index efabe2b..1301a4d 100644
--- a/tests/test_plot_all.py
+++ b/tests/test_plot_all.py
@@ -5,8 +5,11 @@ import pytest
 import tempfile
 import shutil
 
+from pathlib import Path
+from qa4sm_reader.netcdf_transcription import Pytesmo2Qa4smResultsTranscriber
 import qa4sm_reader.plot_all as pa
+from qa4sm_reader.utils import transcribe
 
 # if sys.platform.startswith("win"):
 #     pytestmark = pytest.mark.skip(
@@ -27,15 +30,28 @@ def test_plot_all(plotdir):
     testfile_path = os.path.join(os.path.dirname(__file__), '..', 'tests',
                                  'test_data', 'metadata', testfile)
 
+    temporal_sub_windows_present = \
+        Pytesmo2Qa4smResultsTranscriber.get_tsws_from_ncfile(testfile_path)
+    if not temporal_sub_windows_present:
+        # bulk-only file: transcribe it first and read the temporal
+        # sub-windows from the transcribed copy
+        dataset = transcribe(testfile_path)
+
+        tmp_testfile_path = Path(plotdir) / 'tmp_testfile.nc'
+        encoding = {var: {'zlib': False} for var in dataset.variables}
+        dataset.to_netcdf(tmp_testfile_path, encoding=encoding)
+        testfile_path = tmp_testfile_path
+        temporal_sub_windows_present = \
+            Pytesmo2Qa4smResultsTranscriber.get_tsws_from_ncfile(testfile_path)
+
     pa.plot_all(
         filepath=testfile_path,
+        temporal_sub_windows=temporal_sub_windows_present,
         out_dir=plotdir,
         save_all=True,
         save_metadata=True,
     )
-    assert len(os.listdir(plotdir)) == 60
-    assert all(os.path.splitext(file)[1] in [".png", ".csv"] for file in os.listdir(plotdir)), \
-        "Not all files have been saved as .png or .csv"
+    for tswp in temporal_sub_windows_present:
+        tsw_dir = os.path.join(plotdir, tswp)
+        assert len(os.listdir(tsw_dir)) == 60
+        assert all(os.path.splitext(file)[1] in [".png", ".csv"]
+                   for file in os.listdir(tsw_dir)), \
+            "Not all files have been saved as .png or .csv"
 
     shutil.rmtree(plotdir)
diff --git a/tests/test_plotter.py b/tests/test_plotter.py
index 1dc9d6a..0c2b4a6 100644
--- a/tests/test_plotter.py
+++ b/tests/test_plotter.py
@@ -601,3 +601,9 @@ def test_average_non_additive():
     # Included in the standard interval
     assert 0.5 < avg < 0.7
     assert avg != np.mean(values)
+
+
+def test_logo_exists():
+    assert os.path.exists(
+        os.path.join(os.path.dirname(__file__), '..', 'src', 'qa4sm_reader',
+                     'static', 'images', 'logo', 'QA4SM_logo_long.png'))
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 3e67fe2..78c67c7 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,4 +1,25 @@
+import os
+from pathlib import Path
+from glob import glob
+import xarray as xr
+
 import qa4sm_reader
+from qa4sm_reader.utils import transcribe
+from qa4sm_reader.globals import TEMPORAL_SUB_WINDOW_NC_COORD_NAME
 
 
 def test_get_version():
     assert qa4sm_reader.__version__ != 'unknown'
+
+
+def test_transcribe_all_testfiles():
+    # check that all test files can be transcribed for the subsequent tests;
+    # in-depth testing of the transcription itself is done in
+    # test_netcdf_transcription.py
+    TEST_FILE_ROOT = Path(
+        Path(os.path.dirname(os.path.abspath(__file__))).parent, 'tests',
+        'test_data')
+    # ignore the dedicated intra-annual test files for now; they are tested
+    # separately and in depth
+    test_files = [
+        x for x in glob(str(TEST_FILE_ROOT / '**/*.nc'), recursive=True)
+        if 'intra_annual' not in Path(x).parts
+    ]
+
+    assert len(test_files) == 13
+
+    # transcribe each file once; 'all' (rather than 'any') ensures that a
+    # single failing file cannot slip through
+    transcribed = [transcribe(f) for f in test_files]
+    assert all(isinstance(ds, xr.Dataset) for ds in transcribed)
+    assert all(TEMPORAL_SUB_WINDOW_NC_COORD_NAME in ds.dims
+               for ds in transcribed)
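+
+
+# A minimal usage sketch (hypothetical paths, not part of the test suite),
+# mirroring the pattern used in test_plot_all.py: transcribe a plain pytesmo
+# results file into the QA4SM format before plotting.
+#
+#   ds = transcribe('tests/test_data/basic/some_results.nc')
+#   assert TEMPORAL_SUB_WINDOW_NC_COORD_NAME in ds.dims
+#   ds.to_netcdf('/tmp/transcribed_results.nc',
+#                encoding={var: {'zlib': False} for var in ds.variables})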