diff --git a/a_sync/_bound.py b/a_sync/_bound.py index 0f908e1c..51b44ddb 100644 --- a/a_sync/_bound.py +++ b/a_sync/_bound.py @@ -63,7 +63,7 @@ def __init__(self, coro: Awaitable[T], property: Union[AsyncPropertyDescriptor[T self._property = property def __repr__(self) -> str: return f"<_PropertyGetter for {self._property}._get at {hex(id(self))}>" - def __await__(self) -> T: + def __await__(self) -> Generator[Any, None, T]: return self._coro.__await__() @overload diff --git a/a_sync/_typing.py b/a_sync/_typing.py index 61975886..bbc4c9fe 100644 --- a/a_sync/_typing.py +++ b/a_sync/_typing.py @@ -4,8 +4,8 @@ from decimal import Decimal from typing import (TYPE_CHECKING, Any, AsyncIterable, AsyncIterator, Awaitable, Callable, DefaultDict, Dict, Generic, Iterable, Iterator, - List, Literal, Optional, Set, Tuple, Type, TypedDict, TypeVar, - Union, final, overload) + List, Literal, Optional, Protocol, Set, Tuple, Type, TypedDict, + TypeVar, Union, final, overload) from typing_extensions import Concatenate, ParamSpec, Self, Unpack @@ -17,6 +17,8 @@ V = TypeVar("V") P = ParamSpec("P") +Numeric = Union[int, float, Decimal] + MaybeAwaitable = Union[Awaitable[T], T] Property = Callable[["ASyncABC"], T] diff --git a/a_sync/base.py b/a_sync/base.py index 47636fdd..8e2037e4 100644 --- a/a_sync/base.py +++ b/a_sync/base.py @@ -93,7 +93,7 @@ def __a_sync_flag_default_value_from_signature(cls) -> bool: return flag_value @classmethod - def __get_a_sync_flag_value_from_class_def(cls, flag: str) -> Optional[bool]: + def __get_a_sync_flag_value_from_class_def(cls, flag: Optional[str]) -> Optional[bool]: for spec in [cls, *cls.__bases__]: flag_value = spec.__dict__.get(flag) if flag_value is not None: diff --git a/a_sync/exceptions.py b/a_sync/exceptions.py index 9ad1277a..1a806efd 100644 --- a/a_sync/exceptions.py +++ b/a_sync/exceptions.py @@ -1,5 +1,5 @@ -from typing import Any, Set +from typing import Any, Optional, Set class ASyncFlagException(ValueError): @@ -31,7 +31,7 @@ def __init__(self, target, present_flags): super().__init__(err) class InvalidFlag(ASyncFlagException): - def __init__(self, flag: str): + def __init__(self, flag: Optional[str]): err = f"'flag' must be one of: {self.viable_flags}. You passed {flag}." err += "\nThis code should not be reached and likely indicates an issue with a custom subclass definition." super().__init__(err) diff --git a/a_sync/future.py b/a_sync/future.py index 35e70595..1225a87e 100644 --- a/a_sync/future.py +++ b/a_sync/future.py @@ -23,7 +23,7 @@ def _materialize(meta: "ASyncFuture[T]") -> T: except RuntimeError as e: raise RuntimeError(f"{meta} result is not set and the event loop is running, you will need to await it first") from e -Numeric = Union[int, float, Decimal, "ASyncFuture[int]", "ASyncFuture[float]", "ASyncFuture[Decimal]"] +MetaNumeric = Union[Numeric, "ASyncFuture[int]", "ASyncFuture[float]", "ASyncFuture[Decimal]"] class ASyncFuture(concurrent.futures.Future, Awaitable[T]): __slots__ = "__awaitable__", "__dependencies", "__dependants", "__task" @@ -120,7 +120,7 @@ def __add__(self: "ASyncFuture[Decimal]", other: Awaitable[Decimal]) -> "ASyncFu def __add__(self: "ASyncFuture[Decimal]", other: Awaitable[int]) -> "ASyncFuture[Decimal]":... @overload def __add__(self: "ASyncFuture[int]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... - def __add__(self, other: Numeric) -> "ASyncFuture": + def __add__(self, other: MetaNumeric) -> "ASyncFuture": return ASyncFuture(self.__add(other), dependencies=self.__list_dependencies(other)) @overload def __sub__(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]":... @@ -150,7 +150,7 @@ def __sub__(self: "ASyncFuture[Decimal]", other: Awaitable[Decimal]) -> "ASyncFu def __sub__(self: "ASyncFuture[Decimal]", other: Awaitable[int]) -> "ASyncFuture[Decimal]":... @overload def __sub__(self: "ASyncFuture[int]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... - def __sub__(self, other: Numeric) -> "ASyncFuture": + def __sub__(self, other: MetaNumeric) -> "ASyncFuture": return ASyncFuture(self.__sub(other), dependencies=self.__list_dependencies(other)) def __mul__(self, other) -> "ASyncFuture": return ASyncFuture(self.__mul(other), dependencies=self.__list_dependencies(other)) diff --git a/a_sync/modifiers/manager.py b/a_sync/modifiers/manager.py index 0dd8e3f8..d55573a6 100644 --- a/a_sync/modifiers/manager.py +++ b/a_sync/modifiers/manager.py @@ -70,13 +70,13 @@ def sync_modifier_wrap(*args: P.args, **kwargs: P.kwargs) -> T: return sync_modifier_wrap # Dictionary api - def items(self) -> List[Tuple[str, Any]]: + def items(self) -> ItemsView[str, Any]: return self._modifiers.items() - def keys(self) -> List[str]: + def keys(self) -> KeysView[str]: return self._modifiers.keys() - def values(self) -> List[Any]: + def values(self) -> ValuesView[Any]: return self._modifiers.values() - def __contains__(self, key: str) -> bool: + def __contains__(self, key: str) -> bool: # type: ignore [override] return key in self._modifiers def __iter__(self) -> Iterator[str]: return self._modifiers.__iter__() diff --git a/a_sync/primitives/executor.py b/a_sync/primitives/executor.py index c6c1c32f..28373a0d 100644 --- a/a_sync/primitives/executor.py +++ b/a_sync/primitives/executor.py @@ -31,7 +31,7 @@ async def run(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: Oh, and you can also use kwargs! """ return fn(*args, **kwargs) if self.sync_mode else await self.submit(fn, *args, **kwargs) - def submit(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[T]": + def submit(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[T]": # type: ignore [override] """Submits a job to the executor and returns an `asyncio.Future` that can be awaited for the result without blocking.""" if self.sync_mode: fut = asyncio.ensure_future(self._exec_sync(fn, *args, **kwargs)) @@ -49,7 +49,7 @@ def sync_mode(self) -> bool: return self._max_workers == 0 @property def worker_count_current(self) -> int: - len(getattr(self, f"_{self._workers}")) + return len(getattr(self, f"_{self._workers}")) async def _exec_sync(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: """Just wraps a fn and its args into an awaitable.""" return fn(*args, **kwargs) diff --git a/a_sync/primitives/locks/event.py b/a_sync/primitives/locks/event.py index 01b74719..da821842 100644 --- a/a_sync/primitives/locks/event.py +++ b/a_sync/primitives/locks/event.py @@ -1,13 +1,17 @@ -import asyncio, sys -from functools import cached_property -from typing import Optional +import asyncio +import sys +from a_sync._typing import * from a_sync.primitives._debug import _DebugDaemonMixin class Event(asyncio.Event, _DebugDaemonMixin): """asyncio.Event but with some additional debug logging to help detect deadlocks.""" + _value: bool + _loop: asyncio.BaseEventLoop + _waiters: Deque["asyncio.Future[None]"] + def __init__(self, name: str = "", debug_daemon_interval: int = 300, *, loop: Optional[asyncio.AbstractEventLoop] = None): if sys.version_info >= (3, 10): super().__init__() @@ -22,7 +26,7 @@ def __repr__(self) -> str: if self._waiters: status += f', waiters:{len(self._waiters)}' return f"<{self.__class__.__module__}.{self.__class__.__name__} {label} at {hex(id(self))} [{status}]>" - async def wait(self) -> bool: + async def wait(self) -> Literal[True]: if self.is_set(): return True self._ensure_debug_daemon() diff --git a/a_sync/primitives/locks/prio_semaphore.py b/a_sync/primitives/locks/prio_semaphore.py index cb0f5626..755818d9 100644 --- a/a_sync/primitives/locks/prio_semaphore.py +++ b/a_sync/primitives/locks/prio_semaphore.py @@ -4,9 +4,8 @@ import logging from collections import deque from functools import cached_property -from typing import (Deque, Dict, Generic, List, Literal, Optional, Protocol, - Type, TypeVar) +from a_sync._typing import * from a_sync.primitives.locks.semaphore import Semaphore logger = logging.getLogger(__name__) @@ -188,10 +187,10 @@ async def acquire(self) -> Literal[True]: def release(self) -> None: self._parent.release() -class _PrioritySemaphoreContextManager(_AbstractPrioritySemaphoreContextManager[int]): +class _PrioritySemaphoreContextManager(_AbstractPrioritySemaphoreContextManager[Numeric]): _priority_name = "priority" -class PrioritySemaphore(_AbstractPrioritySemaphore[int, _PrioritySemaphoreContextManager]): # type: ignore [type-var] +class PrioritySemaphore(_AbstractPrioritySemaphore[Numeric, _PrioritySemaphoreContextManager]): # type: ignore [type-var] _context_manager_class = _PrioritySemaphoreContextManager _top_priority = -1 """ diff --git a/a_sync/primitives/locks/semaphore.py b/a_sync/primitives/locks/semaphore.py index 3a3cca76..f22ea4fb 100644 --- a/a_sync/primitives/locks/semaphore.py +++ b/a_sync/primitives/locks/semaphore.py @@ -16,7 +16,7 @@ def __init__(self, value: int, name=None, **kwargs) -> None: """ super().__init__(value, **kwargs) self.name = name or self.__origin__ if hasattr(self, '__origin__') else None - self._decorated = set() + self._decorated: Set[str] = set() # Dank new functionality def __call__(self, fn: Callable[P, T]) -> Callable[P, T]: @@ -31,7 +31,7 @@ def __repr__(self) -> str: def __len__(self) -> int: return len(self._waiters) if self._waiters else 0 - def decorate(self, fn: Callable[P, T]) -> Callable[P, T]: + def decorate(self, fn: CoroFn[P, T]) -> CoroFn[P, T]: if not asyncio.iscoroutinefunction(fn): raise TypeError(f"{fn} must be a coroutine function") @functools.wraps(fn) @@ -41,7 +41,7 @@ async def semaphore_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: self._decorated.add(f"{fn.__module__}.{fn.__name__}") return semaphore_wrapper - async def acquire(self) -> bool: + async def acquire(self) -> Literal[True]: if self._value <= 0: self._ensure_debug_daemon() return await super().acquire() diff --git a/a_sync/task.py b/a_sync/task.py index 502c1fce..31f58fbc 100644 --- a/a_sync/task.py +++ b/a_sync/task.py @@ -19,7 +19,7 @@ def create_task(coro: Awaitable[T], *, name: Optional[str] = None, skip_gc_until __persist(task) return task -class TaskMapping(DefaultDict[K, "asyncio.Task[_V]"]): +class TaskMapping(DefaultDict[K, "asyncio.Task[V]"]): def __init__(self, coro_fn: Callable[Concatenate[K, P], Awaitable[V]] = None, *, name: str = '', **coro_fn_kwargs: P.kwargs) -> None: self._coro_fn = coro_fn self._coro_fn_kwargs = coro_fn_kwargs diff --git a/a_sync/utils/map.py b/a_sync/utils/map.py index e13e1f00..0e7e613b 100644 --- a/a_sync/utils/map.py +++ b/a_sync/utils/map.py @@ -1,6 +1,4 @@ -from typing import Awaitable, Callable, Literal, Tuple, Union, overload - from a_sync._typing import * from a_sync.iter import ASyncIterator