Skip to content

Commit

Permalink
switch to SQLAlchemy 2 (#198)
Browse files Browse the repository at this point in the history
Drop dependency on sqlalchemy-aio, replace with local worker thread implementation.
  • Loading branch information
albertodonato authored Nov 2, 2024
1 parent ec9893a commit 6bf0bb4
Show file tree
Hide file tree
Showing 6 changed files with 472 additions and 175 deletions.
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ dependencies = [
"prometheus-client",
"python-dateutil",
"pyyaml",
"sqlalchemy<1.4",
"sqlalchemy-aio>=0.17",
"sqlalchemy>=2",
"toolrack>=4",
]
optional-dependencies.testing = [
Expand Down
236 changes: 187 additions & 49 deletions query_exporter/db.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
"""Database wrapper."""

import asyncio
from collections.abc import Iterable
from collections.abc import (
Callable,
Iterable,
Sequence,
)
from concurrent import futures
from dataclasses import (
dataclass,
field,
)
from functools import partial
from itertools import chain
import logging
import sys
from threading import (
Thread,
current_thread,
)
from time import (
perf_counter,
time,
Expand All @@ -28,16 +38,17 @@
event,
text,
)
from sqlalchemy.engine import (
Connection,
CursorResult,
Engine,
Row,
)
from sqlalchemy.exc import (
ArgumentError,
NoSuchModuleError,
)
from sqlalchemy_aio import ASYNCIO_STRATEGY
from sqlalchemy_aio.asyncio import AsyncioEngine
from sqlalchemy_aio.base import (
AsyncConnection,
AsyncResultProxy,
)
from sqlalchemy.sql.elements import TextClause

#: Timeout for a query
QueryTimeout = int | float
Expand Down Expand Up @@ -116,7 +127,7 @@ def __init__(self, query_name: str, message: str) -> None:
)


# database errors that mean the query won't ever succeed. Not all possible
# Database errors that mean the query won't ever succeed. Not all possible
# fatal errors are tracked here, because some DBAPI errors can happen in
# circumstances which can be fatal or not. Since there doesn't seem to be a
# reliable way to know, there might be cases when a query will never succeed
Expand All @@ -140,7 +151,7 @@ def __post_init__(self) -> None:
create_db_engine(self.dsn)


def create_db_engine(dsn: str, **kwargs: Any) -> AsyncioEngine:
def create_db_engine(dsn: str, **kwargs: Any) -> Engine:
"""Create the database engine, validating the DSN"""
try:
return create_engine(dsn, **kwargs)
Expand All @@ -161,22 +172,20 @@ class QueryResults(NamedTuple):
"""Results of a database query."""

keys: list[str]
rows: list[tuple[Any]]
rows: Sequence[Row[Any]]
timestamp: float | None = None
latency: float | None = None

@classmethod
async def from_results(cls, results: AsyncResultProxy) -> Self:
def from_result(cls, result: CursorResult[Any]) -> Self:
"""Return a QueryResults from results for a query."""
timestamp = time()
conn_info = results._result_proxy.connection.info
latency = conn_info.get("query_latency", None)
return cls(
await results.keys(),
await results.fetchall(),
timestamp=timestamp,
latency=latency,
)
keys: list[str] = []
rows: Sequence[Row[Any]] = []
if result.returns_rows:
keys, rows = list(result.keys()), result.all()
latency = result.connection.info.get("query_latency", None)
return cls(keys, rows, timestamp=timestamp, latency=latency)


class MetricResult(NamedTuple):
Expand Down Expand Up @@ -276,11 +285,149 @@ def _check_query_parameters(self) -> None:
raise InvalidQueryParameters(self.name)


class WorkerAction:
"""An action to be called in the worker thread."""

def __init__(
self, func: Callable[..., Any], *args: Any, **kwargs: Any
) -> None:
self._func = partial(func, *args, **kwargs)
self._loop = asyncio.get_event_loop()
self._future = self._loop.create_future()

def __str__(self) -> str:
return self._func.func.__name__

def __call__(self) -> None:
"""Call the action asynchronously in a thread-safe way."""
try:
result = self._func()
except Exception as e:
self._call_threadsafe(self._future.set_exception, e)
else:
self._call_threadsafe(self._future.set_result, result)

async def result(self) -> Any:
"""Wait for completion and return the action result."""
return await self._future

def _call_threadsafe(self, call: Callable[..., Any], *args: Any) -> None:
self._loop.call_soon_threadsafe(partial(call, *args))


class DataBaseConnection:
    """A connection to a database engine.

    The blocking engine/connection calls are not made from the event loop:
    they are wrapped in ``WorkerAction`` objects, queued, and run one at a
    time by a dedicated worker thread. The thread is started by
    :meth:`open` and exits once the connection has been closed.
    """

    # the open connection, or None when not connected
    _conn: Connection | None = None
    # the thread running queued actions, or None when not started
    _worker: Thread | None = None

    def __init__(
        self,
        dbname: str,
        engine: Engine,
        logger: logging.Logger = logging.getLogger(),
    ) -> None:
        self.dbname = dbname
        self.engine = engine
        self.logger = logger
        self._loop = asyncio.get_event_loop()
        # actions waiting to be picked up by the worker thread
        self._queue: asyncio.Queue[WorkerAction] = asyncio.Queue()

    @property
    def connected(self) -> bool:
        """Whether the connection is open."""
        return self._conn is not None

    async def open(self) -> None:
        """Open the connection.

        Does nothing if the connection is already open. Errors raised while
        connecting are propagated to the caller.
        """
        if self.connected:
            return

        self._create_worker()
        try:
            await self._call_in_thread(self._connect)
        except Exception:
            # A failed connect leaves ``_conn`` unset, so the worker exits
            # right after reporting the error. Reap it here, otherwise the
            # stale thread reference makes every later open() fail in
            # _create_worker() and the connection can never be retried.
            self._terminate_worker()
            raise

    async def close(self) -> None:
        """Close the connection.

        Does nothing if the connection is not open.
        """
        if not self.connected:
            return

        await self._call_in_thread(self._close)
        self._terminate_worker()

    async def execute(
        self,
        sql: TextClause,
        parameters: dict[str, Any] | None = None,
    ) -> QueryResults:
        """Execute a query, returning results."""
        if parameters is None:
            parameters = {}
        result = await self._call_in_thread(self._execute, sql, parameters)
        # rows are fetched in the worker thread as well, since reading them
        # is done through the same blocking connection
        query_results: QueryResults = await self._call_in_thread(
            QueryResults.from_result, result
        )
        return query_results

    def _create_worker(self) -> None:
        """Start the worker thread."""
        assert not self._worker
        self._worker = Thread(
            target=self._run, name=f"DataBase-{self.dbname}", daemon=True
        )
        self._worker.start()

    def _terminate_worker(self) -> None:
        """Wait for the worker thread to exit and forget about it.

        Only call this once the worker is bound to exit (``_conn`` unset
        after its last action), as the join blocks the calling thread.
        """
        assert self._worker
        self._worker.join()
        self._worker = None

    def _connect(self) -> None:
        """Connect to the engine (run in the worker thread)."""
        self._conn = self.engine.connect()

    def _execute(
        self, sql: TextClause, parameters: dict[str, Any]
    ) -> CursorResult[Any]:
        """Execute a statement (run in the worker thread)."""
        assert self._conn
        return self._conn.execute(sql, parameters)

    def _close(self) -> None:
        """Close the connection (run in the worker thread)."""
        assert self._conn
        # ensure the connection with the DB is actually closed
        self._conn.detach()
        self._conn.close()
        self._conn = None

    def _run(self) -> None:
        """The worker thread function."""

        def debug(message: str) -> None:
            self.logger.debug(f'worker "{current_thread().name}": {message}')

        debug(f"started with ID {current_thread().native_id}")
        while True:
            # the queue belongs to the event loop, so the blocking wait for
            # the next action is done through a coroutine scheduled on it
            future = asyncio.run_coroutine_threadsafe(
                self._queue.get(), self._loop
            )
            action = future.result()
            debug(f'received action "{action}"')
            action()
            self._loop.call_soon_threadsafe(self._queue.task_done)
            if self._conn is None:
                # the connection has been closed, exit the thread
                debug("shutting down")
                return

    async def _call_in_thread(
        self, func: Callable[..., Any], *args: Any, **kwargs: Any
    ) -> Any:
        """Call a sync action in the worker thread.

        Returns the value returned by ``func``, or raises the exception it
        raised.
        """
        call = WorkerAction(func, *args, **kwargs)
        await self._queue.put(call)
        return await call.result()


class DataBase:
"""A database to perform Queries."""

_engine: AsyncioEngine
_conn: AsyncConnection | None = None
_conn: DataBaseConnection
_pending_queries: int = 0

def __init__(
Expand All @@ -291,27 +438,32 @@ def __init__(
self.config = config
self.logger = logger
self._connect_lock = asyncio.Lock()
self._engine = create_db_engine(
execution_options = {}
if self.config.autocommit:
execution_options["isolation_level"] = "AUTOCOMMIT"
engine = create_db_engine(
self.config.dsn,
strategy=ASYNCIO_STRATEGY,
execution_options={"autocommit": self.config.autocommit},
execution_options=execution_options,
)

self._setup_query_latency_tracking()
self._conn = DataBaseConnection(self.config.name, engine, self.logger)
self._setup_query_latency_tracking(engine)

async def __aenter__(self) -> Self:
await self.connect()
return self

async def __aexit__(
self, exc_type: type, exc_value: Exception, traceback: TracebackType
self,
exc_type: type,
exc_value: Exception,
traceback: TracebackType,
) -> None:
await self.close()

@property
def connected(self) -> bool:
"""Whether the database is connected."""
return self._conn is not None
return self._conn.connected

async def connect(self) -> None:
"""Connect to the database."""
Expand All @@ -320,7 +472,7 @@ async def connect(self) -> None:
return

try:
self._conn = await self._engine.connect()
await self._conn.open()
except Exception as error:
raise self._db_error(error, exc_class=DataBaseConnectError)

Expand Down Expand Up @@ -349,10 +501,11 @@ async def execute(self, query: Query) -> MetricResults:
f'running query "{query.name}" on database "{self.config.name}"'
)
self._pending_queries += 1
self._conn: AsyncConnection
try:
result = await self._execute_query(query)
return query.results(await QueryResults.from_results(result))
query_results = await self.execute_sql(
query.sql, parameters=query.parameters, timeout=query.timeout
)
return query.results(query_results)
except TimeoutError:
raise self._query_timeout_error(
query.name, cast(QueryTimeout, query.timeout)
Expand All @@ -372,34 +525,19 @@ async def execute_sql(
sql: str,
parameters: dict[str, Any] | None = None,
timeout: QueryTimeout | None = None,
) -> AsyncResultProxy:
) -> QueryResults:
"""Execute a raw SQL query."""
if parameters is None:
parameters = {}
self._conn: AsyncConnection
return await asyncio.wait_for(
self._conn.execute(text(sql), parameters),
timeout=timeout,
)

async def _execute_query(self, query: Query) -> AsyncResultProxy:
"""Execute a query."""
return await self.execute_sql(
query.sql, parameters=query.parameters, timeout=query.timeout
)

async def _close(self) -> None:
# ensure the connection with the DB is actually closed
self._conn: AsyncConnection
self._conn.sync_connection.detach()
await self._conn.close()
self._conn = None
self._pending_queries = 0
self.logger.debug(f'disconnected from database "{self.config.name}"')

def _setup_query_latency_tracking(self) -> None:
engine = self._engine.sync_engine

def _setup_query_latency_tracking(self, engine: Engine) -> None:
@event.listens_for(engine, "before_cursor_execute") # type: ignore
def before_cursor_execute(
conn, cursor, statement, parameters, context, executemany
Expand Down
Loading

0 comments on commit 6bf0bb4

Please sign in to comment.