Skip to content

Commit

Permalink
better handling of land boundaries
Browse files Browse the repository at this point in the history
  • Loading branch information
vadmbertr committed Feb 16, 2024
1 parent 6515cc5 commit 87868b8
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 5 deletions.
16 changes: 12 additions & 4 deletions jaxparrow/tools/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import jax.numpy as jnp
from jaxtyping import Array, Float

from .sanitize import handle_land_boundary


def interpolation(
field: Float[Array, "lat lon"],
Expand Down Expand Up @@ -35,13 +37,17 @@ def interpolation(
Interpolated field
"""
if axis == 0:
midpoint_values = 0.5 * (field[:-1, :] + field[1:, :])
field_b, field_f = field[:-1, :], field[1:, :]
field_b, field_f = handle_land_boundary(field_b, field_f)
midpoint_values = 0.5 * (field_b + field_f)
if padding == "left":
field = field.at[1:, :].set(midpoint_values)
else: # padding == "right"
field = field.at[:-1, :].set(midpoint_values)
else: # axis == 1
midpoint_values = 0.5 * (field[:, :-1] + field[:, 1:])
field_b, field_f = field[:, :-1], field[:, 1:]
field_b, field_f = handle_land_boundary(field_b, field_f)
midpoint_values = 0.5 * (field_b + field_f)
if padding == "left":
field = field.at[:, 1:].set(midpoint_values)
else:
Expand Down Expand Up @@ -83,16 +89,18 @@ def derivative(
Interpolated field
"""
if axis == 0:
midpoint_values = field[1:, :] - field[:-1, :]
field_b, field_f = field[:-1, :], field[1:, :]
if padding == "left":
pad_width = ((1, 0), (0, 0))
else: # padding == "right"
pad_width = ((0, 1), (0, 0))
else: # axis == 1
midpoint_values = field[:, 1:] - field[:, :-1]
field_b, field_f = field[:, :-1], field[:, 1:]
if padding == "left":
pad_width = ((0, 0), (1, 0))
else:
pad_width = ((0, 0), (0, 1))
field_b, field_f = handle_land_boundary(field_b, field_f)
midpoint_values = field_f - field_b
field = jnp.pad(midpoint_values, pad_width=pad_width, mode="edge") / dxy
return field
31 changes: 30 additions & 1 deletion jaxparrow/tools/sanitize.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,39 @@ def init_mask(
Initialized (if needed) mask
"""
if mask is None:
mask = jnp.isnan(field)
mask = jnp.isfinite(field)
return mask


def handle_land_boundary(
field1: Float[Array, "lat lon"],
field2: Float[Array, "lat lon"]
) -> [Float[Array, "lat lon"], Float[Array, "lat lon"]]:
"""
Replaces the non-finite values of ``field1`` (``field2``) with values of ``field2`` (``field1``), element-wise.
It allows to introduce less non-finite values when applying grid operators.
In such cases, ``field1`` and ``field2`` are left and right shifted versions of a field.
Parameters
----------
field1 : Float[Array, "lat lon"]
A field
field2 : Float[Array, "lat lon"]
Another field
Returns
-------
field1 : Float[Array, "lat lon"]
A field whose non-finite values have been replaced with the ones from ``field2``
field2 : Float[Array, "lat lon"]
A field whose non-finite values have been replaced with the ones from ``field1``
"""
field1 = jnp.where(jnp.isfinite(field1), field1, field2)
field2 = jnp.where(jnp.isfinite(field2), field2, field1)
return field1, field2


def sanitize_grid_np(
lat: Float[Array, "lat lon"],
lon: Float[Array, "lat lon"],
Expand Down

0 comments on commit 87868b8

Please sign in to comment.