Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support amda template params #192

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions speasy/core/dataprovider.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from speasy.core.inventory import ProviderInventory
from speasy.core.inventory.indexes import (DatasetIndex, ParameterIndex,
SpeasyIndex, inventory_has_changed)
from speasy.core.proxy import GetInventory, Proxyfiable
from speasy.core.proxy import GetInventory, Proxyfiable, MINIMUM_REQUIRED_PROXY_VERSION
from speasy.inventories import flat_inventories, tree

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -37,11 +37,13 @@ def _get_inventory_args(provider_name, **kwargs):


class DataProvider:
def __init__(self, provider_name: str, provider_alt_names: List or None = None, inventory_disable_proxy=False):
def __init__(self, provider_name: str, provider_alt_names: List or None = None, inventory_disable_proxy=False,
min_proxy_version=MINIMUM_REQUIRED_PROXY_VERSION):
self.provider_name = provider_name
self._inventory_disable_proxy = inventory_disable_proxy
self.provider_alt_names = provider_alt_names or []
self.flat_inventory = ProviderInventory()
self.min_proxy_version = min_proxy_version
flat_inventories.__dict__[provider_name] = self.flat_inventory
for alt_name in self.provider_alt_names:
flat_inventories.__dict__[alt_name] = self.flat_inventory
Expand All @@ -61,7 +63,8 @@ def update_inventory(self):
lock = Lock()
with lock:
new_inventory = self._inventory(provider_name=self.provider_name,
disable_proxy=self._inventory_disable_proxy)
disable_proxy=self._inventory_disable_proxy,
min_proxy_version=self.min_proxy_version)
if inventory_has_changed(tree.__dict__.get(self.provider_name, SpeasyIndex("", "", "")), new_inventory):
if self.provider_name in tree.__dict__:
tree.__dict__[self.provider_name].clear()
Expand Down
71 changes: 49 additions & 22 deletions speasy/core/impex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@

from ...core import make_utc_datetime

from ...core.inventory.indexes import ComponentIndex, DatasetIndex, ParameterIndex, SpeasyIndex, \
TimetableIndex, CatalogIndex
from ...core.inventory.indexes import (ComponentIndex, DatasetIndex, ParameterIndex, SpeasyIndex,
TimetableIndex, CatalogIndex, DerivedParameterIndex, AnyProductIndex)
from ...core.codecs import get_codec
from ...core.proxy import MINIMUM_REQUIRED_PROXY_VERSION
from ...products.variable import SpeasyVariable, merge, DataContainer
from ...products.catalog import Catalog
from ...products.dataset import Dataset
Expand All @@ -28,7 +29,6 @@
from .utils import load_catalog, load_timetable, is_private, is_public
from .exceptions import MissingCredentials


log = logging.getLogger(__name__)


Expand All @@ -45,7 +45,8 @@ class ImpexProductType(Enum):

class ImpexProvider(DataProvider):
def __init__(self, provider_name: str, server_url: str, max_chunk_size_days: int = 10, capabilities: List = None,
username: str = "", password: str = "", name_mapping: Dict = None, output_format: str = 'CDF'):
username: str = "", password: str = "", name_mapping: Dict = None, output_format: str = 'CDF',
min_proxy_version=MINIMUM_REQUIRED_PROXY_VERSION):
self.provider_name = provider_name
self.server_url = server_url
self.client = ImpexClient(capabilities=capabilities, server_url=server_url,
Expand All @@ -55,7 +56,7 @@ def __init__(self, provider_name: str, server_url: str, max_chunk_size_days: int
self.max_chunk_size_days = max_chunk_size_days
self.name_mapping = name_mapping
self._cdf_codec = get_codec('application/x-cdf')
DataProvider.__init__(self, provider_name=provider_name)
DataProvider.__init__(self, provider_name=provider_name, min_proxy_version=min_proxy_version)

def reset_credentials(self, username: str = "", password: str = ""):
"""Reset user credentials and update the inventory by replacing the information contained in the configuration
Expand Down Expand Up @@ -381,6 +382,7 @@ def get_parameter(self, product, start_time, stop_time,
if hasattr(self, 'has_time_restriction') and self.has_time_restriction(product, start_time, stop_time):
kwargs['disable_proxy'] = True
kwargs['restricted_period'] = True

return self._get_parameter(product, start_time, stop_time, extra_http_headers=extra_http_headers,
output_format=output_format or self.client.output_format, **kwargs)

Expand Down Expand Up @@ -572,16 +574,35 @@ def product_type(self, product_id: str or SpeasyIndex) -> ImpexProductType:

return ImpexProductType.UNKNOWN

def find_parent_dataset(self, product_id: Union[str, DatasetIndex, ParameterIndex,
ComponentIndex]) -> Optional[str]:
def to_index(self, product_id: str or SpeasyIndex) -> AnyProductIndex:
    """Resolve a product identifier to its inventory index object.

    Parameters
    ----------
    product_id:
        Either an index object (dataset, parameter, derived parameter,
        component, timetable or catalog), which is returned unchanged, or a
        string key looked up in this provider's flat inventory.

    Returns
    -------
    The matching index object.

    Raises
    ------
    ValueError
        If ``product_id`` cannot be resolved to a known product.
    KeyError
        If no flat inventory is registered under ``self.provider_name``.
    """
    # isinstance instead of an exact type() match so subclasses of the index
    # classes are accepted as well.
    if isinstance(product_id, (DatasetIndex, ParameterIndex, DerivedParameterIndex,
                               ComponentIndex, TimetableIndex, CatalogIndex)):
        return product_id
    if isinstance(product_id, str):
        inventory = flat_inventories.__dict__[self.provider_name]
        # Same lookup order as before; each collection is only touched when
        # the previous ones did not yield a (truthy) match.
        for collection_name in ('datasets', 'parameters', 'components', 'timetables', 'catalogs'):
            if found := getattr(inventory, collection_name).get(product_id):
                return found
    raise ValueError(f"Unknown product: {product_id}")

def find_parent_dataset(
self,
product_id: Union[str, DatasetIndex, ParameterIndex, DerivedParameterIndex, ComponentIndex]
) -> Optional[str]:
Fixed Show fixed Hide fixed

product_id = to_xmlid(product_id)
product_type = self.product_type(product_id)
if product_type is ImpexProductType.DATASET:
product = self.to_index(product_id)
if isinstance(product, DatasetIndex):
return product_id
elif product_type in (ImpexProductType.COMPONENT, ImpexProductType.PARAMETER):
for dataset in flat_inventories.__dict__[self.provider_name].datasets.values():
if product_id in dataset:
return to_xmlid(dataset)
elif type(product) in (ParameterIndex, ComponentIndex, DerivedParameterIndex):
return product.dataset
return None

@staticmethod
def is_user_product(product_id: str or SpeasyIndex, collection: Dict):
Expand Down Expand Up @@ -762,10 +783,14 @@ def _get_parameter(self, product, start_time, stop_time,
Optional[
SpeasyVariable]:
log.debug(f'Get data: product = {product}, data start time = {start_time}, data stop time = {stop_time}')
if hasattr(self, 'get_real_product_id'):
real_product_id = self.get_real_product_id(product, **kwargs)
if real_product_id:
kwargs['real_product_id'] = real_product_id
return self._dl_parameter(start_time=start_time, stop_time=stop_time, parameter_id=product,
extra_http_headers=extra_http_headers,
output_format=output_format,
product_variables=self._get_product_variables(product),
product_variables=self._get_product_variables(product, **kwargs),
restricted_period=restricted_period,
time_format='UNIXTIME', **kwargs)

Expand All @@ -775,13 +800,14 @@ def _dl_parameter_chunk(self, start_time: datetime, stop_time: datetime, paramet
product_variables: List = None, **kwargs) -> Optional[SpeasyVariable]:
url = self.client.get_parameter(start_time=start_time.strftime('%Y-%m-%dT%H:%M:%SZ'),
stop_time=stop_time.strftime('%Y-%m-%dT%H:%M:%SZ'),
parameter_id=parameter_id, extra_http_headers=extra_http_headers,
parameter_id=parameter_id,
extra_http_headers=extra_http_headers,
use_credentials=use_credentials, **kwargs)
# check status until done
if url is not None:
var = None
if not product_variables:
product_variables = [parameter_id]
product_variables = [kwargs.get('real_product_id', parameter_id)]
if kwargs.get('output_format', self.client.output_format) in ["CDF_ISTP", "CDF"]:
var = self._cdf_codec.load_variables(variables=product_variables, file=url)
else:
Expand Down Expand Up @@ -829,14 +855,15 @@ def _dl_parameter(self, start_time: datetime, stop_time: datetime, parameter_id:
curr_t += dt
return var
else:
return self._dl_parameter_chunk(start_time, stop_time, parameter_id, extra_http_headers=extra_http_headers,
return self._dl_parameter_chunk(start_time, stop_time, parameter_id,
extra_http_headers=extra_http_headers,
use_credentials=use_credentials,
product_variables=product_variables, **kwargs)

def _dl_user_parameter(self, start_time: datetime, stop_time: datetime, parameter_id: str,
**kwargs) -> Optional[SpeasyVariable]:
return self._dl_parameter(parameter_id=parameter_id, start_time=start_time, stop_time=stop_time,
product_variables=self._get_product_variables(parameter_id),
product_variables=self._get_product_variables(parameter_id, **kwargs),
use_credentials=True, **kwargs)

def _dl_timetable(self, timetable_id: str, use_credentials=False, **kwargs):
Expand Down Expand Up @@ -872,9 +899,9 @@ def _dl_catalog(self, catalog_id: str, use_credentials=False, **kwargs):
def _dl_user_catalog(self, catalog_id: str, **kwargs):
    """Download a catalog through ``_dl_catalog`` with credentials enabled.

    All extra keyword arguments are forwarded untouched; ``use_credentials``
    is always forced to ``True`` and therefore must not be supplied by the
    caller.
    """
    return self._dl_catalog(catalog_id, **kwargs, use_credentials=True)

def _get_product_variables(self, product_id: str or SpeasyIndex):
def _get_product_variables(self, product_id: str or SpeasyIndex, **kwargs):
product_id = to_xmlid(product_id)
return [product_id]
return [kwargs.get('real_product_id', product_id)]

@staticmethod
def _concatenate_variables(variables: Dict[str, SpeasyVariable], product_id) -> Optional[SpeasyVariable]:
Expand Down Expand Up @@ -906,8 +933,8 @@ def _concatenate_variables(variables: Dict[str, SpeasyVariable], product_id) ->
values=DataContainer(values=values, meta=meta, name=product_id, is_time_dependent=True),
columns=columns)

def _get_obs_data_tree(self) -> str or None:
return self.client.get_obs_data_tree()
def _get_obs_data_tree(self, add_template_info=False) -> str or None:
return self.client.get_obs_data_tree(add_template_info=add_template_info)

def _get_timetables_tree(self) -> str or None:
    """Return whatever the client's ``get_time_table_list`` request yields."""
    timetables = self.client.get_time_table_list()
    return timetables
Expand Down
9 changes: 6 additions & 3 deletions speasy/core/impex/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,11 @@ def auth(self):
def in_progress(result):
return result == "in progress"

def get_obs_data_tree(self, use_credentials=False):
params = {}
def get_obs_data_tree(self, use_credentials=False, **kwargs):
params = {
}
if kwargs.get('add_template_info', False):
params['templateInfo'] = True
if use_credentials:
params['userID'], params['password'] = self.get_credentials()
return self._send_indirect_request(ImpexEndpoint.OBSTREE, params=params)
Expand Down Expand Up @@ -109,7 +112,7 @@ def get_parameter(self, start_time, stop_time, parameter_id, extra_http_headers=
params = {
'startTime': start_time,
'stopTime': stop_time,
'parameterID': parameter_id,
'parameterID': kwargs.get('real_product_id', parameter_id),
'outputFormat': kwargs.get('output_format', self.output_format)
}

Expand Down
4 changes: 4 additions & 0 deletions speasy/core/impex/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@ class MissingCredentials(Exception):

class UnavailableEndpoint(Exception):
    """Error type for an unavailable endpoint.

    NOTE(review): meaning inferred from the class name; the raise sites are
    not visible in this chunk.
    """


class BadTemplateArgDefinition(Exception):
    """Error type for an invalid template argument definition.

    NOTE(review): meaning inferred from the class name; no raise site is
    visible in this chunk.
    """
39 changes: 36 additions & 3 deletions speasy/core/impex/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from ...core import fix_name
from ...core.inventory.indexes import (CatalogIndex, ComponentIndex,
DatasetIndex, ParameterIndex,
SpeasyIndex, TimetableIndex)
ArgumentListIndex, ArgumentIndex,
DerivedParameterIndex, SpeasyIndex,
TimetableIndex)


def to_xmlid(index_or_str) -> str:
Expand Down Expand Up @@ -60,8 +62,13 @@ def make_dataset_node(parent, node, provider_name, name_key, is_public: bool = T

@staticmethod
def make_parameter_node(parent, node, provider_name, name_key, is_public: bool = True):
param = ImpexXMLParser.make_any_node(parent, node, provider_name, ParameterIndex, name_key=name_key,
is_public=is_public)
if arguments:=node.find('.//arguments'):
arguments.set('name', '__spz_arguments__')
param = ImpexXMLParser.make_any_node(parent, node, provider_name, DerivedParameterIndex, name_key=name_key,
is_public=is_public)
else:
param = ImpexXMLParser.make_any_node(parent, node, provider_name, ParameterIndex, name_key=name_key,
is_public=is_public)
if isinstance(parent, DatasetIndex):
param.start_date = parent.start_date
param.stop_date = parent.stop_date
Expand Down Expand Up @@ -103,6 +110,30 @@ def make_path_node(parent, node, provider_name, name_key, is_public: bool = True
return ImpexXMLParser.make_any_node(parent, node, provider_name, SpeasyIndex, name_key=name_key,
is_public=is_public)

@staticmethod
def parse_template_arguments(parent, node, provider_name, name_key, is_public: bool = True):
    """Create the ``ArgumentListIndex`` node for an ``arguments`` XML element.

    Registered as the handler for the ``'arguments'`` tag in ``parse``; all
    the work is delegated to ``make_any_node``.
    """
    build_node = ImpexXMLParser.make_any_node
    return build_node(parent, node, provider_name, ArgumentListIndex,
                      name_key=name_key, is_public=is_public)

@staticmethod
def parse_template_argument(parent, node, provider_name, name_key, is_public: bool = True):
    """Create the ``ArgumentIndex`` node for a single ``argument`` XML element.

    Before delegating to ``make_any_node`` the element is normalised in place:

    * ``type == 'list'``: every descendant ``item`` element is turned into a
      ``(name, key)`` tuple, removed from ``node``, and the resulting list is
      stored in the ``choices`` attribute.
    * ``type == 'generated-list'``: the type is rewritten to ``'list'`` and
      ``choices`` is generated from the ``minkey``/``maxkey``/``nametpl``
      attributes, replacing the first ``##key##`` of the template by each key.

    Any other type is passed through unchanged.
    """
    arg_type = node.get('type')
    if arg_type == 'list':
        choices = []
        # NOTE(review): items are searched among all descendants ('.//item')
        # but removed from ``node`` directly; Element.remove raises ValueError
        # for a non-direct child -- confirm items are always direct children.
        for entry in node.findall('.//item'):
            choices.append((entry.get('name'), entry.get('key')))
            node.remove(entry)
        node.set('choices', choices)
    elif arg_type == 'generated-list':
        node.set('type', 'list')
        first_key = int(node.get('minkey'))
        last_key = int(node.get('maxkey'))
        template = node.get('nametpl')
        # NOTE(review): range() is half-open, so ``maxkey`` itself is never
        # generated -- confirm that this is the intended bound.
        choices = [(template.replace('##key##', str(key), 1), str(key))
                   for key in range(first_key, last_key)]
        node.set('choices', choices)

    return ImpexXMLParser.make_any_node(parent, node, provider_name, ArgumentIndex,
                                        name_key=name_key, is_public=is_public)


@staticmethod
def parse(xml, provider_name, name_mapping=None, is_public: bool = True):
handlers = {
Expand All @@ -117,6 +148,8 @@ def parse(xml, provider_name, name_mapping=None, is_public: bool = True):
'timetab': ImpexXMLParser.make_timetable_node,
'catalog': ImpexXMLParser.make_catalog_node,
'param': ImpexXMLParser.make_user_parameter_node,
'arguments': ImpexXMLParser.parse_template_arguments,
'argument': ImpexXMLParser.parse_template_argument
}

def _recursive_parser(parent, node, is_node_public):
Expand Down
3 changes: 2 additions & 1 deletion speasy/core/inventory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Dict, Callable
from .indexes import ParameterIndex, DatasetIndex, TimetableIndex, ComponentIndex, CatalogIndex, SpeasyIndex
from .indexes import ParameterIndex, DatasetIndex, TimetableIndex, ComponentIndex, CatalogIndex, SpeasyIndex,DerivedParameterIndex


class ProviderInventory:
Expand All @@ -23,6 +23,7 @@ def __init__(self):
self.components = {}
self._type_lookup = {
ParameterIndex: lambda node: self.parameters.__setitem__(node.spz_uid(), node),
DerivedParameterIndex: lambda node: self.parameters.__setitem__(node.spz_uid(), node),
DatasetIndex: lambda node: self.datasets.__setitem__(node.spz_uid(), node),
TimetableIndex: lambda node: self.timetables.__setitem__(node.spz_uid(), node),
ComponentIndex: lambda node: self.components.__setitem__(node.spz_uid(), node),
Expand Down
43 changes: 42 additions & 1 deletion speasy/core/inventory/indexes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Optional
from typing import Optional, Union

__INDEXES_TYPES__ = {}

Expand Down Expand Up @@ -85,6 +85,45 @@ def __contains__(self, item: str or ComponentIndex):
return True
return False

class ArgumentIndex(SpeasyIndex):
    """Inventory index node describing one argument of an argument list."""

    def __init__(self, name: str, provider: str, uid: str, meta: Optional[dict] = None):
        # No extra state: construction is entirely handled by SpeasyIndex.
        super().__init__(name, provider, uid, meta)

    def __repr__(self):
        return '<ArgumentIndex: {}>'.format(self.spz_name())

class ArgumentListIndex(SpeasyIndex):
    """Inventory index node acting as a sequence of ``ArgumentIndex`` children.

    The children are not stored separately: they are every instance attribute
    whose value is an ``ArgumentIndex``, collected on each access.
    """

    def __init__(self, name: str, provider: str, uid: str, meta: Optional[dict] = None):
        super().__init__(name, provider, uid, meta)

    @property
    def _arguments(self):
        # Rebuilt on every call from the instance dictionary.
        return list(filter(lambda attr: isinstance(attr, ArgumentIndex), vars(self).values()))

    def __repr__(self):
        return '<ArgumentListIndex: {}>'.format(self.spz_name())

    def __getitem__(self, item)->ArgumentIndex:
        arguments = self._arguments
        return arguments[item]

    def __len__(self):
        arguments = self._arguments
        return len(arguments)

    def __iter__(self):
        return iter(self._arguments)

class DerivedParameterIndex(ParameterIndex):
    """Parameter index that additionally carries an argument list.

    The argument list is expected under the ``__spz_arguments__`` attribute;
    it is only declared here, not assigned, so it must be attached to the
    instance elsewhere.
    """
    __spz_arguments__: ArgumentListIndex

    def __init__(self, name: str, provider: str, uid: str, meta: Optional[dict] = None):
        super().__init__(name, provider, uid, meta)

    def spz_arguments(self):
        """Return the attached argument list (AttributeError if none was set)."""
        # The name ends with two underscores, so no private-name mangling
        # applies and a plain getattr is equivalent to the attribute access.
        return getattr(self, '__spz_arguments__')

    def __repr__(self):
        return '<DerivedParameterIndex: {}>'.format(self.spz_name())


class DatasetIndex(SpeasyIndex):
def __init__(self, name: str, provider: str, uid: str, meta: Optional[dict] = None):
Expand Down Expand Up @@ -151,3 +190,5 @@ def inventory_has_changed(orig, new):
if orig.__dict__[orig_key] != new.__dict__[orig_key]:
return True
return False

# Type alias: union of the product index classes of this module, meant for
# annotations (e.g. the return type of ImpexProvider.to_index).
AnyProductIndex = Union[ParameterIndex, DerivedParameterIndex, DatasetIndex, TimetableIndex, CatalogIndex, ComponentIndex]
Loading
Loading