Skip to content

Commit

Permalink
Merge pull request #316 from ecmwf/develop
Browse files Browse the repository at this point in the history
v 1.0.26
  • Loading branch information
mathleur authored Feb 3, 2025
2 parents 0ca8808 + 0eb7f59 commit 5b5c882
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 10 deletions.
37 changes: 29 additions & 8 deletions polytope_feature/datacube/datacube_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import pandas as pd
import xarray as xr

from .transformations.datacube_cyclic.datacube_cyclic import DatacubeAxisCyclic
from .transformations.datacube_mappers.datacube_mappers import DatacubeMapper
Expand Down Expand Up @@ -142,22 +143,39 @@ def find_indices_between(self, indexes_ranges, low, up, datacube, method=None):
)
return indexes_between_ranges

@staticmethod
def values_type(values):
    """Return the type used to select an axis implementation for ``values``.

    xarray variables are converted to a numpy array and described by that
    array's scalar type. An empty input is handled the same way. For any
    other input the Python type of the first element is returned.
    """
    if isinstance(values, (xr.core.variable.IndexVariable, xr.core.variable.Variable)):
        # xarray wrappers: look at the underlying numpy scalar type
        return np.array(values).dtype.type
    if len(values) == 0:
        # No values yet (newly created axis): use the dtype numpy assigns,
        # which is a float for an empty list
        return np.array(values).dtype.type
    return type(values[0])

@staticmethod
def create_standard(name, values, datacube):
    """Register a standard axis called ``name`` on ``datacube``.

    The axis object is a deep copy of the ``_type_to_axis_lookup`` entry
    for the type reported by ``DatacubeAxis.values_type(values)``. It is
    stored in ``datacube._axes`` under ``name`` and ``datacube.axis_counter``
    is incremented.

    Raises:
        ValueError: via ``check_axis_type`` when the value type has no entry
            in ``_type_to_axis_lookup``.
    """
    # Resolve the value type once and use it for both validation and lookup.
    val_type = DatacubeAxis.values_type(values)
    DatacubeAxis.check_axis_type(name, val_type)

    # Copy before touching the datacube so a failure here leaves it unchanged;
    # the copy keeps per-datacube state off the shared lookup entry.
    axis = deepcopy(_type_to_axis_lookup[val_type])
    axis.name = name

    if datacube._axes is None:
        datacube._axes = {}
    datacube._axes[name] = axis
    datacube.axis_counter += 1

@staticmethod
def check_axis_type(name, val_type):
    """Check that an axis implementation exists for ``val_type``.

    Args:
        name: Axis name; only used in the error message.
        val_type: Type of the axis values, tested as a key of
            ``_type_to_axis_lookup``.

    Raises:
        ValueError: if ``val_type`` is not a key of ``_type_to_axis_lookup``.
    """
    if val_type not in _type_to_axis_lookup:
        raise ValueError(f"Could not create a mapper for index type {val_type} for axis {name}")


transformations_order = [
Expand Down Expand Up @@ -302,10 +320,13 @@ def serialize(self, value):
np.int64: IntDatacubeAxis(),
np.datetime64: PandasTimestampDatacubeAxis(),
np.timedelta64: PandasTimedeltaDatacubeAxis(),
pd.Timedelta: PandasTimedeltaDatacubeAxis(),
np.float64: FloatDatacubeAxis(),
np.float32: FloatDatacubeAxis(),
np.int32: IntDatacubeAxis(),
np.str_: UnsliceableDatacubeAxis(),
str: UnsliceableDatacubeAxis(),
np.object_: UnsliceableDatacubeAxis(),
int: IntDatacubeAxis(),
float: FloatDatacubeAxis(),
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from copy import deepcopy
from importlib import import_module

import pandas as pd

from ..datacube_transformations import DatacubeAxisTransformation


Expand Down Expand Up @@ -75,4 +77,48 @@ def make_str(self, value):
return tuple(values)


_type_to_datacube_type_change_lookup = {"int": "TypeChangeStrToInt"}
class TypeChangeStrToTimestamp(DatacubeAxisTypeChange):
    """Type change between date strings and ``pd.Timestamp`` values."""

    def __init__(self, axis_name, new_type):
        self.axis_name = axis_name
        self._new_type = new_type

    def transform_type(self, value):
        """Parse ``value`` with ``pd.Timestamp``; give ``None`` if that raises ``ValueError``."""
        try:
            parsed = pd.Timestamp(value)
        except ValueError:
            parsed = None
        return parsed

    def make_str(self, value):
        """Return a tuple with each timestamp in ``value`` formatted as ``%Y%m%d``."""
        return tuple(stamp.strftime("%Y%m%d") for stamp in value)


class TypeChangeStrToTimedelta(DatacubeAxisTypeChange):
    """Type change between time strings and ``pd.Timedelta`` values.

    The string form puts the hours in the first two characters and the
    minutes in the remainder.
    """

    def __init__(self, axis_name, new_type):
        self.axis_name = axis_name
        self._new_type = new_type

    def transform_type(self, value):
        """Build a ``pd.Timedelta`` from ``value``; give ``None`` on ``ValueError``."""
        try:
            hours, mins = int(value[:2]), int(value[2:])
            return pd.Timedelta(hours=hours, minutes=mins)
        except ValueError:
            return None

    def make_str(self, value):
        """Return a tuple with each timedelta in ``value`` formatted as hours then minutes."""
        formatted = []
        for delta in value:
            # Split the total seconds into whole hours and the leftover seconds
            whole_hours, leftover = divmod(delta.total_seconds(), 3600)
            formatted.append(f"{int(whole_hours):02d}{int(leftover // 60):02d}")
        return tuple(formatted)


# Maps a type-change option name to the name of the class that implements it.
# NOTE(review): presumably keyed by the "type" value of a "type_change"
# transformation in the axis config - confirm against the caller.
_type_to_datacube_type_change_lookup = {
"int": "TypeChangeStrToInt",
"date": "TypeChangeStrToTimestamp",
"time": "TypeChangeStrToTimedelta",
}
2 changes: 1 addition & 1 deletion polytope_feature/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.25"
__version__ = "1.0.26"
69 changes: 69 additions & 0 deletions tests/test_date_time_unmerged.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pandas as pd
import pytest

from polytope_feature.polytope import Polytope, Request
from polytope_feature.shapes import Box, Select, Span


class TestSlicingFDBDatacube:
    """FDB retrieval test using separate ``date`` and ``time`` axes."""

    def setup_method(self, method):
        # Per-axis transformations: string-to-typed conversions for step, date
        # and time, the octahedral grid mapping, and the lat/lon adjustments.
        axis_config = [
            {"axis_name": "step", "transformations": [{"name": "type_change", "type": "int"}]},
            {"axis_name": "date", "transformations": [{"name": "type_change", "type": "date"}]},
            {"axis_name": "time", "transformations": [{"name": "type_change", "type": "time"}]},
            {
                "axis_name": "values",
                "transformations": [
                    {"name": "mapper", "type": "octahedral", "resolution": 1280, "axes": ["latitude", "longitude"]}
                ],
            },
            {"axis_name": "latitude", "transformations": [{"name": "reverse", "is_reverse": True}]},
            {"axis_name": "longitude", "transformations": [{"name": "cyclic", "range": [0, 360]}]},
        ]
        pre_path = {"class": "od", "expver": "0001", "levtype": "sfc", "stream": "oper", "type": "fc"}
        compressed_axes = [
            "longitude",
            "latitude",
            "levtype",
            "step",
            "date",
            "domain",
            "expver",
            "param",
            "class",
            "stream",
            "type",
        ]
        self.options = {
            "axis_config": axis_config,
            "pre_path": pre_path,
            "compressed_axes_config": compressed_axes,
        }

    @pytest.mark.fdb
    def test_fdb_datacube(self):
        """Retrieve a two-day date span at a single time and check the leaf counts."""
        import pygribjump as gj

        shapes = [
            Select("step", [0]),
            Select("levtype", ["sfc"]),
            Select("time", [pd.Timedelta("00:00:00")]),
            Span("date", pd.Timestamp("20240118"), pd.Timestamp("20240119")),
            Select("domain", ["g"]),
            Select("expver", ["0001"]),
            Select("param", ["167"]),
            Select("class", ["od"]),
            Select("stream", ["oper"]),
            Select("type", ["fc"]),
            Box(["latitude", "longitude"], [0, 0], [0.2, 0.2]),
        ]
        request = Request(*shapes)

        self.fdbdatacube = gj.GribJump()
        self.API = Polytope(datacube=self.fdbdatacube, options=self.options)

        result = self.API.retrieve(request)
        result.pprint()
        assert len(result.leaves) == 3
        assert len(result.leaves[0].result) == 3

0 comments on commit 5b5c882

Please sign in to comment.