Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/cache queries #21

Merged
merged 17 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/plexosdb/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ CREATE TABLE `t_category`

CREATE TABLE `t_object`
(
`object_id` INTEGER,
`object_id` INTEGER UNIQUE,
`class_id` INT NULL,
`name` VARCHAR(512) NULL COLLATE NOCASE,
`category_id` INT NULL,
Expand All @@ -241,6 +241,7 @@ CREATE TABLE `t_object`
`X` INT NULL,
`Y` INT NULL,
`Z` INT NULL,
UNIQUE (`class_id`, `name`)
CONSTRAINT PK_t_object
PRIMARY KEY (`object_id`)
);
Expand Down
69 changes: 57 additions & 12 deletions src/plexosdb/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def __init__(
super().__init__()
self._conn = sqlite3.connect(":memory:")
self._sqlite_config()
self._QUERY_CACHE: dict[int, int] = {}

if create_collations:
self._create_collations()
self._create_table_schema()
Expand Down Expand Up @@ -459,14 +461,17 @@ def add_object(

params = (object_name, class_id, category_id, str(uuid.uuid4()), description)
placeholders = ", ".join("?" * len(params))
with self._conn as conn:
cursor = conn.cursor()
cursor.execute(
f"INSERT INTO {Schema.Objects.name}(name, class_id, category_id, GUID, description) "
f"VALUES({placeholders})",
params,
)
object_id = cursor.lastrowid
try:
with self._conn as conn:
cursor = conn.cursor()
cursor.execute(
f"INSERT INTO {Schema.Objects.name}(name, class_id, category_id, GUID, description) "
f"VALUES({placeholders})",
params,
)
object_id = cursor.lastrowid
except sqlite3.IntegrityError as err:
pesap marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(err)

if object_id is None:
raise TypeError("Could not fetch the last row of the insert. Check query format.")
Expand Down Expand Up @@ -837,6 +842,9 @@ def _get_id(
if conditions
else f" WHERE {table_name}.name = :object_name"
)
query_key = self._query_hash(query, params)
if query_key in self._QUERY_CACHE:
return self._QUERY_CACHE[query_key]

result = self.query(query, params)

Expand All @@ -847,7 +855,12 @@ def _get_id(
if len(result) > 1:
msg = f"Multiple ids returned for {object_name} and {class_name}. Try passing addtional filters"
raise ValueError(msg)
return result[0][0] # Get first element of tuple

ret: int = result[0][0] # Get first element of tuple

self._QUERY_CACHE[query_key] = ret

return ret

def get_membership_id(
self,
Expand Down Expand Up @@ -1054,8 +1067,6 @@ def query(self, query_string: str, params=None) -> list[tuple]:
String to get passed to the database connector.
params
Tuple or dict for passing
fetchone
Return firstrow

Note
----
Expand All @@ -1065,7 +1076,9 @@ def query(self, query_string: str, params=None) -> list[tuple]:
"""
with self._conn as conn:
res = conn.execute(query_string, params) if params else conn.execute(query_string)
return res.fetchall()
ret = res.fetchall()

return ret

def ingest_from_records(self, tag: str, record_data: Sequence):
"""Insert elements from xml to database."""
Expand All @@ -1086,8 +1099,11 @@ def ingest_from_records(self, tag: str, record_data: Sequence):
try:
with self._conn as conn:
conn.execute(ingestion_sql, record)
except sqlite3.IntegrityError as err:
raise ValueError(err)
except sqlite3.Error as err:
raise err

logger.trace("Finished ingesting {}", tag)
return

Expand Down Expand Up @@ -1235,3 +1251,32 @@ def _create_collations(self) -> None:
"""Add collate function for helping search enums."""
self._conn.create_collation("NOSPACE", no_space)
return

@staticmethod
def _query_hash(query_string: str, params: tuple | dict | None = None) -> int:
"""
Create a hash int for a query string and params dictionary.

Parameters
----------
query_str
String to get passed to the database connector.
params
Tuple or dict for passing

Returns
-------
Int
likely unique integer for given query_string and params object
"""
if params is None:
return hash(query_string)
if isinstance(params, dict):
return hash((query_string, str(params)))
if isinstance(params, list):
return hash((query_string, *params))
return hash((query_string, params))

def clear_id_cache(self):
"""Clear the cache for the _get_id method."""
self._QUERY_CACHE.clear()
2 changes: 1 addition & 1 deletion tests/data/plexosdb.xml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
<t_object>
<object_id>3</object_id>
<class_id>2</class_id>
<name>SolarPV01</name>
<name>SolarPV02</name>
pesap marked this conversation as resolved.
Show resolved Hide resolved
<category_id>97</category_id>
<GUID>40d15a07-8ccc-460e-919a-ec8e211899a8</GUID>
</t_object>
Expand Down
42 changes: 25 additions & 17 deletions tests/test_plexosdb_sqlite.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytest
import shutil
import xml.etree.ElementTree as ET # noqa: N817
from plexosdb.enums import ClassEnum, CollectionEnum, Schema
from plexosdb.sqlite import PlexosSQLite
from collections.abc import Generator

DB_FILENAME = "plexosdb.xml"

Expand All @@ -12,8 +14,13 @@ def db_empty() -> "PlexosSQLite":


@pytest.fixture
def db(data_folder) -> PlexosSQLite:
return PlexosSQLite(xml_fname=data_folder.joinpath(DB_FILENAME))
def db(data_folder) -> Generator[PlexosSQLite, None, None]:
xml_fname = data_folder.joinpath(DB_FILENAME)
xml_copy = data_folder.joinpath(f"copy_{DB_FILENAME}")
shutil.copy(xml_fname, xml_copy)
db = PlexosSQLite(xml_fname=xml_copy)
yield db
xml_copy.unlink()
pesap marked this conversation as resolved.
Show resolved Hide resolved


def test_database_initialization(db):
Expand Down Expand Up @@ -80,10 +87,6 @@ def test_check_id_exists(db):
assert isinstance(system_check, bool)
assert not system_check

# Check that returns ValueError if multiple object founds
with pytest.raises(ValueError):
_ = db.check_id_exists(Schema.Objects, "SolarPV01", class_name=ClassEnum.Generator)


@pytest.mark.get_functions
def test_get_id(db):
Expand Down Expand Up @@ -156,27 +159,32 @@ def test_get_collection_id(db):

@pytest.mark.get_functions
def test_get_object_id(db):
gen_01_name = "gen1"
gen_id = db.add_object(
gen_01_name, ClassEnum.Generator, CollectionEnum.Generators, description="Test Gen"
)
assert gen_id

gen_id_get = db.get_object_id(gen_01_name, class_name=ClassEnum.Generator)
assert gen_id == gen_id_get

# Add generator with same name different category
gen_01_name = "gen1"
category_name = "PV Gens"

gen_id = db.add_object(
gen_01_name,
ClassEnum.Generator,
CollectionEnum.Generators,
description="Test Gen",
category_name=category_name,
)
assert gen_id

gen_id_get = db.get_object_id(gen_01_name, class_name=ClassEnum.Generator)
assert gen_id == gen_id_get

# Add generator with same name different category
with pytest.raises(ValueError):
_ = db.get_object_id(gen_01_name, class_name=ClassEnum.Generator)
gen_01_name = "gen1"
gen_id = db.add_object(
gen_01_name,
ClassEnum.Generator,
CollectionEnum.Generators,
description="Test Gen",
)
pesap marked this conversation as resolved.
Show resolved Hide resolved
# for a given class, all names should be unique
_ = db.get_object_id(gen_01_name, class_name=ClassEnum.Generator)

max_rank = db.get_category_max_id(ClassEnum.Generator)
assert max_rank == 2 # Data has ranks 0, 1. 2 is with the new category
Expand Down