Skip to content

Commit

Permalink
small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mathleur committed Nov 22, 2023
1 parent d4697b8 commit 4f72f1f
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 15 deletions.
18 changes: 9 additions & 9 deletions polytope/datacube/backends/fdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from .datacube import Datacube, IndexTree

import time


class FDBDatacube(Datacube):
def __init__(self, config={}, axis_options={}):
Expand Down Expand Up @@ -44,13 +46,11 @@ def remove_unwanted_axes(self, leaf_path):
return leaf_path

def get(self, requests: IndexTree, leaf_path={}):
time1 = time.time()
# First when request node is root, go to its children
if requests.axis.name == "root":
if len(requests.children) == 0:
pass
else:
for c in requests.children:
self.get(c)
for c in requests.children:
self.get(c)
# If request node has no children, we have a leaf so need to assign fdb values to it
else:
key_value_path = {requests.axis.name: requests.value}
Expand All @@ -67,6 +67,8 @@ def get(self, requests: IndexTree, leaf_path={}):
else:
for c in requests.children:
self.get(c, leaf_path)
print("TOTAL GET TIME")
print(time.time() - time1)

def get_2nd_last_values(self, requests, leaf_path={}):
# In this function, we recursively loop over the last two layers of the tree and store the indices of the
Expand Down Expand Up @@ -154,11 +156,9 @@ def find_fdb_values(self, path, range_lengths, current_start_idx, lat_length):
interm_request_ranges.append(current_request_ranges)
request_ranges_with_idx = list(enumerate(interm_request_ranges))
sorted_list = sorted(request_ranges_with_idx, key=lambda x: x[1][0])
sorted_request_ranges = [item[1] for item in sorted_list]
original_indices = [item[0] for item in sorted_list]
original_indices, sorted_request_ranges = zip(*sorted_list)
fdb_requests.append(tuple((path, sorted_request_ranges)))
subxarray = self.fdb.extract(fdb_requests)
output_values = subxarray
output_values = self.fdb.extract(fdb_requests)
return (output_values, original_indices)

def datacube_natural_indexes(self, axis, subarray):
Expand Down
7 changes: 6 additions & 1 deletion polytope/engine/hullslicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,12 @@ def _build_sliceable_child(self, polytope, ax, node, datacube, lower, upper, nex
upper = ax.from_float(upper + tol)
flattened = node.flatten()
method = polytope.method
for value in datacube.get_indices(flattened, ax, lower, upper, method):
values = datacube.get_indices(flattened, ax, lower, upper, method)

if len(values) == 0:
node.remove_branch()

for value in values:
# convert to float for slicing
fvalue = ax.to_float(value)
new_polytope = slice(polytope, ax.name, fvalue, slice_axis_idx)
Expand Down
1 change: 0 additions & 1 deletion tests/test_datacube_axes_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def test_created_axes(self):
assert self.datacube._axes["longitude"].has_mapper
assert isinstance(self.datacube._axes["longitude"], FloatDatacubeAxis)
assert not ("values" in self.datacube._axes.keys())
print(list(self.datacube._axes["latitude"].find_indexes({}, self.datacube)[:5]))
assert list(self.datacube._axes["latitude"].find_indexes({}, self.datacube)[:5]) == [
89.94618771566562,
89.87647835333229,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_fdb_datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def setup_method(self, method):
self.API = Polytope(datacube=self.fdbdatacube, engine=self.slicer, axis_options=self.options)

# Testing different shapes
@pytest.mark.skip(reason="can't install fdb branch on CI")
# @pytest.mark.skip(reason="can't install fdb branch on CI")
def test_fdb_datacube(self):
request = Request(
Select("step", [0]),
Expand Down
97 changes: 97 additions & 0 deletions tests/test_incomplete_tree_fdb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import pandas as pd
import pytest
from eccodes import codes_grib_find_nearest, codes_grib_new_from_file
from helper_functions import download_test_data

from polytope.datacube.backends.fdb import FDBDatacube
from polytope.engine.hullslicer import HullSlicer
from polytope.polytope import Polytope, Request
from polytope.shapes import Select


class TestRegularGrid:
    """Retrieve single grid points from an FDB-backed datacube on a regular grid.

    Each test issues a fully specified request (one value per axis) and checks
    that the returned index tree has exactly one leaf and is rooted.
    """

    def setup_method(self, method):
        """Download the test GRIB file and build the datacube, slicer and API.

        `method` is the test method about to run; it is supplied by pytest and
        is not used here.
        """
        nexus_url = "https://get.ecmwf.int/test-data/polytope/test-data/era5-levels-members.grib"
        download_test_data(nexus_url, "era5-levels-members.grib")
        self.options = {
            "values": {
                "transformation": {"mapper": {"type": "regular", "resolution": 30, "axes": ["latitude", "longitude"]}}
            },
            "date": {"transformation": {"merge": {"with": "time", "linkers": [" ", "00"]}}},
            "step": {"transformation": {"type_change": "int"}},
            "number": {"transformation": {"type_change": "int"}},
            "longitude": {"transformation": {"cyclic": [0, 360]}},
        }
        self.config = {"class": "ea", "expver": "0001", "levtype": "pl", "step": 0}
        self.fdbdatacube = FDBDatacube(self.config, axis_options=self.options)
        self.slicer = HullSlicer()
        self.API = Polytope(datacube=self.fdbdatacube, engine=self.slicer, axis_options=self.options)

    def find_nearest_latlon(self, grib_file, target_lat, target_lon):
        """Return, for every message in `grib_file`, the nearest-point lookup result.

        Each GRIB message in the file is passed to `codes_grib_find_nearest`
        together with `target_lat` and `target_lon`; the results are returned
        as a list in message order.
        """
        # GRIB is a binary format, so open in binary mode; the context manager
        # guarantees the file is closed even if eccodes raises part-way through.
        with open(grib_file, "rb") as f:
            # Load the GRIB messages from the file
            messages = []
            while True:
                message = codes_grib_new_from_file(f)
                if message is None:
                    break
                messages.append(message)

            # Find the nearest grid points
            # NOTE(review): the message handles are never released in this
            # helper - confirm whether that matters for the test process.
            nearest_points = [
                codes_grib_find_nearest(message, target_lat, target_lon) for message in messages
            ]

        return nearest_points

    def _single_point_request(self, latitude, longitude):
        """Build the fully specified request used by the tests for one grid point."""
        return Request(
            Select("step", [0]),
            Select("levtype", ["pl"]),
            Select("date", [pd.Timestamp("20170102T120000")]),
            Select("domain", ["g"]),
            Select("expver", ["0001"]),
            Select("param", ["129"]),
            Select("class", ["ea"]),
            Select("stream", ["enda"]),
            Select("type", ["an"]),
            Select("latitude", [latitude]),
            Select("longitude", [longitude]),
            Select("levelist", ["500"]),
            Select("number", ["0"]),
        )

    @pytest.mark.internet
    # NOTE: sibling FDB tests carried a skip marker with the reason
    # "can't install fdb branch on CI"; it is intentionally not applied here.
    def test_incomplete_fdb_branch(self):
        """Selecting latitude 0 / longitude 1 yields a rooted tree with one leaf."""
        request = self._single_point_request(latitude=0, longitude=1)
        result = self.API.retrieve(request)
        result.pprint()
        assert len(result.leaves) == 1
        assert result.is_root()

    @pytest.mark.internet
    # NOTE: sibling FDB tests carried a skip marker with the reason
    # "can't install fdb branch on CI"; it is intentionally not applied here.
    def test_incomplete_fdb_branch_2(self):
        """Selecting latitude 1 / longitude 0 yields a rooted tree with one leaf."""
        request = self._single_point_request(latitude=1, longitude=0)
        result = self.API.retrieve(request)
        result.pprint()
        assert len(result.leaves) == 1
        assert result.is_root()
1 change: 0 additions & 1 deletion tests/test_merge_cyclic_octahedral.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def setup_method(self, method):
self.slicer = HullSlicer()
self.API = Polytope(datacube=self.array, engine=self.slicer, axis_options=self.options)

# @pytest.mark.skip(reason="Need date time to not be strings")
def test_merge_axis(self):
# NOTE: does not work because the date is a string in the merge option...
date = np.datetime64("2000-01-01T06:00:00")
Expand Down
1 change: 0 additions & 1 deletion tests/test_merge_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,4 @@ def setup_method(self, method):
def test_merge_axis(self):
request = Request(Select("date", [pd.Timestamp("2000-01-01T06:00:00")]))
result = self.API.retrieve(request)
# assert result.leaves[0].flatten()["date"] == np.datetime64("2000-01-01T06:00:00")
assert result.leaves[0].flatten()["date"] == pd.Timestamp("2000-01-01T06:00:00")
2 changes: 1 addition & 1 deletion tests/test_regular_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def find_nearest_latlon(self, grib_file, target_lat, target_lon):
return nearest_points

@pytest.mark.internet
@pytest.mark.skip(reason="can't install fdb branch on CI")
# @pytest.mark.skip(reason="can't install fdb branch on CI")
def test_regular_grid(self):
request = Request(
Select("step", [0]),
Expand Down

0 comments on commit 4f72f1f

Please sign in to comment.