diff --git a/setup.cfg b/setup.cfg index 7e93755..d7a3dd8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,6 +34,7 @@ packages = aerovaldb tests_require = tox:tox pytest + [options.entry_points] aerovaldb = @@ -50,6 +51,7 @@ commands_pre = python --version deps = pytest + pytest_asyncio commands = python -m pytest . diff --git a/src/aerovaldb/aerovaldb.py b/src/aerovaldb/aerovaldb.py index c4f1c14..212c06c 100644 --- a/src/aerovaldb/aerovaldb.py +++ b/src/aerovaldb/aerovaldb.py @@ -1,6 +1,7 @@ import abc import functools import inspect +import aiofile def get_method(route): @@ -13,7 +14,7 @@ def get_method(route): def wrap(wrapped): @functools.wraps(wrapped) - def wrapper(self, *args, **kwargs): + async def wrapper(self, *args, **kwargs): sig = inspect.signature(wrapped) route_args = {} for pos, par in enumerate(sig.parameters): @@ -29,7 +30,7 @@ def wrapper(self, *args, **kwargs): f"{wrapped.__name__} got less parameters as expected (>= {len(route_args)+2}): {iex}" ) - return self._get(route, route_args, *args, **kwargs) + return await self._get(route, route_args, *args, **kwargs) return wrapper @@ -84,7 +85,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): pass - def _get(self, route: str, route_args: dict[str, str], *args, **kwargs): + async def _get(self, route: str, route_args: dict[str, str], *args, **kwargs): """Abstract implementation of the main getter functions. All get and put functions map to this function, with a corresponding route as key to enable key/value pair put and get functionality. @@ -106,7 +107,7 @@ def _put(self, obj, route: str, route_args: dict[str, str], *args, **kwargs): raise NotImplementedError @get_method("/v0/glob_stats/{project}/{experiment}/{frequency}") - def get_glob_stats( + async def get_glob_stats( self, project: str, experiment: str, frequency: str, /, *args, **kwargs ): """Fetches a glob_stats object from the database. @@ -135,7 +136,7 @@ def put_glob_stats( raise NotImplementedError @get_method("/v0/contour/{project}/{experiment}/{obsvar}/{model}") - def get_contour( + async def get_contour( self, project: str, experiment: str, obsvar: str, model: str, /, *args, **kwargs ): """Fetch a contour object from the db. @@ -170,7 +171,7 @@ def put_contour( raise NotImplementedError @get_method("/v0/ts/{project}/{experiment}/{region}/{network}/{obsvar}/{layer}") - def get_ts( + async def get_ts( self, project: str, experiment: str, @@ -230,7 +231,7 @@ def put_ts( @get_method( "/v0/ts_weekly/{project}/{experiment}/{station}_{network}-{obsvar}_{layer}" ) - def get_ts_weekly( + async def get_ts_weekly( self, project: str, experiment: str, @@ -282,7 +283,7 @@ def put_ts_weekly( raise NotImplementedError @get_method("/v0/experiments/{project}") - def get_experiments(self, project: str, /, *args, **kwargs): + async def get_experiments(self, project: str, /, *args, **kwargs): """Fetches a list of experiments for a project from the db. :param project: Project ID. @@ -299,7 +300,7 @@ def put_experiments(self, obj, project: str, /, *args, **kwargs): raise NotImplementedError @get_method("/v0/config/{project}/{experiment}") - def get_config(self, project: str, experiment: str, /, *args, **kwargs): + async def get_config(self, project: str, experiment: str, /, *args, **kwargs): """Fetches a configuration from the db. :param project: Project ID. @@ -319,7 +320,7 @@ def put_config(self, obj, project: str, experiment: str, /, *args, **kwargs): raise NotImplementedError @get_method("/v0/menu/{project}/{experiment}") - def get_menu(self, project: str, experiment: str, /, *args, **kwargs): + async def get_menu(self, project: str, experiment: str, /, *args, **kwargs): """Fetches a menu configuartion from the db. :param project: Project ID. @@ -338,7 +339,7 @@ def put_menu(self, obj, project: str, experiment: str, /, *args, **kwargs): raise NotImplementedError @get_method("/v0/statistics/{project}/{experiment}") - def get_statistics(self, project: str, experiment: str, /, *args, **kwargs): + async def get_statistics(self, project: str, experiment: str, /, *args, **kwargs): """Fetches statistics for an experiment. :param project: Project ID. @@ -357,7 +358,7 @@ def put_statistics(self, obj, project: str, experiment: str, /, *args, **kwargs) raise NotImplementedError @get_method("/v0/ranges/{project}/{experiment}") - def get_ranges(self, project: str, experiment: str, /, *args, **kwargs): + async def get_ranges(self, project: str, experiment: str, /, *args, **kwargs): """Fetches ranges from the db. :param project: Project ID. @@ -376,7 +377,7 @@ def put_ranges(self, obj, project: str, experiment: str, /, *args, **kwargs): raise NotImplementedError @get_method("/v0/regions/{project}/{experiment}") - def get_regions(self, project: str, experiment: str, /, *args, **kwargs): + async def get_regions(self, project: str, experiment: str, /, *args, **kwargs): """Fetches regions from db. :param project: Project ID. @@ -395,7 +396,7 @@ def put_regions(self, obj, project: str, experiment: str, /, *args, **kwargs): raise NotImplementedError @get_method("/v0/model_style/{project}") - def get_models_style(self, project: str, /, *args, **kwargs): + async def get_models_style(self, project: str, /, *args, **kwargs): """Fetches model styles from db. :param project: Project ID. @@ -416,7 +417,7 @@ def put_models_style(self, obj, project: str, /, *args, **kwargs): @get_method( "/v0/map/{project}/{experiment}/{network}/{obsvar}/{layer}/{model}/{modvar}" ) - def get_map( + async def get_map( self, project: str, experiment: str, @@ -476,7 +477,7 @@ def put_map( @get_method( "/v0/scat/{project}/{experiment}/{network}-{obsvar}_{layer}_{model}-{modvar}" ) - def get_scat( + async def get_scat( self, project: str, experiment: str, @@ -534,7 +535,7 @@ def put_scat( raise NotImplementedError @get_method("/v0/profiles/{project}/{experiment}/{station}/{network}/{obsvar}") - def get_profiles( + async def get_profiles( self, project: str, experiment: str, @@ -578,7 +579,7 @@ def put_profiles( raise NotImplementedError @get_method("/v0/hm_ts/{project}/{experiment}") - def get_hm_ts( + async def get_hm_ts( self, project: str, experiment: str, @@ -622,7 +623,7 @@ def put_hm_ts( @get_method( "/v0/forecast/{project}/{experiment}/{station}/{network}/{obsvar}/{layer}" ) - def get_forecast( + async def get_forecast( self, project: str, experiment: str, @@ -674,7 +675,7 @@ def put_forecast( raise NotImplementedError @get_method("/v0/gridded_map/{project}/{experiment}/{obsvar}/{model}") - def get_gridded_map( + async def get_gridded_map( self, project: str, experiment: str, obsvar: str, model: str, /, *args, **kwargs ): """Fetches gridded map. @@ -709,7 +710,9 @@ def put_gridded_map( raise NotImplementedError @get_method("/v0/report/{project}/{experiment}/{title}") - def get_report(self, project: str, experiment: str, title: str, /, *args, **kwargs): + async def get_report( + self, project: str, experiment: str, title: str, /, *args, **kwargs + ): """Fetch report. :param project: Project ID. diff --git a/src/aerovaldb/jsonfiledb.py b/src/aerovaldb/jsonfiledb.py index 0f03d9b..418e40c 100644 --- a/src/aerovaldb/jsonfiledb.py +++ b/src/aerovaldb/jsonfiledb.py @@ -5,6 +5,7 @@ import os import json import orjson +import aiofile from enum import Enum AccessType = Enum("AccessType", ["JSON_STR", "FILE_PATH", "OBJ"]) @@ -102,7 +103,7 @@ def _get_file_path_from_route(self, route, route_args, /, *args, **kwargs): raise ValueError("Error in relative path resolution.") return Path(os.path.join(self._basedir, relative_path)).resolve() - def _get(self, route, route_args, *args, **kwargs): + async def _get(self, route, route_args, *args, **kwargs): access_type = self._normalize_access_type(kwargs.get("access_type", None)) file_path = self._get_file_path_from_route(route, route_args, **kwargs) @@ -114,13 +115,15 @@ def _get(self, route, route_args, *args, **kwargs): return str(file_path) if access_type == AccessType.JSON_STR: - with open(file_path, "rb") as f: - json = str(f.read()) + async with aiofile.async_open(file_path, "rb") as f: + raw = await f.read() + + json = str(raw) return json - with open(file_path, "rb") as f: - raw = f.read() + async with aiofile.async_open(file_path, "rb") as f: + raw = await f.read() return orjson.loads(raw) diff --git a/tests/test_jsonfiledb.py b/tests/test_jsonfiledb.py index d9583bf..068f4ae 100644 --- a/tests/test_jsonfiledb.py +++ b/tests/test_jsonfiledb.py @@ -1,7 +1,11 @@ import pytest import aerovaldb +import asyncio +pytest_plugins = ("pytest_asyncio",) + +@pytest.mark.asyncio @pytest.mark.parametrize("resource", (("json_files:./tests/test-db/json",))) @pytest.mark.parametrize( "fun,args,kwargs,expected", @@ -93,7 +97,7 @@ "obsvar": "obsvar", "layer": "layer", }, - "./project/experiment/hm/ts/network-obsvar-layer" + "./project/experiment/hm/ts/network-obsvar-layer", ), ( "get_hm_ts", @@ -104,7 +108,7 @@ "layer": "layer", "station": "region", }, - "./project/experiment/hm/ts/" + "./project/experiment/hm/ts/", ), ( "get_forecast", @@ -126,43 +130,48 @@ ), ), ) -def test_getter(resource: str, fun: str, args: list, kwargs: dict, expected): +async def test_getter(resource: str, fun: str, args: list, kwargs: dict, expected): with aerovaldb.open(resource) as db: f = getattr(db, fun) if kwargs is not None: - data = f(*args, **kwargs) + data = await f(*args, **kwargs) else: - data = f(*args) + data = await f(*args) assert data["path"] == expected -def test_put_glob_stats(): +@pytest.mark.asyncio +async def test_put_glob_stats(): # TODO: These tests should ideally cleanup after themselves. For now # it is best to delete ./tests/test-db/tmp before running to verify # that they run as intended. with aerovaldb.open("json_files:./tests/test-db/tmp") as db: obj = {"data": "gibberish"} db.put_glob_stats(obj, "test1", "test2", "test3") - read_data = db.get_glob_stats("test1", "test2", "test3") + read_data = await db.get_glob_stats("test1", "test2", "test3") assert obj["data"] == read_data["data"] -def test_put_contour(): +@pytest.mark.asyncio +async def test_put_contour(): with aerovaldb.open("json_files:./tests/test-db/tmp") as db: obj = {"data": "gibberish"} db.put_contour(obj, "test1", "test2", "test3", "test4") - read_data = db.get_contour("test1", "test2", "test3", "test4") + read_data = await db.get_contour("test1", "test2", "test3", "test4") assert obj["data"] == read_data["data"] -def test_put_ts(): +@pytest.mark.asyncio +async def test_put_ts(): with aerovaldb.open("json_files:./tests/test-db/tmp") as db: obj = {"data": "gibberish"} db.put_ts(obj, "test1", "test2", "test3", "test4", "test5", "test6") - read_data = db.get_ts("test1", "test2", "test3", "test4", "test5", "test6") + read_data = await db.get_ts( + "test1", "test2", "test3", "test4", "test5", "test6" + ) assert obj["data"] == read_data["data"]