From 762a92426bf28601b58b4eeb37426506e9564c73 Mon Sep 17 00:00:00 2001
From: Steph Merritt
Date: Thu, 18 Apr 2024 15:18:54 +0100
Subject: [PATCH] Writing out to SQL database.

---
 src/adler/dataclasses/AdlerData.py        | 147 +++++++++++++++++++++-
 tests/adler/dataclasses/test_AdlerData.py |  32 ++++-
 tests/data/test_SQL_database_table.csv    |   2 +
 3 files changed, 178 insertions(+), 3 deletions(-)
 create mode 100644 tests/data/test_SQL_database_table.csv

diff --git a/src/adler/dataclasses/AdlerData.py b/src/adler/dataclasses/AdlerData.py
index 9d24d43..1453cda 100644
--- a/src/adler/dataclasses/AdlerData.py
+++ b/src/adler/dataclasses/AdlerData.py
@@ -1,5 +1,8 @@
-from dataclasses import dataclass, field
+import os
+import sqlite3
 import numpy as np
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
 
 FILTER_DEPENDENT_KEYS = ["phaseAngle_min", "phaseAngle_range", "nobs", "arc"]
 
@@ -209,6 +212,148 @@ def get_phase_parameters_in_filter(self, filter_name, model_name=None):
 
         return output_obj
 
+    def _get_database_connection(self, filepath):
+        """Returns the connection to the output SQL database, creating it if it does not exist.
+
+        Parameters
+        -----------
+        filepath : path-like object
+            Filepath with the location of the output SQL database.
+
+        Returns
+        ----------
+        con : sqlite3 Connection object
+            The connection to the output database.
+
+        """
+
+        database_exists = os.path.isfile(
+            filepath
+        )  # check this FIRST as the next statement creates the db if it doesn't exist
+        con = sqlite3.connect(filepath)
+
+        if not database_exists:  # we need to make the table and a couple of starter columns
+            cur = con.cursor()
+            cur.execute("CREATE TABLE AdlerData(ssObjectId, timestamp)")
+
+        return con
+
+    def _get_database_columns(self, con, table_name):
+        """Gets a list of the current columns in a given table in a SQL database.
+
+        Parameters
+        -----------
+        con : sqlite3 Connection object
+            The connection to the output SQL database.
+
+        table_name : str
+            The name of the relevant table in the database.
+
+        Returns
+        ----------
+        list of str
+            List of current columns existing in the table.
+
+        """
+
+        cur = con.cursor()
+        # the WHERE 1=0 clause returns no rows but still populates cur.description with the column metadata
+        cur.execute(f"""SELECT * from {table_name} where 1=0""")
+        return [d[0] for d in cur.description]
+
+    def _get_row_data_and_columns(self):
+        """Collects all of the data present in the AdlerData object as a list, with a corresponding list of column names,
+        in preparation for a row to be written to a SQL database table.
+
+        Returns
+        ----------
+        row_data : list
+            A list containing all of the relevant data present in the AdlerData object.
+
+        required_columns : list of str
+            A list of the corresponding column names, in the same order.
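+
+        Examples
+        ----------
+        An illustrative sketch, assuming adler_data is a hypothetical AdlerData
+        instance already populated via populate_phase_parameters() for the "u" filter:
+
+        >>> row_data, required_columns = adler_data._get_row_data_and_columns()
+        >>> required_columns[:6]
+        ['ssObjectId', 'timestamp', 'u_phaseAngle_min', 'u_phaseAngle_range', 'u_nobs', 'u_arc']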
+ + """ + required_columns = ["ssObjectId", "timestamp"] + row_data = [self.ssObjectId, str(datetime.now(timezone.utc))] + + for f, filter_name in enumerate(self.filter_list): + columns_by_filter = ["_".join([filter_name, filter_key]) for filter_key in FILTER_DEPENDENT_KEYS] + data_by_filter = [ + getattr(self.filter_dependent_values[f], filter_key) for filter_key in FILTER_DEPENDENT_KEYS + ] + + required_columns.extend(columns_by_filter) + row_data.extend(data_by_filter) + + for m, model_name in enumerate(self.filter_dependent_values[f].model_list): + columns_by_model = [ + "_".join([filter_name, model_name, model_key]) for model_key in MODEL_DEPENDENT_KEYS + ] + data_by_model = [ + getattr(self.filter_dependent_values[f].model_dependent_values[m], model_key) + for model_key in MODEL_DEPENDENT_KEYS + ] + + required_columns.extend(columns_by_model) + row_data.extend(data_by_model) + + return row_data, required_columns + + def _ensure_columns(self, con, table_name, current_columns, required_columns): + """Creates new columns in a given table of a SQL database as needed by checking the list of current columns against a list + of required columns. + + + Parameters + ----------- + con : sqlite3 Connection object + The connection to the output SQL database. + + table_name : str + The name of the relevant table in the database. + + current_columns : list of str + A list of the columns already existing in the database table. + + required_columns : list of str + A list of the columns needed in the database table. + + """ + + cur = con.cursor() + for column_name in required_columns: + if column_name not in current_columns: + cur.execute(f"""ALTER TABLE {table_name} ADD COLUMN {column_name}""") + + def write_row_to_database(self, filepath, table_name="AdlerData"): + """Writes all of the relevant data contained within the AdlerData object to a timestamped row in a SQLite database. + + Parameters + ----------- + filepath : path-like object + Filepath with the location of the output SQL database. + + table_name : str, optiona + String containing the table name to write the data to. Default is "AdlerData". 
+ + """ + + con = self._get_database_connection(filepath) + + row_data, required_columns = self._get_row_data_and_columns() + current_columns = self._get_database_columns(con, table_name) + self._ensure_columns(con, table_name, current_columns, required_columns) + + column_names = ",".join(required_columns) + column_spaces = ",".join(["?"] * len(required_columns)) + sql_command = "INSERT INTO %s (%s) values(%s)" % (table_name, column_names, column_spaces) + + cur = con.cursor() + cur.execute(sql_command, row_data) + con.commit() + con.close() + @dataclass class FilterDependentAdler: diff --git a/tests/adler/dataclasses/test_AdlerData.py b/tests/adler/dataclasses/test_AdlerData.py index 1363746..c00c656 100644 --- a/tests/adler/dataclasses/test_AdlerData.py +++ b/tests/adler/dataclasses/test_AdlerData.py @@ -1,9 +1,13 @@ -from numpy.testing import assert_array_equal +import os import pytest import numpy as np +import pandas as pd +import sqlite3 -from adler.dataclasses.AdlerData import AdlerData +from numpy.testing import assert_array_equal +from adler.dataclasses.AdlerData import AdlerData +from adler.utilities.tests_utilities import get_test_data_filepath # setting up the AdlerData object to be used for testing @@ -126,6 +130,9 @@ def test_populate_phase_parameters(): test_object.populate_phase_parameters("u", model_name="model_1", H=15.0) # check to make sure filter-dependent parameter is correctly updated (then return it to previous) + test_object.populate_phase_parameters("u", nobs=99) + assert test_object.filter_dependent_values[0].nobs == 99 + test_object.populate_phase_parameters("u", nobs=13) # testing to make sure the correct error messages trigger with pytest.raises(ValueError) as error_info_1: @@ -222,3 +229,24 @@ def test_print_data(capsys): expected = "Filter: u\nPhase angle minimum: 11.0\nPhase angle range: 12.0\nNumber of observations: 13\nArc: 14.0\nModel: model_1.\n\tH: 15.0\n\tH error: 16.0\n\tPhase parameter 1: 17.0\n\tPhase parameter 1 error: 18.0\n\tPhase parameter 2: nan\n\tPhase parameter 2 error: nan\nModel: model_2.\n\tH: 25.0\n\tH error: 26.0\n\tPhase parameter 1: 27.0\n\tPhase parameter 1 error: 28.0\n\tPhase parameter 2: 29.0\n\tPhase parameter 2 error: 30.0\n\n\nFilter: g\nPhase angle minimum: 31.0\nPhase angle range: 32.0\nNumber of observations: 33\nArc: 34.0\nModel: model_1.\n\tH: 35.0\n\tH error: 36.0\n\tPhase parameter 1: 37.0\n\tPhase parameter 1 error: 38.0\n\tPhase parameter 2: nan\n\tPhase parameter 2 error: nan\n\n\nFilter: r\nPhase angle minimum: 41.0\nPhase angle range: 42.0\nNumber of observations: 43\nArc: 44.0\nModel: model_2.\n\tH: 45.0\n\tH error: 46.0\n\tPhase parameter 1: 47.0\n\tPhase parameter 1 error: 48.0\n\tPhase parameter 2: 49.0\n\tPhase parameter 2 error: 50.0\n\n\n" assert captured.out == expected + + +def test_write_row_to_database(tmp_path): + db_location = os.path.join(tmp_path, "test_AdlerData_database.db") + test_object.write_row_to_database(db_location) + + con = sqlite3.connect(db_location) + written_data = pd.read_sql_query("SELECT * from AdlerData", con) + con.close() + + expected_data_filepath = get_test_data_filepath("test_SQL_database_table.csv") + expected_data = pd.read_csv(expected_data_filepath) + + # we don't expect the timestamp column to be the same, obviously + expected_data = expected_data.drop(columns="timestamp") + written_data = written_data.drop(columns="timestamp") + + # note that because I'm using Pandas there's some small dtype and np.nan/None stuff to clear up + # but this makes for a quick streamlined 
+    expected_data = expected_data.replace({np.nan: None})
+    pd.testing.assert_frame_equal(expected_data, written_data, check_dtype=False)
diff --git a/tests/data/test_SQL_database_table.csv b/tests/data/test_SQL_database_table.csv
new file mode 100644
index 0000000..790d23c
--- /dev/null
+++ b/tests/data/test_SQL_database_table.csv
@@ -0,0 +1,2 @@
+ssObjectId,timestamp,u_phaseAngle_min,u_phaseAngle_range,u_nobs,u_arc,u_model_1_H,u_model_1_H_err,u_model_1_phase_parameter_1,u_model_1_phase_parameter_1_err,u_model_1_phase_parameter_2,u_model_1_phase_parameter_2_err,u_model_2_H,u_model_2_H_err,u_model_2_phase_parameter_1,u_model_2_phase_parameter_1_err,u_model_2_phase_parameter_2,u_model_2_phase_parameter_2_err,g_phaseAngle_min,g_phaseAngle_range,g_nobs,g_arc,g_model_1_H,g_model_1_H_err,g_model_1_phase_parameter_1,g_model_1_phase_parameter_1_err,g_model_1_phase_parameter_2,g_model_1_phase_parameter_2_err,r_phaseAngle_min,r_phaseAngle_range,r_nobs,r_arc,r_model_2_H,r_model_2_H_err,r_model_2_phase_parameter_1,r_model_2_phase_parameter_1_err,r_model_2_phase_parameter_2,r_model_2_phase_parameter_2_err
+666,2024-04-18 13:32:07.096776+00:00,11.0,12.0,13,14.0,15.0,16.0,17.0,18.0,,,25.0,26.0,27.0,28.0,29.0,30.0,31.0,32.0,33,34.0,35.0,36.0,37.0,38.0,,,41.0,42.0,43,44.0,45.0,46.0,47.0,48.0,49.0,50.0
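
For reference, the CSV above also documents the flattened column naming scheme:
{filter}_{key} for filter-dependent values and {filter}_{model}_{key} for
model-dependent ones. A minimal sketch of reading one value back out, assuming
a database previously written by write_row_to_database() at the hypothetical
path "adler_out.db":

    import sqlite3

    con = sqlite3.connect("adler_out.db")
    cur = con.cursor()
    # fetch the absolute magnitude fitted with model_1 in the u filter
    cur.execute("SELECT u_model_1_H FROM AdlerData WHERE ssObjectId = ?", (666,))
    print(cur.fetchone())  # e.g. (15.0,)
    con.close()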