Skip to content

Commit

Permalink
Writing out to SQL database. (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
astronomerritt authored Apr 19, 2024
1 parent 3eb5f30 commit 558686c
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 3 deletions.
147 changes: 146 additions & 1 deletion src/adler/dataclasses/AdlerData.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from dataclasses import dataclass, field
import os
import sqlite3
import numpy as np
from dataclasses import dataclass, field
from datetime import datetime, timezone


FILTER_DEPENDENT_KEYS = ["phaseAngle_min", "phaseAngle_range", "nobs", "arc"]
Expand Down Expand Up @@ -209,6 +212,148 @@ def get_phase_parameters_in_filter(self, filter_name, model_name=None):

return output_obj

def _get_database_connection(self, filepath):
"""Returns the connection to the output SQL database, creating it if it does not exist.
Parameters
-----------
filepath : path-like object
Filepath with the location of the output SQL database.
Returns
----------
con : sqlite3 Connection object
The connection to the output database.
"""

database_exists = os.path.isfile(
filepath
) # check this FIRST as the next statement creates the db if it doesn't exist
con = sqlite3.connect(filepath)

if not database_exists: # we need to make the table and a couple of starter columns
cur = con.cursor()
cur.execute("CREATE TABLE AdlerData(ssObjectId, timestamp)")

return con

def _get_database_columns(self, con, table_name):
"""Gets a list of the current columns in a given table in a SQL database.
Parameters
-----------
con : sqlite3 Connection object
The connection to the output SQL database.
table_name : str
The name of the relevant table in the database.
Returns
----------
list of str
List of current columns existing in the table.
"""

cur = con.cursor()
cur.execute(f"""SELECT * from {table_name} where 1=0""")
return [d[0] for d in cur.description]

def _get_row_data_and_columns(self):
"""Collects all of the data present in the AdlerData object as a list with a corresponding list of column names,
in preparation for a row to be written to a SQL database table.
Returns
-----------
row_data : list
A list containing all of the relevant data present in the AdlerData object.
required_columns : list of str
A list of the corresponding column names in the same order.
"""
required_columns = ["ssObjectId", "timestamp"]
row_data = [self.ssObjectId, str(datetime.now(timezone.utc))]

for f, filter_name in enumerate(self.filter_list):
columns_by_filter = ["_".join([filter_name, filter_key]) for filter_key in FILTER_DEPENDENT_KEYS]
data_by_filter = [
getattr(self.filter_dependent_values[f], filter_key) for filter_key in FILTER_DEPENDENT_KEYS
]

required_columns.extend(columns_by_filter)
row_data.extend(data_by_filter)

for m, model_name in enumerate(self.filter_dependent_values[f].model_list):
columns_by_model = [
"_".join([filter_name, model_name, model_key]) for model_key in MODEL_DEPENDENT_KEYS
]
data_by_model = [
getattr(self.filter_dependent_values[f].model_dependent_values[m], model_key)
for model_key in MODEL_DEPENDENT_KEYS
]

required_columns.extend(columns_by_model)
row_data.extend(data_by_model)

return row_data, required_columns

def _ensure_columns(self, con, table_name, current_columns, required_columns):
"""Creates new columns in a given table of a SQL database as needed by checking the list of current columns against a list
of required columns.
Parameters
-----------
con : sqlite3 Connection object
The connection to the output SQL database.
table_name : str
The name of the relevant table in the database.
current_columns : list of str
A list of the columns already existing in the database table.
required_columns : list of str
A list of the columns needed in the database table.
"""

cur = con.cursor()
for column_name in required_columns:
if column_name not in current_columns:
cur.execute(f"""ALTER TABLE {table_name} ADD COLUMN {column_name}""")

def write_row_to_database(self, filepath, table_name="AdlerData"):
"""Writes all of the relevant data contained within the AdlerData object to a timestamped row in a SQLite database.
Parameters
-----------
filepath : path-like object
Filepath with the location of the output SQL database.
table_name : str, optiona
String containing the table name to write the data to. Default is "AdlerData".
"""

con = self._get_database_connection(filepath)

row_data, required_columns = self._get_row_data_and_columns()
current_columns = self._get_database_columns(con, table_name)
self._ensure_columns(con, table_name, current_columns, required_columns)

column_names = ",".join(required_columns)
column_spaces = ",".join(["?"] * len(required_columns))
sql_command = "INSERT INTO %s (%s) values(%s)" % (table_name, column_names, column_spaces)

cur = con.cursor()
cur.execute(sql_command, row_data)
con.commit()
con.close()


@dataclass
class FilterDependentAdler:
Expand Down
32 changes: 30 additions & 2 deletions tests/adler/dataclasses/test_AdlerData.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from numpy.testing import assert_array_equal
import os
import pytest
import numpy as np
import pandas as pd
import sqlite3

from adler.dataclasses.AdlerData import AdlerData
from numpy.testing import assert_array_equal

from adler.dataclasses.AdlerData import AdlerData
from adler.utilities.tests_utilities import get_test_data_filepath

# setting up the AdlerData object to be used for testing

Expand Down Expand Up @@ -126,6 +130,9 @@ def test_populate_phase_parameters():
test_object.populate_phase_parameters("u", model_name="model_1", H=15.0)

# check to make sure filter-dependent parameter is correctly updated (then return it to previous)
test_object.populate_phase_parameters("u", nobs=99)
assert test_object.filter_dependent_values[0].nobs == 99
test_object.populate_phase_parameters("u", nobs=13)

# testing to make sure the correct error messages trigger
with pytest.raises(ValueError) as error_info_1:
Expand Down Expand Up @@ -222,3 +229,24 @@ def test_print_data(capsys):
expected = "Filter: u\nPhase angle minimum: 11.0\nPhase angle range: 12.0\nNumber of observations: 13\nArc: 14.0\nModel: model_1.\n\tH: 15.0\n\tH error: 16.0\n\tPhase parameter 1: 17.0\n\tPhase parameter 1 error: 18.0\n\tPhase parameter 2: nan\n\tPhase parameter 2 error: nan\nModel: model_2.\n\tH: 25.0\n\tH error: 26.0\n\tPhase parameter 1: 27.0\n\tPhase parameter 1 error: 28.0\n\tPhase parameter 2: 29.0\n\tPhase parameter 2 error: 30.0\n\n\nFilter: g\nPhase angle minimum: 31.0\nPhase angle range: 32.0\nNumber of observations: 33\nArc: 34.0\nModel: model_1.\n\tH: 35.0\n\tH error: 36.0\n\tPhase parameter 1: 37.0\n\tPhase parameter 1 error: 38.0\n\tPhase parameter 2: nan\n\tPhase parameter 2 error: nan\n\n\nFilter: r\nPhase angle minimum: 41.0\nPhase angle range: 42.0\nNumber of observations: 43\nArc: 44.0\nModel: model_2.\n\tH: 45.0\n\tH error: 46.0\n\tPhase parameter 1: 47.0\n\tPhase parameter 1 error: 48.0\n\tPhase parameter 2: 49.0\n\tPhase parameter 2 error: 50.0\n\n\n"

assert captured.out == expected


def test_write_row_to_database(tmp_path):
db_location = os.path.join(tmp_path, "test_AdlerData_database.db")
test_object.write_row_to_database(db_location)

con = sqlite3.connect(db_location)
written_data = pd.read_sql_query("SELECT * from AdlerData", con)
con.close()

expected_data_filepath = get_test_data_filepath("test_SQL_database_table.csv")
expected_data = pd.read_csv(expected_data_filepath)

# we don't expect the timestamp column to be the same, obviously
expected_data = expected_data.drop(columns="timestamp")
written_data = written_data.drop(columns="timestamp")

# note that because I'm using Pandas there's some small dtype and np.nan/None stuff to clear up
# but this makes for a quick streamlined test anyway
expected_data = expected_data.replace({np.nan: None})
pd.testing.assert_frame_equal(expected_data, written_data, check_dtype=False)
2 changes: 2 additions & 0 deletions tests/data/test_SQL_database_table.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ssObjectId,timestamp,u_phaseAngle_min,u_phaseAngle_range,u_nobs,u_arc,u_model_1_H,u_model_1_H_err,u_model_1_phase_parameter_1,u_model_1_phase_parameter_1_err,u_model_1_phase_parameter_2,u_model_1_phase_parameter_2_err,u_model_2_H,u_model_2_H_err,u_model_2_phase_parameter_1,u_model_2_phase_parameter_1_err,u_model_2_phase_parameter_2,u_model_2_phase_parameter_2_err,g_phaseAngle_min,g_phaseAngle_range,g_nobs,g_arc,g_model_1_H,g_model_1_H_err,g_model_1_phase_parameter_1,g_model_1_phase_parameter_1_err,g_model_1_phase_parameter_2,g_model_1_phase_parameter_2_err,r_phaseAngle_min,r_phaseAngle_range,r_nobs,r_arc,r_model_2_H,r_model_2_H_err,r_model_2_phase_parameter_1,r_model_2_phase_parameter_1_err,r_model_2_phase_parameter_2,r_model_2_phase_parameter_2_err
666,2024-04-18 13:32:07.096776+00:00,11.0,12.0,13,14.0,15.0,16.0,17.0,18.0,,,25.0,26.0,27.0,28.0,29.0,30.0,31.0,32.0,33,34.0,35.0,36.0,37.0,38.0,,,41.0,42.0,43,44.0,45.0,46.0,47.0,48.0,49.0,50.0

0 comments on commit 558686c

Please sign in to comment.