Skip to content

Commit

Permalink
chore: fix mypy errs
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler committed Feb 11, 2024
1 parent c44e616 commit a9c1828
Show file tree
Hide file tree
Showing 12 changed files with 32 additions and 29 deletions.
2 changes: 1 addition & 1 deletion a_sync/_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self, coro: Awaitable[T], property: Union[AsyncPropertyDescriptor[T
self._property = property
def __repr__(self) -> str:
return f"<_PropertyGetter for {self._property}._get at {hex(id(self))}>"
def __await__(self) -> T:
def __await__(self) -> Generator[Any, None, T]:
return self._coro.__await__()

@overload
Expand Down
6 changes: 4 additions & 2 deletions a_sync/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from decimal import Decimal
from typing import (TYPE_CHECKING, Any, AsyncIterable, AsyncIterator, Awaitable,
Callable, DefaultDict, Dict, Generic, Iterable, Iterator,
List, Literal, Optional, Set, Tuple, Type, TypedDict, TypeVar,
Union, final, overload)
List, Literal, Optional, Protocol, Set, Tuple, Type, TypedDict,
TypeVar, Union, final, overload)

from typing_extensions import Concatenate, ParamSpec, Self, Unpack

Expand All @@ -17,6 +17,8 @@
V = TypeVar("V")
P = ParamSpec("P")

Numeric = Union[int, float, Decimal]

MaybeAwaitable = Union[Awaitable[T], T]

Property = Callable[["ASyncABC"], T]
Expand Down
2 changes: 1 addition & 1 deletion a_sync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __a_sync_flag_default_value_from_signature(cls) -> bool:
return flag_value

@classmethod
def __get_a_sync_flag_value_from_class_def(cls, flag: str) -> Optional[bool]:
def __get_a_sync_flag_value_from_class_def(cls, flag: Optional[str]) -> Optional[bool]:
for spec in [cls, *cls.__bases__]:
flag_value = spec.__dict__.get(flag)
if flag_value is not None:
Expand Down
4 changes: 2 additions & 2 deletions a_sync/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

from typing import Any, Set
from typing import Any, Optional, Set


class ASyncFlagException(ValueError):
Expand Down Expand Up @@ -31,7 +31,7 @@ def __init__(self, target, present_flags):
super().__init__(err)

class InvalidFlag(ASyncFlagException):
def __init__(self, flag: str):
def __init__(self, flag: Optional[str]):
err = f"'flag' must be one of: {self.viable_flags}. You passed {flag}."
err += "\nThis code should not be reached and likely indicates an issue with a custom subclass definition."
super().__init__(err)
Expand Down
6 changes: 3 additions & 3 deletions a_sync/future.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _materialize(meta: "ASyncFuture[T]") -> T:
except RuntimeError as e:
raise RuntimeError(f"{meta} result is not set and the event loop is running, you will need to await it first") from e

Numeric = Union[int, float, Decimal, "ASyncFuture[int]", "ASyncFuture[float]", "ASyncFuture[Decimal]"]
MetaNumeric = Union[Numeric, "ASyncFuture[int]", "ASyncFuture[float]", "ASyncFuture[Decimal]"]

class ASyncFuture(concurrent.futures.Future, Awaitable[T]):
__slots__ = "__awaitable__", "__dependencies", "__dependants", "__task"
Expand Down Expand Up @@ -120,7 +120,7 @@ def __add__(self: "ASyncFuture[Decimal]", other: Awaitable[Decimal]) -> "ASyncFu
def __add__(self: "ASyncFuture[Decimal]", other: Awaitable[int]) -> "ASyncFuture[Decimal]":...
@overload
def __add__(self: "ASyncFuture[int]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":...
def __add__(self, other: Numeric) -> "ASyncFuture":
def __add__(self, other: MetaNumeric) -> "ASyncFuture":
return ASyncFuture(self.__add(other), dependencies=self.__list_dependencies(other))
@overload
def __sub__(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]":...
Expand Down Expand Up @@ -150,7 +150,7 @@ def __sub__(self: "ASyncFuture[Decimal]", other: Awaitable[Decimal]) -> "ASyncFu
def __sub__(self: "ASyncFuture[Decimal]", other: Awaitable[int]) -> "ASyncFuture[Decimal]":...
@overload
def __sub__(self: "ASyncFuture[int]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":...
def __sub__(self, other: Numeric) -> "ASyncFuture":
def __sub__(self, other: MetaNumeric) -> "ASyncFuture":
return ASyncFuture(self.__sub(other), dependencies=self.__list_dependencies(other))
def __mul__(self, other) -> "ASyncFuture":
return ASyncFuture(self.__mul(other), dependencies=self.__list_dependencies(other))
Expand Down
8 changes: 4 additions & 4 deletions a_sync/modifiers/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ def sync_modifier_wrap(*args: P.args, **kwargs: P.kwargs) -> T:
return sync_modifier_wrap

# Dictionary api
def items(self) -> List[Tuple[str, Any]]:
def items(self) -> ItemsView[str, Any]:
return self._modifiers.items()
def keys(self) -> List[str]:
def keys(self) -> KeysView[str]:
return self._modifiers.keys()
def values(self) -> List[Any]:
def values(self) -> ValuesView[Any]:
return self._modifiers.values()
def __contains__(self, key: str) -> bool:
def __contains__(self, key: str) -> bool: # type: ignore [override]
return key in self._modifiers
def __iter__(self) -> Iterator[str]:
return self._modifiers.__iter__()
Expand Down
4 changes: 2 additions & 2 deletions a_sync/primitives/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def run(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
Oh, and you can also use kwargs!
"""
return fn(*args, **kwargs) if self.sync_mode else await self.submit(fn, *args, **kwargs)
def submit(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[T]":
def submit(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[T]": # type: ignore [override]
"""Submits a job to the executor and returns an `asyncio.Future` that can be awaited for the result without blocking."""
if self.sync_mode:
fut = asyncio.ensure_future(self._exec_sync(fn, *args, **kwargs))
Expand All @@ -49,7 +49,7 @@ def sync_mode(self) -> bool:
return self._max_workers == 0
@property
def worker_count_current(self) -> int:
len(getattr(self, f"_{self._workers}"))
return len(getattr(self, f"_{self._workers}"))
async def _exec_sync(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
"""Just wraps a fn and its args into an awaitable."""
return fn(*args, **kwargs)
Expand Down
12 changes: 8 additions & 4 deletions a_sync/primitives/locks/event.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@

import asyncio, sys
from functools import cached_property
from typing import Optional
import asyncio
import sys

from a_sync._typing import *
from a_sync.primitives._debug import _DebugDaemonMixin


class Event(asyncio.Event, _DebugDaemonMixin):
"""asyncio.Event but with some additional debug logging to help detect deadlocks."""
_value: bool
_loop: asyncio.BaseEventLoop
_waiters: Deque["asyncio.Future[None]"]

def __init__(self, name: str = "", debug_daemon_interval: int = 300, *, loop: Optional[asyncio.AbstractEventLoop] = None):
if sys.version_info >= (3, 10):
super().__init__()
Expand All @@ -22,7 +26,7 @@ def __repr__(self) -> str:
if self._waiters:
status += f', waiters:{len(self._waiters)}'
return f"<{self.__class__.__module__}.{self.__class__.__name__} {label} at {hex(id(self))} [{status}]>"
async def wait(self) -> bool:
async def wait(self) -> Literal[True]:
if self.is_set():
return True
self._ensure_debug_daemon()
Expand Down
7 changes: 3 additions & 4 deletions a_sync/primitives/locks/prio_semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import logging
from collections import deque
from functools import cached_property
from typing import (Deque, Dict, Generic, List, Literal, Optional, Protocol,
Type, TypeVar)

from a_sync._typing import *
from a_sync.primitives.locks.semaphore import Semaphore

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -188,10 +187,10 @@ async def acquire(self) -> Literal[True]:
def release(self) -> None:
self._parent.release()

class _PrioritySemaphoreContextManager(_AbstractPrioritySemaphoreContextManager[int]):
class _PrioritySemaphoreContextManager(_AbstractPrioritySemaphoreContextManager[Numeric]):
_priority_name = "priority"

class PrioritySemaphore(_AbstractPrioritySemaphore[int, _PrioritySemaphoreContextManager]): # type: ignore [type-var]
class PrioritySemaphore(_AbstractPrioritySemaphore[Numeric, _PrioritySemaphoreContextManager]): # type: ignore [type-var]
_context_manager_class = _PrioritySemaphoreContextManager
_top_priority = -1
"""
Expand Down
6 changes: 3 additions & 3 deletions a_sync/primitives/locks/semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, value: int, name=None, **kwargs) -> None:
"""
super().__init__(value, **kwargs)
self.name = name or self.__origin__ if hasattr(self, '__origin__') else None
self._decorated = set()
self._decorated: Set[str] = set()

# Dank new functionality
def __call__(self, fn: Callable[P, T]) -> Callable[P, T]:
Expand All @@ -31,7 +31,7 @@ def __repr__(self) -> str:
def __len__(self) -> int:
return len(self._waiters) if self._waiters else 0

def decorate(self, fn: Callable[P, T]) -> Callable[P, T]:
def decorate(self, fn: CoroFn[P, T]) -> CoroFn[P, T]:
if not asyncio.iscoroutinefunction(fn):
raise TypeError(f"{fn} must be a coroutine function")
@functools.wraps(fn)
Expand All @@ -41,7 +41,7 @@ async def semaphore_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
self._decorated.add(f"{fn.__module__}.{fn.__name__}")
return semaphore_wrapper

async def acquire(self) -> bool:
async def acquire(self) -> Literal[True]:
if self._value <= 0:
self._ensure_debug_daemon()
return await super().acquire()
Expand Down
2 changes: 1 addition & 1 deletion a_sync/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def create_task(coro: Awaitable[T], *, name: Optional[str] = None, skip_gc_until
__persist(task)
return task

class TaskMapping(DefaultDict[K, "asyncio.Task[_V]"]):
class TaskMapping(DefaultDict[K, "asyncio.Task[V]"]):
def __init__(self, coro_fn: Callable[Concatenate[K, P], Awaitable[V]] = None, *, name: str = '', **coro_fn_kwargs: P.kwargs) -> None:
self._coro_fn = coro_fn
self._coro_fn_kwargs = coro_fn_kwargs
Expand Down
2 changes: 0 additions & 2 deletions a_sync/utils/map.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@

from typing import Awaitable, Callable, Literal, Tuple, Union, overload

from a_sync._typing import *
from a_sync.iter import ASyncIterator

Expand Down

0 comments on commit a9c1828

Please sign in to comment.