From a725d132920c987b0533579578fe8f2ca9217d22 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 24 Jan 2024 17:43:07 -0500 Subject: [PATCH] refactor: fix higher order fn type hints --- .../evaluation_service/wrapper.py | 20 +++++++++++++------ .../evaluationservice/utilities/async_orm.py | 9 +++++---- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/python/services/evaluationservice/dmod/evaluationservice/evaluation_service/wrapper.py b/python/services/evaluationservice/dmod/evaluationservice/evaluation_service/wrapper.py index be1239c2d..68868c501 100644 --- a/python/services/evaluationservice/dmod/evaluationservice/evaluation_service/wrapper.py +++ b/python/services/evaluationservice/dmod/evaluationservice/evaluation_service/wrapper.py @@ -48,11 +48,19 @@ def __iter__(self): return iter(self.__multiple_values) +_T = typing.TypeVar("_T") +_V = typing.TypeVar("_V") + +class wrapper_fn(typing.Protocol, typing.Generic[_T, _V]): + def __call__(self, *args: _T, **kwargs: _V) -> typing.Union[_MODEL_TYPE, typing.Sequence[_MODEL_TYPE]]: + ... + + def wrapper_caller( - function: typing.Callable[[typing.Any, ...], typing.Union[_MODEL_TYPE, typing.Sequence[_MODEL_TYPE]]], + function: wrapper_fn[_T, _V], _wrapper_return_values: WrapperResults, - args: typing.Iterable, - kwargs: typing.Mapping + args: typing.Sequence[_T], + kwargs: typing.Mapping[str, _V], ): function_results = function(*args, **kwargs) @@ -79,9 +87,9 @@ def __init__(self, model_type: typing.Type[django_models.Model]): def __call_wrapper( self, - function: typing.Callable[[typing.Any, ...], typing.Union[_MODEL_TYPE, typing.Sequence[_MODEL_TYPE]]], - *args, - **kwargs + function: wrapper_fn[_T, _V], + *args: _T, + **kwargs: _V, ) -> typing.Union[_MODEL_TYPE, typing.Sequence[_MODEL_TYPE]]: results = WrapperResults() diff --git a/python/services/evaluationservice/dmod/evaluationservice/utilities/async_orm.py b/python/services/evaluationservice/dmod/evaluationservice/utilities/async_orm.py index e79c73c8e..3e5feaf3a 100644 --- a/python/services/evaluationservice/dmod/evaluationservice/utilities/async_orm.py +++ b/python/services/evaluationservice/dmod/evaluationservice/utilities/async_orm.py @@ -2,6 +2,7 @@ Helper functions used for interacting with the Django ORM in an asynchronous context """ import typing +import typing_extensions import functools import asyncio import threading @@ -10,7 +11,7 @@ _T = typing.TypeVar("_T") -LAZY_RETRIEVER = typing.Callable[[typing.Any, ...], typing.Optional[QuerySet[_T]]] +LAZY_RETRIEVER = typing.Callable[[typing_extensions.ParamSpec], typing.Optional[QuerySet[_T]]] def get_values_eagerly(function: LAZY_RETRIEVER, *args, **kwargs) -> typing.Optional[typing.Sequence[_T]]: @@ -35,7 +36,7 @@ def get_values_eagerly(function: LAZY_RETRIEVER, *args, **kwargs) -> typing.Opti return result -async def communicate_with_database(function: typing.Callable[[typing.Any, ...], _T], *args, **kwargs) -> _T: +async def communicate_with_database(function: typing.Callable[[typing_extensions.ParamSpec], _T], *args, **kwargs) -> _T: """ Use a function that has to use the Django database. @@ -56,7 +57,7 @@ async def communicate_with_database(function: typing.Callable[[typing.Any, ...], result = await asyncio.get_running_loop().run_in_executor(None, prepared_function) return result -def wrapper_communicate(function: typing.Callable[[typing.Any, ...], _T], cwds_return_data: typing.MutableMapping, kwargs: typing.Mapping): +def wrapper_communicate(function: typing.Callable[[typing_extensions.ParamSpecKwargs], _T], cwds_return_data: typing.MutableMapping, kwargs: typing.Mapping): results = function(**kwargs) if results and isinstance(results, QuerySet): @@ -65,7 +66,7 @@ def wrapper_communicate(function: typing.Callable[[typing.Any, ...], _T], cwds_r cwds_return_data['results'] = results -def select_from_database(function: typing.Callable[[typing.Any, ...], _T], **kwargs) -> _T: +def select_from_database(function: typing.Callable[[typing_extensions.ParamSpec], _T], **kwargs) -> _T: _cwds_return_data = { "results": [] }