Skip to content

Commit

Permalink
Merge pull request #7 from VectorInstitute/add_support_for_db_file
Browse files Browse the repository at this point in the history
Add support to query DB files using sqlite
  • Loading branch information
amrit110 authored Dec 13, 2023
2 parents a227009 + c8f5c66 commit 1f03ebb
Show file tree
Hide file tree
Showing 9 changed files with 233 additions and 75 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ repos:
- id: black

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.1.0'
rev: 'v0.1.7'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
43 changes: 22 additions & 21 deletions cycquery/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,19 @@ class DatasetQuerier:
Parameters
----------
database
Name of database.
user
Username for database.
password
Password for database.
dbms
Database management system.
host
Hostname of database.
port
Port of database.
dbms : str
The database management system type (e.g., 'postgresql', 'mysql', 'sqlite').
user : str, optional
The username for the database, by default empty. Not used for SQLite.
pwd : str, optional
The password for the database, by default empty. Not used for SQLite.
host : str, optional
The host address of the database, by default empty. Not used for SQLite.
port : int, optional
The port number for the database, by default None. Not used for SQLite.
database : str, optional
The name of the database or the path to the database file (for SQLite),
by default empty.
Notes
-----
Expand All @@ -102,12 +103,12 @@ class DatasetQuerier:

def __init__(
self,
database: str,
user: str,
password: str,
dbms: str = "postgresql",
host: str = "localhost",
port: int = 5432,
dbms: str,
user: str = "",
password: str = "",
host: str = "",
port: Optional[int] = None,
database: str = "",
) -> None:
config = DatasetQuerierConfig(
database=database,
Expand Down Expand Up @@ -176,7 +177,7 @@ def list_columns(self, schema_name: str, table_name: str) -> List[str]:
"""
return list(
getattr(getattr(self.db, schema_name), table_name).data.columns.keys(),
getattr(getattr(self.db, schema_name), table_name).data_.columns.keys(),
)

def list_custom_tables(self) -> List[str]:
Expand Down Expand Up @@ -231,7 +232,7 @@ def get_table(
Table with mapped columns.
"""
table = _create_get_table_lambdafn(schema_name, table_name)(self.db).data
table = _create_get_table_lambdafn(schema_name, table_name)(self.db).data_

if cast_timestamp_cols:
table = _cast_timestamp_cols(table)
Expand All @@ -258,7 +259,7 @@ def _template_table_method(
A query interface object.
"""
table = getattr(getattr(self.db, schema_name), table_name).data
table = getattr(getattr(self.db, schema_name), table_name).data_
table = _to_subquery(table)

return QueryInterface(self.db, table)
Expand Down
142 changes: 102 additions & 40 deletions cycquery/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,60 @@

def _get_db_url(
dbms: str,
user: str,
pwd: str,
host: str,
port: int,
database: str,
user: str = "",
pwd: str = "",
host: str = "",
port: Optional[int] = None,
database: str = "",
) -> str:
"""Combine to make Database URL string."""
"""
Generate a database connection URL.
This function constructs a URL for database connection, which is compatible
with various database management systems (DBMS), including support for SQLite
database files.
Parameters
----------
dbms : str
The database management system type (e.g., 'postgresql', 'mysql', 'sqlite').
user : str, optional
The username for the database, by default empty. Not used for SQLite.
pwd : str, optional
The password for the database, by default empty. Not used for SQLite.
host : str, optional
The host address of the database, by default empty. Not used for SQLite.
port : int, optional
The port number for the database, by default None. Not used for SQLite.
database : str, optional
The name of the database or the path to the database file (for SQLite),
by default empty.
Returns
-------
str
A string representing the database connection URL. For SQLite,
it returns a URL in the format 'sqlite:///path_to_database.db'.
For other DBMS types, it returns a URL in the
format 'dbms://user:password@host:port/database'.
Examples
--------
>>> _get_db_url('postgresql', 'user', 'pass', 'localhost', 5432, 'mydatabase')
'postgresql://user:pass@localhost:5432/mydatabase'
>>> _get_db_url('sqlite', database='path_to_database.db')
'sqlite:///path_to_database.db'
"""
if dbms.lower() not in ["postgresql", "mysql", "sqlite"]:
raise ValueError(
f"Database management system '{dbms}' is not supported, "
f"please use one of 'postgresql', 'mysql', or 'sqlite'.",
)
if dbms.lower() == "sqlite":
return f"sqlite:///{database}" # SQLite expects a file path as the database parameter

return f"{dbms}://{user}:{quote_plus(pwd)}@{host}:{str(port)}/{database}"


Expand All @@ -57,27 +104,28 @@ class DatasetQuerierConfig:
Attributes
----------
dbms
Database management system.
host
Hostname of database.
port
Port of database.
database
Name of database.
user
Username for database.
password
Password for database.
dbms : str
The database management system type (e.g., 'postgresql', 'mysql', 'sqlite').
user : str, optional
The username for the database, by default empty. Not used for SQLite.
pwd : str, optional
The password for the database, by default empty. Not used for SQLite.
host : str, optional
The host address of the database, by default empty. Not used for SQLite.
port : int, optional
The port number for the database, by default None. Not used for SQLite.
database : str, optional
The name of the database or the path to the database file (for SQLite),
by default empty.
"""

database: str
user: str
password: str
dbms: str = "postgresql"
host: str = "localhost"
port: int = 5432
dbms: str
user: str = ""
password: str = ""
host: str = ""
port: Optional[int] = None
database: str = ""


class Database:
Expand Down Expand Up @@ -110,18 +158,32 @@ def __init__(self, config: DatasetQuerierConfig) -> None:
self.config = config
self.is_connected = False

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(SOCKET_CONNECTION_TIMEOUT)
try:
is_port_open = sock.connect_ex((self.config.host, self.config.port))
except socket.gaierror:
LOGGER.error("""Server name not known, cannot establish connection!""")
return
if is_port_open:
LOGGER.error(
"""Valid server host but port seems open, check if server is up!""",
)
return
# Check if server is up or database file exists.
if self.config.dbms.lower() == "sqlite":
if not os.path.exists(self.config.database):
LOGGER.error(
f"""Database file '{self.config.database}' does not exist!""",
)
return
else:
if not self.config.host:
LOGGER.error("""No server host provided!""")
return
if not self.config.port:
LOGGER.error("""No server port provided!""")
return
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(SOCKET_CONNECTION_TIMEOUT)
try:
is_port_open = sock.connect_ex((self.config.host, self.config.port))
except socket.gaierror:
LOGGER.error("""Server name not known, cannot establish connection!""")
return
if is_port_open:
LOGGER.error(
"""Valid server host but port seems open, check if server is up!""",
)
return

self.engine = self._create_engine()
self.session = self._create_session()
Expand Down Expand Up @@ -185,10 +247,10 @@ def _setup(self) -> None:
table = DBTable(table_name, meta[schema_name].tables[table_name])
for column in meta[schema_name].tables[table_name].columns:
setattr(table, column.name, column)
if not isinstance(table.name, str):
table.name = str(table.name)
self._tables.append(table.name)
setattr(schema, get_attr_name(table.name), table)
if not isinstance(table.name_, str):
table.name_ = str(table.name_)
self._tables.append(table.name_)
setattr(schema, get_attr_name(table.name_), table)
setattr(self, schema_name, schema)

@time_function
Expand Down
3 changes: 2 additions & 1 deletion cycquery/post_process/gemini.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Post-processing functions applied to queried GEMINI data (Pandas DataFrames)."""

import pandas as pd
from cyclops.query.post_process.util import process_care_unit_changepoints

from cycquery.post_process.util import process_care_unit_changepoints


CARE_UNIT_HIERARCHY = [
Expand Down
20 changes: 10 additions & 10 deletions cycquery/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ class DBSchema:
Parameters
----------
name: str
name_: str
Name of schema.
data: sqlalchemy.sql.schema.MetaData
data_: sqlalchemy.sql.schema.MetaData
Metadata for schema.
"""

name: str
data: sqlalchemy.sql.schema.MetaData
name_: str
data_: sqlalchemy.sql.schema.MetaData


@dataclass
Expand All @@ -52,15 +52,15 @@ class DBTable:
Parameters
----------
name: str
name_: str
Name of table.
data: sqlalchemy.sql.schema.Table
data_: sqlalchemy.sql.schema.Table
Metadata for schema.
"""

name: str
data: sqlalchemy.sql.schema.MetaData
name_: str
data_: sqlalchemy.sql.schema.MetaData


TABLE_OBJECTS = [Table, Select, Subquery, DBTable]
Expand Down Expand Up @@ -91,7 +91,7 @@ def _to_subquery(table: TableTypes) -> Subquery:
return select(table).subquery()

if isinstance(table, DBTable):
return select(table.data).subquery()
return select(table.data_).subquery()

raise ValueError(
f"""Table has type {type(table)}, but must have one of the
Expand Down Expand Up @@ -123,7 +123,7 @@ def _to_select(table: TableTypes) -> Select:
return select(table)

if isinstance(table, DBTable):
return select(table.data)
return select(table.data_)

raise ValueError(
f"""Table has type {type(table)}, but must have one of the
Expand Down
2 changes: 1 addition & 1 deletion tests/cycquery/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_dataset_querier():
)
assert len(querier.list_tables()) == 69
assert len(querier.list_schemas()) == 4
assert len(querier.list_tables(schema_name="cdm_synthea10")) == 43
assert len(querier.list_tables(schema_name="cdm_synthea10")) == 44
visit_occrrence_columns = querier.list_columns("cdm_synthea10", "visit_occurrence")
assert len(visit_occrrence_columns) == 17
assert "visit_occurrence_id" in visit_occrrence_columns
3 changes: 3 additions & 0 deletions tests/cycquery/test_eicu.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
def test_eicu_querier():
"""Test EICUQuerier on eICU-CRD."""
querier = EICUQuerier(
dbms="postgresql",
database="eicu",
user="postgres",
password="pwd",
host="localhost",
port=5432,
)

patients = querier.eicu_crd.patient().run(limit=10)
Expand Down
3 changes: 3 additions & 0 deletions tests/cycquery/test_mimiciv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
def test_mimiciv_querier():
"""Test MIMICQuerier on MIMICIV-2.0."""
querier = MIMICIVQuerier(
dbms="postgresql",
host="localhost",
port=5432,
database="mimiciv-2.0",
user="postgres",
password="pwd",
Expand Down
Loading

0 comments on commit 1f03ebb

Please sign in to comment.