From 2e0161688ebe19858db93ad57d8e6f23b7d3d413 Mon Sep 17 00:00:00 2001 From: pesap Date: Wed, 18 Sep 2024 16:52:22 -0600 Subject: [PATCH] feat: Enable capability of find enums with spaces and new methods for easy interaction (#17) --- src/plexosdb/enums.py | 11 ++- src/plexosdb/schema.sql | 4 +- src/plexosdb/sqlite.py | 149 ++++++++++++++++++++++++++-------- src/plexosdb/utils.py | 9 ++ src/plexosdb/xml_handler.py | 95 +--------------------- tests/data/plexosdb.xml | 28 +++++++ tests/test_plexosdb_sqlite.py | 21 ++++- tests/test_plexosdb_xml.py | 38 +-------- 8 files changed, 186 insertions(+), 169 deletions(-) diff --git a/src/plexosdb/enums.py b/src/plexosdb/enums.py index e59c86f..0e80b38 100644 --- a/src/plexosdb/enums.py +++ b/src/plexosdb/enums.py @@ -71,6 +71,7 @@ class ClassEnum(StrEnum): Production = "Production" Performance = "Performance" Variable = "Variable" + Constraint = "Constraint" plexos_class_mapping = {enum_member.name: enum_member.value for enum_member in ClassEnum} @@ -84,8 +85,7 @@ class CollectionEnum(StrEnum): HeadStorage = "HeadStorage" TailStorage = "TailStorage" Nodes = "Nodes" - Battery = "Battery" - Storage = "Storage" + Storages = "Storages" Emissions = "Emissions" Reserves = "Reserves" Batteries = "Batteries" @@ -98,10 +98,12 @@ class CollectionEnum(StrEnum): NodeTo = "NodeTo" Transformers = "Transformers" Interfaces = "Interfaces" - Scenarios = "Scenarios" Models = "Models" Scenario = "Scenario" + Scenarios = "Scenarios" Horizon = "Horizon" + Horizons = "Horizons" + Report = "Report" Reports = "Reports" PASA = "PASA" MTSchedule = "MTSchedule" @@ -109,8 +111,11 @@ class CollectionEnum(StrEnum): Transmission = "Transmission" Production = "Production" Diagnostic = "Diagnostic" + Diagnostics = "Diagnostics" Performance = "Performance" DataFiles = "DataFiles" + Constraint = "Constraint" + Constraints = "Constraints" def str2enum(string, schema_enum=Schema) -> Schema | None: diff --git a/src/plexosdb/schema.sql b/src/plexosdb/schema.sql index fadf287..11ad044 100644 --- a/src/plexosdb/schema.sql +++ b/src/plexosdb/schema.sql @@ -89,7 +89,7 @@ CREATE TABLE `t_custom_rule` CREATE TABLE `t_class` ( `class_id` INT NOT NULL, - `name` VARCHAR(255) NULL, + `name` VARCHAR(255) NULL COLLATE NOSPACE, `class_group_id` INT NULL, `is_enabled` BIT NULL, `lang_id` INT NULL, @@ -105,7 +105,7 @@ CREATE TABLE `t_collection` `collection_id` INT NOT NULL, `parent_class_id` INT NULL, `child_class_id` INT NULL, - `name` VARCHAR(255) NULL, + `name` VARCHAR(255) NULL COLLATE NOSPACE, `min_count` INT NULL, `max_count` INT NULL, `complement_name` VARCHAR(255) NULL, diff --git a/src/plexosdb/sqlite.py b/src/plexosdb/sqlite.py index ec7a2f3..f198264 100644 --- a/src/plexosdb/sqlite.py +++ b/src/plexosdb/sqlite.py @@ -10,12 +10,11 @@ from loguru import logger -from .utils import batched +from .utils import batched, no_space from .enums import ClassEnum, CollectionEnum, Schema, str2enum from .xml_handler import XMLHandler SYSTEM_CLASS_NAME = "System" -MASTER_FILE = files("plexosdb").joinpath("master.xml") class PlexosSQLite: @@ -33,6 +32,7 @@ def __init__(self, xml_fname: str | None = None, xml_handler: XMLHandler | None super().__init__() self._conn = sqlite3.connect(":memory:") self._sqlite_config() + self._create_collations() self._create_table_schema() self._populate_database(xml_fname=xml_fname, xml_handler=xml_handler) @@ -262,11 +262,11 @@ def add_property( KeyError When the property is not a valid string for the collection. """ + parent_class = parent_class or ClassEnum.System object_id = self.get_object_id(object_name, class_name=object_class) - collection_id = self.get_collection_id( + valid_properties = self.get_valid_properties( collection, child_class=object_class, parent_class=parent_class ) - valid_properties = self.get_valid_properties(collection_id) if property_name not in valid_properties: msg = ( f"Property {property_name} does not exist for collection: {collection}. " @@ -280,7 +280,6 @@ def add_property( # Add system membership parent_object_name = parent_object_name or SYSTEM_CLASS_NAME # Default to system class - parent_class = parent_class or ClassEnum.System membership_id = self.get_membership_id( child_name=object_name, @@ -306,10 +305,9 @@ def add_property( # Add scenario tag if passed if scenario: - scenario_id = self.check_id_exists(Schema.Objects, scenario, class_name=ClassEnum.Scenario) + scenario_id = self.get_scenario_id(scenario) if scenario_id is None: scenario_id = self.add_object(scenario, ClassEnum.Scenario, CollectionEnum.Scenario) - self.execute_query("INSERT into t_tag(object_id,data_id) values (?,?)", (scenario_id, data_id)) # Add text if passed @@ -344,15 +342,16 @@ def add_property_from_records( property_ids = {key: value for key, value in collection_properties} component_names = tuple(d["name"] for d in records) component_memberships_query = f""" - SELECT - t_object.name as name, - membership_id - FROM - t_membership - inner join t_object on t_membership.child_object_id = t_object.object_id - WHERE - t_membership.parent_object_id = {parent_object_id} and - t_object.name in ({", ".join(["?" for _ in range(len(component_names))])}) + SELECT + t_object.name as name, + membership_id + FROM + t_membership + INNER JOIN + t_object on t_membership.child_object_id = t_object.object_id + WHERE + t_membership.parent_object_id = {parent_object_id} AND + t_object.name in ({", ".join(["?" for _ in range(len(component_names))])}) """ component_memberships = self.query(component_memberships_query, params=component_names) component_memberships_dict: dict = {key: value for key, value in component_memberships} @@ -430,8 +429,10 @@ def add_object( Name to be added to the object class_id ClassEnum from the object to be added. E.g., for generators class_id=ClassEnum.Generators - collection_id + collection_name Collection for system membership. E.g., for generators class_enum=CollectionEnum.SystemGenerators + parent_class_name + Name of the parent class if different from System. Notes ----- @@ -584,6 +585,43 @@ def get_collection_id( Schema.Collection, collection.name, parent_class_name=parent_class, child_class_name=child_class ) + def get_category_max_id(self, class_enum: ClassEnum) -> int: + """Return the current max rank for a given category.""" + class_id = self._get_id(Schema.Class, class_enum.name) + query = """ + SELECT + max(rank) + FROM + t_category + LEFT JOIN + t_class ON t_class.class_id = t_category.class_id + WHERE + t_class.class_id = :class_id + """ + return self.query(query, params={"class_id": class_id})[0][0] + + def get_class_id(self, class_enum: ClassEnum) -> int: + """Return the ID for a given class. + + Parameters + ---------- + class_name : ClassEnum + The enum collection from which to retrieve the ID. + + Returns + ------- + int + The ID corresponding to the object, or None if not found. + + Raises + ------ + KeyError + If ID does not exists on the database. + ValueError + If multiple IDs are returned for the given class. + """ + return self._get_id(Schema.Class, class_enum.name) + def get_property_id( self, property_name: str, @@ -616,9 +654,9 @@ def get_property_id( ValueError If multiple IDs are returned for the given parent/child class provided. """ - collection_id = self.get_collection_id(collection, parent_class=parent_class, child_class=child_class) - valid_properties = self.get_valid_properties(collection_id) - + valid_properties = self.get_valid_properties( + collection, parent_class=parent_class, child_class=child_class + ) if property_name not in valid_properties: msg = ( f"Property {property_name} does not exist for collection: {collection}. " @@ -626,13 +664,16 @@ def get_property_id( ) raise KeyError(msg) + collection_id = self.get_collection_id(collection, parent_class=parent_class, child_class=child_class) + query_id = """ SELECT property_id FROM `t_property` WHERE name = :property_name - AND collection_id = :collection_id + AND + collection_id = :collection_id """ params = {"property_name": property_name, "collection_id": collection_id} result = self.query(query_id, params) @@ -744,32 +785,54 @@ def _get_id( """ table_name = table.name column_name = table.label - - query_id = f"SELECT {column_name} FROM `{table_name}` WHERE name = :object_name" params = { "object_name": object_name, } - if class_name: + + query = f"SELECT {column_name} FROM `{table_name}`" + conditions = [] + join_clauses = [] + + if class_name is not None: + assert isinstance(class_name, ClassEnum) class_id = self._get_id(Schema.Class, class_name) - query_id += " and class_id = :class_id" + conditions.append("class_id = :class_id") params["class_id"] = class_id - if parent_class_name: + if parent_class_name is not None: + assert isinstance(parent_class_name, ClassEnum) parent_class_id = self._get_id(Schema.Class, parent_class_name.name) - query_id += " and parent_class_id = :parent_class_id" + join_clauses.append( + f" LEFT JOIN t_class as parent_class ON {table_name}.parent_class_id = parent_class.class_id" + ) + conditions.append("parent_class_id = :parent_class_id") params["parent_class_id"] = parent_class_id - if child_class_name: + if child_class_name is not None: + assert isinstance(child_class_name, ClassEnum) child_class_id = self._get_id(Schema.Class, child_class_name.name) - query_id += " and child_class_id = :child_class_id" + join_clauses.append( + f" LEFT JOIN t_class AS child_class ON {table_name}.child_class_id = child_class.class_id" + ) + conditions.append("child_class_id = :child_class_id") params["child_class_id"] = child_class_id if category_name and class_name: category_id = self.get_category_id(category_name, class_name) - query_id += " and category_id = :category_id" + conditions.append("category_id = :category_id") params["category_id"] = category_id - result = self.query(query_id, params) + # Build final query + query += " ".join(join_clauses) + if conditions: + query += " WHERE " + " AND ".join(conditions) + query += ( + f" AND {table_name}.name = :object_name" + if conditions + else f" WHERE {table_name}.name = :object_name" + ) + + result = self.query(query, params) if not result: msg = f"No object found with the requested {params=}" @@ -950,12 +1013,18 @@ def get_memberships( def get_scenario_id(self, scenario_name: str) -> int: """Return scenario id for a given scenario name.""" scenario_id = self.check_id_exists(Schema.Objects, scenario_name, class_name=ClassEnum.Scenario) - if scenario_id is None: - scenario_id = self.add_object(scenario_name, ClassEnum.Scenario, CollectionEnum.Scenario) + if not scenario_id: + scenario_id = self.add_object(scenario_name, ClassEnum.Scenario, CollectionEnum.Scenarios) return scenario_id - def get_valid_properties(self, collection_id: int) -> list[str]: + def get_valid_properties( + self, + collection: CollectionEnum, + parent_class: ClassEnum | None = None, + child_class: ClassEnum | None = None, + ) -> list[str]: """Return list of valid property names per collection.""" + collection_id = self.get_collection_id(collection, parent_class=parent_class, child_class=child_class) query_string = "SELECT name from t_property where collection_id = ?" result = self.query(query_string, (collection_id,)) return [d[0] for d in result] @@ -1139,8 +1208,11 @@ def _create_table_schema(self) -> None: def _populate_database(self, xml_fname: str | None, xml_handler: XMLHandler | None = None): fpath = xml_fname if fpath is None and not xml_handler: - fpath = MASTER_FILE # type: ignore - logger.debug("Using {} as default file", fpath) + msg = ( + "Base XML file was not provided. " + "Make sure that you are passing either `xml_fname` or xml_handler`." + ) + raise FileNotFoundError(msg) if not xml_handler: xml_handler = XMLHandler.parse(fpath=fpath) # type: ignore @@ -1152,3 +1224,8 @@ def _populate_database(self, xml_fname: str | None, xml_handler: XMLHandler | No if schema: record_dict = xml_handler.get_records(schema) self.ingest_from_records(tag, record_dict) + + def _create_collations(self) -> None: + """Add collate function for helping search enums.""" + self._conn.create_collation("NOSPACE", no_space) + return diff --git a/src/plexosdb/utils.py b/src/plexosdb/utils.py index df45fb0..ac9f36a 100644 --- a/src/plexosdb/utils.py +++ b/src/plexosdb/utils.py @@ -51,3 +51,12 @@ def validate_string(value: str) -> Any: logger.trace("Could not parse {}", value) finally: return value + + +def no_space(a: str, b: str) -> int: + """Collate function for catching strings with spaces.""" + if a.replace(" ", "") == b.replace(" ", ""): + return 0 + if a.replace(" ", "") < b.replace(" ", ""): + return -1 + return 1 diff --git a/src/plexosdb/xml_handler.py b/src/plexosdb/xml_handler.py index 418c0ec..7a22152 100644 --- a/src/plexosdb/xml_handler.py +++ b/src/plexosdb/xml_handler.py @@ -1,16 +1,13 @@ """Plexos Input XML API.""" -from collections import defaultdict import xml.etree.ElementTree as ET # noqa: N817 +from collections import defaultdict from collections.abc import Iterable, Iterator -from enum import Enum -from functools import lru_cache from os import PathLike from loguru import logger -from .exceptions import MultlipleElementsError -from .enums import CollectionEnum, Schema +from .enums import Schema from .utils import validate_string @@ -73,58 +70,6 @@ def get_records( ) ) - def get_id(self, element_enum: Schema, *, label: str | None = None, **tag_elements) -> str: - """Return element ID matching name, tags and values. - - This function should return the element_id for a a single element. If - the query returns more than one element, it will raise an error. - - Returns - ------- - str - Element type id - - Raises - ------ - KeyError - If combination of element_name and tags do not exists - MultipleElementsError - If more than one element found - """ - element = list(self.iter(element_enum, **tag_elements)) - - if not element: - msg = f"{element_enum=} with {tag_elements=} not found" - raise KeyError(msg) - - if len(element) > 1: - msg = ( - f"Multiple elements returned for {element_enum=}.{tag_elements}. " - "Use `iter` too see all the returned elements or provide additional filters." - ) - raise MultlipleElementsError(msg) - - if label is None: - return element[0].findtext(element_enum.label) # type: ignore - - return element[0].findtext(label) # type: ignore - - def get_max_id(self, element_type: Schema): - """Return max id for a given child class. - - Paramters - --------- - element_type - XML parent tag to iterate over. - """ - # element = list(self.iter(element_type.name)) - return max(0, self._counts.get(element_type.name, 0)) - - def _get_xml_element(self, element_type: Schema, label: str | None = None, **kwargs) -> ET.Element: - element_id = self.get_id(element_enum=element_type, label=label, **kwargs) - element = list(self.iter(element_type, element_id, label=label))[0] # noqa: RUF015 - return element - def iter( self, element_type: Schema, *elements: Iterable[str | int], label: str | None = None, **tags ) -> Iterable[ET.Element]: @@ -188,8 +133,6 @@ def to_xml(self, fpath: str | PathLike) -> None: def _cache_iter(self, element_type: Schema, **tag_elements) -> Iterator | list: if not tag_elements: return iter(self._cache[element_type.name]) - if element_type.label not in tag_elements: - return filter(construct_condition_lambda(**tag_elements), self._cache[element_type.name]) index = int(tag_elements[element_type.label]) - 1 return iter([self._cache[element_type.name][index]]) @@ -212,24 +155,6 @@ def _iter_elements(self, element_type: str, *elements, **tag_elements) -> Iterat elements = self.root.findall(xpath_query) # type: ignore yield from elements - @lru_cache - def get_valid_properties_list(self, collection_enum: CollectionEnum | None = None): - """Return list of valid properties for the given Collection.""" - return list( - map( - lambda x: x.findtext("name"), - self.iter(Schema.Property, collection_id=collection_enum), - ) - ) - - @lru_cache - def get_valid_properties_dict(self, collection_enum: CollectionEnum | None = None): - """Return list of valid properties for the given Collection.""" - return { - x.findtext("property_id"): x.findtext("name") - for x in self.iter(Schema.Property, collection_id=collection_enum) - } - def _remove_namespace(self, namespace: str) -> None: """Remove namespace in the passed document in place. @@ -245,22 +170,6 @@ def _remove_namespace(self, namespace: str) -> None: elem.tag = elem.tag[nsl:] -def construct_condition_lambda(**kwargs): # noqa: D103 - # Precompute the values of findtext calls - findtext_values = {key: str(value) for key, value in kwargs.items() if value} - - # Construct the lambda function - def condition(x): - for key, value in findtext_values.items(): - if isinstance(value, Enum): - value = value.value - if x.findtext(key) != value: - return False - return True - - return condition - - def xml_query(element_name: str, *tags, **tag_elements) -> str: """Construct XPath query for extracting data from a XML with no namespace. diff --git a/tests/data/plexosdb.xml b/tests/data/plexosdb.xml index 222ac7c..6870ce6 100644 --- a/tests/data/plexosdb.xml +++ b/tests/data/plexosdb.xml @@ -216,6 +216,34 @@ 87 Report objects + + 700 + 1 + 78 + Scenarios + 0 + -1 + true + true + 90 + Scenario objects + + + 774 + 96 + 38 + Gas Nodes + 0 + -1 + Layouts + 0 + -1 + true + true + 115 + set of Gas Node objects in the Layout + set of Layouts containing the Gas Node + 47 1 diff --git a/tests/test_plexosdb_sqlite.py b/tests/test_plexosdb_sqlite.py index 906bea7..25c5291 100644 --- a/tests/test_plexosdb_sqlite.py +++ b/tests/test_plexosdb_sqlite.py @@ -12,7 +12,7 @@ def db_empty() -> "PlexosSQLite": @pytest.fixture -def db(data_folder) -> "PlexosSQLite": +def db(data_folder) -> PlexosSQLite: return PlexosSQLite(xml_fname=data_folder.joinpath(DB_FILENAME)) @@ -178,6 +178,9 @@ def test_get_object_id(db): with pytest.raises(ValueError): _ = db.get_object_id(gen_01_name, class_name=ClassEnum.Generator) + max_rank = db.get_category_max_id(ClassEnum.Generator) + assert max_rank == 2 # Data has ranks 0, 1. 2 is with the new category + # Now actually filter by category object_id = db.get_object_id(gen_01_name, class_name=ClassEnum.Generator, category_name=category_name) assert object_id @@ -604,6 +607,22 @@ def test_create_table_element(db): assert column_element.text == str(column_value) +def test_populate_database(db): + with pytest.raises(FileNotFoundError): + _ = db._populate_database(xml_fname=None, xml_handler=None) + + +def test_coalesce_columns(db): + # Database can find some columns specified with Collate specified + collection_id = db.query("SELECT collection_id from t_collection where name = 'Gas Nodes'")[0][0] + assert collection_id + collection_id_nospace = db.query("SELECT collection_id from t_collection where name = 'GasNodes'")[0][0] + assert collection_id_nospace + + # Both should be equal + assert collection_id_nospace == collection_id + + def test_to_xml(db, tmp_path): fname = "testing" fpath = tmp_path / fname diff --git a/tests/test_plexosdb_xml.py b/tests/test_plexosdb_xml.py index 0be3a57..21e9505 100644 --- a/tests/test_plexosdb_xml.py +++ b/tests/test_plexosdb_xml.py @@ -1,11 +1,10 @@ -import pytest import os -from pathlib import Path import xml.etree.ElementTree as ET # noqa: N817 -from plexosdb.exceptions import MultlipleElementsError +from pathlib import Path + +import pytest from plexosdb.enums import Schema -from plexosdb.xml_handler import XMLHandler -from plexosdb.xml_handler import xml_query +from plexosdb.xml_handler import XMLHandler, xml_query XML_FPATH = Path("tests").joinpath("data/plexosdb.xml") NAMESPACE = "http://tempuri.org/MasterDataSet.xsd" @@ -50,16 +49,6 @@ def test_cache_construction(): assert len(handler._cache) == 9 # Total number of elements parsed -@pytest.mark.parametrize( - "class_name,category,name,expected_id", - [(Schema.Class, None, "System", "1"), (Schema.Objects, 2, "SolarPV01", "2")], -) -def test_xml_get_id(xml_handler, class_name, category, name, expected_id): - element_id = xml_handler.get_id(class_name, name=name, category_id=category) - assert element_id is not None - assert element_id == expected_id - - def test_iter(xml_handler): elements = list(xml_handler.iter(Schema.Objects)) assert len(elements) == 3 @@ -83,25 +72,6 @@ def test_save_xml(tmp_path): assert getattr(handler, "_counts", False) is False -# TODO(pesap): Add test to round-trip serialization of plexos model. -# https://github.nrel.gov/PCM/R2X/issues/361 - - -def test_get_element_id_returns(xml_handler): - # Assert that we raise and error if element combination is not found - with pytest.raises(KeyError): - xml_handler.get_id(Schema.Class, name="test") - - # Assert that raises error if multiple matches found - with pytest.raises(MultlipleElementsError, match="Multiple elements returned"): - xml_handler.get_id(Schema.Class) - - -def test_get_max_id(xml_handler): - max_id = xml_handler.get_max_id(Schema.Objects) - assert max_id == 3 - - @pytest.mark.parametrize( "element_name, tags, tag_elements, expected_query", [