Skip to content

Commit

Permalink
refactor: fix higher order fn type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
aaraney committed Jan 24, 2024
1 parent a9f5440 commit 7a5d2a6
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,19 @@ def __iter__(self):
return iter(self.__multiple_values)


_T = typing.TypeVar("_T")
_V = typing.TypeVar("_V")

class wrapper_fn(typing.Protocol, typing.Generic[_T, _V]):
def __call__(self, *args: _T, **kwargs: _V) -> typing.Union[_MODEL_TYPE, typing.Sequence[_MODEL_TYPE]]:
...


def wrapper_caller(
function: typing.Callable[[typing.Any, ...], typing.Union[_MODEL_TYPE, typing.Sequence[_MODEL_TYPE]]],
function: wrapper_fn[_T, _V],
_wrapper_return_values: WrapperResults,
args: typing.Iterable,
kwargs: typing.Mapping
args: typing.Sequence[_T],
kwargs: typing.Mapping[str, _V],
):
function_results = function(*args, **kwargs)

Expand All @@ -79,9 +87,9 @@ def __init__(self, model_type: typing.Type[django_models.Model]):

def __call_wrapper(
self,
function: typing.Callable[[typing.Any, ...], typing.Union[_MODEL_TYPE, typing.Sequence[_MODEL_TYPE]]],
*args,
**kwargs
function: wrapper_fn[_T, _V],
*args: _T,
**kwargs: _V,
) -> typing.Union[_MODEL_TYPE, typing.Sequence[_MODEL_TYPE]]:
results = WrapperResults()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Helper functions used for interacting with the Django ORM in an asynchronous context
"""
import typing
import typing_extensions
import functools
import asyncio
import threading
Expand All @@ -10,7 +11,7 @@


_T = typing.TypeVar("_T")
LAZY_RETRIEVER = typing.Callable[[typing.Any, ...], typing.Optional[QuerySet[_T]]]
LAZY_RETRIEVER = typing.Callable[[typing_extensions.ParamSpec], typing.Optional[QuerySet[_T]]]


def get_values_eagerly(function: LAZY_RETRIEVER, *args, **kwargs) -> typing.Optional[typing.Sequence[_T]]:
Expand All @@ -35,7 +36,7 @@ def get_values_eagerly(function: LAZY_RETRIEVER, *args, **kwargs) -> typing.Opti
return result


async def communicate_with_database(function: typing.Callable[[typing.Any, ...], _T], *args, **kwargs) -> _T:
async def communicate_with_database(function: typing.Callable[[typing_extensions.ParamSpec], _T], *args, **kwargs) -> _T:
"""
Use a function that has to use the Django database.
Expand All @@ -56,7 +57,7 @@ async def communicate_with_database(function: typing.Callable[[typing.Any, ...],
result = await asyncio.get_running_loop().run_in_executor(None, prepared_function)
return result

def wrapper_communicate(function: typing.Callable[[typing.Any, ...], _T], cwds_return_data: typing.MutableMapping, kwargs: typing.Mapping):
def wrapper_communicate(function: typing.Callable[[typing_extensions.ParamSpecKwargs], _T], cwds_return_data: typing.MutableMapping, kwargs: typing.Mapping):
results = function(**kwargs)

if results and isinstance(results, QuerySet):
Expand All @@ -65,7 +66,7 @@ def wrapper_communicate(function: typing.Callable[[typing.Any, ...], _T], cwds_r
cwds_return_data['results'] = results


def select_from_database(function: typing.Callable[[typing.Any, ...], _T], **kwargs) -> _T:
def select_from_database(function: typing.Callable[[typing_extensions.ParamSpec], _T], **kwargs) -> _T:
_cwds_return_data = {
"results": []
}
Expand Down

0 comments on commit 7a5d2a6

Please sign in to comment.