From 1214d7514e868f203935a1cd22be794f8c9b2efc Mon Sep 17 00:00:00 2001 From: Amrit K Date: Wed, 13 Dec 2023 10:39:49 -0500 Subject: [PATCH 1/3] Add support to query DB files using sqlite --- .pre-commit-config.yaml | 2 +- cycquery/base.py | 39 ++++----- cycquery/orm.py | 136 ++++++++++++++++++++++---------- cycquery/post_process/gemini.py | 3 +- cycquery/util.py | 20 ++--- tests/cycquery/test_orm.py | 90 ++++++++++++++++++++- 6 files changed, 218 insertions(+), 72 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c6437a6..b0b62a8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: - id: black - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: 'v0.1.0' + rev: 'v0.1.7' hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/cycquery/base.py b/cycquery/base.py index dfedee3..cfaae09 100644 --- a/cycquery/base.py +++ b/cycquery/base.py @@ -76,18 +76,19 @@ class DatasetQuerier: Parameters ---------- - database - Name of database. - user - Username for database. - password - Password for database. - dbms - Database management system. - host - Hostname of database. - port - Port of database. + dbms : str + The database management system type (e.g., 'postgresql', 'mysql', 'sqlite'). + user : str, optional + The username for the database, by default empty. Not used for SQLite. + password : str, optional + The password for the database, by default empty. Not used for SQLite. + host : str, optional + The host address of the database, by default empty. Not used for SQLite. + port : int, optional + The port number for the database, by default None. Not used for SQLite. + database : str, optional + The name of the database or the path to the database file (for SQLite), + by default empty. 
Notes ----- @@ -102,12 +103,12 @@ class DatasetQuerier: def __init__( self, - database: str, - user: str, - password: str, - dbms: str = "postgresql", - host: str = "localhost", - port: int = 5432, + dbms: str, + user: str = "", + password: str = "", + host: str = "", + port: Optional[int] = None, + database: str = "", ) -> None: config = DatasetQuerierConfig( database=database, @@ -258,7 +259,7 @@ def _template_table_method( A query interface object. """ - table = getattr(getattr(self.db, schema_name), table_name).data + table = getattr(getattr(self.db, schema_name), table_name).data_ table = _to_subquery(table) return QueryInterface(self.db, table) diff --git a/cycquery/orm.py b/cycquery/orm.py index d4dab6e..75f3091 100644 --- a/cycquery/orm.py +++ b/cycquery/orm.py @@ -41,13 +41,60 @@ def _get_db_url( dbms: str, - user: str, - pwd: str, - host: str, - port: int, - database: str, + user: str = "", + pwd: str = "", + host: str = "", + port: Optional[int] = None, + database: str = "", ) -> str: - """Combine to make Database URL string.""" + """ + Generate a database connection URL. + + This function constructs a URL for database connection, which is compatible + with various database management systems (DBMS), including support for SQLite + database files. + + Parameters + ---------- + dbms : str + The database management system type (e.g., 'postgresql', 'mysql', 'sqlite'). + user : str, optional + The username for the database, by default empty. Not used for SQLite. + pwd : str, optional + The password for the database, by default empty. Not used for SQLite. + host : str, optional + The host address of the database, by default empty. Not used for SQLite. + port : int, optional + The port number for the database, by default None. Not used for SQLite. + database : str, optional + The name of the database or the path to the database file (for SQLite), + by default empty. + + Returns + ------- + str + A string representing the database connection URL. 
For SQLite, + it returns a URL in the format 'sqlite:///path_to_database.db'. + For other DBMS types, it returns a URL in the + format 'dbms://user:password@host:port/database'. + + Examples + -------- + >>> _get_db_url('postgresql', 'user', 'pass', 'localhost', 5432, 'mydatabase') + 'postgresql://user:pass@localhost:5432/mydatabase' + + >>> _get_db_url('sqlite', database='path_to_database.db') + 'sqlite:///path_to_database.db' + + """ + if dbms.lower() not in ["postgresql", "mysql", "sqlite"]: + raise ValueError( + f"Database management system '{dbms}' is not supported, " + f"please use one of 'postgresql', 'mysql', or 'sqlite'.", + ) + if dbms.lower() == "sqlite": + return f"sqlite:///{database}" # SQLite expects a file path as the database parameter + return f"{dbms}://{user}:{quote_plus(pwd)}@{host}:{str(port)}/{database}" @@ -57,27 +104,28 @@ class DatasetQuerierConfig: Attributes ---------- - dbms - Database management system. - host - Hostname of database. - port - Port of database. - database - Name of database. - user - Username for database. - password - Password for database. + dbms : str + The database management system type (e.g., 'postgresql', 'mysql', 'sqlite'). + user : str, optional + The username for the database, by default empty. Not used for SQLite. + password : str, optional + The password for the database, by default empty. Not used for SQLite. + host : str, optional + The host address of the database, by default empty. Not used for SQLite. + port : int, optional + The port number for the database, by default None. Not used for SQLite. + database : str, optional + The name of the database or the path to the database file (for SQLite), + by default empty. 
""" - database: str - user: str - password: str - dbms: str = "postgresql" - host: str = "localhost" - port: int = 5432 + dbms: str + user: str = "" + password: str = "" + host: str = "" + port: Optional[int] = None + database: str = "" class Database: @@ -110,18 +158,26 @@ def __init__(self, config: DatasetQuerierConfig) -> None: self.config = config self.is_connected = False - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(SOCKET_CONNECTION_TIMEOUT) - try: - is_port_open = sock.connect_ex((self.config.host, self.config.port)) - except socket.gaierror: - LOGGER.error("""Server name not known, cannot establish connection!""") - return - if is_port_open: - LOGGER.error( - """Valid server host but port seems open, check if server is up!""", - ) - return + # Check if server is up or database file exists. + if self.config.dbms.lower() == "sqlite": + if not os.path.exists(self.config.database): + LOGGER.error( + f"""Database file '{self.config.database}' does not exist!""", + ) + return + else: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(SOCKET_CONNECTION_TIMEOUT) + try: + is_port_open = sock.connect_ex((self.config.host, self.config.port)) + except socket.gaierror: + LOGGER.error("""Server name not known, cannot establish connection!""") + return + if is_port_open: + LOGGER.error( + """Valid server host but port seems open, check if server is up!""", + ) + return self.engine = self._create_engine() self.session = self._create_session() @@ -185,10 +241,10 @@ def _setup(self) -> None: table = DBTable(table_name, meta[schema_name].tables[table_name]) for column in meta[schema_name].tables[table_name].columns: setattr(table, column.name, column) - if not isinstance(table.name, str): - table.name = str(table.name) - self._tables.append(table.name) - setattr(schema, get_attr_name(table.name), table) + if not isinstance(table.name_, str): + table.name_ = str(table.name_) + self._tables.append(table.name_) + 
setattr(schema, get_attr_name(table.name_), table) setattr(self, schema_name, schema) @time_function diff --git a/cycquery/post_process/gemini.py b/cycquery/post_process/gemini.py index 6fa5e16..89cb416 100644 --- a/cycquery/post_process/gemini.py +++ b/cycquery/post_process/gemini.py @@ -1,7 +1,8 @@ """Post-processing functions applied to queried GEMINI data (Pandas DataFrames).""" import pandas as pd -from cyclops.query.post_process.util import process_care_unit_changepoints + +from cycquery.post_process.util import process_care_unit_changepoints CARE_UNIT_HIERARCHY = [ diff --git a/cycquery/util.py b/cycquery/util.py index a246298..32a459c 100644 --- a/cycquery/util.py +++ b/cycquery/util.py @@ -35,15 +35,15 @@ class DBSchema: Parameters ---------- - name: str + name_: str Name of schema. - data: sqlalchemy.sql.schema.MetaData + _data: sqlalchemy.sql.schema.MetaData Metadata for schema. """ - name: str - data: sqlalchemy.sql.schema.MetaData + name_: str + data_: sqlalchemy.sql.schema.MetaData @dataclass @@ -52,15 +52,15 @@ class DBTable: Parameters ---------- - name: str + name_: str Name of table. - data: sqlalchemy.sql.schema.Table + data_: sqlalchemy.sql.schema.Table Metadata for schema. 
""" - name: str - data: sqlalchemy.sql.schema.MetaData + name_: str + data_: sqlalchemy.sql.schema.MetaData TABLE_OBJECTS = [Table, Select, Subquery, DBTable] @@ -91,7 +91,7 @@ def _to_subquery(table: TableTypes) -> Subquery: return select(table).subquery() if isinstance(table, DBTable): - return select(table.data).subquery() + return select(table.data_).subquery() raise ValueError( f"""Table has type {type(table)}, but must have one of the @@ -123,7 +123,7 @@ def _to_select(table: TableTypes) -> Select: return select(table) if isinstance(table, DBTable): - return select(table.data) + return select(table.data_) raise ValueError( f"""Table has type {type(table)}, but must have one of the diff --git a/tests/cycquery/test_orm.py b/tests/cycquery/test_orm.py index 794081c..401d2c2 100644 --- a/tests/cycquery/test_orm.py +++ b/tests/cycquery/test_orm.py @@ -1,11 +1,99 @@ """Test cyclops.query.orm module.""" import os +import sqlite3 import pandas as pd import pytest -from cycquery import OMOPQuerier +from cycquery import DatasetQuerier, OMOPQuerier +from cycquery.orm import _get_db_url + + +# Function to create and populate the database +def create_dummy_database(db_file): + """Create dummy database file.""" + conn = sqlite3.connect(db_file) + cursor = conn.cursor() + + # Create a table + cursor.execute( + """CREATE TABLE test_table ( + id INTEGER PRIMARY KEY, + name TEXT, + age INTEGER + )""", + ) + + # Insert dummy data + dummy_data = [(1, "Alice", 30), (2, "Bob", 25), (3, "Charlie", 35)] + cursor.executemany( + "INSERT INTO test_table (id, name, age) VALUES (?, ?, ?)", + dummy_data, + ) + + # Save (commit) the changes and close the connection + conn.commit() + conn.close() + + +def test_dataset_querier(): + """Test DatasetQuerier.""" + db_file = "test_database.db" + + # Ensure database file doesn't exist before test + if os.path.exists(db_file): + os.remove(db_file) + + create_dummy_database(db_file) + + # Test DatasetQuerier + querier = DatasetQuerier( + 
dbms="sqlite", + database=db_file, + ) + assert querier is not None + test_table = querier.main.test_table().run() + assert len(test_table) == 3 + assert test_table["name"].tolist() == ["Alice", "Bob", "Charlie"] + assert test_table["age"].tolist() == [30, 25, 35] + + # Clean up: remove the database file after testing + os.remove(db_file) + + +def test_get_db_url(): + """Test _get_db_url.""" + # Test for a typical SQL database (e.g., PostgreSQL, MySQL) + assert ( + _get_db_url("postgresql", "user", "pass", "localhost", 5432, "mydatabase") + == "postgresql://user:pass@localhost:5432/mydatabase" + ) + assert ( + _get_db_url("mysql", "root", "rootpass", "dbhost", 3306, "somedb") + == "mysql://root:rootpass@dbhost:3306/somedb" + ) + + # Test for SQLite database file + assert _get_db_url("sqlite", database="mydatabase.db") == "sqlite:///mydatabase.db" + + # Test handling of empty parameters for typical SQL databases + assert _get_db_url("mysql", "", "", "", None, "") == "mysql://:@:None/" + + # Test handling of None for port + assert ( + _get_db_url("postgresql", "user", "pass", "localhost", None, "mydatabase") + == "postgresql://user:pass@localhost:None/mydatabase" + ) + + # Test case insensitivity for DBMS + assert _get_db_url("SQLITE", database="mydatabase.db") == "sqlite:///mydatabase.db" + + # Test for incorrect usage + with pytest.raises( + ValueError, + ): + _get_db_url("unknown_dbms", "user", "pass", "localhost", 1234, "mydatabase") @pytest.mark.integration_test() From 4683af8a4c2874556ae27d80b3d2285612d9dd97 Mon Sep 17 00:00:00 2001 From: Amrit K Date: Wed, 13 Dec 2023 10:50:51 -0500 Subject: [PATCH 2/3] Fix integration tests, and attribute access --- cycquery/base.py | 4 ++-- cycquery/orm.py | 6 ++++++ tests/cycquery/test_base.py | 2 +- tests/cycquery/test_eicu.py | 3 +++ tests/cycquery/test_mimiciv.py | 3 +++ 5 files changed, 15 insertions(+), 3 deletions(-) diff --git a/cycquery/base.py b/cycquery/base.py index cfaae09..4006f26 100644 --- a/cycquery/base.py 
+++ b/cycquery/base.py @@ -177,7 +177,7 @@ def list_columns(self, schema_name: str, table_name: str) -> List[str]: """ return list( - getattr(getattr(self.db, schema_name), table_name).data.columns.keys(), + getattr(getattr(self.db, schema_name), table_name).data_.columns.keys(), ) def list_custom_tables(self) -> List[str]: @@ -232,7 +232,7 @@ def get_table( Table with mapped columns. """ - table = _create_get_table_lambdafn(schema_name, table_name)(self.db).data + table = _create_get_table_lambdafn(schema_name, table_name)(self.db).data_ if cast_timestamp_cols: table = _cast_timestamp_cols(table) diff --git a/cycquery/orm.py b/cycquery/orm.py index 75f3091..068356d 100644 --- a/cycquery/orm.py +++ b/cycquery/orm.py @@ -166,6 +166,12 @@ def __init__(self, config: DatasetQuerierConfig) -> None: ) return else: + if not self.config.host: + LOGGER.error("""No server host provided!""") + return + if not self.config.port: + LOGGER.error("""No server port provided!""") + return sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(SOCKET_CONNECTION_TIMEOUT) try: diff --git a/tests/cycquery/test_base.py b/tests/cycquery/test_base.py index 45d548f..f487372 100644 --- a/tests/cycquery/test_base.py +++ b/tests/cycquery/test_base.py @@ -16,7 +16,7 @@ def test_dataset_querier(): ) assert len(querier.list_tables()) == 69 assert len(querier.list_schemas()) == 4 - assert len(querier.list_tables(schema_name="cdm_synthea10")) == 43 + assert len(querier.list_tables(schema_name="cdm_synthea10")) == 44 visit_occrrence_columns = querier.list_columns("cdm_synthea10", "visit_occurrence") assert len(visit_occrrence_columns) == 17 assert "visit_occurrence_id" in visit_occrrence_columns diff --git a/tests/cycquery/test_eicu.py b/tests/cycquery/test_eicu.py index ab6bc56..cfefeb9 100644 --- a/tests/cycquery/test_eicu.py +++ b/tests/cycquery/test_eicu.py @@ -9,9 +9,12 @@ def test_eicu_querier(): """Test EICUQuerier on eICU-CRD.""" querier = EICUQuerier( + dbms="postgresql", 
database="eicu", user="postgres", password="pwd", + host="localhost", + port=5432, ) patients = querier.eicu_crd.patient().run(limit=10) diff --git a/tests/cycquery/test_mimiciv.py b/tests/cycquery/test_mimiciv.py index 72cad50..b2479e4 100644 --- a/tests/cycquery/test_mimiciv.py +++ b/tests/cycquery/test_mimiciv.py @@ -9,6 +9,9 @@ def test_mimiciv_querier(): """Test MIMICQuerier on MIMICIV-2.0.""" querier = MIMICIVQuerier( + dbms="postgresql", + host="localhost", + port=5432, database="mimiciv-2.0", user="postgres", password="pwd", From c8f5c66236ca45f5fe12c49639ce13159ddbd3d5 Mon Sep 17 00:00:00 2001 From: Amrit K Date: Wed, 13 Dec 2023 10:55:16 -0500 Subject: [PATCH 3/3] Small fix --- cycquery/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cycquery/util.py b/cycquery/util.py index 32a459c..654b2ea 100644 --- a/cycquery/util.py +++ b/cycquery/util.py @@ -37,7 +37,7 @@ class DBSchema: ---------- name_: str Name of schema. - _data: sqlalchemy.sql.schema.MetaData + data_: sqlalchemy.sql.schema.MetaData Metadata for schema. """