Skip to content

Commit

Permalink
S01E04
Browse files Browse the repository at this point in the history
  • Loading branch information
ansipunk committed Mar 3, 2024
1 parent 9d78221 commit d7ff8e8
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
23 changes: 17 additions & 6 deletions databases/backends/psycopg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import psycopg
import psycopg_pool
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
from sqlalchemy.sql import ClauseElement

from databases.backends.common.records import Record, create_column_maps
from databases.backends.dialects.psycopg import compile_query, get_dialect
from databases.core import DatabaseURL
from databases.interfaces import (
ConnectionBackend,
Expand All @@ -29,7 +29,7 @@ def __init__(
) -> None:
self._database_url = DatabaseURL(database_url)
self._options = options
self._dialect = get_dialect()
self._dialect = PGDialect_psycopg()
self._pool = None

async def connect(self) -> None:
Expand Down Expand Up @@ -86,7 +86,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
if self._connection is None:
raise RuntimeError("Connection is not acquired")

query_str, args, result_columns = compile_query(query, self._dialect)
query_str, args, result_columns = self._compile(query)

async with self._connection.cursor() as cursor:
await cursor.execute(query_str, args)
Expand All @@ -99,7 +99,7 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterfa
if self._connection is None:
raise RuntimeError("Connection is not acquired")

query_str, args, result_columns = compile_query(query, self._dialect)
query_str, args, result_columns = self._compile(query)

async with self._connection.cursor() as cursor:
await cursor.execute(query_str, args)
Expand All @@ -125,7 +125,7 @@ async def execute(self, query: ClauseElement) -> typing.Any:
if self._connection is None:
raise RuntimeError("Connection is not acquired")

query_str, args, _ = compile_query(query, self._dialect)
query_str, args, _ = self._compile(query)

async with self._connection.cursor() as cursor:
await cursor.execute(query_str, args)
Expand All @@ -141,7 +141,7 @@ async def iterate(
if self._connection is None:
raise RuntimeError("Connection is not acquired")

query_str, args, result_columns = compile_query(query, self._dialect)
query_str, args, result_columns = self._compile(query)
column_maps = create_column_maps(result_columns)

async with self._connection.cursor() as cursor:
Expand All @@ -164,6 +164,17 @@ def raw_connection(self) -> typing.Any:
raise RuntimeError("Connection is not acquired")
return self._connection

def _compile(
self, query: ClauseElement,
) -> typing.Tuple[str, typing.Mapping[str, typing.Any], tuple]:
compiled = query.compile(dialect=self._dialect)

compiled_query = compiled.string
params = compiled.params
result_map = compiled._result_columns

return compiled_query, params, result_map


class PsycopgTransaction(TransactionBackend):
_connecttion: PsycopgConnection
Expand Down
1 change: 1 addition & 0 deletions tests/test_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -1338,6 +1338,7 @@ async def test_queries_with_expose_backend_connection(database_url):
"mysql+asyncmy",
"mysql+aiomysql",
"postgresql+aiopg",
"postgresql+psycopg",
]:
insert_query = "INSERT INTO notes (text, completed) VALUES (%s, %s)"
else:
Expand Down

0 comments on commit d7ff8e8

Please sign in to comment.