From a9b6c5dbf720a3cd4d4199303ddaf7d67b050b33 Mon Sep 17 00:00:00 2001 From: Jiwoo Lee Date: Tue, 2 Apr 2024 09:58:08 -0700 Subject: [PATCH 1/3] [PR]: Regridding nan update (dask part only) (#634) --- xcdat/regridder/regrid2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xcdat/regridder/regrid2.py b/xcdat/regridder/regrid2.py index 602a16d7..3c139c03 100644 --- a/xcdat/regridder/regrid2.py +++ b/xcdat/regridder/regrid2.py @@ -106,7 +106,7 @@ def _regrid( lon_mapping, lon_weights = _map_longitude(src_lon_bnds, dst_lon_bnds) # convert to pure numpy - input_data = input_data_var.astype(np.float32).data + input_data = input_data_var.astype(np.float32).values y_name, y_index = _get_dimension(input_data_var, "Y") x_name, x_index = _get_dimension(input_data_var, "X") @@ -498,4 +498,4 @@ def _get_bounds_ensure_dtype(ds, axis): if bounds.dtype != np.float32: bounds = bounds.astype(np.float32) - return bounds.data + return bounds.values From 472c04b36aa89eb7f9d6d49cff0b0851dfb8edce Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Tue, 2 Apr 2024 11:32:50 -0700 Subject: [PATCH 2/3] [PR]: Remove deprecated features and APIs for next release (#628) * Remove deprecated features and APIs - Remove `horizontal_xesmf()` and `horizontal_regrid2()` - Remove `**kwargs` from `create_grid()` and `_deprecated_create_grid()` - Remove `add_bounds` accepting boolean arg in `open_dataset()`, `open_mfdataset()` and `_postprocess_dataset()` * Remove tests for deprecated code - Remove `lxml` dependency * Remove `_prepare_coordinate()` * Fix `test_global_mean_grid` and `test_zonal_grid` --- conda-env/ci.yml | 1 - conda-env/dev.yml | 1 - tests/test_dataset.py | 117 ----------------------------------- tests/test_regrid.py | 113 +++++++--------------------------- xcdat/dataset.py | 117 +---------------------------------- xcdat/regridder/accessor.py | 119 ------------------------------------ xcdat/regridder/grid.py | 117 +++++------------------------------ 7 files changed, 39 insertions(+), 546 deletions(-) diff --git a/conda-env/ci.yml b/conda-env/ci.yml index 455062e0..3df06815 100644 --- a/conda-env/ci.yml +++ b/conda-env/ci.yml @@ -10,7 +10,6 @@ dependencies: - cf_xarray >=0.7.3 # Constrained because https://github.com/xarray-contrib/cf-xarray/issues/467 - cftime - dask - - lxml # TODO: Remove this in v0.7.0 once cdml/XML support is dropped - netcdf4 - numpy >=1.23.0 # This version of numpy includes support for Python 3.11. - pandas diff --git a/conda-env/dev.yml b/conda-env/dev.yml index 83088470..bdcd3c48 100644 --- a/conda-env/dev.yml +++ b/conda-env/dev.yml @@ -10,7 +10,6 @@ dependencies: - cf_xarray >=0.7.3 # Constrained because https://github.com/xarray-contrib/cf-xarray/issues/467 - cftime - dask - - lxml # TODO: Remove this in v0.7.0 once cdml/XML support is dropped - netcdf4 - numpy >=1.23.0 # This version of numpy includes support for Python 3.11. - pandas diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 74b5edcc..b3841dff 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -5,7 +5,6 @@ import numpy as np import pytest import xarray as xr -from lxml import etree from tests.fixtures import generate_dataset from xcdat._logger import _setup_custom_logger @@ -77,9 +76,6 @@ def test_skips_adding_bounds(self): ds = generate_dataset(decode_times=True, cf_compliant=True, has_bounds=False) ds.to_netcdf(self.file_path) - result = open_dataset(self.file_path, add_bounds=False) - assert result.identical(ds) - result = open_dataset(self.file_path, add_bounds=None) assert result.identical(ds) @@ -324,48 +320,6 @@ def test_keeps_specified_var_and_preserves_bounds(self): assert result.identical(expected) - def test_raises_deprecation_warning_when_passing_add_bounds_true(self): - ds_no_bounds = generate_dataset( - decode_times=True, cf_compliant=True, has_bounds=False - ) - ds_no_bounds.to_netcdf(self.file_path) - - with warnings.catch_warnings(record=True) as w: - result = open_dataset(self.file_path, add_bounds=True) - - assert len(w) == 1 - assert issubclass(w[0].category, DeprecationWarning) - assert str(w[0].message) == ( - "`add_bounds=True` will be deprecated after v0.6.0. Please use a list " - "of axis strings instead (e.g., `add_bounds=['X', 'Y']`)." - ) - - expected = generate_dataset( - decode_times=True, cf_compliant=True, has_bounds=True - ) - expected = expected.drop_vars("time_bnds") - del expected["time"].attrs["bounds"] - - assert result.identical(expected) - - def test_raises_deprecation_warning_when_passing_add_bounds_false(self): - ds_no_bounds = generate_dataset( - decode_times=True, cf_compliant=True, has_bounds=False - ) - ds_no_bounds.to_netcdf(self.file_path) - - with warnings.catch_warnings(record=True) as w: - result = open_dataset(self.file_path, add_bounds=False) - - assert len(w) == 1 - assert issubclass(w[0].category, DeprecationWarning) - assert str(w[0].message) == ( - "`add_bounds=False` will be deprecated after v0.6.0. Please use " - "`add_bounds=None` instead." - ) - - assert result.identical(ds_no_bounds) - class TestOpenMfDataset: @pytest.fixture(autouse=True) @@ -410,80 +364,9 @@ def test_skips_adding_bounds(self): ds = generate_dataset(decode_times=True, cf_compliant=True, has_bounds=False) ds.to_netcdf(self.file_path1) - result = open_mfdataset(self.file_path1, add_bounds=False) - assert result.identical(ds) - result = open_mfdataset(self.file_path1, add_bounds=None) assert result.identical(ds) - def test_raises_error_if_xml_does_not_have_root_directory_attr(self): - ds1 = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) - ds1.to_netcdf(self.file_path1) - ds2 = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) - ds2 = ds2.rename_vars({"ts": "tas"}) - ds2.to_netcdf(self.file_path2) - - # Create the XML file - xml_path = f"{self.dir}/datasets.xml" - page = etree.Element("dataset") - doc = etree.ElementTree(page) - doc.write(xml_path, xml_declaration=True, encoding="utf-16") - - with pytest.raises(KeyError): - open_mfdataset(xml_path, decode_times=True) - - def test_opens_datasets_from_xml_using_str_path(self): - ds1 = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) - ds1.to_netcdf(self.file_path1) - ds2 = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) - ds2 = ds2.rename_vars({"ts": "tas"}) - ds2.to_netcdf(self.file_path2) - - # Create the XML file - xml_path = f"{self.dir}/datasets.xml" - page = etree.Element("dataset", directory=str(self.dir)) - doc = etree.ElementTree(page) - doc.write(xml_path, xml_declaration=True, encoding="utf-16") - - result = open_mfdataset(xml_path, decode_times=True) - expected = ds1.merge(ds2) - - result.identical(expected) - - def test_opens_datasets_from_xml_raises_deprecation_warning(self): - ds1 = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) - ds1.to_netcdf(self.file_path1) - ds2 = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) - ds2 = ds2.rename_vars({"ts": "tas"}) - ds2.to_netcdf(self.file_path2) - - # Create the XML file - xml_path = f"{self.dir}/datasets.xml" - page = etree.Element("dataset", directory=str(self.dir)) - doc = etree.ElementTree(page) - doc.write(xml_path, xml_declaration=True, encoding="utf-16") - - with pytest.warns(DeprecationWarning): - open_mfdataset(xml_path, decode_times=True) - - def test_opens_datasets_from_xml_using_pathlib_path(self): - ds1 = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) - ds1.to_netcdf(self.file_path1) - ds2 = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) - ds2 = ds2.rename_vars({"ts": "tas"}) - ds2.to_netcdf(self.file_path2) - - # Create the XML file - xml_path = self.dir / "datasets.xml" - page = etree.Element("dataset", directory=str(self.dir)) - doc = etree.ElementTree(page) - doc.write(xml_path, xml_declaration=True, encoding="utf-16") - - result = open_mfdataset(xml_path, decode_times=True) - expected = ds1.merge(ds2) - - result.identical(expected) - def test_raises_error_if_directory_has_no_netcdf_files(self): with pytest.raises(ValueError): open_mfdataset(str(self.dir), decode_times=True) diff --git a/tests/test_regrid.py b/tests/test_regrid.py index 13de350b..193a22a6 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -1,7 +1,6 @@ import datetime import re import sys -import warnings from unittest import mock import numpy as np @@ -38,7 +37,8 @@ class TestXGCMRegridder: def setup(self): self.ds = fixtures.generate_lev_dataset() - self.output_grid = grid.create_grid(lev=np.linspace(10000, 2000, 2)) + z = grid.create_axis("lev", np.linspace(10000, 2000, 2), generate_bounds=False) + self.output_grid = grid.create_grid(z=z) def test_multiple_z_axes(self): self.ds = self.ds.assign_coords({"ilev": self.ds.lev.copy().rename("ilev")}) @@ -891,73 +891,6 @@ def test_create_grid_wrong_axis_value(self): ): grid.create_grid(x=(self.lon, self.lon_bnds, self.lat)) # type: ignore[arg-type] - def test_deprecated_unexpected_coordinate(self): - lev = np.linspace(1000, 1, 2) - - with pytest.raises( - ValueError, - match="Coordinate mass is not valid, reference `xcdat.axis.VAR_NAME_MAP` for valid options.", - ): - grid.create_grid(lev=lev, mass=np.linspace(10, 20, 2)) - - def test_deprecated_create_grid_lev(self): - lev = np.linspace(1000, 1, 2) - lev_bnds = np.array([[1499.5, 500.5], [500.5, -498.5]]) - - with warnings.catch_warnings(record=True) as w: - new_grid = grid.create_grid(lev=(lev, lev_bnds)) - - assert len(w) == 1 - assert issubclass(w[0].category, DeprecationWarning) - assert ( - str(w[0].message) - == "**kwargs will be deprecated, see docstring and use 'x', 'y', or 'z' arguments" - ) - - assert np.array_equal(new_grid.lev, lev) - assert np.array_equal(new_grid.lev_bnds, lev_bnds) - - def test_deprecated_create_grid(self): - lat = np.array([-45, 0, 45]) - lon = np.array([30, 60, 90, 120, 150]) - lat_bnds = np.array([[-67.5, -22.5], [-22.5, 22.5], [22.5, 67.5]]) - lon_bnds = np.array([[15, 45], [45, 75], [75, 105], [105, 135], [135, 165]]) - - new_grid = grid.create_grid(lat=lat, lon=lon) - - assert np.array_equal(new_grid.lat, lat) - assert np.array_equal(new_grid.lat_bnds, lat_bnds) - assert new_grid.lat.units == "degrees_north" - assert np.array_equal(new_grid.lon, lon) - assert np.array_equal(new_grid.lon_bnds, lon_bnds) - assert new_grid.lon.units == "degrees_east" - - da_lat = xr.DataArray( - name="lat", - data=lat, - dims=["lat"], - attrs={"units": "degrees_north", "axis": "Y"}, - ) - da_lon = xr.DataArray( - name="lon", - data=lon, - dims=["lon"], - attrs={"units": "degrees_east", "axis": "X"}, - ) - da_lat_bnds = xr.DataArray(name="lat_bnds", data=lat_bnds, dims=["lat", "bnds"]) - da_lon_bnds = xr.DataArray(name="lon_bnds", data=lon_bnds, dims=["lon", "bnds"]) - - new_grid = grid.create_grid( - lat=(da_lat, da_lat_bnds), lon=(da_lon, da_lon_bnds) - ) - - assert np.array_equal(new_grid.lat, lat) - assert np.array_equal(new_grid.lat_bnds, lat_bnds) - assert new_grid.lat.units == "degrees_north" - assert np.array_equal(new_grid.lon, lon) - assert np.array_equal(new_grid.lon_bnds, lon_bnds) - assert new_grid.lon.units == "degrees_east" - def test_uniform_grid(self): new_grid = grid.create_uniform_grid(-90, 90, 4.0, -180, 180, 5.0) @@ -986,10 +919,14 @@ def test_gaussian_grid(self): assert uneven_grid.lon.shape == (67,) def test_global_mean_grid(self): - source_grid = grid.create_grid( - lat=np.array([-80, -40, 0, 40, 80]), - lon=np.array([0, 45, 90, 180, 270, 360]), + x = grid.create_axis( + "lon", np.array([0, 45, 90, 180, 270, 360]), generate_bounds=True ) + y = grid.create_axis( + "lat", np.array([-80, -40, 0, 40, 80]), generate_bounds=True + ) + + source_grid = grid.create_grid(x=x, y=y) mean_grid = grid.create_global_mean_grid(source_grid) @@ -1068,9 +1005,14 @@ def test_raises_error_for_global_mean_grid_if_an_axis_has_multiple_dimensions(se grid.create_global_mean_grid(source_grid_with_2_lons) def test_zonal_grid(self): - source_grid = grid.create_grid( - lat=np.array([-80, -40, 0, 40, 80]), lon=np.array([-160, -80, 80, 160]) + x = grid.create_axis( + "lon", np.array([-160, -80, 80, 160]), generate_bounds=True ) + y = grid.create_axis( + "lat", np.array([-80, -40, 0, 40, 80]), generate_bounds=True + ) + + source_grid = grid.create_grid(x=x, y=y) zonal_grid = grid.create_zonal_grid(source_grid) @@ -1194,7 +1136,9 @@ def test_horizontal(self): assert output_data.ts.shape == (15, 4, 4) def test_vertical(self): - output_grid = grid.create_grid(lev=np.linspace(10000, 2000, 2)) + z = grid.create_axis("lev", np.linspace(10000, 2000, 2), generate_bounds=False) + + output_grid = grid.create_grid(z=z) output_data = self.vertical_ds.regridder.vertical( "so", output_grid, tool="xgcm", method="linear" @@ -1210,7 +1154,8 @@ def test_vertical(self): assert output_data.so.shape == (15, 4, 4, 4) def test_vertical_multiple_z_axes(self): - output_grid = grid.create_grid(lev=np.linspace(10000, 2000, 2)) + z = grid.create_axis("lev", np.linspace(10000, 2000, 2), generate_bounds=False) + output_grid = grid.create_grid(z=z) self.vertical_ds = self.vertical_ds.assign_coords( {"ilev": self.vertical_ds.lev.copy().rename("ilev")} @@ -1312,22 +1257,6 @@ def test_vertical_tool_check(self, _get_input_grid): ): self.ac.vertical("ts", mock_data, tool="dummy", target_data=None) # type: ignore - @pytest.mark.filterwarnings("ignore:.*invalid value.*divide.*:RuntimeWarning") - def test_convenience_methods(self): - ds = fixtures.generate_dataset( - decode_times=True, cf_compliant=False, has_bounds=True - ) - - out_grid = grid.create_gaussian_grid(32) - - output_xesmf = ds.regridder.horizontal_xesmf("ts", out_grid, method="bilinear") - - assert output_xesmf.ts.shape == (15, 32, 65) - - output_regrid2 = ds.regridder.horizontal_regrid2("ts", out_grid) - - assert output_regrid2.ts.shape == (15, 32, 65) - class TestBase: def test_preserve_bounds(self): diff --git a/xcdat/dataset.py b/xcdat/dataset.py index f3fe3e01..45774b48 100644 --- a/xcdat/dataset.py +++ b/xcdat/dataset.py @@ -3,7 +3,6 @@ import os import pathlib -import warnings from datetime import datetime from functools import partial from io import BufferedIOBase @@ -13,7 +12,6 @@ import xarray as xr from dateutil import parser from dateutil import relativedelta as rd -from lxml import etree from xarray.backends.common import AbstractDataStore from xarray.coding.cftime_offsets import get_date_type from xarray.coding.times import convert_times, decode_cf_datetime @@ -46,7 +44,7 @@ def open_dataset( path: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, data_var: Optional[str] = None, - add_bounds: List[CFAxisKey] | None | bool = ["X", "Y"], + add_bounds: List[CFAxisKey] | None = ["X", "Y"], decode_times: bool = True, center_times: bool = False, lon_orient: Optional[Tuple[float, float]] = None, @@ -54,10 +52,6 @@ def open_dataset( ) -> xr.Dataset: """Wraps ``xarray.open_dataset()`` with post-processing options. - .. deprecated:: v0.6.0 - ``add_bounds`` boolean arguments (True/False) are being deprecated. - Please use either a list (e.g., ["X", "Y"]) to specify axes or ``None``. - Parameters ---------- path : str, Path, file-like or DataStore @@ -133,7 +127,7 @@ def open_dataset( def open_mfdataset( paths: str | NestedSequence[str | os.PathLike], data_var: Optional[str] = None, - add_bounds: List[CFAxisKey] | None | Literal[False] = ["X", "Y"], + add_bounds: List[CFAxisKey] | None = ["X", "Y"], decode_times: bool = True, center_times: bool = False, lon_orient: Optional[Tuple[float, float]] = None, @@ -143,10 +137,6 @@ def open_mfdataset( ) -> xr.Dataset: """Wraps ``xarray.open_mfdataset()`` with post-processing options. - .. deprecated:: v0.6.0 - ``add_bounds`` boolean arguments (True/False) are being deprecated. - Please use either a list (e.g., ["X", "Y"]) to specify axes or ``None``. - Parameters ---------- paths : str | NestedSequence[str | os.PathLike] @@ -162,15 +152,6 @@ def open_mfdataset( If concatenation along more than one dimension is desired, then ``paths`` must be a nested list-of-lists (see [2]_ ``xarray.combine_nested`` for details). - * File path to an XML file with a ``directory`` attribute (e.g., - ``"path/to/files"``). If ``directory`` is set to a blank string - (""), then the current directory is substituted ("."). This option - is intended to support the CDAT CDML dialect of XML files, but it - can work with any XML file that has the ``directory`` attribute. - Refer to [4]_ for more information on CDML. NOTE: This feature is - deprecated in v0.6.0 and will be removed in the subsequent release. - CDAT (including cdms2/CDML) is in maintenance only mode and marked - for end-of-life by the end of 2023. add_bounds: List[CFAxisKey] | None | bool List of CF axes to try to add bounds for (if missing), by default ["X", "Y"]. Set to None to not add any missing bounds. Please note that @@ -240,15 +221,6 @@ def open_mfdataset( in-memory copy you are manipulating in xarray is modified: the original file on disk is never touched. - The CDAT "Climate Data Markup Language" (CDML) is a deprecated dialect of - XML with a defined set of attributes. CDML is still used by current and - former users of CDAT. To enable CDML users to adopt xCDAT more easily in - their workflows, xCDAT can parse XML/CDML files for the ``directory`` - to generate a glob or list of file paths. Refer to [4]_ for more information - on CDML. NOTE: This feature is deprecated in v0.6.0 and will be removed in - the subsequent release. CDAT (including cdms2/CDML) is in maintenance only - mode and marked for end-of-life by the end of 2023. - References ---------- .. [2] https://docs.xarray.dev/en/stable/generated/xarray.combine_nested.html @@ -258,13 +230,6 @@ def open_mfdataset( if isinstance(paths, str) or isinstance(paths, pathlib.Path): if os.path.isdir(paths): paths = _parse_dir_for_nc_glob(paths) - elif _is_xml_filepath(paths): - warnings.warn( - "`open_mfdataset()` will no longer support CDML/XML paths after " - "v0.6.0 because CDAT is marked for end-of-life at the end of 2023.", - DeprecationWarning, - ) - paths = _parse_xml_for_nc_glob(paths) preprocess = partial(_preprocess, decode_times=decode_times, callable=preprocess) @@ -422,58 +387,6 @@ def decode_time(dataset: xr.Dataset) -> xr.Dataset: return ds -def _is_xml_filepath(paths: str | pathlib.Path) -> bool: - """Checks if the ``paths`` argument is a path to an XML file. - - Parameters - ---------- - paths : str | pathlib.Path - A string or pathlib.Path represnting a file path. - - Returns - ------- - bool - """ - if isinstance(paths, str): - return paths.split(".")[-1] == "xml" - elif isinstance(paths, pathlib.Path): - return paths.parts[-1].endswith("xml") - - -def _parse_xml_for_nc_glob(xml_path: str | pathlib.Path) -> str | List[str]: - """ - Parses an XML file for the ``directory`` attr to return a string glob or - list of string file paths. - - Parameters - ---------- - xml_path : str | pathlib.Path - The XML file path. - - Returns - ------- - str | List[str] - A string glob of `*.nc` paths. - - """ - # `resolve_entities=False` and `no_network=True` guards against XXE attacks. - # Source: https://rules.sonarsource.com/python/RSPEC-2755 - parser = etree.XMLParser(resolve_entities=False, no_network=True) - tree = etree.parse(xml_path, parser) - root = tree.getroot() - - dir_attr = root.attrib.get("directory") - if dir_attr is None: - raise KeyError( - f"The XML file ({xml_path}) does not have a 'directory' attribute " - "that points to a directory of `.nc` dataset files." - ) - - glob_path = dir_attr + "/*.nc" - - return glob_path - - def _parse_dir_for_nc_glob(dir_path: str | pathlib.Path) -> str: """Parses a directory for a glob of `*.nc` paths. @@ -557,10 +470,6 @@ def _postprocess_dataset( ) -> xr.Dataset: """Post-processes a Dataset object. - .. deprecated:: v0.6.0 - ``add_bounds`` boolean arguments (True/False) are being deprecated. - Please use either a list (e.g., ["X", "Y"]) to specify axes or ``None``. - Parameters ---------- dataset : xr.Dataset @@ -611,28 +520,6 @@ def _postprocess_dataset( if center_times: ds = center_times_func(dataset) - # TODO: Boolean (`True`/`False`) will be deprecated after v0.6.0. - if add_bounds is True: - add_bounds = ["X", "Y"] - warnings.warn( - ( - "`add_bounds=True` will be deprecated after v0.6.0. Please use a list " - "of axis strings instead (e.g., `add_bounds=['X', 'Y']`)." - ), - DeprecationWarning, - stacklevel=2, - ) - elif add_bounds is False: - add_bounds = None - warnings.warn( - ( - "`add_bounds=False` will be deprecated after v0.6.0. Please use " - "`add_bounds=None` instead." - ), - DeprecationWarning, - stacklevel=2, - ) - if add_bounds is not None: ds = ds.bounds.add_missing_bounds(axes=add_bounds) diff --git a/xcdat/regridder/accessor.py b/xcdat/regridder/accessor.py index b8a93c35..b4b34d95 100644 --- a/xcdat/regridder/accessor.py +++ b/xcdat/regridder/accessor.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from typing import Any, List, Literal, Tuple import xarray as xr @@ -117,124 +116,6 @@ def _get_axis_data( return coord_var, bounds_var - # TODO Either provide generic `horizontal` and `vertical` methods or tool specific - def horizontal_xesmf( - self, - data_var: str, - output_grid: xr.Dataset, - **options: Any, - ) -> xr.Dataset: - """ - Deprecated, will be removed with 0.7.0 release. - - Extends the xESMF library for horizontal regridding between structured - rectilinear and curvilinear grids. - - This method extends ``xESMF`` by automatically constructing the - ``xe.XESMFRegridder`` object, preserving source bounds, and generating - missing bounds. It regrids ``data_var`` in the dataset to - ``output_grid``. - - Option documentation :py:func:`xcdat.regridder.xesmf.XESMFRegridder` - - Parameters - ---------- - data_var: str - Name of the variable in the `xr.Dataset` to regrid. - output_grid : xr.Dataset - Dataset containing output grid. - options : Dict[str, Any] - Dictionary with extra parameters for the regridder. - - Returns - ------- - xr.Dataset - With the ``data_var`` variable on the grid defined in ``output_grid``. - - Raises - ------ - ValueError - If tool is not supported. - - Examples - -------- - - Generate output grid: - - >>> output_grid = xcdat.create_gaussian_grid(32) - - Regrid data to output grid using xesmf: - - >>> ds.regridder.horizontal_xesmf("ts", output_grid) - """ - warnings.warn( - "`horizontal_xesmf` will be deprecated in 0.7.x, please migrate to using " - "`horizontal(..., tool='xesmf')` method.", - DeprecationWarning, - stacklevel=2, - ) - - regridder = HORIZONTAL_REGRID_TOOLS["xesmf"](self._ds, output_grid, **options) - - return regridder.horizontal(data_var, self._ds) - - # TODO Either provide generic `horizontal` and `vertical` methods or tool specific - def horizontal_regrid2( - self, - data_var: str, - output_grid: xr.Dataset, - **options: Any, - ) -> xr.Dataset: - """ - Deprecated, will be removed with 0.7.0 release. - - Pure python implementation of CDAT's regrid2 horizontal regridder. - - Regrids ``data_var`` in dataset to ``output_grid`` using regrid2's - algorithm. - - Options documentation :py:func:`xcdat.regridder.regrid2.Regrid2Regridder` - - Parameters - ---------- - data_var: str - Name of the variable in the `xr.Dataset` to regrid. - output_grid : xr.Dataset - Dataset containing output grid. - options : Dict[str, Any] - Dictionary with extra parameters for the regridder. - - Returns - ------- - xr.Dataset - With the ``data_var`` variable on the grid defined in ``output_grid``. - - Raises - ------ - ValueError - If tool is not supported. - - Examples - -------- - Generate output grid: - - >>> output_grid = xcdat.create_gaussian_grid(32) - - Regrid data to output grid using regrid2: - - >>> ds.regridder.horizontal_regrid2("ts", output_grid) - """ - warnings.warn( - "`horizontal_regrid2` will be deprecated in 0.7.x, please migrate to using " - "`horizontal(..., tool='regrid2')` method.", - DeprecationWarning, - stacklevel=2, - ) - - regridder = HORIZONTAL_REGRID_TOOLS["regrid2"](self._ds, output_grid, **options) - - return regridder.horizontal(data_var, self._ds) - def horizontal( self, data_var: str, diff --git a/xcdat/regridder/grid.py b/xcdat/regridder/grid.py index b2b32564..28600a6f 100644 --- a/xcdat/regridder/grid.py +++ b/xcdat/regridder/grid.py @@ -1,12 +1,12 @@ -import warnings -from typing import Any, Dict, List, Optional, Tuple, Union +from __future__ import annotations + +from typing import Dict, List, Optional, Tuple, Union import numpy as np import xarray as xr from xcdat.axis import COORD_DEFAULT_ATTRS, VAR_NAME_MAP, CFAxisKey, get_dim_coords from xcdat.bounds import create_bounds -from xcdat.regridder.base import CoordOptionalBnds # First 50 zeros for the bessel function # Taken from https://github.com/CDAT/cdms/blob/dd41a8dd3b5bac10a4bfdf6e56f6465e11efc51d/regrid2/Src/_regridmodule.c#L3390-L3402 @@ -435,41 +435,24 @@ def create_zonal_grid(grid: xr.Dataset) -> xr.Dataset: def create_grid( - x: Optional[ - Union[ - xr.DataArray, - Tuple[xr.DataArray, Optional[xr.DataArray]], - ] - ] = None, - y: Optional[ - Union[ - xr.DataArray, - Tuple[xr.DataArray, Optional[xr.DataArray]], - ] - ] = None, - z: Optional[ - Union[ - xr.DataArray, - Tuple[xr.DataArray, Optional[xr.DataArray]], - ] - ] = None, + x: xr.DataArray | Tuple[xr.DataArray, xr.DataArray | None] | None = None, + y: xr.DataArray | Tuple[xr.DataArray, xr.DataArray | None] | None = None, + z: xr.DataArray | Tuple[xr.DataArray, xr.DataArray | None] | None = None, attrs: Optional[Dict[str, str]] = None, - **kwargs: CoordOptionalBnds, ) -> xr.Dataset: """Creates a grid dataset using the specified axes. - .. deprecated:: v0.6.0 - ``**kwargs`` argument is being deprecated, please migrate to - ``x``, ``y``, or ``z`` arguments to create future grids. - Parameters ---------- - x : Optional[Union[xr.DataArray, Tuple[xr.DataArray]]] - Data with optional bounds to use for the "X" axis, by default None. - y : Optional[Union[xr.DataArray, Tuple[xr.DataArray]]] - Data with optional bounds to use for the "Y" axis, by default None. - z : Optional[Union[xr.DataArray, Tuple[xr.DataArray]]] - Data with optional bounds to use for the "Z" axis, by default None. + x : xr.DataArray | Tuple[xr.DataArray, xr.DataArray | None] | None + An optional dataarray or tuple of a datarray with optional bounds to use + for the "X" axis, by default None. + y : xr.DataArray | Tuple[xr.DataArray, xr.DataArray | None] | None = None, + An optional dataarray or tuple of a datarray with optional bounds to use + for the "Y" axis, by default None. + z : xr.DataArray | Tuple[xr.DataArray, xr.DataArray | None] | None + An optional dataarray or tuple of a datarray with optional bounds to use + for the "Z" axis, by default None. attrs : Optional[Dict[str, str]] Custom attributes to be added to the generated `xr.Dataset`. @@ -515,16 +498,8 @@ def create_grid( >>> ) >>> grid = create_grid(z=z) """ - if np.all([item is None for item in (x, y, z)]) and len(kwargs) == 0: + if np.all([item is None for item in (x, y, z)]): raise ValueError("Must pass at least 1 axis to create a grid.") - elif np.all([item is None for item in (x, y, z)]) and len(kwargs) > 0: - warnings.warn( - "**kwargs will be deprecated, see docstring and use 'x', 'y', or 'z' arguments", - DeprecationWarning, - stacklevel=2, - ) - - return _deprecated_create_grid(**kwargs) axes = {"x": x, "y": y, "z": z} ds = xr.Dataset(attrs={} if attrs is None else attrs.copy()) @@ -557,66 +532,6 @@ def create_grid( return ds -def _deprecated_create_grid(**kwargs: CoordOptionalBnds) -> xr.Dataset: - coords = {} - data_vars = {} - - for name, data in kwargs.items(): - if name in VAR_NAME_MAP["X"]: - coord, bnds = _prepare_coordinate(name, data, **COORD_DEFAULT_ATTRS["X"]) - elif name in VAR_NAME_MAP["Y"]: - coord, bnds = _prepare_coordinate(name, data, **COORD_DEFAULT_ATTRS["Y"]) - elif name in VAR_NAME_MAP["Z"]: - coord, bnds = _prepare_coordinate(name, data, **COORD_DEFAULT_ATTRS["Z"]) - else: - raise ValueError( - f"Coordinate {name} is not valid, reference " - "`xcdat.axis.VAR_NAME_MAP` for valid options." - ) - - coords[name] = coord - - if bnds is not None: - bnds = bnds.copy() - - if isinstance(bnds, np.ndarray): - bnds = xr.DataArray( - name=f"{name}_bnds", - data=bnds.copy(), - dims=[name, "bnds"], - ) - - data_vars[bnds.name] = bnds - - coord.attrs["bounds"] = bnds.name - - grid = xr.Dataset(data_vars, coords=coords) - - grid = grid.bounds.add_missing_bounds(axes=["X", "Y"]) - - return grid - - -def _prepare_coordinate(name: str, data: CoordOptionalBnds, **attrs: Any): - if isinstance(data, tuple): - coord, bnds = data[0], data[1] - else: - coord, bnds = data, None - - # ensure we make a copy - coord = coord.copy() - - if isinstance(coord, np.ndarray): - coord = xr.DataArray( - name=name, - data=coord, - dims=[name], - attrs=attrs, - ) - - return coord, bnds - - def create_axis( name: str, data: Union[List[Union[int, float]], np.ndarray], From 1f4d22a0e532425a5992bc387b24bf171089bd40 Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Wed, 10 Apr 2024 13:01:23 -0700 Subject: [PATCH 3/3] [PR]: Update Regrid2 missing and fill value behaviors to align with CDAT and add `unmapped_to_nan` arg for output data (#613) Co-authored-by: tomvothecoder Co-authored-by: Jiwoo Lee --- .github/workflows/build_workflow.yml | 3 +- tests/test_regrid.py | 36 +++++++-- xcdat/regridder/regrid2.py | 108 +++++++++++++++++++++------ xcdat/regridder/xesmf.py | 24 +++--- 4 files changed, 131 insertions(+), 40 deletions(-) diff --git a/.github/workflows/build_workflow.yml b/.github/workflows/build_workflow.yml index 108e40bd..3dcbf886 100644 --- a/.github/workflows/build_workflow.yml +++ b/.github/workflows/build_workflow.yml @@ -107,10 +107,11 @@ jobs: pytest - name: Upload Coverage Report - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: file: "tests_coverage_reports/coverage.xml" fail_ci_if_error: true + token: ${{ secrets.CODECOV_TOKEN }} # `build-result` is a workaround to skipped matrix jobs in `build` not being considered "successful", # which can block PR merges if matrix jobs are required status checks. diff --git a/tests/test_regrid.py b/tests/test_regrid.py index 193a22a6..c7253f4f 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -496,11 +496,36 @@ def test_regrid_input_mask(self): output_data = regridder.horizontal("ts", self.coarse_2d_ds) + # np.nan != np.nan, replace with 1e20 + output_data = output_data.fillna(1e20) + + expected_output = np.array( + [ + [1e20] * 4, + [1.0] * 4, + [1.0] * 4, + [1e20] * 4, + ], + dtype=np.float32, + ) + + assert np.all(output_data.ts.values == expected_output) + + @pytest.mark.filterwarnings("ignore:.*invalid value.*true_divide.*:RuntimeWarning") + def test_regrid_input_mask_unmapped_to_nan(self): + regridder = regrid2.Regrid2Regridder( + self.coarse_2d_ds, self.fine_2d_ds, unmapped_to_nan=False + ) + + self.coarse_2d_ds["mask"] = (("lat", "lon"), [[0, 0], [1, 1], [0, 0]]) + + output_data = regridder.horizontal("ts", self.coarse_2d_ds) + expected_output = np.array( [ [0.0] * 4, - [0.70710677] * 4, - [0.70710677] * 4, + [1.0] * 4, + [1.0] * 4, [0.0] * 4, ], dtype=np.float32, @@ -690,7 +715,7 @@ def test_regrid(self): assert "time_bnds" in output @pytest.mark.parametrize( - "name,value,attr_name", + "name,value,_", [ ("periodic", True, "_periodic"), ("extrap_method", "inverse_dist", "_extrap_method"), @@ -700,14 +725,15 @@ def test_regrid(self): ("ignore_degenerate", False, "_ignore_degenerate"), ], ) - def test_flags(self, name, value, attr_name): + def test_flags(self, name, value, _): ds = self.ds.copy() options = {name: value} regridder = xesmf.XESMFRegridder(ds, self.new_grid, "bilinear", **options) - assert getattr(regridder, attr_name) == value + assert name in regridder._extra_options + assert regridder._extra_options[name] == value def test_no_variable(self): ds = self.ds.copy() diff --git a/xcdat/regridder/regrid2.py b/xcdat/regridder/regrid2.py index 3c139c03..823215d7 100644 --- a/xcdat/regridder/regrid2.py +++ b/xcdat/regridder/regrid2.py @@ -1,4 +1,4 @@ -from typing import Any, List, Tuple +from typing import Any, List, Optional, Tuple import numpy as np import xarray as xr @@ -8,7 +8,13 @@ class Regrid2Regridder(BaseRegridder): - def __init__(self, input_grid: xr.Dataset, output_grid: xr.Dataset, **options: Any): + def __init__( + self, + input_grid: xr.Dataset, + output_grid: xr.Dataset, + unmapped_to_nan=True, + **options: Any, + ): """ Pure python implementation of the regrid2 horizontal regridder from CDMS2's regrid2 module. @@ -47,6 +53,8 @@ def __init__(self, input_grid: xr.Dataset, output_grid: xr.Dataset, **options: A """ super().__init__(input_grid, output_grid, **options) + self._unmapped_to_nan = unmapped_to_nan + def vertical(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: """Placeholder for base class.""" raise NotImplementedError() @@ -66,20 +74,31 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: dst_lat_bnds = _get_bounds_ensure_dtype(self._output_grid, "Y") dst_lon_bnds = _get_bounds_ensure_dtype(self._output_grid, "X") - src_mask = self._input_grid.get("mask", None) + src_mask_da = self._input_grid.get("mask", None) + + # DataArray to np.ndarray, handle error when None + try: + src_mask = src_mask_da.values # type: ignore + except AttributeError: + src_mask = None - # apply source mask to input data - if src_mask is not None: - masked_value = self._input_grid.attrs.get("_FillValue", None) + nan_replace = input_data_var.encoding.get("_FillValue", None) - if masked_value is None: - masked_value = self._input_grid.attrs.get("missing_value", 0.0) + if nan_replace is None: + nan_replace = input_data_var.encoding.get("missing_value", 1e20) - # Xarray defaults to masking with np.nan, CDAT masked with _FillValue or missing_value which defaults to 1e20 - input_data_var = input_data_var.where(src_mask != 0.0, masked_value) + # exclude alternative of NaN values if there are any + input_data_var = input_data_var.where(input_data_var != nan_replace) + # horizontal regrid output_data = _regrid( - input_data_var, src_lat_bnds, src_lon_bnds, dst_lat_bnds, dst_lon_bnds + input_data_var, + src_lat_bnds, + src_lon_bnds, + dst_lat_bnds, + dst_lon_bnds, + src_mask, + unmapped_to_nan=self._unmapped_to_nan, ) output_ds = _build_dataset( @@ -101,7 +120,13 @@ def _regrid( src_lon_bnds: np.ndarray, dst_lat_bnds: np.ndarray, dst_lon_bnds: np.ndarray, + src_mask: Optional[np.ndarray], + omitted=None, + unmapped_to_nan=True, ) -> np.ndarray: + if omitted is None: + omitted = np.nan + lat_mapping, lat_weights = _map_latitude(src_lat_bnds, dst_lat_bnds) lon_mapping, lon_weights = _map_longitude(src_lon_bnds, dst_lon_bnds) @@ -114,6 +139,11 @@ def _regrid( y_length = len(lat_mapping) x_length = len(lon_mapping) + if src_mask is None: + input_data_shape = input_data.shape + + src_mask = np.ones((input_data_shape[y_index], input_data_shape[x_index])) + other_dims = { x: y for x, y in input_data_var.sizes.items() if x not in (y_name, x_name) } @@ -122,6 +152,7 @@ def _regrid( data_shape = [y_length * x_length] + other_sizes # output data is always float32 in original code output_data = np.zeros(data_shape, dtype=np.float32) + output_mask = np.ones(data_shape, dtype=np.float32) is_2d = input_data_var.ndim <= 2 @@ -129,14 +160,23 @@ def _regrid( # TODO: how common is lon by lat data? may need to reshape for y in range(y_length): y_seg = np.take(input_data, lat_mapping[y], axis=y_index) + y_mask_seg = np.take(src_mask, lat_mapping[y], axis=0) for x in range(x_length): x_seg = np.take(y_seg, lon_mapping[x], axis=x_index, mode="wrap") + x_mask_seg = np.take(y_mask_seg, lon_mapping[x], axis=1, mode="wrap") - cell_weight = np.dot(lat_weights[y], lon_weights[x]) + cell_weights = np.multiply( + np.dot(lat_weights[y], lon_weights[x]), x_mask_seg + ) + + cell_weight = np.sum(cell_weights) output_seg_index = y * x_length + x + if cell_weight == 0.0: + output_mask[output_seg_index] = 0.0 + # using the `out` argument is more performant, places data directly into # array memory rather than allocating a new variable. wasn't working for # single element output, needs further investigation as we may not need @@ -144,23 +184,30 @@ def _regrid( if is_2d: output_data[output_seg_index] = np.divide( np.sum( - np.multiply(x_seg, cell_weight), + np.multiply(x_seg, cell_weights), axis=(y_index, x_index), ), - np.sum(cell_weight), + cell_weight, ) else: output_seg = output_data[output_seg_index] np.divide( np.sum( - np.multiply(x_seg, cell_weight), + np.multiply(x_seg, cell_weights), axis=(y_index, x_index), ), - np.sum(cell_weight), + cell_weight, out=output_seg, ) + if cell_weight <= 0.0: + output_data[output_seg_index] = omitted + + # default for unmapped is nan due to division by zero, use output mask to repalce + if not unmapped_to_nan: + output_data[output_mask == 0.0] = 0.0 + output_data_shape = [y_length, x_length] + other_sizes output_data = output_data.reshape(output_data_shape) @@ -208,7 +255,9 @@ def _build_dataset( return output_ds -def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: +def _map_latitude( + src: np.ndarray, dst: np.ndarray +) -> Tuple[List[np.ndarray], List[np.ndarray]]: """ Map source to destination latitude. @@ -230,7 +279,7 @@ def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: Returns ------- - Tuple[List, List] + Tuple[List[np.ndarray], List[np.ndarray]] A tuple of cell mappings and cell weights. """ src_south, src_north = _extract_bounds(src) @@ -255,14 +304,25 @@ def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: ] # convert latitude to cell weight (difference of height above/below equator) - weights = [ - (np.sin(np.deg2rad(x)) - np.sin(np.deg2rad(y))).reshape((-1, 1)) - for x, y in bounds - ] + weights = _get_latitude_weights(bounds) return mapping, weights +def _get_latitude_weights( + bounds: List[Tuple[np.ndarray, np.ndarray]] +) -> List[np.ndarray]: + weights = [] + + for x, y in bounds: + cell_weight = np.sin(np.deg2rad(x)) - np.sin(np.deg2rad(y)) + cell_weight = cell_weight.reshape((-1, 1)) + + weights.append(cell_weight) + + return weights + + def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: """ Map source to destination longitude. @@ -347,12 +407,12 @@ def _extract_bounds(bounds: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: Parameters ---------- bounds : np.ndarray - Dataset containing axis with bounds. + A numpy array of bounds values. Returns ------- Tuple[np.ndarray, np.ndarray] - A tuple containing the lower and upper bounds for the axis. + A tuple containing the lower and upper bounds for the axis. """ if bounds[0, 0] < bounds[0, 1]: lower = bounds[:, 0] diff --git a/xcdat/regridder/xesmf.py b/xcdat/regridder/xesmf.py index 90239469..7ad9dba0 100644 --- a/xcdat/regridder/xesmf.py +++ b/xcdat/regridder/xesmf.py @@ -28,6 +28,7 @@ def __init__( extrap_dist_exponent: Optional[float] = None, extrap_num_src_pnts: Optional[int] = None, ignore_degenerate: bool = True, + unmapped_to_nan: bool = True, **options: Any, ): """Extension of ``xESMF`` regridder. @@ -74,6 +75,8 @@ def __init__( This only applies to "conservative" and "conservative_normed" regridding methods. + unmapped_to_nan : bool + Sets values of unmapped points to `np.nan` instead of 0 (ESMF default). **options : Any Additional arguments passed to the underlying ``xesmf.XESMFRegridder`` constructor. @@ -126,11 +129,17 @@ def __init__( ) self._method = method - self._periodic = periodic - self._extrap_method = extrap_method - self._extrap_dist_exponent = extrap_dist_exponent - self._extrap_num_src_pnts = extrap_num_src_pnts - self._ignore_degenerate = ignore_degenerate + + # Re-pack xesmf arguments, broken out for validation/documentation + options.update( + periodic=periodic, + extrap_method=extrap_method, + extrap_dist_exponent=extrap_dist_exponent, + extrap_num_src_pnts=extrap_num_src_pnts, + ignore_degenerate=ignore_degenerate, + unmapped_to_nan=unmapped_to_nan, + ) + self._extra_options = options def vertical(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: @@ -150,11 +159,6 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: self._input_grid, self._output_grid, method=self._method, - periodic=self._periodic, - extrap_method=self._extrap_method, - extrap_dist_exponent=self._extrap_dist_exponent, - extrap_num_src_pnts=self._extrap_num_src_pnts, - ignore_degenerate=self._ignore_degenerate, **self._extra_options, )