Skip to content

Commit

Permalink
Sqlite works and all tests green
Browse files Browse the repository at this point in the history
  • Loading branch information
thorbjoernl committed Aug 30, 2024
1 parent 5df6a31 commit 7532b94
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 17 deletions.
112 changes: 95 additions & 17 deletions src/aerovaldb/sqlitedb/sqlitedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class AerovalSqliteDB(AerovalDB):
"forecast": extract_substitutions(ROUTE_FORECAST),
"gridded_map": extract_substitutions(ROUTE_GRIDDED_MAP),
"report": extract_substitutions(ROUTE_REPORT),
"reportimages": extract_substitutions(ROUTE_REPORT_IMAGE),
}

TABLE_NAME_TO_ROUTE = {
Expand All @@ -128,6 +129,7 @@ class AerovalSqliteDB(AerovalDB):
"forecast": ROUTE_FORECAST,
"gridded_map": ROUTE_GRIDDED_MAP,
"report": ROUTE_REPORT,
"reportimages": ROUTE_REPORT_IMAGE,
}

def __init__(self, database: str, /, **kwargs):
Expand Down Expand Up @@ -208,6 +210,7 @@ def __init__(self, database: str, /, **kwargs):
ROUTE_FORECAST: "forecast",
ROUTE_GRIDDED_MAP: "gridded_map",
ROUTE_REPORT: "report",
ROUTE_REPORT_IMAGE: "reportimages",
},
version_provider=self._get_version,
)
Expand Down Expand Up @@ -304,18 +307,31 @@ def _initialize_db(self):
args = AerovalSqliteDB.TABLE_COLUMN_NAMES[table_name]

column_names = ",".join(args)

cur.execute(
f"""
CREATE TABLE IF NOT EXISTS {table_name}(
{column_names},
ctime TEXT,
mtime TEXT,
json TEXT,
UNIQUE({column_names}))
"""
)
if table_name == "reportimages":
cur.execute(
f"""
CREATE TABLE IF NOT EXISTS {table_name}(
{column_names},
ctime TEXT,
mtime TEXT,
blob BLOB,
UNIQUE({column_names})
)
"""
)
else:
cur.execute(
f"""
CREATE TABLE IF NOT EXISTS {table_name}(
{column_names},
ctime TEXT,
mtime TEXT,
json TEXT,
UNIQUE({column_names}))
"""
)

cur.execute(
f"""
Expand Down Expand Up @@ -391,7 +407,7 @@ async def _get(self, route, route_args, **kwargs):
raise FileNotFoundError("Object not found")
for r in fetched:
for k in r.keys():
if k in ("json", "ctime", "mtime"):
if k in ("json", "blob", "ctime", "mtime"):
continue
if not (k in route_args | kwargs) and r[k] is not None:
break
Expand Down Expand Up @@ -476,6 +492,11 @@ async def get_by_uri(

route, route_args, kwargs = parse_uri(uri)

if route == ROUTE_REPORT_IMAGE:
return await self.get_report_image(
route_args["project"], route_args["experiment"], route_args["path"]
)

return await self._get(
route,
route_args,
Expand All @@ -488,6 +509,11 @@ async def get_by_uri(
@async_and_sync
async def put_by_uri(self, obj, uri: str):
route, route_args, kwargs = parse_uri(uri)
if route == ROUTE_REPORT_IMAGE:
await self.put_report_image(
obj, route_args["project"], route_args["experiment"], route_args["path"]
)
return

await self._put(obj, route, route_args, **kwargs)

Expand All @@ -510,14 +536,20 @@ async def list_all(self):
route_args = {}
kwargs = {}
for k in r.keys():
if k in ["json", "ctime", "mtime"]:
if k in ["json", "blob", "ctime", "mtime"]:
continue
if k in arg_names:
route_args[k] = r[k]
else:
kwargs[k] = r[k]

uri = build_uri(route, route_args, kwargs)
if route == ROUTE_REPORT_IMAGE:
for k, v in route_args.items():
route_args[k] = v.replace("/", ":")

uri = build_uri(route, route_args, kwargs)
else:
uri = build_uri(route, route_args, kwargs)
result.append(uri)
return result

Expand Down Expand Up @@ -564,7 +596,7 @@ def list_glob_stats(
route_args = {}
kwargs = {}
for k in r.keys():
if k in ["json", "ctime", "mtime"]:
if k in ["json", "blob", "ctime", "mtime"]:
continue

if k in arg_names:
Expand Down Expand Up @@ -607,7 +639,7 @@ async def list_timeseries(
route_args = {}
kwargs = {}
for k in r.keys():
if k in ["json", "ctime", "mtime"]:
if k in ["json", "blob", "ctime", "mtime"]:
continue

if k in arg_names:
Expand Down Expand Up @@ -649,3 +681,49 @@ def rm_experiment_data(self, project: str, experiment: str) -> None:
""",
(project, experiment),
)

@async_and_sync
async def get_report_image(
self,
project: str,
experiment: str,
path: str,
access_type: str | AccessType = AccessType.BLOB,
):
access_type = self._normalize_access_type(access_type)

if access_type != AccessType.BLOB:
raise UnsupportedOperation(
f"Sqlitedb does not support accesstype {access_type}."
)

cur = self._con.cursor()
cur.execute(
"""
SELECT * FROM reportimages
WHERE
(project, experiment, path) = (?, ?, ?)
""",
(project, experiment, path),
)
fetched = cur.fetchone()

if fetched is None:
raise FileNotFoundError(f"Object not found. {project, experiment, path}")

return fetched["blob"]

@async_and_sync
async def put_report_image(self, obj, project: str, experiment: str, path: str):
cur = self._con.cursor()

if not isinstance(obj, bytes):
raise TypeError(f"Expected bytes. Got {type(obj)}")

cur.execute(
"""
INSERT OR REPLACE INTO reportimages(project, experiment, path, blob)
VALUES(?, ?, ?, ?)
""",
(project, experiment, path, obj),
)
Binary file modified tests/test-db/sqlite/test.sqlite
Binary file not shown.

0 comments on commit 7532b94

Please sign in to comment.