Skip to content

Commit

Permalink
Fixed a failing unit test when exporting file based templates to a da…
Browse files Browse the repository at this point in the history
…tabase
  • Loading branch information
christophertubbs committed Nov 2, 2023
1 parent 4a46030 commit 133396a
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 43 deletions.
2 changes: 1 addition & 1 deletion python/lib/core/dmod/core/common/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,5 +492,5 @@ def _inner_sequence(self) -> typing.MutableSequence[_T]:
def _get_handlers(self) -> typing.Dict[CollectionEvent, typing.MutableSequence[typing.Callable]]:
return self._handlers

__root__: list[_T] = pydantic.Field(default_factory=list)
__root__: typing.List[_T] = pydantic.Field(default_factory=list)
_handlers: typing.Dict[CollectionEvent, typing.List[typing.Callable]] = PrivateAttr(default_factory=dict)
20 changes: 18 additions & 2 deletions python/lib/core/dmod/core/common/helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,15 @@ def flat(collection: typing.Iterable[typing.Iterable[_CLASS_TYPE]]) -> typing.Se
>>> example_values = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
>>> flat(example_values)
[1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> second_example_values = {
... "one": [1, 2, 3],
... "two": [4, 5, 6],
... "three": 7,
... "four": 8,
... "five": 9
... }
>>> flat(second_example_values)
[1, 2, 3, 4, 5, 6, 7, 8, 9]
Args:
collection: The collection of collections to flatten
Expand All @@ -671,8 +680,15 @@ def flat(collection: typing.Iterable[typing.Iterable[_CLASS_TYPE]]) -> typing.Se
"""
flattened_list: typing.MutableSequence[_CLASS_TYPE] = list()

for inner_collection in collection:
flattened_list.extend(inner_collection)
if isinstance(collection, typing.Mapping):
for mapped_value in collection.values():
if is_iterable_type(mapped_value):
flattened_list.extend(mapped_value)
else:
flattened_list.append(mapped_value)
else:
for inner_collection in collection:
flattened_list.extend(inner_collection)

return flattened_list

Expand Down
83 changes: 44 additions & 39 deletions python/lib/evaluations/dmod/evaluations/specification/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def initialize(self, connection: DBAPIConnection):
if cursor:
cursor.close()

def create_keyed_insert_query(self, columns: typing.Sequence[str]) -> QueryDetails:
def create_keyed_insert_query(self) -> QueryDetails:
"""
Generate a query that will insert unique values into the table
Expand All @@ -153,28 +153,20 @@ def create_keyed_insert_query(self, columns: typing.Sequence[str]) -> QueryDetai
Returns:
A sql query that will insert data into a database table
"""
missing_columns: typing.Set[str] = self.required_columns.union(self.keys).difference(columns)

if missing_columns:
raise ValueError(
f"Cannot create a script to insert values into '{self.name}' - "
f"missing the following required columns: {', '.join(columns)}"
)

# Insert values into each column that don't already exist based on the given values
query = f"""
INSERT INTO {self.name}
({', '.join(columns)})
SELECT {', '.join(['?' for _ in columns])}
({', '.join(self.column_names)})
SELECT {', '.join(['?' for _ in self.column_names])}
WHERE NOT EXISTS (
SELECT 1
FROM {self.name}
WHERE {' AND '.join([f'{key} = ?' for key in self.keys])}
)"""
labels = [name for name in columns] + [key for key in self.keys]
labels = [name for name in self.column_names] + [key for key in self.keys]
return QueryDetails(query=query, value_labels=labels)

def create_unkeyed_insert_query(self, columns: typing.Sequence[str]) -> QueryDetails:
def create_unkeyed_insert_query(self) -> QueryDetails:
"""
Create an insert query for this table that does not identify values based on keys
Expand All @@ -184,21 +176,13 @@ def create_unkeyed_insert_query(self, columns: typing.Sequence[str]) -> QueryDet
Returns:
An insert script for this table that will insert values into the given columns
"""
missing_columns: typing.Set[str] = self.required_columns.difference(columns)

if missing_columns:
raise ValueError(
f"Cannot create a script to insert values into '{self.name}' - "
f"missing the following required columns: {', '.join(columns)}"
)

# Insert values into the given database
query = f"""
INSERT INTO {self.name}
({', '.join(columns)}
VALUES ({', '.join(['?' for _ in columns])});
({', '.join(self.column_names)})
VALUES ({', '.join(['?' for _ in self.column_names])});
"""
labels = [name for name in columns]
labels = [name for name in self.column_names]
return QueryDetails(query=query, value_labels=labels)

def override_table(self, connection: DBAPIConnection):
Expand All @@ -219,7 +203,7 @@ def override_table(self, connection: DBAPIConnection):
try:
cursor = connection.cursor()

cursor.execute(f"TRUNCATE TABLE {self.name}")
cursor.execute(f"DROP TABLE {self.name}")
cursor.fetchall()

try:
Expand Down Expand Up @@ -258,24 +242,41 @@ def insert_templates(
query_details: QueryDetails

if self.keys and not override:
query_details = self.create_keyed_insert_query(self.column_names)
query_details = self.create_keyed_insert_query()
else:
query_details = self.create_unkeyed_insert_query(self.column_names)
query_details = self.create_unkeyed_insert_query()

query: str = query_details.query

parameters: typing.Sequence[typing.Sequence[typing.Any]] = [
[
getattr(row, label)
for label in query_details.value_labels
]
for row in rows
]
parameters: typing.List[typing.Sequence[typing.Any]] = list()

for template in rows:
template_parameters: typing.List[str, int, float, bool] = list()

for label in query_details.value_labels:
# The configuration may not always be on the template object and will need to be read, so call
# 'get_configuration' instead of just grabbing 'configuration'
if label == 'configuration':
template_parameters.append(
json.dumps(
template.get_configuration()
)
)
else:
template_parameters.append(getattr(template, label))

parameters.append(template_parameters)

cursor: typing.Optional[DBAPICursor] = None

try:
cursor = connection.cursor()
cursor.executemany(query, parameters)

try:
cursor.executemany(query, parameters)
except:
print(query, file=sys.stderr)
raise

try:
connection.commit()
Expand Down Expand Up @@ -308,14 +309,12 @@ def get_template_table(name: str) -> Table:
Column(name="name", datatype="VARCHAR(255)"),
Column(name="specification_type", datatype="VARCHAR(255)"),
Column(name="description", datatype="VARCHAR(500)", optional=True),
Column(name="author", datatype="VARCHAR(500)", optional=True),
Column(name="author_name", datatype="VARCHAR(500)", optional=True),
Column(name="configuration", datatype="TEXT"),
Column(name="last_modified", datatype="VARCHAR(50)")
],
keys=[
'name',
'specification_type',
'author'
'name'
]
)

Expand Down Expand Up @@ -511,6 +510,12 @@ def __eq__(self, other):
and self.specification_type == other.specification_type \
and self.get_configuration() == other.get_configuration()

def __str__(self):
return f"[{self.specification_type}] {self.name}{f' : {self.description}' if self.description else ''}"

def __repr__(self):
return self.__str__()


class TemplateManager(abc.ABC, TemplateManagerProtocol):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ def dict(self) -> typing.Dict[str, typing.Union[str, pathlib.Path]]:

return details

def __str__(self):
return f"[{self.specification_type}] {self.name}{f' : {self.description}' if self.description else ''}"

def __repr__(self):
return self.__str__()


def serialize_path(path: pathlib.Path, *args, **kwargs) -> str:
return str(path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,14 @@ def field_choice(self) -> typing.Tuple[str, str]:
return self.name, self.name

def to_details(self) -> TemplateDetails:
return BasicTemplateDetails.from_details(self)
return BasicTemplateDetails.copy(self)

def __str__(self):
return f"[{self.specification_type}] {self.name}{':' + self.description if self.description else ''}"

def __repr__(self):
return self.__str__()


class EvaluationDefinition(models.Model):
"""
Expand Down

0 comments on commit 133396a

Please sign in to comment.