Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support amda template params #192

Merged
merged 7 commits into from
Feb 21, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
WIP
brenard31 committed Jan 22, 2025
commit de2fa571f02c17d3116573dcc64daea9f263a4ff
27 changes: 17 additions & 10 deletions speasy/core/impex/__init__.py
Original file line number Diff line number Diff line change
@@ -26,7 +26,7 @@
from .parser import ImpexXMLParser, to_xmlid
from .client import ImpexClient, ImpexEndpoint
from .utils import load_catalog, load_timetable, is_private, is_public
from .exceptions import MissingCredentials
from .exceptions import MissingCredentials, MissingTemplateArgs


log = logging.getLogger(__name__)
@@ -381,6 +381,7 @@ def get_parameter(self, product, start_time, stop_time,
if hasattr(self, 'has_time_restriction') and self.has_time_restriction(product, start_time, stop_time):
kwargs['disable_proxy'] = True
kwargs['restricted_period'] = True

return self._get_parameter(product, start_time, stop_time, extra_http_headers=extra_http_headers,
output_format=output_format or self.client.output_format, **kwargs)

@@ -762,10 +763,14 @@ def _get_parameter(self, product, start_time, stop_time,
Optional[
SpeasyVariable]:
log.debug(f'Get data: product = {product}, data start time = {start_time}, data stop time = {stop_time}')
if hasattr(self, 'get_real_product_id'):
real_product_id = self.get_real_product_id(product, **kwargs)
if real_product_id:
kwargs['real_product_id'] = real_product_id
return self._dl_parameter(start_time=start_time, stop_time=stop_time, parameter_id=product,
extra_http_headers=extra_http_headers,
output_format=output_format,
product_variables=self._get_product_variables(product),
product_variables=self._get_product_variables(product, **kwargs),
restricted_period=restricted_period,
time_format='UNIXTIME', **kwargs)

@@ -775,13 +780,14 @@ def _dl_parameter_chunk(self, start_time: datetime, stop_time: datetime, paramet
product_variables: List = None, **kwargs) -> Optional[SpeasyVariable]:
url = self.client.get_parameter(start_time=start_time.strftime('%Y-%m-%dT%H:%M:%SZ'),
stop_time=stop_time.strftime('%Y-%m-%dT%H:%M:%SZ'),
parameter_id=parameter_id, extra_http_headers=extra_http_headers,
parameter_id=parameter_id,
extra_http_headers=extra_http_headers,
use_credentials=use_credentials, **kwargs)
# check status until done
if url is not None:
var = None
if not product_variables:
product_variables = [parameter_id]
product_variables = [kwargs.get('real_product_id', parameter_id)]
if kwargs.get('output_format', self.client.output_format) in ["CDF_ISTP", "CDF"]:
var = self._cdf_codec.load_variables(variables=product_variables, file=url)
else:
@@ -829,14 +835,15 @@ def _dl_parameter(self, start_time: datetime, stop_time: datetime, parameter_id:
curr_t += dt
return var
else:
return self._dl_parameter_chunk(start_time, stop_time, parameter_id, extra_http_headers=extra_http_headers,
return self._dl_parameter_chunk(start_time, stop_time, parameter_id,
extra_http_headers=extra_http_headers,
use_credentials=use_credentials,
product_variables=product_variables, **kwargs)

def _dl_user_parameter(self, start_time: datetime, stop_time: datetime, parameter_id: str,
**kwargs) -> Optional[SpeasyVariable]:
return self._dl_parameter(parameter_id=parameter_id, start_time=start_time, stop_time=stop_time,
product_variables=self._get_product_variables(parameter_id),
product_variables=self._get_product_variables(parameter_id, **kwargs),
use_credentials=True, **kwargs)

def _dl_timetable(self, timetable_id: str, use_credentials=False, **kwargs):
@@ -872,9 +879,9 @@ def _dl_catalog(self, catalog_id: str, use_credentials=False, **kwargs):
def _dl_user_catalog(self, catalog_id: str, **kwargs):
return self._dl_catalog(catalog_id, use_credentials=True, **kwargs)

def _get_product_variables(self, product_id: str or SpeasyIndex):
def _get_product_variables(self, product_id: str or SpeasyIndex, **kwargs):
product_id = to_xmlid(product_id)
return [product_id]
return [kwargs.get('real_product_id', product_id)]

@staticmethod
def _concatenate_variables(variables: Dict[str, SpeasyVariable], product_id) -> Optional[SpeasyVariable]:
@@ -906,8 +913,8 @@ def _concatenate_variables(variables: Dict[str, SpeasyVariable], product_id) ->
values=DataContainer(values=values, meta=meta, name=product_id, is_time_dependent=True),
columns=columns)

def _get_obs_data_tree(self) -> str or None:
return self.client.get_obs_data_tree()
def _get_obs_data_tree(self, add_template_info=False) -> str or None:
return self.client.get_obs_data_tree(add_template_info=add_template_info)

def _get_timetables_tree(self) -> str or None:
return self.client.get_time_table_list()
9 changes: 6 additions & 3 deletions speasy/core/impex/client.py
Original file line number Diff line number Diff line change
@@ -75,8 +75,11 @@ def auth(self):
def in_progress(result):
return result == "in progress"

def get_obs_data_tree(self, use_credentials=False):
params = {}
def get_obs_data_tree(self, use_credentials=False, **kwargs):
params = {
}
if kwargs.get('add_template_info', False):
params['templateInfo'] = True
if use_credentials:
params['userID'], params['password'] = self.get_credentials()
return self._send_indirect_request(ImpexEndpoint.OBSTREE, params=params)
@@ -109,7 +112,7 @@ def get_parameter(self, start_time, stop_time, parameter_id, extra_http_headers=
params = {
'startTime': start_time,
'stopTime': stop_time,
'parameterID': parameter_id,
'parameterID': kwargs.get('real_product_id', parameter_id),
'outputFormat': kwargs.get('output_format', self.output_format)
}

8 changes: 8 additions & 0 deletions speasy/core/impex/exceptions.py
Original file line number Diff line number Diff line change
@@ -4,3 +4,11 @@ class MissingCredentials(Exception):

class UnavailableEndpoint(Exception):
pass


class MissingTemplateArgs(Exception):
pass


class BadTemplateArgDefinition(Exception):
pass
28 changes: 28 additions & 0 deletions speasy/core/impex/parser.py
Original file line number Diff line number Diff line change
@@ -103,6 +103,31 @@ def make_path_node(parent, node, provider_name, name_key, is_public: bool = True
return ImpexXMLParser.make_any_node(parent, node, provider_name, SpeasyIndex, name_key=name_key,
is_public=is_public)

@staticmethod
def parse_template_arguments(parent, node, provider_name, name_key, is_public: bool = True):
parent.__dict__['arguments'] = {}
return parent.arguments

@staticmethod
def parse_template_argument(parent, node, provider_name, name_key, is_public: bool = True):
parent[node.get('key')] = {
'name': node.get('name'),
'type': node.get('type'),
'default': node.get('default')
}
if parent[node.get('key')]['type'] == 'generated-list':
parent[node.get('key')]['type'] = 'list'
parent[node.get('key')]['items'] = {k: node.get('nametpl').replace('##key##', str(k))
for k in range(int(node.get('minkey')), int(node.get('maxkey')))}
elif parent[node.get('key')]['type'] == 'list':
parent[node.get('key')]['items'] = {}
return parent[node.get('key')]

@staticmethod
def parse_template_argument_item(parent, node, provider_name, name_key, is_public: bool = True):
parent['items'][node.get('key')] = node.get('name')
return {}

@staticmethod
def parse(xml, provider_name, name_mapping=None, is_public: bool = True):
handlers = {
@@ -117,6 +142,9 @@ def parse(xml, provider_name, name_mapping=None, is_public: bool = True):
'timetab': ImpexXMLParser.make_timetable_node,
'catalog': ImpexXMLParser.make_catalog_node,
'param': ImpexXMLParser.make_user_parameter_node,
'arguments': ImpexXMLParser.parse_template_arguments,
'argument': ImpexXMLParser.parse_template_argument,
'item': ImpexXMLParser.parse_template_argument_item,
}

def _recursive_parser(parent, node, is_node_public):
1 change: 1 addition & 0 deletions speasy/core/proxy/__init__.py
Original file line number Diff line number Diff line change
@@ -83,6 +83,7 @@ def __init__(self):

@staticmethod
def get(path: str, start_time: str, stop_time: str, **kwargs):
print(kwargs)
url = proxy_cfg.url()
if url.endswith("/"):
url = url[:-1]
46 changes: 41 additions & 5 deletions speasy/webservices/amda/ws.py
Original file line number Diff line number Diff line change
@@ -14,11 +14,13 @@
from ...core.inventory.indexes import (CatalogIndex, ParameterIndex,
SpeasyIndex, TimetableIndex)
from ...core.proxy import PROXY_ALLOWED_KWARGS, GetProduct, Proxyfiable
from ...inventories import flat_inventories
from ...products.catalog import Catalog
from ...products.timetable import TimeTable
from ...products.variable import SpeasyVariable

from ...core.impex import ImpexProvider, ImpexEndpoint
from ...core.impex import ImpexProvider, ImpexEndpoint, to_xmlid
from ...core.impex.exceptions import MissingTemplateArgs, BadTemplateArgDefinition


log = logging.getLogger(__name__)
@@ -35,12 +37,42 @@
}


def _amda_replace_arguments_in_template(product: ParameterIndex, additional_arguments: Dict):
product_id = product.template
for k, v in product.arguments.items():
print(v)
if v['type'] == 'list':
if additional_arguments[k] not in v['items'].keys():
raise BadTemplateArgDefinition()
product_id = product_id.replace(f'##{k}##', str(additional_arguments[k]))

return product_id


def _amda_get_real_product_id(product_id: str or SpeasyIndex, **kwargs):
product_id = to_xmlid(product_id)
product = flat_inventories.__dict__[amda_provider_name].parameters[product_id]
if hasattr(product, 'template'):
additional_arguments = kwargs.get('additional_arguments', {})
if not hasattr(product, 'arguments'):
return product_id
real_product_id = product.template
for k, v in product.arguments.items():
if k not in additional_arguments:
raise MissingTemplateArgs()
real_product_id = _amda_replace_arguments_in_template(product, additional_arguments)
else:
real_product_id = product_id
return real_product_id


def _amda_cache_entry_name(prefix: str, product: str, start_time: str, **kwargs):
output_format: str = kwargs.get('output_format', 'cdf_istp')
real_product_id = _amda_get_real_product_id(product, **kwargs)
if output_format.lower() == 'cdf_istp':
return f"{prefix}/{product}-cdf_istp/{start_time}"
return f"{prefix}/{real_product_id}-cdf_istp/{start_time}"
else:
return f"{prefix}/{product}/{start_time}"
return f"{prefix}/{real_product_id}/{start_time}"


def _amda_get_proxy_parameter_args(start_time: datetime, stop_time: datetime, product: str, **kwargs) -> Dict:
@@ -120,6 +152,9 @@ def product_version(self, parameter_id: str or ParameterIndex):
return self.flat_inventory.datasets[dataset].lastModificationDate
return self.flat_inventory.datasets[dataset].lastUpdate

def get_real_product_id(self, product_id: str or SpeasyIndex, **kwargs):
return _amda_get_real_product_id(product_id, **kwargs)

@CacheCall(cache_retention=amda_cfg.user_cache_retention(), is_pure=True)
def get_timetable(self, timetable_id: str or TimetableIndex, **kwargs) -> Optional[TimeTable]:
"""Get timetable data by ID.
@@ -229,7 +264,8 @@ def get_user_catalog(self, catalog_id: str or CatalogIndex, **kwargs) -> Optiona
return super().get_user_catalog(catalog_id)

@AllowedKwargs(
PROXY_ALLOWED_KWARGS + CACHE_ALLOWED_KWARGS + GET_DATA_ALLOWED_KWARGS + ['output_format', 'restricted_period'])
PROXY_ALLOWED_KWARGS + CACHE_ALLOWED_KWARGS + GET_DATA_ALLOWED_KWARGS +
['output_format', 'restricted_period', 'additional_arguments'])
@EnsureUTCDateTime()
@ParameterRangeCheck()
@Cacheable(prefix=amda_provider_name, version=product_version, fragment_hours=lambda x: 12,
@@ -277,7 +313,7 @@ def _get_parameter(self, product, start_time, stop_time,

@CacheCall(cache_retention=24 * 60 * 60, is_pure=True)
def _get_obs_data_tree(self) -> str or None:
return super()._get_obs_data_tree()
return super()._get_obs_data_tree(add_template_info=True)

@CacheCall(cache_retention=amda_cfg.user_cache_retention(), is_pure=True)
def _get_timetables_tree(self) -> str or None: