Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PR]: Update Regrid2 missing and fill value behaviors to align with CDAT and add unmapped_to_nan arg for output data #613

Merged
merged 16 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/build_workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,11 @@ jobs:
pytest

- name: Upload Coverage Report
uses: codecov/codecov-action@v3
uses: codecov/codecov-action@v4
with:
file: "tests_coverage_reports/coverage.xml"
fail_ci_if_error: true
token: ${{ secrets.CODECOV_TOKEN }}

# `build-result` is a workaround to skipped matrix jobs in `build` not being considered "successful",
# which can block PR merges if matrix jobs are required status checks.
Expand Down
36 changes: 31 additions & 5 deletions tests/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,11 +496,36 @@ def test_regrid_input_mask(self):

output_data = regridder.horizontal("ts", self.coarse_2d_ds)

# np.nan != np.nan, replace with 1e20
output_data = output_data.fillna(1e20)

expected_output = np.array(
[
[1e20] * 4,
[1.0] * 4,
[1.0] * 4,
[1e20] * 4,
],
dtype=np.float32,
)

assert np.all(output_data.ts.values == expected_output)

@pytest.mark.filterwarnings("ignore:.*invalid value.*true_divide.*:RuntimeWarning")
def test_regrid_input_mask_unmapped_to_nan(self):
regridder = regrid2.Regrid2Regridder(
self.coarse_2d_ds, self.fine_2d_ds, unmapped_to_nan=False
)

self.coarse_2d_ds["mask"] = (("lat", "lon"), [[0, 0], [1, 1], [0, 0]])

output_data = regridder.horizontal("ts", self.coarse_2d_ds)

expected_output = np.array(
[
[0.0] * 4,
[0.70710677] * 4,
[0.70710677] * 4,
[1.0] * 4,
[1.0] * 4,
[0.0] * 4,
],
dtype=np.float32,
Expand Down Expand Up @@ -690,7 +715,7 @@ def test_regrid(self):
assert "time_bnds" in output

@pytest.mark.parametrize(
"name,value,attr_name",
"name,value,_",
[
("periodic", True, "_periodic"),
("extrap_method", "inverse_dist", "_extrap_method"),
Expand All @@ -700,14 +725,15 @@ def test_regrid(self):
("ignore_degenerate", False, "_ignore_degenerate"),
],
)
def test_flags(self, name, value, attr_name):
def test_flags(self, name, value, _):
ds = self.ds.copy()

options = {name: value}

regridder = xesmf.XESMFRegridder(ds, self.new_grid, "bilinear", **options)

assert getattr(regridder, attr_name) == value
assert name in regridder._extra_options
assert regridder._extra_options[name] == value

def test_no_variable(self):
ds = self.ds.copy()
Expand Down
108 changes: 84 additions & 24 deletions xcdat/regridder/regrid2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Tuple
from typing import Any, List, Optional, Tuple

import numpy as np
import xarray as xr
Expand All @@ -8,7 +8,13 @@


class Regrid2Regridder(BaseRegridder):
def __init__(self, input_grid: xr.Dataset, output_grid: xr.Dataset, **options: Any):
def __init__(
self,
input_grid: xr.Dataset,
output_grid: xr.Dataset,
unmapped_to_nan=True,
**options: Any,
):
"""
Pure python implementation of the regrid2 horizontal regridder from
CDMS2's regrid2 module.
Expand Down Expand Up @@ -47,6 +53,8 @@ def __init__(self, input_grid: xr.Dataset, output_grid: xr.Dataset, **options: A
"""
super().__init__(input_grid, output_grid, **options)

self._unmapped_to_nan = unmapped_to_nan

def vertical(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
"""Placeholder for base class."""
raise NotImplementedError()
Expand All @@ -66,20 +74,31 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
dst_lat_bnds = _get_bounds_ensure_dtype(self._output_grid, "Y")
dst_lon_bnds = _get_bounds_ensure_dtype(self._output_grid, "X")

src_mask = self._input_grid.get("mask", None)
src_mask_da = self._input_grid.get("mask", None)

# DataArray to np.ndarray, handle error when None
try:
src_mask = src_mask_da.values # type: ignore
except AttributeError:
src_mask = None

# apply source mask to input data
if src_mask is not None:
masked_value = self._input_grid.attrs.get("_FillValue", None)
nan_replace = input_data_var.encoding.get("_FillValue", None)

if masked_value is None:
masked_value = self._input_grid.attrs.get("missing_value", 0.0)
if nan_replace is None:
nan_replace = input_data_var.encoding.get("missing_value", 1e20)

# Xarray defaults to masking with np.nan, CDAT masked with _FillValue or missing_value which defaults to 1e20
input_data_var = input_data_var.where(src_mask != 0.0, masked_value)
# exclude alternative of NaN values if there are any
input_data_var = input_data_var.where(input_data_var != nan_replace)

# horizontal regrid
output_data = _regrid(
input_data_var, src_lat_bnds, src_lon_bnds, dst_lat_bnds, dst_lon_bnds
input_data_var,
src_lat_bnds,
src_lon_bnds,
dst_lat_bnds,
dst_lon_bnds,
src_mask,
unmapped_to_nan=self._unmapped_to_nan,
)

output_ds = _build_dataset(
Expand All @@ -101,7 +120,13 @@ def _regrid(
src_lon_bnds: np.ndarray,
dst_lat_bnds: np.ndarray,
dst_lon_bnds: np.ndarray,
src_mask: Optional[np.ndarray],
omitted=None,
unmapped_to_nan=True,
) -> np.ndarray:
if omitted is None:
omitted = np.nan

lat_mapping, lat_weights = _map_latitude(src_lat_bnds, dst_lat_bnds)
lon_mapping, lon_weights = _map_longitude(src_lon_bnds, dst_lon_bnds)

Expand All @@ -114,6 +139,11 @@ def _regrid(
y_length = len(lat_mapping)
x_length = len(lon_mapping)

if src_mask is None:
input_data_shape = input_data.shape

src_mask = np.ones((input_data_shape[y_index], input_data_shape[x_index]))

other_dims = {
x: y for x, y in input_data_var.sizes.items() if x not in (y_name, x_name)
}
Expand All @@ -122,45 +152,62 @@ def _regrid(
data_shape = [y_length * x_length] + other_sizes
# output data is always float32 in original code
output_data = np.zeros(data_shape, dtype=np.float32)
output_mask = np.ones(data_shape, dtype=np.float32)

is_2d = input_data_var.ndim <= 2

# TODO: need to optimize further, investigate using ufuncs and dask arrays
# TODO: how common is lon by lat data? may need to reshape
for y in range(y_length):
y_seg = np.take(input_data, lat_mapping[y], axis=y_index)
y_mask_seg = np.take(src_mask, lat_mapping[y], axis=0)

for x in range(x_length):
x_seg = np.take(y_seg, lon_mapping[x], axis=x_index, mode="wrap")
x_mask_seg = np.take(y_mask_seg, lon_mapping[x], axis=1, mode="wrap")

cell_weight = np.dot(lat_weights[y], lon_weights[x])
cell_weights = np.multiply(
np.dot(lat_weights[y], lon_weights[x]), x_mask_seg
)

cell_weight = np.sum(cell_weights)

output_seg_index = y * x_length + x

if cell_weight == 0.0:
output_mask[output_seg_index] = 0.0

# using the `out` argument is more performant, places data directly into
# array memory rather than allocating a new variable. wasn't working for
# single element output, needs further investigation as we may not need
# branch
if is_2d:
output_data[output_seg_index] = np.divide(
np.sum(
np.multiply(x_seg, cell_weight),
np.multiply(x_seg, cell_weights),
axis=(y_index, x_index),
),
np.sum(cell_weight),
cell_weight,
)
else:
output_seg = output_data[output_seg_index]

np.divide(
np.sum(
np.multiply(x_seg, cell_weight),
np.multiply(x_seg, cell_weights),
axis=(y_index, x_index),
),
np.sum(cell_weight),
cell_weight,
out=output_seg,
)

if cell_weight <= 0.0:
output_data[output_seg_index] = omitted

# default for unmapped is nan due to division by zero, use output mask to repalce
if not unmapped_to_nan:
output_data[output_mask == 0.0] = 0.0

output_data_shape = [y_length, x_length] + other_sizes

output_data = output_data.reshape(output_data_shape)
Expand Down Expand Up @@ -208,7 +255,9 @@ def _build_dataset(
return output_ds


def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
def _map_latitude(
src: np.ndarray, dst: np.ndarray
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
"""
Map source to destination latitude.

Expand All @@ -230,7 +279,7 @@ def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:

Returns
-------
Tuple[List, List]
Tuple[List[np.ndarray], List[np.ndarray]]
A tuple of cell mappings and cell weights.
"""
src_south, src_north = _extract_bounds(src)
Expand All @@ -255,14 +304,25 @@ def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
]

# convert latitude to cell weight (difference of height above/below equator)
weights = [
(np.sin(np.deg2rad(x)) - np.sin(np.deg2rad(y))).reshape((-1, 1))
for x, y in bounds
]
weights = _get_latitude_weights(bounds)

return mapping, weights


def _get_latitude_weights(
bounds: List[Tuple[np.ndarray, np.ndarray]]
) -> List[np.ndarray]:
weights = []

for x, y in bounds:
cell_weight = np.sin(np.deg2rad(x)) - np.sin(np.deg2rad(y))
cell_weight = cell_weight.reshape((-1, 1))

weights.append(cell_weight)

return weights


def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
Comment on lines +312 to +325
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I extracted this section of code into its own private function and used an explicit for loop to make it more readable compared to list comprehension (IMO).

Copy link
Collaborator

@tomvothecoder tomvothecoder Mar 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I don't think Dask Arrays support np.deg2rad and/or np.sin with np.nan values, resulting in ValueError: cannot convert float NaN to integer in #615.

"""
Map source to destination longitude.
Expand Down Expand Up @@ -347,12 +407,12 @@ def _extract_bounds(bounds: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
Parameters
----------
bounds : np.ndarray
Dataset containing axis with bounds.
A numpy array of bounds values.

Returns
-------
Tuple[np.ndarray, np.ndarray]
A tuple containing the lower and upper bounds for the axis.
A tuple containing the lower and upper bounds for the axis.
"""
if bounds[0, 0] < bounds[0, 1]:
lower = bounds[:, 0]
Expand Down
24 changes: 14 additions & 10 deletions xcdat/regridder/xesmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
extrap_dist_exponent: Optional[float] = None,
extrap_num_src_pnts: Optional[int] = None,
ignore_degenerate: bool = True,
unmapped_to_nan: bool = True,
**options: Any,
):
"""Extension of ``xESMF`` regridder.
Expand Down Expand Up @@ -74,6 +75,8 @@ def __init__(

This only applies to "conservative" and "conservative_normed"
regridding methods.
unmapped_to_nan : bool
Sets values of unmapped points to `np.nan` instead of 0 (ESMF default).
**options : Any
Additional arguments passed to the underlying ``xesmf.XESMFRegridder``
constructor.
Expand Down Expand Up @@ -126,11 +129,17 @@ def __init__(
)

self._method = method
self._periodic = periodic
self._extrap_method = extrap_method
self._extrap_dist_exponent = extrap_dist_exponent
self._extrap_num_src_pnts = extrap_num_src_pnts
self._ignore_degenerate = ignore_degenerate

# Re-pack xesmf arguments, broken out for validation/documentation
options.update(
periodic=periodic,
extrap_method=extrap_method,
extrap_dist_exponent=extrap_dist_exponent,
extrap_num_src_pnts=extrap_num_src_pnts,
ignore_degenerate=ignore_degenerate,
unmapped_to_nan=unmapped_to_nan,
)

self._extra_options = options

def vertical(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
Expand All @@ -150,11 +159,6 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
self._input_grid,
self._output_grid,
method=self._method,
periodic=self._periodic,
extrap_method=self._extrap_method,
extrap_dist_exponent=self._extrap_dist_exponent,
extrap_num_src_pnts=self._extrap_num_src_pnts,
ignore_degenerate=self._ignore_degenerate,
**self._extra_options,
)

Expand Down
Loading