diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index f85ca99a..a46893de 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -18,7 +18,7 @@ jobs: services: mysql: - image: mysql:5.7 + image: mariadb:11 env: MYSQL_USER: username MYSQL_PASSWORD: password @@ -26,10 +26,10 @@ jobs: MYSQL_DATABASE: testsuite ports: - 3306:3306 - options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3 + options: --health-cmd="mariadb-admin ping" --health-interval=10s --health-timeout=5s --health-retries=3 postgres: - image: postgres:14 + image: postgres:16 env: POSTGRES_USER: username POSTGRES_PASSWORD: password @@ -59,5 +59,6 @@ jobs: mysql+asyncmy://username:password@localhost:3306/testsuite, postgresql://username:password@localhost:5432/testsuite, postgresql+aiopg://username:password@127.0.0.1:5432/testsuite, - postgresql+asyncpg://username:password@localhost:5432/testsuite + postgresql+asyncpg://username:password@localhost:5432/testsuite, + postgresql+psycopg://username:password@localhost:5432/testsuite run: "scripts/test" diff --git a/README.md b/README.md index f40cd173..edf68e40 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ Database drivers supported are: * [asyncpg][asyncpg] * [aiopg][aiopg] +* [psycopg3][psycopg3] * [aiomysql][aiomysql] * [asyncmy][asyncmy] * [aiosqlite][aiosqlite] @@ -42,6 +43,7 @@ You can install the required database drivers with: ```shell $ pip install databases[asyncpg] $ pip install databases[aiopg] +$ pip install databases[psycopg3] $ pip install databases[aiomysql] $ pip install databases[asyncmy] $ pip install databases[aiosqlite] @@ -105,6 +107,7 @@ for examples of how to start using databases together with SQLAlchemy core expre [pymysql]: https://github.com/PyMySQL/PyMySQL [asyncpg]: https://github.com/MagicStack/asyncpg [aiopg]: https://github.com/aio-libs/aiopg +[psycopg3]: https://github.com/psycopg/psycopg [aiomysql]: https://github.com/aio-libs/aiomysql [asyncmy]: https://github.com/long2ice/asyncmy [aiosqlite]: https://github.com/omnilib/aiosqlite diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py index 0b4d95a3..9928f8b3 100644 --- a/databases/backends/aiopg.py +++ b/databases/backends/aiopg.py @@ -7,7 +7,6 @@ import aiopg from sqlalchemy.engine.cursor import CursorResultMetaData from sqlalchemy.engine.interfaces import Dialect, ExecutionContext -from sqlalchemy.engine.row import Row from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement diff --git a/databases/backends/postgres.py b/databases/backends/asyncpg.py similarity index 95% rename from databases/backends/postgres.py rename to databases/backends/asyncpg.py index c42688e1..92ad93b0 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/asyncpg.py @@ -19,7 +19,7 @@ logger = logging.getLogger("databases") -class PostgresBackend(DatabaseBackend): +class AsyncpgBackend(DatabaseBackend): def __init__( self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any ) -> None: @@ -78,12 +78,12 @@ async def disconnect(self) -> None: await self._pool.close() self._pool = None - def connection(self) -> "PostgresConnection": - return PostgresConnection(self, self._dialect) + def connection(self) -> "AsyncpgConnection": + return AsyncpgConnection(self, self._dialect) -class PostgresConnection(ConnectionBackend): - def __init__(self, database: PostgresBackend, dialect: Dialect): +class AsyncpgConnection(ConnectionBackend): + def __init__(self, database: AsyncpgBackend, dialect: Dialect): self._database = database self._dialect = dialect self._connection: typing.Optional[asyncpg.connection.Connection] = None @@ -159,7 +159,7 @@ async def iterate( yield Record(row, result_columns, self._dialect, column_maps) def transaction(self) -> TransactionBackend: - return PostgresTransaction(connection=self) + return AsyncpgTransaction(connection=self) def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: compiled = query.compile( @@ -197,8 +197,8 @@ def raw_connection(self) -> asyncpg.connection.Connection: return self._connection -class PostgresTransaction(TransactionBackend): - def __init__(self, connection: PostgresConnection): +class AsyncpgTransaction(TransactionBackend): + def __init__(self, connection: AsyncpgConnection): self._connection = connection self._transaction: typing.Optional[asyncpg.transaction.Transaction] = None diff --git a/databases/backends/common/records.py b/databases/backends/common/records.py index e963af50..65032fc8 100644 --- a/databases/backends/common/records.py +++ b/databases/backends/common/records.py @@ -1,12 +1,9 @@ -import enum import typing -from datetime import date, datetime, time from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.engine.row import Row as SQLRow from sqlalchemy.sql.compiler import _CompileLabel from sqlalchemy.sql.schema import Column -from sqlalchemy.sql.sqltypes import JSON from sqlalchemy.types import TypeEngine from databases.interfaces import Record as RecordInterface diff --git a/databases/backends/psycopg.py b/databases/backends/psycopg.py new file mode 100644 index 00000000..527b2600 --- /dev/null +++ b/databases/backends/psycopg.py @@ -0,0 +1,275 @@ +import typing + +import psycopg +import psycopg.adapt +import psycopg.types +import psycopg_pool +from psycopg.rows import namedtuple_row +from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg +from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.sql import ClauseElement +from sqlalchemy.sql.schema import Column + +from databases.backends.common.records import Record, create_column_maps +from databases.core import DatabaseURL +from databases.interfaces import ( + ConnectionBackend, + DatabaseBackend, + Record as RecordInterface, + TransactionBackend, +) + +try: + import orjson + + def load(data): + return orjson.loads(data) + + def dump(data): + return orjson.dumps(data) + +except ImportError: + import json + + def load(data): + return json.loads(data.decode("utf-8")) + + def dump(data): + return json.dumps(data).encode("utf-8") + + +class JsonLoader(psycopg.adapt.Loader): + def load(self, data): + return load(data) + + +class JsonDumper(psycopg.adapt.Dumper): + def dump(self, data): + return dump(data) + + +class PsycopgBackend(DatabaseBackend): + _database_url: DatabaseURL + _options: typing.Dict[str, typing.Any] + _dialect: Dialect + _pool: typing.Optional[psycopg_pool.AsyncConnectionPool] = None + + def __init__( + self, + database_url: typing.Union[DatabaseURL, str], + **options: typing.Dict[str, typing.Any], + ) -> None: + self._database_url = DatabaseURL(database_url) + self._options = options + self._dialect = PGDialect_psycopg() + self._dialect.implicit_returning = True + + async def connect(self) -> None: + if self._pool is not None: + return + + url = self._database_url._url.replace("postgresql+psycopg", "postgresql") + self._pool = psycopg_pool.AsyncConnectionPool(url, open=False, **self._options) + + # TODO: Add configurable timeouts + await self._pool.open() + + async def disconnect(self) -> None: + if self._pool is None: + return + + # TODO: Add configurable timeouts + await self._pool.close() + self._pool = None + + def connection(self) -> "PsycopgConnection": + return PsycopgConnection(self, self._dialect) + + +class PsycopgConnection(ConnectionBackend): + _database: PsycopgBackend + _dialect: Dialect + _connection: typing.Optional[psycopg.AsyncConnection] = None + + def __init__(self, database: PsycopgBackend, dialect: Dialect) -> None: + self._database = database + self._dialect = dialect + + async def acquire(self) -> None: + if self._connection is not None: + return + + if self._database._pool is None: + raise RuntimeError("PsycopgBackend is not running") + + # TODO: Add configurable timeouts + connection = await self._database._pool.getconn() + connection.adapters.register_loader("json", JsonLoader) + connection.adapters.register_loader("jsonb", JsonLoader) + connection.adapters.register_dumper(dict, JsonDumper) + connection.adapters.register_dumper(list, JsonDumper) + await connection.set_autocommit(True) + self._connection = connection + + async def release(self) -> None: + if self._connection is None: + return + + await self._database._pool.putconn(self._connection) + self._connection = None + + async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: + if self._connection is None: + raise RuntimeError("Connection is not acquired") + + query_str, args, result_columns = self._compile(query) + + async with self._connection.cursor(row_factory=namedtuple_row) as cursor: + await cursor.execute(query_str, args) + rows = await cursor.fetchall() + + column_maps = create_column_maps(result_columns) + return [ + PsycopgRecord(row, result_columns, self._dialect, column_maps) + for row in rows + ] + + async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: + if self._connection is None: + raise RuntimeError("Connection is not acquired") + + query_str, args, result_columns = self._compile(query) + + async with self._connection.cursor(row_factory=namedtuple_row) as cursor: + await cursor.execute(query_str, args) + row = await cursor.fetchone() + + if row is None: + return None + + return PsycopgRecord( + row, + result_columns, + self._dialect, + create_column_maps(result_columns), + ) + + async def fetch_val( + self, query: ClauseElement, column: typing.Any = 0 + ) -> typing.Any: + row = await self.fetch_one(query) + return None if row is None else row[column] + + async def execute(self, query: ClauseElement) -> typing.Any: + if self._connection is None: + raise RuntimeError("Connection is not acquired") + + query_str, args, _ = self._compile(query) + + async with self._connection.cursor(row_factory=namedtuple_row) as cursor: + await cursor.execute(query_str, args) + + async def execute_many(self, queries: typing.List[ClauseElement]) -> None: + # TODO: Find a way to use psycopg's executemany + for query in queries: + await self.execute(query) + + async def iterate( + self, query: ClauseElement + ) -> typing.AsyncGenerator[typing.Mapping, None]: + if self._connection is None: + raise RuntimeError("Connection is not acquired") + + query_str, args, result_columns = self._compile(query) + column_maps = create_column_maps(result_columns) + + async with self._connection.cursor(row_factory=namedtuple_row) as cursor: + await cursor.execute(query_str, args) + + while True: + row = await cursor.fetchone() + + if row is None: + break + + yield PsycopgRecord(row, result_columns, self._dialect, column_maps) + + def transaction(self) -> "TransactionBackend": + return PsycopgTransaction(connection=self) + + @property + def raw_connection(self) -> typing.Any: + if self._connection is None: + raise RuntimeError("Connection is not acquired") + return self._connection + + def _compile( + self, + query: ClauseElement, + ) -> typing.Tuple[str, typing.Mapping[str, typing.Any], tuple]: + compiled = query.compile( + dialect=self._dialect, + compile_kwargs={"render_postcompile": True}, + ) + + compiled_query = compiled.string + params = compiled.params + result_map = compiled._result_columns + + return compiled_query, params, result_map + + +class PsycopgTransaction(TransactionBackend): + _connecttion: PsycopgConnection + _transaction: typing.Optional[psycopg.AsyncTransaction] + + def __init__(self, connection: PsycopgConnection): + self._connection = connection + self._transaction: typing.Optional[psycopg.AsyncTransaction] = None + + async def start( + self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any] + ) -> None: + if self._connection._connection is None: + raise RuntimeError("Connection is not acquired") + + transaction = psycopg.AsyncTransaction( + self._connection._connection, **extra_options + ) + async with transaction._conn.lock: + await transaction._conn.wait(transaction._enter_gen()) + self._transaction = transaction + + async def commit(self) -> None: + if self._transaction is None: + raise RuntimeError("Transaction was not started") + + async with self._transaction._conn.lock: + await self._transaction._conn.wait(self._transaction._commit_gen()) + + async def rollback(self) -> None: + if self._transaction is None: + raise RuntimeError("Transaction was not started") + + async with self._transaction._conn.lock: + await self._transaction._conn.wait(self._transaction._rollback_gen(None)) + + +class PsycopgRecord(Record): + @property + def _mapping(self) -> typing.Mapping: + return self._row._asdict() + + def __getitem__(self, key: typing.Any) -> typing.Any: + if len(self._column_map) == 0: + if isinstance(key, str): + return self._mapping[key] + return self._row[key] + elif isinstance(key, Column): + idx, datatype = self._column_map_full[str(key)] + elif isinstance(key, int): + idx, datatype = self._column_map_int[key] + else: + idx, datatype = self._column_map[key] + + return self._row[idx] diff --git a/databases/core.py b/databases/core.py index d55dd3c8..cba06ced 100644 --- a/databases/core.py +++ b/databases/core.py @@ -43,12 +43,16 @@ class Database: SUPPORTED_BACKENDS = { - "postgresql": "databases.backends.postgres:PostgresBackend", + "postgres": "databases.backends.asyncpg:AsyncpgBackend", + "postgresql": "databases.backends.asyncpg:AsyncpgBackend", "postgresql+aiopg": "databases.backends.aiopg:AiopgBackend", - "postgres": "databases.backends.postgres:PostgresBackend", + "postgresql+asyncpg": "databases.backends.asyncpg:AsyncpgBackend", + "postgresql+psycopg": "databases.backends.psycopg:PsycopgBackend", "mysql": "databases.backends.mysql:MySQLBackend", + "mysql+aiomysql": "databases.backends.asyncmy:MySQLBackend", "mysql+asyncmy": "databases.backends.asyncmy:AsyncMyBackend", "sqlite": "databases.backends.sqlite:SQLiteBackend", + "sqlite+aiosqlite": "databases.backends.sqlite:SQLiteBackend", } _connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']" diff --git a/docs/contributing.md b/docs/contributing.md index 92ab3b3c..fccc8687 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -73,7 +73,7 @@ run all of those with lint script version: '2.1' services: postgres: - image: postgres:10.8 + image: postgres:16 environment: POSTGRES_USER: username POSTGRES_PASSWORD: password @@ -82,7 +82,7 @@ run all of those with lint script - 5432:5432 mysql: - image: mysql:5.7 + image: mariadb:11 environment: MYSQL_USER: username MYSQL_PASSWORD: password diff --git a/docs/index.md b/docs/index.md index 7c3cebf2..c4e581d5 100644 --- a/docs/index.md +++ b/docs/index.md @@ -31,6 +31,7 @@ Database drivers supported are: * [asyncpg][asyncpg] * [aiopg][aiopg] +* [psycopg3][psycopg3] * [aiomysql][aiomysql] * [asyncmy][asyncmy] * [aiosqlite][aiosqlite] @@ -40,6 +41,7 @@ You can install the required database drivers with: ```shell $ pip install databases[asyncpg] $ pip install databases[aiopg] +$ pip install databases[psycopg3] $ pip install databases[aiomysql] $ pip install databases[asyncmy] $ pip install databases[aiosqlite] @@ -103,6 +105,7 @@ for examples of how to start using databases together with SQLAlchemy core expre [pymysql]: https://github.com/PyMySQL/PyMySQL [asyncpg]: https://github.com/MagicStack/asyncpg [aiopg]: https://github.com/aio-libs/aiopg +[psycopg3]: https://github.com/psycopg/psycopg [aiomysql]: https://github.com/aio-libs/aiomysql [asyncmy]: https://github.com/long2ice/asyncmy [aiosqlite]: https://github.com/omnilib/aiosqlite diff --git a/requirements.txt b/requirements.txt index 8b05a46e..450b1c63 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,21 @@ -e . +# Speedups +orjson==3.9.15 + # Async database drivers asyncmy==0.2.9 aiomysql==0.2.0 aiopg==1.4.0 aiosqlite==0.20.0 asyncpg==0.29.0 +psycopg==3.1.18 +psycopg-binary==3.1.18 +psycopg-pool==3.2.1 # Sync database drivers for standard tooling around setup/teardown/migrations. psycopg==3.1.18 +psycopg-binary==3.1.18 pymysql==1.1.0 # Testing diff --git a/setup.py b/setup.py index 41c0c584..33b6f137 100644 --- a/setup.py +++ b/setup.py @@ -47,16 +47,18 @@ def get_packages(package): author_email="tom@tomchristie.com", packages=get_packages("databases"), package_data={"databases": ["py.typed"]}, - install_requires=["sqlalchemy>=2.0.7"], + install_requires=["sqlalchemy>=2.0.11"], extras_require={ - "postgresql": ["asyncpg"], - "asyncpg": ["asyncpg"], - "aiopg": ["aiopg"], "mysql": ["aiomysql"], "aiomysql": ["aiomysql"], "asyncmy": ["asyncmy"], + "postgresql": ["asyncpg"], + "aiopg": ["aiopg"], + "asyncpg": ["asyncpg"], + "psycopg3": ["psycopg", "psycopg-pool"], "sqlite": ["aiosqlite"], "aiosqlite": ["aiosqlite"], + "orjson": ["orjson"], }, classifiers=[ "Development Status :: 3 - Alpha", diff --git a/tests/test_connection_options.py b/tests/test_connection_options.py index 81ce2ac7..757393a4 100644 --- a/tests/test_connection_options.py +++ b/tests/test_connection_options.py @@ -6,7 +6,7 @@ import pytest from databases.backends.aiopg import AiopgBackend -from databases.backends.postgres import PostgresBackend +from databases.backends.asyncpg import AsyncpgBackend from databases.core import DatabaseURL from tests.test_databases import DATABASE_URLS, async_adapter @@ -19,7 +19,7 @@ def test_postgres_pool_size(): - backend = PostgresBackend("postgres://localhost/database?min_size=1&max_size=20") + backend = AsyncpgBackend("postgres://localhost/database?min_size=1&max_size=20") kwargs = backend._get_connection_kwargs() assert kwargs == {"min_size": 1, "max_size": 20} @@ -29,43 +29,43 @@ async def test_postgres_pool_size_connect(): for url in DATABASE_URLS: if DatabaseURL(url).dialect != "postgresql": continue - backend = PostgresBackend(url + "?min_size=1&max_size=20") + backend = AsyncpgBackend(url + "?min_size=1&max_size=20") await backend.connect() await backend.disconnect() def test_postgres_explicit_pool_size(): - backend = PostgresBackend("postgres://localhost/database", min_size=1, max_size=20) + backend = AsyncpgBackend("postgres://localhost/database", min_size=1, max_size=20) kwargs = backend._get_connection_kwargs() assert kwargs == {"min_size": 1, "max_size": 20} def test_postgres_ssl(): - backend = PostgresBackend("postgres://localhost/database?ssl=true") + backend = AsyncpgBackend("postgres://localhost/database?ssl=true") kwargs = backend._get_connection_kwargs() assert kwargs == {"ssl": True} def test_postgres_ssl_verify_full(): - backend = PostgresBackend("postgres://localhost/database?ssl=verify-full") + backend = AsyncpgBackend("postgres://localhost/database?ssl=verify-full") kwargs = backend._get_connection_kwargs() assert kwargs == {"ssl": "verify-full"} def test_postgres_explicit_ssl(): - backend = PostgresBackend("postgres://localhost/database", ssl=True) + backend = AsyncpgBackend("postgres://localhost/database", ssl=True) kwargs = backend._get_connection_kwargs() assert kwargs == {"ssl": True} def test_postgres_explicit_ssl_verify_full(): - backend = PostgresBackend("postgres://localhost/database", ssl="verify-full") + backend = AsyncpgBackend("postgres://localhost/database", ssl="verify-full") kwargs = backend._get_connection_kwargs() assert kwargs == {"ssl": "verify-full"} def test_postgres_no_extra_options(): - backend = PostgresBackend("postgres://localhost/database") + backend = AsyncpgBackend("postgres://localhost/database") kwargs = backend._get_connection_kwargs() assert kwargs == {} @@ -74,7 +74,7 @@ def test_postgres_password_as_callable(): def gen_password(): return "Foo" - backend = PostgresBackend( + backend = AsyncpgBackend( "postgres://:password@localhost/database", password=gen_password ) kwargs = backend._get_connection_kwargs() diff --git a/tests/test_databases.py b/tests/test_databases.py index d9d9e4d6..66164aea 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -134,6 +134,7 @@ def create_test_database(): "postgresql+aiopg", "sqlite+aiosqlite", "postgresql+asyncpg", + "postgresql+psycopg", ]: url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) @@ -151,6 +152,7 @@ def create_test_database(): "postgresql+aiopg", "sqlite+aiosqlite", "postgresql+asyncpg", + "postgresql+psycopg", ]: url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) @@ -202,17 +204,17 @@ async def test_queries(database_url): assert len(results) == 3 assert results[0]["text"] == "example1" - assert results[0]["completed"] == True + assert results[0]["completed"] is True assert results[1]["text"] == "example2" - assert results[1]["completed"] == False + assert results[1]["completed"] is False assert results[2]["text"] == "example3" - assert results[2]["completed"] == True + assert results[2]["completed"] is True # fetch_one() query = notes.select() result = await database.fetch_one(query=query) assert result["text"] == "example1" - assert result["completed"] == True + assert result["completed"] is True # fetch_val() query = sqlalchemy.sql.select(*[notes.c.text]) @@ -244,11 +246,11 @@ async def test_queries(database_url): iterate_results.append(result) assert len(iterate_results) == 3 assert iterate_results[0]["text"] == "example1" - assert iterate_results[0]["completed"] == True + assert iterate_results[0]["completed"] is True assert iterate_results[1]["text"] == "example2" - assert iterate_results[1]["completed"] == False + assert iterate_results[1]["completed"] is False assert iterate_results[2]["text"] == "example3" - assert iterate_results[2]["completed"] == True + assert iterate_results[2]["completed"] is True @pytest.mark.parametrize("database_url", DATABASE_URLS) @@ -278,26 +280,26 @@ async def test_queries_raw(database_url): results = await database.fetch_all(query=query, values={"completed": True}) assert len(results) == 2 assert results[0]["text"] == "example1" - assert results[0]["completed"] == True + assert results[0]["completed"] is True assert results[1]["text"] == "example3" - assert results[1]["completed"] == True + assert results[1]["completed"] is True # fetch_one() query = "SELECT * FROM notes WHERE completed = :completed" result = await database.fetch_one(query=query, values={"completed": False}) assert result["text"] == "example2" - assert result["completed"] == False + assert result["completed"] is False # fetch_val() query = "SELECT completed FROM notes WHERE text = :text" result = await database.fetch_val(query=query, values={"text": "example1"}) - assert result == True + assert result is True query = "SELECT * FROM notes WHERE text = :text" result = await database.fetch_val( query=query, values={"text": "example1"}, column="completed" ) - assert result == True + assert result is True # iterate() query = "SELECT * FROM notes" @@ -306,11 +308,11 @@ async def test_queries_raw(database_url): iterate_results.append(result) assert len(iterate_results) == 3 assert iterate_results[0]["text"] == "example1" - assert iterate_results[0]["completed"] == True + assert iterate_results[0]["completed"] is True assert iterate_results[1]["text"] == "example2" - assert iterate_results[1]["completed"] == False + assert iterate_results[1]["completed"] is False assert iterate_results[2]["text"] == "example3" - assert iterate_results[2]["completed"] == True + assert iterate_results[2]["completed"] is True @pytest.mark.parametrize("database_url", DATABASE_URLS) @@ -378,7 +380,7 @@ async def test_results_support_mapping_interface(database_url): assert isinstance(results_as_dicts[0]["id"], int) assert results_as_dicts[0]["text"] == "example1" - assert results_as_dicts[0]["completed"] == True + assert results_as_dicts[0]["completed"] is True @pytest.mark.parametrize("database_url", DATABASE_URLS) @@ -465,7 +467,7 @@ async def test_execute_return_val(database_url): query = notes.select().where(notes.c.id == pk) result = await database.fetch_one(query) assert result["text"] == "example1" - assert result["completed"] == True + assert result["completed"] is True @pytest.mark.parametrize("database_url", DATABASE_URLS) @@ -855,7 +857,7 @@ async def test_transaction_commit_low_level(database_url): try: query = notes.insert().values(text="example1", completed=True) await database.execute(query) - except: # pragma: no cover + except Exception: # pragma: no cover await transaction.rollback() else: await transaction.commit() @@ -879,7 +881,7 @@ async def test_transaction_rollback_low_level(database_url): query = notes.insert().values(text="example1", completed=True) await database.execute(query) raise RuntimeError() - except: + except Exception: await transaction.rollback() else: # pragma: no cover await transaction.commit() @@ -1336,6 +1338,7 @@ async def test_queries_with_expose_backend_connection(database_url): "mysql+asyncmy", "mysql+aiomysql", "postgresql+aiopg", + "postgresql+psycopg", ]: insert_query = "INSERT INTO notes (text, completed) VALUES (%s, %s)" else: @@ -1351,10 +1354,13 @@ async def test_queries_with_expose_backend_connection(database_url): ]: cursor = await raw_connection.cursor() await cursor.execute(insert_query, values) - elif database.url.scheme == "mysql+asyncmy": + elif database.url.scheme in ["mysql+asyncmy", "postgresql+psycopg"]: async with raw_connection.cursor() as cursor: await cursor.execute(insert_query, values) - elif database.url.scheme in ["postgresql", "postgresql+asyncpg"]: + elif database.url.scheme in [ + "postgresql", + "postgresql+asyncpg", + ]: await raw_connection.execute(insert_query, *values) elif database.url.scheme in ["sqlite", "sqlite+aiosqlite"]: await raw_connection.execute(insert_query, values) @@ -1365,7 +1371,7 @@ async def test_queries_with_expose_backend_connection(database_url): if database.url.scheme in ["mysql", "mysql+aiomysql"]: cursor = await raw_connection.cursor() await cursor.executemany(insert_query, values) - elif database.url.scheme == "mysql+asyncmy": + elif database.url.scheme in ["mysql+asyncmy", "postgresql+psycopg"]: async with raw_connection.cursor() as cursor: await cursor.executemany(insert_query, values) elif database.url.scheme == "postgresql+aiopg": @@ -1388,7 +1394,7 @@ async def test_queries_with_expose_backend_connection(database_url): cursor = await raw_connection.cursor() await cursor.execute(select_query) results = await cursor.fetchall() - elif database.url.scheme == "mysql+asyncmy": + elif database.url.scheme in ["mysql+asyncmy", "postgresql+psycopg"]: async with raw_connection.cursor() as cursor: await cursor.execute(select_query) results = await cursor.fetchall() @@ -1400,16 +1406,16 @@ async def test_queries_with_expose_backend_connection(database_url): assert len(results) == 3 # Raw output for the raw request assert results[0][1] == "example1" - assert results[0][2] == True + assert results[0][2] is True assert results[1][1] == "example2" - assert results[1][2] == False + assert results[1][2] is False assert results[2][1] == "example3" - assert results[2][2] == True + assert results[2][2] is True # fetch_one() if database.url.scheme in ["postgresql", "postgresql+asyncpg"]: result = await raw_connection.fetchrow(select_query) - elif database.url.scheme == "mysql+asyncmy": + elif database.url.scheme in ["mysql+asyncmy", "postgresql+psycopg"]: async with raw_connection.cursor() as cursor: await cursor.execute(select_query) result = await cursor.fetchone() @@ -1420,7 +1426,7 @@ async def test_queries_with_expose_backend_connection(database_url): # Raw output for the raw request assert result[1] == "example1" - assert result[2] == True + assert result[2] is True @pytest.mark.parametrize("database_url", DATABASE_URLS) @@ -1591,7 +1597,7 @@ async def test_column_names(database_url, select_query): assert sorted(results[0]._mapping.keys()) == ["completed", "id", "text"] assert results[0]["text"] == "example1" - assert results[0]["completed"] == True + assert results[0]["completed"] is True @pytest.mark.parametrize("database_url", DATABASE_URLS) @@ -1626,23 +1632,6 @@ async def test_result_named_access(database_url): assert result.completed is True -@pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter -async def test_mapping_property_interface(database_url): - """ - Test that all connections implement interface with `_mapping` property - """ - async with Database(database_url) as database: - query = notes.select() - single_result = await database.fetch_one(query=query) - assert single_result._mapping["text"] == "example1" - assert single_result._mapping["completed"] is True - - list_result = await database.fetch_all(query=query) - assert list_result[0]._mapping["text"] == "example1" - assert list_result[0]._mapping["completed"] is True - - @async_adapter async def test_should_not_maintain_ref_when_no_cache_param(): async with Database( diff --git a/tests/test_integration.py b/tests/test_integration.py index 139f8ffe..0605529f 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,7 +1,10 @@ +import contextlib + import pytest import sqlalchemy from starlette.applications import Starlette from starlette.responses import JSONResponse +from starlette.routing import Route from starlette.testclient import TestClient from databases import Database, DatabaseURL @@ -29,6 +32,7 @@ def create_test_database(): "postgresql+aiopg", "sqlite+aiosqlite", "postgresql+asyncpg", + "postgresql+psycopg", ]: url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) @@ -45,6 +49,7 @@ def create_test_database(): "postgresql+aiopg", "sqlite+aiosqlite", "postgresql+asyncpg", + "postgresql+psycopg", ]: url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) @@ -53,17 +58,13 @@ def create_test_database(): def get_app(database_url): database = Database(database_url, force_rollback=True) - app = Starlette() - @app.on_event("startup") - async def startup(): + @contextlib.asynccontextmanager + async def lifespan(app): await database.connect() - - @app.on_event("shutdown") - async def shutdown(): + yield await database.disconnect() - @app.route("/notes", methods=["GET"]) async def list_notes(request): query = notes.select() results = await database.fetch_all(query) @@ -73,14 +74,18 @@ async def list_notes(request): ] return JSONResponse(content) - @app.route("/notes", methods=["POST"]) async def add_note(request): data = await request.json() query = notes.insert().values(text=data["text"], completed=data["completed"]) await database.execute(query) return JSONResponse({"text": data["text"], "completed": data["completed"]}) - return app + routes = [ + Route("/notes", list_notes, methods=["GET"]), + Route("/notes", add_note, methods=["POST"]), + ] + + return Starlette(routes=routes, lifespan=lifespan) @pytest.mark.parametrize("database_url", DATABASE_URLS)