From 3d6b37cc5ec206872857a03d946c543304966c71 Mon Sep 17 00:00:00 2001 From: Mathilde Leuridan Date: Wed, 22 Nov 2023 14:49:41 +0100 Subject: [PATCH] make All shape work --- polytope/datacube/datacube_axis.py | 5 ++ polytope/shapes.py | 2 +- tests/test_shapes.py | 74 ++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 tests/test_shapes.py diff --git a/polytope/datacube/datacube_axis.py b/polytope/datacube/datacube_axis.py index 77a223710..625051f7c 100644 --- a/polytope/datacube/datacube_axis.py +++ b/polytope/datacube/datacube_axis.py @@ -1,3 +1,4 @@ +import math from abc import ABC, abstractmethod from copy import deepcopy from typing import Any, List @@ -18,6 +19,10 @@ def update_range(): def to_intervals(range): update_range() + if range[0] == -math.inf: + range[0] = cls.range[0] + if range[1] == math.inf: + range[1] = cls.range[1] axis_lower = cls.range[0] axis_upper = cls.range[1] axis_range = axis_upper - axis_lower diff --git a/polytope/shapes.py b/polytope/shapes.py index 306d1a64a..cc77273a7 100644 --- a/polytope/shapes.py +++ b/polytope/shapes.py @@ -89,7 +89,7 @@ def __repr__(self): class Span(Shape): """1-D range along a single axis""" - def __init__(self, axis, lower=None, upper=None): + def __init__(self, axis, lower=-math.inf, upper=math.inf): assert not isinstance(lower, list) assert not isinstance(upper, list) self.axis = axis diff --git a/tests/test_shapes.py b/tests/test_shapes.py new file mode 100644 index 000000000..ba355dd4c --- /dev/null +++ b/tests/test_shapes.py @@ -0,0 +1,74 @@ +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from polytope.datacube.backends.FDB_datacube import FDBDatacube +from polytope.datacube.backends.xarray import XArrayDatacube +from polytope.engine.hullslicer import HullSlicer +from polytope.polytope import Polytope, Request +from polytope.shapes import All, Select, Span + + +class TestSlicing3DXarrayDatacube: + def setup_method(self, method): + # Create a dataarray with 3 labelled axes using different index types + array = xr.DataArray( + np.random.randn(3, 6, 129, 360), + dims=("date", "step", "level", "longitude"), + coords={ + "date": pd.date_range("2000-01-01", "2000-01-03", 3), + "step": [0, 3, 6, 9, 12, 15], + "level": range(1, 130), + "longitude": range(0, 360), + }, + ) + self.xarraydatacube = XArrayDatacube(array) + self.options = {"longitude": {"transformation": {"cyclic": [0, 360]}}} + self.slicer = HullSlicer() + self.API = Polytope(datacube=array, engine=self.slicer, axis_options=self.options) + + def test_all(self): + request = Request(Select("step", [3]), Select("date", ["2000-01-01"]), All("level"), Select("longitude", [1])) + result = self.API.retrieve(request) + assert len(result.leaves) == 129 + + def test_all_cyclic(self): + request = Request(Select("step", [3]), Select("date", ["2000-01-01"]), Select("level", [1]), All("longitude")) + result = self.API.retrieve(request) + # result.pprint() + assert len(result.leaves) == 360 + + @pytest.mark.skip(reason="can't install fdb branch on CI") + def test_all_mapper_cyclic(self): + self.options = { + "values": { + "transformation": { + "mapper": {"type": "octahedral", "resolution": 1280, "axes": ["latitude", "longitude"]} + } + }, + "date": {"transformation": {"merge": {"with": "time", "linkers": ["T", "00"]}}}, + "step": {"transformation": {"type_change": "int"}}, + "longitude": {"transformation": {"cyclic": [0, 360]}}, + } + self.config = {"class": "od", "expver": "0001", "levtype": "sfc", "step": 11} + self.fdbdatacube = FDBDatacube(self.config, axis_options=self.options) + self.slicer = HullSlicer() + self.API = Polytope(datacube=self.fdbdatacube, engine=self.slicer, axis_options=self.options) + + request = Request( + Select("step", [11]), + Select("levtype", ["sfc"]), + Select("date", [pd.Timestamp("20230710T120000")]), + Select("domain", ["g"]), + Select("expver", ["0001"]), + Select("param", ["151130"]), + Select("class", ["od"]), + Select("stream", ["oper"]), + Select("type", ["fc"]), + Span("latitude", 89.9, 90), + All("longitude"), + ) + result = self.API.retrieve(request) + # result.pprint() + assert len(result.leaves) == 20