Skip to content

Commit

Permalink
feat: Asynchronous getters
Browse files Browse the repository at this point in the history
  • Loading branch information
thorbjoernl committed May 27, 2024
1 parent 528d39c commit 35b4d74
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 37 deletions.
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ packages = aerovaldb
tests_require =
tox:tox
pytest


[options.entry_points]
aerovaldb =
Expand All @@ -50,6 +51,7 @@ commands_pre =
python --version
deps =
pytest
pytest_asyncio
commands =
python -m pytest .

45 changes: 24 additions & 21 deletions src/aerovaldb/aerovaldb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import functools
import inspect
import aiofile


def get_method(route):
Expand All @@ -13,7 +14,7 @@ def get_method(route):

def wrap(wrapped):
@functools.wraps(wrapped)
def wrapper(self, *args, **kwargs):
async def wrapper(self, *args, **kwargs):
sig = inspect.signature(wrapped)
route_args = {}
for pos, par in enumerate(sig.parameters):
Expand All @@ -29,7 +30,7 @@ def wrapper(self, *args, **kwargs):
f"{wrapped.__name__} got less parameters as expected (>= {len(route_args)+2}): {iex}"
)

return self._get(route, route_args, *args, **kwargs)
return await self._get(route, route_args, *args, **kwargs)

return wrapper

Expand Down Expand Up @@ -84,7 +85,7 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
pass

def _get(self, route: str, route_args: dict[str, str], *args, **kwargs):
async def _get(self, route: str, route_args: dict[str, str], *args, **kwargs):
"""Abstract implementation of the main getter functions. All get and put
functions map to this function, with a corresponding route as key
to enable key/value pair put and get functionality.
Expand All @@ -106,7 +107,7 @@ def _put(self, obj, route: str, route_args: dict[str, str], *args, **kwargs):
raise NotImplementedError

@get_method("/v0/glob_stats/{project}/{experiment}/{frequency}")
def get_glob_stats(
async def get_glob_stats(
self, project: str, experiment: str, frequency: str, /, *args, **kwargs
):
"""Fetches a glob_stats object from the database.
Expand Down Expand Up @@ -135,7 +136,7 @@ def put_glob_stats(
raise NotImplementedError

@get_method("/v0/contour/{project}/{experiment}/{obsvar}/{model}")
def get_contour(
async def get_contour(
self, project: str, experiment: str, obsvar: str, model: str, /, *args, **kwargs
):
"""Fetch a contour object from the db.
Expand Down Expand Up @@ -170,7 +171,7 @@ def put_contour(
raise NotImplementedError

@get_method("/v0/ts/{project}/{experiment}/{region}/{network}/{obsvar}/{layer}")
def get_ts(
async def get_ts(
self,
project: str,
experiment: str,
Expand Down Expand Up @@ -230,7 +231,7 @@ def put_ts(
@get_method(
"/v0/ts_weekly/{project}/{experiment}/{station}_{network}-{obsvar}_{layer}"
)
def get_ts_weekly(
async def get_ts_weekly(
self,
project: str,
experiment: str,
Expand Down Expand Up @@ -282,7 +283,7 @@ def put_ts_weekly(
raise NotImplementedError

@get_method("/v0/experiments/{project}")
def get_experiments(self, project: str, /, *args, **kwargs):
async def get_experiments(self, project: str, /, *args, **kwargs):
"""Fetches a list of experiments for a project from the db.
:param project: Project ID.
Expand All @@ -299,7 +300,7 @@ def put_experiments(self, obj, project: str, /, *args, **kwargs):
raise NotImplementedError

@get_method("/v0/config/{project}/{experiment}")
def get_config(self, project: str, experiment: str, /, *args, **kwargs):
async def get_config(self, project: str, experiment: str, /, *args, **kwargs):
"""Fetches a configuration from the db.
:param project: Project ID.
Expand All @@ -319,7 +320,7 @@ def put_config(self, obj, project: str, experiment: str, /, *args, **kwargs):
raise NotImplementedError

@get_method("/v0/menu/{project}/{experiment}")
def get_menu(self, project: str, experiment: str, /, *args, **kwargs):
async def get_menu(self, project: str, experiment: str, /, *args, **kwargs):
"""Fetches a menu configuartion from the db.
:param project: Project ID.
Expand All @@ -338,7 +339,7 @@ def put_menu(self, obj, project: str, experiment: str, /, *args, **kwargs):
raise NotImplementedError

@get_method("/v0/statistics/{project}/{experiment}")
def get_statistics(self, project: str, experiment: str, /, *args, **kwargs):
async def get_statistics(self, project: str, experiment: str, /, *args, **kwargs):
"""Fetches statistics for an experiment.
:param project: Project ID.
Expand All @@ -357,7 +358,7 @@ def put_statistics(self, obj, project: str, experiment: str, /, *args, **kwargs)
raise NotImplementedError

@get_method("/v0/ranges/{project}/{experiment}")
def get_ranges(self, project: str, experiment: str, /, *args, **kwargs):
async def get_ranges(self, project: str, experiment: str, /, *args, **kwargs):
"""Fetches ranges from the db.
:param project: Project ID.
Expand All @@ -376,7 +377,7 @@ def put_ranges(self, obj, project: str, experiment: str, /, *args, **kwargs):
raise NotImplementedError

@get_method("/v0/regions/{project}/{experiment}")
def get_regions(self, project: str, experiment: str, /, *args, **kwargs):
async def get_regions(self, project: str, experiment: str, /, *args, **kwargs):
"""Fetches regions from db.
:param project: Project ID.
Expand All @@ -395,7 +396,7 @@ def put_regions(self, obj, project: str, experiment: str, /, *args, **kwargs):
raise NotImplementedError

@get_method("/v0/model_style/{project}")
def get_models_style(self, project: str, /, *args, **kwargs):
async def get_models_style(self, project: str, /, *args, **kwargs):
"""Fetches model styles from db.
:param project: Project ID.
Expand All @@ -416,7 +417,7 @@ def put_models_style(self, obj, project: str, /, *args, **kwargs):
@get_method(
"/v0/map/{project}/{experiment}/{network}/{obsvar}/{layer}/{model}/{modvar}"
)
def get_map(
async def get_map(
self,
project: str,
experiment: str,
Expand Down Expand Up @@ -476,7 +477,7 @@ def put_map(
@get_method(
"/v0/scat/{project}/{experiment}/{network}-{obsvar}_{layer}_{model}-{modvar}"
)
def get_scat(
async def get_scat(
self,
project: str,
experiment: str,
Expand Down Expand Up @@ -534,7 +535,7 @@ def put_scat(
raise NotImplementedError

@get_method("/v0/profiles/{project}/{experiment}/{station}/{network}/{obsvar}")
def get_profiles(
async def get_profiles(
self,
project: str,
experiment: str,
Expand Down Expand Up @@ -578,7 +579,7 @@ def put_profiles(
raise NotImplementedError

@get_method("/v0/hm_ts/{project}/{experiment}")
def get_hm_ts(
async def get_hm_ts(
self,
project: str,
experiment: str,
Expand Down Expand Up @@ -622,7 +623,7 @@ def put_hm_ts(
@get_method(
"/v0/forecast/{project}/{experiment}/{station}/{network}/{obsvar}/{layer}"
)
def get_forecast(
async def get_forecast(
self,
project: str,
experiment: str,
Expand Down Expand Up @@ -674,7 +675,7 @@ def put_forecast(
raise NotImplementedError

@get_method("/v0/gridded_map/{project}/{experiment}/{obsvar}/{model}")
def get_gridded_map(
async def get_gridded_map(
self, project: str, experiment: str, obsvar: str, model: str, /, *args, **kwargs
):
"""Fetches gridded map.
Expand Down Expand Up @@ -709,7 +710,9 @@ def put_gridded_map(
raise NotImplementedError

@get_method("/v0/report/{project}/{experiment}/{title}")
def get_report(self, project: str, experiment: str, title: str, /, *args, **kwargs):
async def get_report(
self, project: str, experiment: str, title: str, /, *args, **kwargs
):
"""Fetch report.
:param project: Project ID.
Expand Down
13 changes: 8 additions & 5 deletions src/aerovaldb/jsonfiledb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import json
import orjson
import aiofile
from enum import Enum

AccessType = Enum("AccessType", ["JSON_STR", "FILE_PATH", "OBJ"])
Expand Down Expand Up @@ -102,7 +103,7 @@ def _get_file_path_from_route(self, route, route_args, /, *args, **kwargs):
raise ValueError("Error in relative path resolution.")
return Path(os.path.join(self._basedir, relative_path)).resolve()

def _get(self, route, route_args, *args, **kwargs):
async def _get(self, route, route_args, *args, **kwargs):
access_type = self._normalize_access_type(kwargs.get("access_type", None))

file_path = self._get_file_path_from_route(route, route_args, **kwargs)
Expand All @@ -114,13 +115,15 @@ def _get(self, route, route_args, *args, **kwargs):
return str(file_path)

if access_type == AccessType.JSON_STR:
with open(file_path, "rb") as f:
json = str(f.read())
async with aiofile.async_open(file_path, "rb") as f:
raw = await f.read()

json = str(raw)

return json

with open(file_path, "rb") as f:
raw = f.read()
async with aiofile.async_open(file_path, "rb") as f:
raw = await f.read()

return orjson.loads(raw)

Expand Down
31 changes: 20 additions & 11 deletions tests/test_jsonfiledb.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import pytest
import aerovaldb
import asyncio

pytest_plugins = ("pytest_asyncio",)


@pytest.mark.asyncio
@pytest.mark.parametrize("resource", (("json_files:./tests/test-db/json",)))
@pytest.mark.parametrize(
"fun,args,kwargs,expected",
Expand Down Expand Up @@ -93,7 +97,7 @@
"obsvar": "obsvar",
"layer": "layer",
},
"./project/experiment/hm/ts/network-obsvar-layer"
"./project/experiment/hm/ts/network-obsvar-layer",
),
(
"get_hm_ts",
Expand All @@ -104,7 +108,7 @@
"layer": "layer",
"station": "region",
},
"./project/experiment/hm/ts/"
"./project/experiment/hm/ts/",
),
(
"get_forecast",
Expand All @@ -126,43 +130,48 @@
),
),
)
def test_getter(resource: str, fun: str, args: list, kwargs: dict, expected):
async def test_getter(resource: str, fun: str, args: list, kwargs: dict, expected):
with aerovaldb.open(resource) as db:
f = getattr(db, fun)

if kwargs is not None:
data = f(*args, **kwargs)
data = await f(*args, **kwargs)
else:
data = f(*args)
data = await f(*args)

assert data["path"] == expected


def test_put_glob_stats():
@pytest.mark.asyncio
async def test_put_glob_stats():
# TODO: These tests should ideally cleanup after themselves. For now
# it is best to delete ./tests/test-db/tmp before running to verify
# that they run as intended.
with aerovaldb.open("json_files:./tests/test-db/tmp") as db:
obj = {"data": "gibberish"}
db.put_glob_stats(obj, "test1", "test2", "test3")
read_data = db.get_glob_stats("test1", "test2", "test3")
read_data = await db.get_glob_stats("test1", "test2", "test3")

assert obj["data"] == read_data["data"]


def test_put_contour():
@pytest.mark.asyncio
async def test_put_contour():
with aerovaldb.open("json_files:./tests/test-db/tmp") as db:
obj = {"data": "gibberish"}
db.put_contour(obj, "test1", "test2", "test3", "test4")
read_data = db.get_contour("test1", "test2", "test3", "test4")
read_data = await db.get_contour("test1", "test2", "test3", "test4")
assert obj["data"] == read_data["data"]


def test_put_ts():
@pytest.mark.asyncio
async def test_put_ts():
with aerovaldb.open("json_files:./tests/test-db/tmp") as db:
obj = {"data": "gibberish"}
db.put_ts(obj, "test1", "test2", "test3", "test4", "test5", "test6")

read_data = db.get_ts("test1", "test2", "test3", "test4", "test5", "test6")
read_data = await db.get_ts(
"test1", "test2", "test3", "test4", "test5", "test6"
)

assert obj["data"] == read_data["data"]

0 comments on commit 35b4d74

Please sign in to comment.