Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add callable that returns a model as a supported model type #333

Merged
merged 2 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 53 additions & 19 deletions pydapper/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ def execute(self, sql: str, param: Union["ParamType", "ListParamType"] = None) -
rowcount = handler.execute(cursor)
return rowcount

def _buffered_query(self, handler: BaseSqlParamHandler, model: Type["_T"]) -> List["_T"]:
def _buffered_query(
self, handler: BaseSqlParamHandler, model: Union[Type["_T"], Callable[..., "_T"]]
) -> List["_T"]:
with self._cursor_context_proxy() as cursor:
handler.execute(cursor)
headers = get_col_names(cursor)
Expand Down Expand Up @@ -183,12 +185,22 @@ def query(

@overload
def query(
self, sql: str, param: Optional["ParamType"] = ..., buffered: "Literal[True]" = True, *, model: Type["_T"]
self,
sql: str,
param: Optional["ParamType"] = ...,
buffered: "Literal[True]" = True,
*,
model: Union[Type["_T"], Callable[..., "_T"]],
) -> List["_T"]: ...

@overload
def query(
self, sql: str, param: Optional["ParamType"] = ..., *, model: Type["_T"], buffered: "Literal[False]"
self,
sql: str,
param: Optional["ParamType"] = ...,
*,
model: Union[Type["_T"], Callable[..., "_T"]],
buffered: "Literal[False]",
) -> Generator["_T", None, None]: ...

def query(self, sql, model=dict, param=None, buffered=True):
Expand Down Expand Up @@ -230,7 +242,9 @@ def query_multiple(
def query_first(self, sql: str, model: Type[Dict] = dict, param: Optional["ParamType"] = ...) -> Dict[str, Any]: ...

@overload
def query_first(self, sql: str, param: Optional["ParamType"] = ..., *, model: Type["_T"]) -> "_T": ...
def query_first(
self, sql: str, param: Optional["ParamType"] = ..., *, model: Union[Type["_T"], Callable[..., "_T"]]
) -> "_T": ...

def query_first(self, sql, model=dict, param=None):
handler = self.SqlParamHandler(sql, param)
Expand Down Expand Up @@ -260,7 +274,7 @@ def query_first_or_default(
default: Callable[[], "_Default"],
param: Optional["ParamType"] = ...,
*,
model: Type["_T"],
model: Union[Type["_T"], Callable[..., "_T"]],
) -> Union["_Default", "_T"]: ...

@overload
Expand All @@ -270,7 +284,7 @@ def query_first_or_default(
default: "_Default",
param: Optional["ParamType"] = ...,
*,
model: Type["_T"],
model: Union[Type["_T"], Callable[..., "_T"]],
) -> Union["_Default", "_T"]: ...

def query_first_or_default(self, sql, default, model=dict, param=None):
Expand All @@ -285,7 +299,9 @@ def query_single(
) -> Dict[str, Any]: ...

@overload
def query_single(self, sql: str, param: Optional["ParamType"] = ..., *, model: Type["_T"]) -> "_T": ...
def query_single(
self, sql: str, param: Optional["ParamType"] = ..., *, model: Union[Type["_T"], Callable[..., "_T"]]
) -> "_T": ...

def query_single(self, sql, model=dict, param=None):
handler = self.SqlParamHandler(sql, param)
Expand Down Expand Up @@ -320,7 +336,7 @@ def query_single_or_default(
default: Callable[[], "_Default"],
param: Optional["ParamType"] = ...,
*,
model: Type["_T"],
model: Union[Type["_T"], Callable[..., "_T"]],
) -> Union["_Default", "_T"]: ...

@overload
Expand All @@ -330,7 +346,7 @@ def query_single_or_default(
default: "_Default",
param: Optional["ParamType"] = ...,
*,
model: Type["_T"],
model: Union[Type["_T"], Callable[..., "_T"]],
) -> Union["_Default", "_T"]: ...

def query_single_or_default(self, sql, default, model=dict, param=None):
Expand Down Expand Up @@ -376,14 +392,18 @@ async def execute_async(self, sql: str, param: Union["ParamType", "ListParamType
async with self.cursor() as cursor:
return await handler.execute_async(cursor)

async def _buffered_query(self, handler: BaseSqlParamHandler, model: Type["_T"]) -> List["_T"]:
async def _buffered_query(
self, handler: BaseSqlParamHandler, model: Union[Type["_T"], Callable[..., "_T"]]
) -> List["_T"]:
async with self.cursor() as cursor:
await handler.execute_async(cursor)
headers = get_col_names(cursor)
data = await cursor.fetchall()
return [serialize_dict_row(model, database_row_to_dict(headers, row)) for row in data]

async def _unbuffered_query(self, handler: BaseSqlParamHandler, model: Type["_T"]) -> AsyncGenerator["_T", None]:
async def _unbuffered_query(
self, handler: BaseSqlParamHandler, model: Union[Type["_T"], Callable[..., "_T"]]
) -> AsyncGenerator["_T", None]:
async with self.cursor() as cursor:
await handler.execute_async(cursor)
headers = get_col_names(cursor)
Expand All @@ -410,12 +430,22 @@ async def query_async(

@overload
async def query_async(
self, sql: str, param: Optional["ParamType"] = ..., buffered: "Literal[True]" = True, *, model: Type["_T"]
self,
sql: str,
param: Optional["ParamType"] = ...,
buffered: "Literal[True]" = True,
*,
model: Union[Type["_T"], Callable[..., "_T"]],
) -> List["_T"]: ...

@overload
async def query_async(
self, sql: str, param: Optional["ParamType"] = ..., *, model: Type["_T"], buffered: "Literal[False]"
self,
sql: str,
param: Optional["ParamType"] = ...,
*,
model: Union[Type["_T"], Callable[..., "_T"]],
buffered: "Literal[False]",
) -> AsyncGenerator["_T", None]: ...

async def query_async(self, sql, model=dict, param=None, buffered=True):
Expand Down Expand Up @@ -460,7 +490,9 @@ async def query_first_async(
) -> Dict[str, Any]: ...

@overload
async def query_first_async(self, sql: str, param: Optional["ParamType"] = ..., *, model: Type["_T"]) -> "_T": ...
async def query_first_async(
self, sql: str, param: Optional["ParamType"] = ..., *, model: Union[Type["_T"], Callable[..., "_T"]]
) -> "_T": ...

async def query_first_async(self, sql, model=dict, param=None):
handler = self.SqlParamHandler(sql, param)
Expand Down Expand Up @@ -490,7 +522,7 @@ async def query_first_or_default_async(
default: Callable[[], "_Default"],
param: Optional["ParamType"] = ...,
*,
model: Type["_T"],
model: Union[Type["_T"], Callable[..., "_T"]],
) -> Union["_Default", "_T"]: ...

@overload
Expand All @@ -500,7 +532,7 @@ async def query_first_or_default_async(
default: "_Default",
param: Optional["ParamType"] = ...,
*,
model: Type["_T"],
model: Union[Type["_T"], Callable[..., "_T"]],
) -> Union["_Default", "_T"]: ...

async def query_first_or_default_async(self, sql, default, model=dict, param=None):
Expand All @@ -515,7 +547,9 @@ async def query_single_async(
) -> Dict[str, Any]: ...

@overload
async def query_single_async(self, sql: str, param: Optional["ParamType"] = ..., *, model: Type["_T"]) -> "_T": ...
async def query_single_async(
self, sql: str, param: Optional["ParamType"] = ..., *, model: Union[Type["_T"], Callable[..., "_T"]]
) -> "_T": ...

async def query_single_async(self, sql, model=dict, param=None):
handler = self.SqlParamHandler(sql, param)
Expand Down Expand Up @@ -550,7 +584,7 @@ async def query_single_or_default_async(
default: Callable[[], "_Default"],
param: Optional["ParamType"] = ...,
*,
model: Type["_T"],
model: Union[Type["_T"], Callable[..., "_T"]],
) -> Union["_Default", "_T"]: ...

@overload
Expand All @@ -560,7 +594,7 @@ async def query_single_or_default_async(
default: "_Default",
param: Optional["ParamType"] = ...,
*,
model: Type["_T"],
model: Union[Type["_T"], Callable[..., "_T"]],
) -> Union["_Default", "_T"]: ...

async def query_single_or_default_async(self, sql, default, model=dict, param=None):
Expand Down
4 changes: 3 additions & 1 deletion pydapper/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import importlib
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Tuple
from typing import Type
from typing import TypeVar
from typing import Union
from typing import overload

if TYPE_CHECKING:
Expand All @@ -32,7 +34,7 @@ def serialize_dict_row(model: Type[Dict], row: Dict[str, Any]) -> Dict[str, Any]


@overload
def serialize_dict_row(model: Type["_T"], row: Dict[str, Any]) -> "_T": ...
def serialize_dict_row(model: Union[Type["_T"], Callable[..., "_T"]], row: Dict[str, Any]) -> "_T": ...


def serialize_dict_row(model, row):
Expand Down
5 changes: 5 additions & 0 deletions tests/annotation_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def query(query: str) -> None:
assert_type(commands.query(query, buffered=True), List[Dict[str, Any]])
assert_type(commands.query(query, buffered=False), Generator[Dict[str, Any], None, None])
assert_type(commands.query(query, model=Task, buffered=True), List[Task])
assert_type(commands.query(query, model=lambda **kwargs: Task(**kwargs)), List[Task])
assert_type(
commands.query(query, model=lambda **kwargs: Task(**kwargs), buffered=False),
Generator[Task, None, None],
)
assert_type(commands.query(query, model=Task, buffered=False), Generator[Task, None, None])

@staticmethod
Expand Down
Loading