Skip to content

Commit

Permalink
Add support to query DB files using sqlite
Browse files Browse the repository at this point in the history
  • Loading branch information
amrit110 committed Dec 13, 2023
1 parent a227009 commit 1214d75
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 72 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ repos:
- id: black

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.1.0'
rev: 'v0.1.7'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
39 changes: 20 additions & 19 deletions cycquery/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,19 @@ class DatasetQuerier:
Parameters
----------
database
Name of database.
user
Username for database.
password
Password for database.
dbms
Database management system.
host
Hostname of database.
port
Port of database.
dbms : str
The database management system type (e.g., 'postgresql', 'mysql', 'sqlite').
user : str, optional
The username for the database, by default empty. Not used for SQLite.
pwd : str, optional
The password for the database, by default empty. Not used for SQLite.
host : str, optional
The host address of the database, by default empty. Not used for SQLite.
port : int, optional
The port number for the database, by default None. Not used for SQLite.
database : str, optional
The name of the database or the path to the database file (for SQLite),
by default empty.
Notes
-----
Expand All @@ -102,12 +103,12 @@ class DatasetQuerier:

def __init__(
self,
database: str,
user: str,
password: str,
dbms: str = "postgresql",
host: str = "localhost",
port: int = 5432,
dbms: str,
user: str = "",
password: str = "",
host: str = "",
port: Optional[int] = None,
database: str = "",
) -> None:
config = DatasetQuerierConfig(
database=database,
Expand Down Expand Up @@ -258,7 +259,7 @@ def _template_table_method(
A query interface object.
"""
table = getattr(getattr(self.db, schema_name), table_name).data
table = getattr(getattr(self.db, schema_name), table_name).data_
table = _to_subquery(table)

return QueryInterface(self.db, table)
Expand Down
136 changes: 96 additions & 40 deletions cycquery/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,60 @@

def _get_db_url(
dbms: str,
user: str,
pwd: str,
host: str,
port: int,
database: str,
user: str = "",
pwd: str = "",
host: str = "",
port: Optional[int] = None,
database: str = "",
) -> str:
"""Combine to make Database URL string."""
"""
Generate a database connection URL.
This function constructs a URL for database connection, which is compatible
with various database management systems (DBMS), including support for SQLite
database files.
Parameters
----------
dbms : str
The database management system type (e.g., 'postgresql', 'mysql', 'sqlite').
user : str, optional
The username for the database, by default empty. Not used for SQLite.
pwd : str, optional
The password for the database, by default empty. Not used for SQLite.
host : str, optional
The host address of the database, by default empty. Not used for SQLite.
port : int, optional
The port number for the database, by default None. Not used for SQLite.
database : str, optional
The name of the database or the path to the database file (for SQLite),
by default empty.
Returns
-------
str
A string representing the database connection URL. For SQLite,
it returns a URL in the format 'sqlite:///path_to_database.db'.
For other DBMS types, it returns a URL in the
format 'dbms://user:password@host:port/database'.
Examples
--------
>>> _get_db_url('postgresql', 'user', 'pass', 'localhost', 5432, 'mydatabase')
'postgresql://user:pass@localhost:5432/mydatabase'
>>> _get_db_url('sqlite', database='path_to_database.db')
'sqlite:///path_to_database.db'
"""
if dbms.lower() not in ["postgresql", "mysql", "sqlite"]:
raise ValueError(
f"Database management system '{dbms}' is not supported, "
f"please use one of 'postgresql', 'mysql', or 'sqlite'.",
)
if dbms.lower() == "sqlite":
return f"sqlite:///{database}" # SQLite expects a file path as the database parameter

return f"{dbms}://{user}:{quote_plus(pwd)}@{host}:{str(port)}/{database}"


Expand All @@ -57,27 +104,28 @@ class DatasetQuerierConfig:
Attributes
----------
dbms
Database management system.
host
Hostname of database.
port
Port of database.
database
Name of database.
user
Username for database.
password
Password for database.
dbms : str
The database management system type (e.g., 'postgresql', 'mysql', 'sqlite').
user : str, optional
The username for the database, by default empty. Not used for SQLite.
pwd : str, optional
The password for the database, by default empty. Not used for SQLite.
host : str, optional
The host address of the database, by default empty. Not used for SQLite.
port : int, optional
The port number for the database, by default None. Not used for SQLite.
database : str, optional
The name of the database or the path to the database file (for SQLite),
by default empty.
"""

database: str
user: str
password: str
dbms: str = "postgresql"
host: str = "localhost"
port: int = 5432
dbms: str
user: str = ""
password: str = ""
host: str = ""
port: Optional[int] = None
database: str = ""


class Database:
Expand Down Expand Up @@ -110,18 +158,26 @@ def __init__(self, config: DatasetQuerierConfig) -> None:
self.config = config
self.is_connected = False

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(SOCKET_CONNECTION_TIMEOUT)
try:
is_port_open = sock.connect_ex((self.config.host, self.config.port))
except socket.gaierror:
LOGGER.error("""Server name not known, cannot establish connection!""")
return
if is_port_open:
LOGGER.error(
"""Valid server host but port seems open, check if server is up!""",
)
return
# Check if server is up or database file exists.
if self.config.dbms.lower() == "sqlite":
if not os.path.exists(self.config.database):
LOGGER.error(
f"""Database file '{self.config.database}' does not exist!""",
)
return
else:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(SOCKET_CONNECTION_TIMEOUT)
try:
is_port_open = sock.connect_ex((self.config.host, self.config.port))
except socket.gaierror:
LOGGER.error("""Server name not known, cannot establish connection!""")
return
if is_port_open:
LOGGER.error(
"""Valid server host but port seems open, check if server is up!""",
)
return

self.engine = self._create_engine()
self.session = self._create_session()
Expand Down Expand Up @@ -185,10 +241,10 @@ def _setup(self) -> None:
table = DBTable(table_name, meta[schema_name].tables[table_name])
for column in meta[schema_name].tables[table_name].columns:
setattr(table, column.name, column)
if not isinstance(table.name, str):
table.name = str(table.name)
self._tables.append(table.name)
setattr(schema, get_attr_name(table.name), table)
if not isinstance(table.name_, str):
table.name_ = str(table.name_)
self._tables.append(table.name_)
setattr(schema, get_attr_name(table.name_), table)
setattr(self, schema_name, schema)

@time_function
Expand Down
3 changes: 2 additions & 1 deletion cycquery/post_process/gemini.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Post-processing functions applied to queried GEMINI data (Pandas DataFrames)."""

import pandas as pd
from cyclops.query.post_process.util import process_care_unit_changepoints

from cycquery.post_process.util import process_care_unit_changepoints


CARE_UNIT_HIERARCHY = [
Expand Down
20 changes: 10 additions & 10 deletions cycquery/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ class DBSchema:
Parameters
----------
name: str
name_: str
Name of schema.
data: sqlalchemy.sql.schema.MetaData
_data: sqlalchemy.sql.schema.MetaData
Metadata for schema.
"""

name: str
data: sqlalchemy.sql.schema.MetaData
name_: str
data_: sqlalchemy.sql.schema.MetaData


@dataclass
Expand All @@ -52,15 +52,15 @@ class DBTable:
Parameters
----------
name: str
name_: str
Name of table.
data: sqlalchemy.sql.schema.Table
data_: sqlalchemy.sql.schema.Table
Metadata for schema.
"""

name: str
data: sqlalchemy.sql.schema.MetaData
name_: str
data_: sqlalchemy.sql.schema.MetaData


TABLE_OBJECTS = [Table, Select, Subquery, DBTable]
Expand Down Expand Up @@ -91,7 +91,7 @@ def _to_subquery(table: TableTypes) -> Subquery:
return select(table).subquery()

if isinstance(table, DBTable):
return select(table.data).subquery()
return select(table.data_).subquery()

raise ValueError(
f"""Table has type {type(table)}, but must have one of the
Expand Down Expand Up @@ -123,7 +123,7 @@ def _to_select(table: TableTypes) -> Select:
return select(table)

if isinstance(table, DBTable):
return select(table.data)
return select(table.data_)

raise ValueError(
f"""Table has type {type(table)}, but must have one of the
Expand Down
Loading

0 comments on commit 1214d75

Please sign in to comment.