From 49711037911128d8826aca60972197fd155afdf7 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 3 Jul 2024 13:23:55 +0200 Subject: [PATCH] fix the cell centers computation (#61) * expose the grid info object * use the grid info methods to implement the reconstruction of cell ids * add a method to retrieve the cell centers using the accessor * forward to the grid object * remove the left-over conversion methods on the index base class * add a test to make sure the accessor returns the right cell centers * inverse the order of coordinates for H3 * fix the old `assign_latlon_coords` method * change the expectations in the h3 tests now we always expect longitude, then latitude as a tuple --- xdggs/accessor.py | 22 +++++- xdggs/h3.py | 4 +- xdggs/index.py | 11 +-- xdggs/tests/test_accessor.py | 126 +++++++++++++++++++++++++++++++++++ xdggs/tests/test_h3.py | 17 +++-- 5 files changed, 160 insertions(+), 20 deletions(-) create mode 100644 xdggs/tests/test_accessor.py diff --git a/xdggs/accessor.py b/xdggs/accessor.py index 2013e7e..f8b67cc 100644 --- a/xdggs/accessor.py +++ b/xdggs/accessor.py @@ -1,6 +1,7 @@ import numpy.typing as npt import xarray as xr +from xdggs.grid import DGGSInfo from xdggs.index import DGGSIndex @@ -55,6 +56,10 @@ def params(self) -> dict: """The grid parameters after normalization.""" return self.index.grid.to_dict() + @property + def grid_info(self) -> DGGSInfo: + return self.index.grid_info + def sel_latlon( self, latitude: npt.ArrayLike, longitude: npt.ArrayLike ) -> xr.Dataset | xr.DataArray: @@ -74,15 +79,28 @@ def sel_latlon( with all cells that contain the input latitude/longitude data points.
""" - cell_indexers = {self._name: self.index._latlon2cellid(latitude, longitude)} + cell_indexers = { + self._name: self.grid_info.geographic2cell_ids(lon=longitude, lat=latitude) + } return self._obj.sel(cell_indexers) def assign_latlon_coords(self) -> xr.Dataset | xr.DataArray: """Return a new Dataset or DataArray with new "latitude" and "longitude" coordinates representing the grid cell centers.""" - lon_data, lat_data = self.index.cell_centers + lon_data, lat_data = self.index.cell_centers() + return self._obj.assign_coords( latitude=(self.index._dim, lat_data), longitude=(self.index._dim, lon_data), ) + + def cell_centers(self): + lon_data, lat_data = self.index.cell_centers() + + return xr.Dataset( + coords={ + "latitude": (self.index._dim, lat_data), + "longitude": (self.index._dim, lon_data), + } + ) diff --git a/xdggs/h3.py b/xdggs/h3.py index dabc6fc..17db4ed 100644 --- a/xdggs/h3.py +++ b/xdggs/h3.py @@ -38,7 +38,9 @@ def to_dict(self: Self) -> dict[str, Any]: def cell_ids2geographic( self, cell_ids: np.ndarray ) -> tuple[np.ndarray, np.ndarray]: - return cells_to_coordinates(cell_ids, radians=False) + lat, lon = cells_to_coordinates(cell_ids, radians=False) + + return lon, lat def geographic2cell_ids(self, lon, lat): return coordinates_to_cells(lat, lon, self.resolution, radians=False) diff --git a/xdggs/index.py b/xdggs/index.py index 4728b82..2c3a7c2 100644 --- a/xdggs/index.py +++ b/xdggs/index.py @@ -69,17 +69,8 @@ def sel(self, labels, method=None, tolerance=None): def _replace(self, new_pd_index: PandasIndex): raise NotImplementedError() - def _latlon2cellid(self, lat: Any, lon: Any) -> np.ndarray: - """convert latitude / longitude points to cell ids.""" - raise NotImplementedError() - - def _cellid2latlon(self, cell_ids: Any) -> tuple[np.ndarray, np.ndarray]: - """convert cell ids to latitude / longitude (cell centers).""" - raise NotImplementedError() - - @property def cell_centers(self) -> tuple[np.ndarray, np.ndarray]: - return 
self._cellid2latlon(self._pd_index.index.values) + return self._grid.cell_ids2geographic(self._pd_index.index.values) @property def grid_info(self) -> DGGSInfo: diff --git a/xdggs/tests/test_accessor.py b/xdggs/tests/test_accessor.py new file mode 100644 index 0000000..cd6755a --- /dev/null +++ b/xdggs/tests/test_accessor.py @@ -0,0 +1,126 @@ +import pytest +import xarray as xr + +import xdggs + + +@pytest.mark.parametrize( + ["obj", "expected"], + ( + ( + xr.DataArray( + [0], + coords={ + "cell_ids": ( + "cells", + [3], + { + "grid_name": "healpix", + "resolution": 1, + "indexing_scheme": "ring", + }, + ) + }, + dims="cells", + ), + xr.Dataset( + coords={ + "latitude": ("cells", [66.44353569089877]), + "longitude": ("cells", [315.0]), + } + ), + ), + ( + xr.Dataset( + coords={ + "cell_ids": ( + "cells", + [0x832830FFFFFFFFF], + {"grid_name": "h3", "resolution": 3}, + ) + } + ), + xr.Dataset( + coords={ + "latitude": ("cells", [38.19320895]), + "longitude": ("cells", [-122.19619676]), + } + ), + ), + ), +) +def test_cell_centers(obj, expected): + obj_ = obj.pipe(xdggs.decode) + + actual = obj_.dggs.cell_centers() + + xr.testing.assert_allclose(actual, expected) + + +@pytest.mark.parametrize( + ["obj", "expected"], + ( + ( + xr.DataArray( + [0], + coords={ + "cell_ids": ( + "cells", + [3], + { + "grid_name": "healpix", + "resolution": 1, + "indexing_scheme": "ring", + }, + ) + }, + dims="cells", + ), + xr.DataArray( + [0], + coords={ + "latitude": ("cells", [66.44353569089877]), + "longitude": ("cells", [315.0]), + "cell_ids": ( + "cells", + [3], + { + "grid_name": "healpix", + "resolution": 1, + "indexing_scheme": "ring", + }, + ), + }, + dims="cells", + ), + ), + ( + xr.Dataset( + coords={ + "cell_ids": ( + "cells", + [0x832830FFFFFFFFF], + {"grid_name": "h3", "resolution": 3}, + ) + } + ), + xr.Dataset( + coords={ + "latitude": ("cells", [38.19320895]), + "longitude": ("cells", [-122.19619676]), + "cell_ids": ( + "cells", + [0x832830FFFFFFFFF], + {"grid_name": 
"h3", "resolution": 3}, + ), + } + ), + ), + ), +) +def test_assign_latlon_coords(obj, expected): + obj_ = obj.pipe(xdggs.decode) + + actual = obj_.dggs.assign_latlon_coords() + + xr.testing.assert_allclose(actual, expected) diff --git a/xdggs/tests/test_h3.py b/xdggs/tests/test_h3.py index 85a1b0e..d82cab4 100644 --- a/xdggs/tests/test_h3.py +++ b/xdggs/tests/test_h3.py @@ -14,13 +14,13 @@ np.array([0x832833FFFFFFFFF, 0x832834FFFFFFFFF, 0x832835FFFFFFFFF]), ] cell_centers = [ - np.array([[38.19320895, -122.19619676]]), - np.array([[38.63853196, -123.43390346], [38.82387033, -121.00991811]]), + np.array([[-122.19619676, 38.19320895]]), + np.array([[-123.43390346, 38.63853196], [-121.00991811, 38.82387033]]), np.array( [ - [39.27846774, -122.2594399], - [37.09786649, -122.13425086], - [37.55231005, -123.35925909], + [-122.2594399, 39.27846774], + [-122.13425086, 37.09786649], + [-123.35925909, 37.55231005], ] ), ] @@ -99,8 +99,9 @@ def test_cell_ids2geographic(self, cell_ids, cell_centers): grid = h3.H3Info(resolution=3) actual = grid.cell_ids2geographic(cell_ids) - expected = cell_centers + expected = cell_centers.T + assert isinstance(actual, tuple) and len(actual) == 2 np.testing.assert_allclose(actual, expected) @pytest.mark.parametrize( @@ -109,7 +110,9 @@ def test_cell_ids2geographic(self, cell_ids, cell_centers): def test_geographic2cell_ids(self, cell_centers, cell_ids): grid = h3.H3Info(resolution=3) - actual = grid.geographic2cell_ids(cell_centers[:, 1], cell_centers[:, 0]) + actual = grid.geographic2cell_ids( + lon=cell_centers[:, 0], lat=cell_centers[:, 1] + ) expected = cell_ids np.testing.assert_equal(actual, expected)