Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use correct type hints for query methods #587

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions databases/backends/aiopg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,14 @@
import aiopg
from sqlalchemy.engine.cursor import CursorResultMetaData
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
from sqlalchemy.engine.row import Row
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement

from databases.backends.common.records import Record, Row, create_column_maps
from databases.backends.compilers.psycopg import PGCompiler_psycopg
from databases.backends.dialects.psycopg import PGDialect_psycopg
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import (
ConnectionBackend,
DatabaseBackend,
Record as RecordInterface,
TransactionBackend,
)
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend

logger = logging.getLogger("databases")

Expand Down Expand Up @@ -118,7 +112,7 @@ async def release(self) -> None:
await self._database._pool.release(self._connection)
self._connection = None

async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
async def fetch_all(self, query: ClauseElement) -> typing.List[Record]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand All @@ -142,7 +136,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
finally:
cursor.close()

async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand Down Expand Up @@ -186,7 +180,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:

async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
) -> typing.AsyncGenerator[Record, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand Down
13 changes: 4 additions & 9 deletions databases/backends/asyncmy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@

from databases.backends.common.records import Record, Row, create_column_maps
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import (
ConnectionBackend,
DatabaseBackend,
Record as RecordInterface,
TransactionBackend,
)
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend

logger = logging.getLogger("databases")

Expand Down Expand Up @@ -108,7 +103,7 @@ async def release(self) -> None:
await self._database._pool.release(self._connection)
self._connection = None

async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
async def fetch_all(self, query: ClauseElement) -> typing.List[Record]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand All @@ -134,7 +129,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
finally:
await cursor.close()

async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand Down Expand Up @@ -180,7 +175,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:

async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
) -> typing.AsyncGenerator[Record, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand Down
2 changes: 0 additions & 2 deletions databases/backends/common/records.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import enum
import typing
from datetime import date, datetime, time

from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.engine.row import Row as SQLRow
Expand Down
13 changes: 4 additions & 9 deletions databases/backends/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@

from databases.backends.common.records import Record, Row, create_column_maps
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import (
ConnectionBackend,
DatabaseBackend,
Record as RecordInterface,
TransactionBackend,
)
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend

logger = logging.getLogger("databases")

Expand Down Expand Up @@ -108,7 +103,7 @@ async def release(self) -> None:
await self._database._pool.release(self._connection)
self._connection = None

async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
async def fetch_all(self, query: ClauseElement) -> typing.List[Record]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand All @@ -131,7 +126,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
finally:
await cursor.close()

async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand Down Expand Up @@ -177,7 +172,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:

async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
) -> typing.AsyncGenerator[Record, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand Down
13 changes: 4 additions & 9 deletions databases/backends/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,7 @@
from databases.backends.common.records import Record, create_column_maps
from databases.backends.dialects.psycopg import dialect as psycopg_dialect
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import (
ConnectionBackend,
DatabaseBackend,
Record as RecordInterface,
TransactionBackend,
)
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend

logger = logging.getLogger("databases")

Expand Down Expand Up @@ -99,15 +94,15 @@ async def release(self) -> None:
self._connection = await self._database._pool.release(self._connection)
self._connection = None

async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
async def fetch_all(self, query: ClauseElement) -> typing.List[Record]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns = self._compile(query)
rows = await self._connection.fetch(query_str, *args)
dialect = self._dialect
column_maps = create_column_maps(result_columns)
return [Record(row, result_columns, dialect, column_maps) for row in rows]

async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns = self._compile(query)
row = await self._connection.fetchrow(query_str, *args)
Expand Down Expand Up @@ -151,7 +146,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:

async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
) -> typing.AsyncGenerator[Record, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand Down
2 changes: 1 addition & 1 deletion databases/backends/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:

async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
) -> typing.AsyncGenerator[Record, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand Down
4 changes: 2 additions & 2 deletions databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ async def iterate(
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
) -> typing.AsyncGenerator[typing.Mapping, None]:
) -> typing.AsyncGenerator[Record, None]:
async with self.connection() as connection:
async for record in connection.iterate(query, values):
yield record
Expand Down Expand Up @@ -328,7 +328,7 @@ async def iterate(
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
) -> typing.AsyncGenerator[typing.Any, None]:
) -> typing.AsyncGenerator[Record, None]:
built_query = self._build_query(query, values)
async with self.transaction():
async with self._query_lock:
Expand Down
2 changes: 1 addition & 1 deletion databases/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:

async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Mapping, None]:
) -> typing.AsyncGenerator["Record", None]:
raise NotImplementedError() # pragma: no cover
# mypy needs async iterators to contain a `yield`
# https://github.com/python/mypy/issues/5385#issuecomment-407281656
Expand Down
Loading