From dd73fd73f4e4cc5abcdb1e1f7ca4ec19f816b760 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 23 Dec 2023 02:15:55 +0100 Subject: [PATCH] Add support for `psycopg` and `asyncpg` drivers This introduces the `crate+psycopg://`, `crate+asyncpg://`, and `crate+urllib3://` dialect identifiers. The asynchronous variant of `psycopg` is also supported. --- CHANGES.md | 3 + pyproject.toml | 12 ++- src/sqlalchemy_cratedb/dialect.py | 50 +++++++++-- src/sqlalchemy_cratedb/dialect_more.py | 106 ++++++++++++++++++++++++ tests/conftest.py | 2 +- tests/engine_test.py | 110 +++++++++++++++++++++++++ 6 files changed, 273 insertions(+), 10 deletions(-) create mode 100644 src/sqlalchemy_cratedb/dialect_more.py create mode 100644 tests/engine_test.py diff --git a/CHANGES.md b/CHANGES.md index 92faa7c..3d007f7 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,6 +1,9 @@ # Changelog ## Unreleased +- Added support for `psycopg` and `asyncpg` drivers, by introducing the + `crate+psycopg://`, `crate+asyncpg://`, and `crate+urllib3://` dialect + identifiers. The asynchronous variant of `psycopg` is also supported. ## 2024/11/04 0.40.1 - CI: Verified support on Python 3.13 diff --git a/pyproject.toml b/pyproject.toml index 8ff9687..0734d25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,7 +89,7 @@ dependencies = [ "verlib2==0.2", ] optional-dependencies.all = [ - "sqlalchemy-cratedb[vector]", + "sqlalchemy-cratedb[postgresql,vector]", ] optional-dependencies.develop = [ "mypy<1.14", @@ -102,6 +102,9 @@ optional-dependencies.doc = [ "crate-docs-theme>=0.26.5", "sphinx>=3.5,<9", ] +optional-dependencies.postgresql = [ + "sqlalchemy-postgresql-relaxed", +] optional-dependencies.release = [ "build<2", "twine<6", @@ -112,6 +115,7 @@ optional-dependencies.test = [ "pandas<2.3", "pueblo>=0.0.7", "pytest<9", + "pytest-asyncio<0.24", "pytest-cov<7", "pytest-mock<4", ] @@ -122,7 +126,11 @@ urls.changelog = "https://github.com/crate/sqlalchemy-cratedb/blob/main/CHANGES. urls.documentation = "https://cratedb.com/docs/sqlalchemy-cratedb/" urls.homepage = "https://cratedb.com/docs/sqlalchemy-cratedb/" urls.repository = "https://github.com/crate/sqlalchemy-cratedb" -entry-points."sqlalchemy.dialects".crate = "sqlalchemy_cratedb:dialect" +entry-points."sqlalchemy.dialects"."crate" = "sqlalchemy_cratedb:dialect" +entry-points."sqlalchemy.dialects"."crate.asyncpg" = "sqlalchemy_cratedb.dialect_more:dialect_asyncpg" +entry-points."sqlalchemy.dialects"."crate.psycopg" = "sqlalchemy_cratedb.dialect_more:dialect_psycopg" +entry-points."sqlalchemy.dialects"."crate.psycopg_async" = "sqlalchemy_cratedb.dialect_more:dialect_psycopg_async" +entry-points."sqlalchemy.dialects"."crate.urllib3" = "sqlalchemy_cratedb.dialect_more:dialect_urllib3" [tool.black] line-length = 100 diff --git a/src/sqlalchemy_cratedb/dialect.py b/src/sqlalchemy_cratedb/dialect.py index f630aeb..253f060 100644 --- a/src/sqlalchemy_cratedb/dialect.py +++ b/src/sqlalchemy_cratedb/dialect.py @@ -21,6 +21,7 @@ import logging from datetime import date, datetime +from types import ModuleType from sqlalchemy import types as sqltypes from sqlalchemy.engine import default, reflection @@ -212,6 +213,12 @@ def initialize(self, connection): # get default schema name self.default_schema_name = self._get_default_schema_name(connection) + def set_isolation_level(self, dbapi_connection, level): + """ + For CrateDB, this is implemented as a noop. + """ + pass + def do_rollback(self, connection): # if any exception is raised by the dbapi, sqlalchemy by default # attempts to do a rollback crate doesn't support rollbacks. @@ -230,7 +237,21 @@ def connect(self, host=None, port=None, *args, **kwargs): use_ssl = asbool(kwargs.pop("ssl", False)) if use_ssl: servers = ["https://" + server for server in servers] - return self.dbapi.connect(servers=servers, **kwargs) + + is_module = isinstance(self.dbapi, ModuleType) + if is_module: + driver_name = self.dbapi.__name__ + else: + driver_name = self.dbapi.__class__.__name__ + if driver_name == "crate.client": + if "database" in kwargs: + del kwargs["database"] + return self.dbapi.connect(servers=servers, **kwargs) + elif driver_name in ["psycopg", "PsycopgAdaptDBAPI", "AsyncAdapt_asyncpg_dbapi"]: + return self.dbapi.connect(host=host, port=port, **kwargs) + else: + raise ValueError(f"Unknown driver variant: {driver_name}") + return self.dbapi.connect(**kwargs) def do_execute(self, cursor, statement, parameters, context=None): @@ -300,10 +321,12 @@ def get_table_names(self, connection, schema=None, **kw): if schema is None: schema = self._get_effective_schema_name(connection) cursor = connection.exec_driver_sql( - "SELECT table_name FROM information_schema.tables " - "WHERE {0} = ? " - "AND table_type = 'BASE TABLE' " - "ORDER BY table_name ASC, {0} ASC".format(self.schema_column), + self._format_query( + "SELECT table_name FROM information_schema.tables " + "WHERE {0} = ? " + "AND table_type = 'BASE TABLE' " + "ORDER BY table_name ASC, {0} ASC" + ).format(self.schema_column), (schema or self.default_schema_name,), ) return [row[0] for row in cursor.fetchall()] @@ -326,7 +349,7 @@ def get_columns(self, connection, table_name, schema=None, **kw): "AND column_name !~ ?".format(self.schema_column) ) cursor = connection.exec_driver_sql( - query, + self._format_query(query), ( table_name, schema or self.default_schema_name, @@ -366,7 +389,9 @@ def result_fun(result): rows = result.fetchone() return set(rows[0] if rows else []) - pk_result = engine.exec_driver_sql(query, (table_name, schema or self.default_schema_name)) + pk_result = engine.exec_driver_sql( + self._format_query(query), (table_name, schema or self.default_schema_name) + ) pks = result_fun(pk_result) return {"constrained_columns": sorted(pks), "name": "PRIMARY KEY"} @@ -405,6 +430,17 @@ def has_ilike_operator(self): server_version_info = self.server_version_info return server_version_info is not None and server_version_info >= (4, 1, 0) + def _format_query(self, query): + """ + When using the PostgreSQL protocol with drivers `psycopg` or `asyncpg`, + the paramstyle is not `qmark`, but `pyformat`. + + TODO: Review: Is it legit and sane? Are there alternatives? + """ + if self.paramstyle == "pyformat": + query = query.replace("= ?", "= %s").replace("!~ ?", "!~ %s") + return query + class DateTrunc(functions.GenericFunction): name = "date_trunc" diff --git a/src/sqlalchemy_cratedb/dialect_more.py b/src/sqlalchemy_cratedb/dialect_more.py new file mode 100644 index 0000000..0263012 --- /dev/null +++ b/src/sqlalchemy_cratedb/dialect_more.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8; -*- +# +# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor +# license agreements. See the NOTICE file distributed with this work for +# additional information regarding copyright ownership. Crate licenses +# this file to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may +# obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# However, if you have executed another commercial license agreement +# with Crate these terms will supersede the license and you may use the +# software solely pursuant to the terms of the relevant commercial agreement. +from sqlalchemy.engine.reflection import Inspector +from sqlalchemy_postgresql_relaxed.asyncpg import PGDialect_asyncpg_relaxed +from sqlalchemy_postgresql_relaxed.base import PGDialect_relaxed +from sqlalchemy_postgresql_relaxed.psycopg import ( + PGDialect_psycopg_relaxed, + PGDialectAsync_psycopg_relaxed, +) + +from sqlalchemy_cratedb import dialect + + +class CrateDialectPostgresAdapter(PGDialect_relaxed, dialect): + """ + Provide a dialect on top of the relaxed PostgreSQL dialect. + """ + + inspector = Inspector + + # Need to manually override some methods because of polymorphic inheritance woes. + # TODO: Investigate if this can be solved using metaprogramming or other techniques. + has_schema = dialect.has_schema + has_table = dialect.has_table + get_schema_names = dialect.get_schema_names + get_table_names = dialect.get_table_names + get_view_names = dialect.get_view_names + get_columns = dialect.get_columns + get_pk_constraint = dialect.get_pk_constraint + get_foreign_keys = dialect.get_foreign_keys + get_indexes = dialect.get_indexes + + get_multi_columns = dialect.get_multi_columns + get_multi_pk_constraint = dialect.get_multi_pk_constraint + get_multi_foreign_keys = dialect.get_multi_foreign_keys + + # TODO: Those may want to go to dialect instead? + def get_multi_indexes(self, *args, **kwargs): + return [] + + def get_multi_unique_constraints(self, *args, **kwargs): + return [] + + def get_multi_check_constraints(self, *args, **kwargs): + return [] + + def get_multi_table_comment(self, *args, **kwargs): + return [] + + +class CrateDialect_psycopg(PGDialect_psycopg_relaxed, CrateDialectPostgresAdapter): + driver = "psycopg" + + @classmethod + def get_async_dialect_cls(cls, url): + return CrateDialectAsync_psycopg + + @classmethod + def import_dbapi(cls): + import psycopg + + return psycopg + + +class CrateDialectAsync_psycopg(PGDialectAsync_psycopg_relaxed, CrateDialectPostgresAdapter): + driver = "psycopg_async" + is_async = True + + +class CrateDialect_asyncpg(PGDialect_asyncpg_relaxed, CrateDialectPostgresAdapter): + driver = "asyncpg" + + # TODO: asyncpg may have `paramstyle="numeric_dollar"`. Review this! + + # TODO: AttributeError: module 'asyncpg' has no attribute 'paramstyle' + """ + @classmethod + def import_dbapi(cls): + import asyncpg + + return asyncpg + """ + + +dialect_urllib3 = dialect +dialect_psycopg = CrateDialect_psycopg +dialect_psycopg_async = CrateDialectAsync_psycopg +dialect_asyncpg = CrateDialect_asyncpg diff --git a/tests/conftest.py b/tests/conftest.py index 88b10d9..5b54603 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,6 @@ def cratedb_service(): Provide a CrateDB service instance to the test suite. """ db = CrateDBTestAdapter() - db.start() + db.start(ports={4200: None, 5432: None}) yield db db.stop() diff --git a/tests/engine_test.py b/tests/engine_test.py new file mode 100644 index 0000000..a9ffbf2 --- /dev/null +++ b/tests/engine_test.py @@ -0,0 +1,110 @@ +import pytest +import sqlalchemy as sa +from sqlalchemy.dialects import registry as dialect_registry + +from sqlalchemy_cratedb.sa_version import SA_2_0, SA_VERSION + +if SA_VERSION < SA_2_0: + raise pytest.skip("Only supported on SQLAlchemy 2.0 and higher", allow_module_level=True) + +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine + +# Registering the additional dialects manually seems to be needed when running +# under tests. Apparently, manual registration is not needed under regular +# circumstances, as this is wired through the `sqlalchemy.dialects` entrypoint +# registrations in `pyproject.toml`. It is definitively weird, but c'est la vie. +dialect_registry.register("crate.urllib3", "sqlalchemy_cratedb.dialect_more", "dialect_urllib3") +dialect_registry.register("crate.asyncpg", "sqlalchemy_cratedb.dialect_more", "dialect_asyncpg") +dialect_registry.register("crate.psycopg", "sqlalchemy_cratedb.dialect_more", "dialect_psycopg") + + +QUERY = sa.text("SELECT mountain, coordinates FROM sys.summits ORDER BY mountain LIMIT 3;") + + +def test_engine_sync_vanilla(cratedb_service): + """ + crate:// -- Verify connectivity and data transport with vanilla HTTP-based driver. + """ + port4200 = cratedb_service.cratedb.get_exposed_port(4200) + engine = sa.create_engine(f"crate://crate@localhost:{port4200}/", echo=True) + assert isinstance(engine, sa.engine.Engine) + with engine.connect() as connection: + result = connection.execute(QUERY) + assert result.mappings().fetchone() == { + "mountain": "Acherkogel", + "coordinates": [10.95667, 47.18917], + } + + +def test_engine_sync_urllib3(cratedb_service): + """ + crate+urllib3:// -- Verify connectivity and data transport *explicitly* selecting the HTTP driver. + """ # noqa: E501 + port4200 = cratedb_service.cratedb.get_exposed_port(4200) + engine = sa.create_engine( + f"crate+urllib3://crate@localhost:{port4200}/", isolation_level="AUTOCOMMIT", echo=True + ) + assert isinstance(engine, sa.engine.Engine) + with engine.connect() as connection: + result = connection.execute(QUERY) + assert result.mappings().fetchone() == { + "mountain": "Acherkogel", + "coordinates": [10.95667, 47.18917], + } + + +def test_engine_sync_psycopg(cratedb_service): + """ + crate+psycopg:// -- Verify connectivity and data transport using the psycopg driver (version 3). + """ + port5432 = cratedb_service.cratedb.get_exposed_port(5432) + engine = sa.create_engine( + f"crate+psycopg://crate@localhost:{port5432}/", isolation_level="AUTOCOMMIT", echo=True + ) + assert isinstance(engine, sa.engine.Engine) + with engine.connect() as connection: + result = connection.execute(QUERY) + assert result.mappings().fetchone() == { + "mountain": "Acherkogel", + "coordinates": "(10.95667,47.18917)", + } + + +@pytest.mark.asyncio +async def test_engine_async_psycopg(cratedb_service): + """ + crate+psycopg:// -- Verify connectivity and data transport using the psycopg driver (version 3). + This time, in asynchronous mode. + """ + port5432 = cratedb_service.cratedb.get_exposed_port(5432) + engine = create_async_engine( + f"crate+psycopg://crate@localhost:{port5432}/", isolation_level="AUTOCOMMIT", echo=True + ) + assert isinstance(engine, AsyncEngine) + async with engine.begin() as conn: + result = await conn.execute(QUERY) + assert result.mappings().fetchone() == { + "mountain": "Acherkogel", + "coordinates": "(10.95667,47.18917)", + } + + +@pytest.mark.asyncio +async def test_engine_async_asyncpg(cratedb_service): + """ + crate+asyncpg:// -- Verify connectivity and data transport using the asyncpg driver. + This exclusively uses asynchronous mode. + """ + port5432 = cratedb_service.cratedb.get_exposed_port(5432) + from asyncpg.pgproto.types import Point + + engine = create_async_engine( + f"crate+asyncpg://crate@localhost:{port5432}/", isolation_level="AUTOCOMMIT", echo=True + ) + assert isinstance(engine, AsyncEngine) + async with engine.begin() as conn: + result = await conn.execute(QUERY) + assert result.mappings().fetchone() == { + "mountain": "Acherkogel", + "coordinates": Point(10.95667, 47.18917), + }