Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/healpix mapper #71

Merged
merged 12 commits into from
Sep 8, 2023
13 changes: 11 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,17 @@ jobs:
- name: Download Git LFS file
run: |
git lfs pull
- name: Install eccodes
run: sudo apt-get install -y libeccodes-dev
# - name: Install eccodes
# run: sudo apt-get install -y libeccodes-dev
- name: Install eccodes and Dependencies
id: install-dependencies
uses: ecmwf-actions/build-package@v2
with:
self_build: false
dependencies: |
ecmwf/ecbuild@develop
MathisRosenhauer/libaec@master
ecmwf/eccodes@develop

- name: Setup Python
uses: actions/setup-python@v4
Expand Down
87 changes: 87 additions & 0 deletions examples/healpix_grid_box_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import geopandas as gpd
import matplotlib.pyplot as plt
from earthkit import data
from eccodes import codes_grib_find_nearest, codes_grib_new_from_file

from polytope.datacube.backends.xarray import XArrayDatacube
from polytope.engine.hullslicer import HullSlicer
from polytope.polytope import Polytope, Request
from polytope.shapes import Box, Select


class TestOctahedralGrid:
def setup_method(self, method):
ds = data.from_source("file", "./tests/data/healpix.grib")
self.latlon_array = ds.to_xarray().isel(step=0).isel(time=0).isel(isobaricInhPa=0).z
self.xarraydatacube = XArrayDatacube(self.latlon_array)
self.options = {
"values": {
"transformation": {"mapper": {"type": "healpix", "resolution": 32, "axes": ["latitude", "longitude"]}}
}
}
self.slicer = HullSlicer()
self.API = Polytope(datacube=self.latlon_array, engine=self.slicer, axis_options=self.options)

def find_nearest_latlon(self, grib_file, target_lat, target_lon):
# Open the GRIB file
f = open(grib_file)

# Load the GRIB messages from the file
messages = []
while True:
message = codes_grib_new_from_file(f)
if message is None:
break
messages.append(message)

# Find the nearest grid points
nearest_points = []
for message in messages:
nearest_index = codes_grib_find_nearest(message, target_lat, target_lon)
nearest_points.append(nearest_index)

# Close the GRIB file
f.close()

return nearest_points

def test_octahedral_grid(self):
request = Request(
Box(["latitude", "longitude"], [-2, -2], [10, 10]),
Select("time", ["2022-12-14T12:00:00"]),
Select("step", ["01:00:00"]),
Select("isobaricInhPa", [500]),
Select("valid_time", ["2022-12-14T13:00:00"]),
)
result = self.API.retrieve(request)
assert len(result.leaves) == 35

lats = []
lons = []
eccodes_lats = []
eccodes_lons = []
tol = 1e-8
for i in range(len(result.leaves)):
cubepath = result.leaves[i].flatten()
lat = cubepath["latitude"]
lon = cubepath["longitude"]
lats.append(lat)
lons.append(lon)
nearest_points = self.find_nearest_latlon("./tests/data/healpix.grib", lat, lon)
eccodes_lat = nearest_points[0][0]["lat"]
eccodes_lon = nearest_points[0][0]["lon"]
eccodes_lats.append(eccodes_lat)
eccodes_lons.append(eccodes_lon)
assert eccodes_lat - tol <= lat
assert lat <= eccodes_lat + tol
assert eccodes_lon - tol <= lon
assert lon <= eccodes_lon + tol
assert len(eccodes_lats) == 35
worldmap = gpd.read_file(gpd.datasets.get_path("naturalearth_lowres"))
fig, ax = plt.subplots(figsize=(12, 6))
worldmap.plot(color="darkgrey", ax=ax)

plt.scatter(eccodes_lons, eccodes_lats, c="blue", marker="s", s=20)
plt.scatter(lons, lats, s=16, c="red", cmap="YlOrRd")
plt.colorbar(label="Temperature")
plt.show()
89 changes: 88 additions & 1 deletion polytope/datacube/transformations/datacube_mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,93 @@ def unmap(self, first_val, second_val):
return final_transformation.unmap(first_val, second_val)


class HealpixGridMapper(DatacubeMapper):
def __init__(self, base_axis, mapped_axes, resolution):
self._mapped_axes = mapped_axes
self._base_axis = base_axis
self._resolution = resolution

def first_axis_vals(self):
rad2deg = 180 / math.pi
vals = [0] * (4 * self._resolution - 1)

# Polar caps
for i in range(1, self._resolution):
val = 90 - (rad2deg * math.acos(1 - (i * i / (3 * self._resolution * self._resolution))))
vals[i - 1] = val
vals[4 * self._resolution - 1 - i] = -val
# Equatorial belts
for i in range(self._resolution, 2 * self._resolution):
val = 90 - (rad2deg * math.acos((4 * self._resolution - 2 * i) / (3 * self._resolution)))
vals[i - 1] = val
vals[4 * self._resolution - 1 - i] = -val
# Equator
vals[2 * self._resolution - 1] = 0

return vals

def map_first_axis(self, lower, upper):
axis_lines = self.first_axis_vals()
return_vals = [val for val in axis_lines if lower <= val <= upper]
return return_vals

def second_axis_vals(self, first_val):
tol = 1e-8
first_val = [i for i in self.first_axis_vals() if first_val - tol <= i <= first_val + tol][0]
idx = self.first_axis_vals().index(first_val)

# Polar caps
if idx < self._resolution - 1 or 3 * self._resolution - 1 < idx <= 4 * self._resolution - 2:
start = 45 / (idx + 1)
vals = [start + i * (360 / (4 * (idx + 1))) for i in range(4 * (idx + 1))]
return vals
# Equatorial belts
start = 45 / self._resolution
if self._resolution - 1 <= idx < 2 * self._resolution - 1 or 2 * self._resolution <= idx < 3 * self._resolution:
r_start = start * (2 - (((idx + 1) - self._resolution + 1) % 2))
vals = [r_start + i * (360 / (4 * self._resolution)) for i in range(4 * self._resolution)]
return vals
# Equator
temp_val = 1 if self._resolution % 2 else 0
r_start = start * (1 - temp_val)
if idx == 2 * self._resolution - 1:
vals = [r_start + i * (360 / (4 * self._resolution)) for i in range(4 * self._resolution)]
return vals

def map_second_axis(self, first_val, lower, upper):
axis_lines = self.second_axis_vals(first_val)
return_vals = [val for val in axis_lines if lower <= val <= upper]
return return_vals

def axes_idx_to_healpix_idx(self, first_idx, second_idx):
idx = 0
for i in range(self._resolution - 1):
if i != first_idx:
idx += 4 * (i + 1)
else:
idx += second_idx
for i in range(self._resolution - 1, 3 * self._resolution):
if i != first_idx:
idx += 4 * self._resolution
else:
idx += second_idx
for i in range(3 * self._resolution, 4 * self._resolution - 1):
if i != first_idx:
idx += 4 * (4 * self._resolution - 1 - i + 1)
else:
idx += second_idx
return idx

def unmap(self, first_val, second_val):
tol = 1e-8
first_val = [i for i in self.first_axis_vals() if first_val - tol <= i <= first_val + tol][0]
first_idx = self.first_axis_vals().index(first_val)
second_val = [i for i in self.second_axis_vals(first_val) if second_val - tol <= i <= second_val + tol][0]
second_idx = self.second_axis_vals(first_val).index(second_val)
healpix_index = self.axes_idx_to_healpix_idx(first_idx, second_idx)
return healpix_index


class OctahedralGridMapper(DatacubeMapper):
def __init__(self, base_axis, mapped_axes, resolution):
self._mapped_axes = mapped_axes
Expand Down Expand Up @@ -2785,4 +2872,4 @@ def unmap(self, first_val, second_val):
return octahedral_index


_type_to_datacube_mapper_lookup = {"octahedral": "OctahedralGridMapper"}
_type_to_datacube_mapper_lookup = {"octahedral": "OctahedralGridMapper", "healpix": "HealpixGridMapper"}
2 changes: 0 additions & 2 deletions polytope/engine/hullslicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,10 @@ def _build_sliceable_child(self, polytope, ax, node, datacube, lower, upper, nex
fvalue = ax.to_float(value)
new_polytope = slice(polytope, ax.name, fvalue)
# store the native type
# remapped_val = (ax.remap([value, value])[0][0] + ax.remap([value, value])[0][1])/2
remapped_val = value
if ax.is_cyclic:
remapped_val = (ax.remap([value, value])[0][0] + ax.remap([value, value])[0][1]) / 2
remapped_val = round(remapped_val, int(-math.log10(ax.tol)))
# child = node.create_child(ax, value)
child = node.create_child(ax, remapped_val)
child["unsliced_polytopes"] = copy(node["unsliced_polytopes"])
child["unsliced_polytopes"].remove(polytope)
Expand Down
3 changes: 3 additions & 0 deletions tests/data/healpix.grib
Git LFS file not shown
75 changes: 75 additions & 0 deletions tests/test_healpix_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from earthkit import data
from eccodes import codes_grib_find_nearest, codes_grib_new_from_file

from polytope.datacube.backends.xarray import XArrayDatacube
from polytope.engine.hullslicer import HullSlicer
from polytope.polytope import Polytope, Request
from polytope.shapes import Box, Select


class TestOctahedralGrid:
def setup_method(self, method):
ds = data.from_source("file", "./tests/data/healpix.grib")
self.latlon_array = ds.to_xarray().isel(step=0).isel(time=0).isel(isobaricInhPa=0).z
self.xarraydatacube = XArrayDatacube(self.latlon_array)
self.options = {
"values": {
"transformation": {"mapper": {"type": "healpix", "resolution": 32, "axes": ["latitude", "longitude"]}}
}
}
self.slicer = HullSlicer()
self.API = Polytope(datacube=self.latlon_array, engine=self.slicer, axis_options=self.options)

def find_nearest_latlon(self, grib_file, target_lat, target_lon):
# Open the GRIB file
f = open(grib_file)

# Load the GRIB messages from the file
messages = []
while True:
message = codes_grib_new_from_file(f)
if message is None:
break
messages.append(message)

# Find the nearest grid points
nearest_points = []
for message in messages:
nearest_index = codes_grib_find_nearest(message, target_lat, target_lon)
nearest_points.append(nearest_index)

# Close the GRIB file
f.close()

return nearest_points

def test_octahedral_grid(self):
request = Request(
Box(["latitude", "longitude"], [-2, -2], [10, 10]),
Select("time", ["2022-12-14T12:00:00"]),
Select("step", ["01:00:00"]),
Select("isobaricInhPa", [500]),
Select("valid_time", ["2022-12-14T13:00:00"]),
)
result = self.API.retrieve(request)
assert len(result.leaves) == 35

lats = []
lons = []
eccodes_lats = []
tol = 1e-8
for i in range(len(result.leaves)):
cubepath = result.leaves[i].flatten()
lat = cubepath["latitude"]
lon = cubepath["longitude"]
lats.append(lat)
lons.append(lon)
nearest_points = self.find_nearest_latlon("./tests/data/healpix.grib", lat, lon)
eccodes_lat = nearest_points[0][0]["lat"]
eccodes_lon = nearest_points[0][0]["lon"]
eccodes_lats.append(eccodes_lat)
assert eccodes_lat - tol <= lat
assert lat <= eccodes_lat + tol
assert eccodes_lon - tol <= lon
assert lon <= eccodes_lon + tol
assert len(eccodes_lats) == 35
Loading