Skip to content

Commit

Permalink
fix the cell centers computation (#61)
Browse files Browse the repository at this point in the history
* expose the grid info object

* use the grid info methods to implement the reconstruction of cell ids

* add a method to retrieve the cell centers using the accessor

* forward to the grid object

* remove the left-over conversion methods on the index base class

* add a test to make sure the accessor returns the right cell centers

* inverse the order of coordinates for H3

* fix the old `assign_latlon_coords` method

* change the expectations in the h3 tests

now we always expect longitude, then latitude as a tuple
  • Loading branch information
keewis committed Jul 3, 2024
1 parent 2e003f5 commit 4971103
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 20 deletions.
22 changes: 20 additions & 2 deletions xdggs/accessor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy.typing as npt
import xarray as xr

from xdggs.grid import DGGSInfo
from xdggs.index import DGGSIndex


Expand Down Expand Up @@ -55,6 +56,10 @@ def params(self) -> dict:
"""The grid parameters after normalization."""
return self.index.grid.to_dict()

@property
def grid_info(self) -> DGGSInfo:
return self.index.grid_info

def sel_latlon(
self, latitude: npt.ArrayLike, longitude: npt.ArrayLike
) -> xr.Dataset | xr.DataArray:
Expand All @@ -74,15 +79,28 @@ def sel_latlon(
with all cells that contain the input latitude/longitude data points.
"""
cell_indexers = {self._name: self.index._latlon2cellid(latitude, longitude)}
cell_indexers = {
self._name: self.grid_info.geographic2cell_ids(latitude, longitude)
}
return self._obj.sel(cell_indexers)

def assign_latlon_coords(self) -> xr.Dataset | xr.DataArray:
"""Return a new Dataset or DataArray with new "latitude" and "longitude"
coordinates representing the grid cell centers."""

lon_data, lat_data = self.index.cell_centers
lon_data, lat_data = self.index.cell_centers()

return self._obj.assign_coords(
latitude=(self.index._dim, lat_data),
longitude=(self.index._dim, lon_data),
)

def cell_centers(self):
lon_data, lat_data = self.index.cell_centers()

return xr.Dataset(
coords={
"latitude": (self.index._dim, lat_data),
"longitude": (self.index._dim, lon_data),
}
)
4 changes: 3 additions & 1 deletion xdggs/h3.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def to_dict(self: Self) -> dict[str, Any]:
def cell_ids2geographic(
self, cell_ids: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
return cells_to_coordinates(cell_ids, radians=False)
lat, lon = cells_to_coordinates(cell_ids, radians=False)

return lon, lat

def geographic2cell_ids(self, lon, lat):
return coordinates_to_cells(lat, lon, self.resolution, radians=False)
Expand Down
11 changes: 1 addition & 10 deletions xdggs/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,8 @@ def sel(self, labels, method=None, tolerance=None):
def _replace(self, new_pd_index: PandasIndex):
raise NotImplementedError()

def _latlon2cellid(self, lat: Any, lon: Any) -> np.ndarray:
"""convert latitude / longitude points to cell ids."""
raise NotImplementedError()

def _cellid2latlon(self, cell_ids: Any) -> tuple[np.ndarray, np.ndarray]:
"""convert cell ids to latitude / longitude (cell centers)."""
raise NotImplementedError()

@property
def cell_centers(self) -> tuple[np.ndarray, np.ndarray]:
return self._cellid2latlon(self._pd_index.index.values)
return self._grid.cell_ids2geographic(self._pd_index.index.values)

@property
def grid_info(self) -> DGGSInfo:
Expand Down
126 changes: 126 additions & 0 deletions xdggs/tests/test_accessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import pytest
import xarray as xr

import xdggs


@pytest.mark.parametrize(
["obj", "expected"],
(
(
xr.DataArray(
[0],
coords={
"cell_ids": (
"cells",
[3],
{
"grid_name": "healpix",
"resolution": 1,
"indexing_scheme": "ring",
},
)
},
dims="cells",
),
xr.Dataset(
coords={
"latitude": ("cells", [66.44353569089877]),
"longitude": ("cells", [315.0]),
}
),
),
(
xr.Dataset(
coords={
"cell_ids": (
"cells",
[0x832830FFFFFFFFF],
{"grid_name": "h3", "resolution": 3},
)
}
),
xr.Dataset(
coords={
"latitude": ("cells", [38.19320895]),
"longitude": ("cells", [-122.19619676]),
}
),
),
),
)
def test_cell_centers(obj, expected):
obj_ = obj.pipe(xdggs.decode)

actual = obj_.dggs.cell_centers()

xr.testing.assert_allclose(actual, expected)


@pytest.mark.parametrize(
["obj", "expected"],
(
(
xr.DataArray(
[0],
coords={
"cell_ids": (
"cells",
[3],
{
"grid_name": "healpix",
"resolution": 1,
"indexing_scheme": "ring",
},
)
},
dims="cells",
),
xr.DataArray(
[0],
coords={
"latitude": ("cells", [66.44353569089877]),
"longitude": ("cells", [315.0]),
"cell_ids": (
"cells",
[3],
{
"grid_name": "healpix",
"resolution": 1,
"indexing_scheme": "ring",
},
),
},
dims="cells",
),
),
(
xr.Dataset(
coords={
"cell_ids": (
"cells",
[0x832830FFFFFFFFF],
{"grid_name": "h3", "resolution": 3},
)
}
),
xr.Dataset(
coords={
"latitude": ("cells", [38.19320895]),
"longitude": ("cells", [-122.19619676]),
"cell_ids": (
"cells",
[0x832830FFFFFFFFF],
{"grid_name": "h3", "resolution": 3},
),
}
),
),
),
)
def test_assign_latlon_coords(obj, expected):
obj_ = obj.pipe(xdggs.decode)

actual = obj_.dggs.assign_latlon_coords()

xr.testing.assert_allclose(actual, expected)
17 changes: 10 additions & 7 deletions xdggs/tests/test_h3.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
np.array([0x832833FFFFFFFFF, 0x832834FFFFFFFFF, 0x832835FFFFFFFFF]),
]
cell_centers = [
np.array([[38.19320895, -122.19619676]]),
np.array([[38.63853196, -123.43390346], [38.82387033, -121.00991811]]),
np.array([[-122.19619676, 38.19320895]]),
np.array([[-123.43390346, 38.63853196], [-121.00991811, 38.82387033]]),
np.array(
[
[39.27846774, -122.2594399],
[37.09786649, -122.13425086],
[37.55231005, -123.35925909],
[-122.2594399, 39.27846774],
[-122.13425086, 37.09786649],
[-123.35925909, 37.55231005],
]
),
]
Expand Down Expand Up @@ -99,8 +99,9 @@ def test_cell_ids2geographic(self, cell_ids, cell_centers):
grid = h3.H3Info(resolution=3)

actual = grid.cell_ids2geographic(cell_ids)
expected = cell_centers
expected = cell_centers.T

assert isinstance(actual, tuple) and len(actual) == 2
np.testing.assert_allclose(actual, expected)

@pytest.mark.parametrize(
Expand All @@ -109,7 +110,9 @@ def test_cell_ids2geographic(self, cell_ids, cell_centers):
def test_geographic2cell_ids(self, cell_centers, cell_ids):
grid = h3.H3Info(resolution=3)

actual = grid.geographic2cell_ids(cell_centers[:, 1], cell_centers[:, 0])
actual = grid.geographic2cell_ids(
lon=cell_centers[:, 0], lat=cell_centers[:, 1]
)
expected = cell_ids

np.testing.assert_equal(actual, expected)
Expand Down

0 comments on commit 4971103

Please sign in to comment.