Skip to content

Commit

Permalink
the rest
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler committed Sep 6, 2024
1 parent 77c43cb commit 6674d03
Show file tree
Hide file tree
Showing 9 changed files with 294 additions and 17 deletions.
4 changes: 4 additions & 0 deletions a_sync/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
This module initializes the a_sync library by importing and organizing various components, utilities, and classes.
It provides a convenient and unified interface for asynchronous programming with a focus on flexibility and efficiency.
"""

from a_sync import aliases, exceptions, iter, task
from a_sync.a_sync import ASyncGenericBase, ASyncGenericSingleton, a_sync
Expand Down
73 changes: 69 additions & 4 deletions a_sync/_smart.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""
This module defines smart future and task utilities for the a_sync library.
These utilities provide enhanced functionality for managing asynchronous tasks and futures,
including task shielding and a custom task factory for creating SmartTask instances.
"""

import asyncio
import logging
Expand All @@ -18,10 +23,17 @@
logger = logging.getLogger(__name__)

class _SmartFutureMixin(Generic[T]):
"""
Mixin class that provides common functionality for smart futures and tasks.
"""
_queue: Optional["SmartProcessingQueue[Any, Any, T]"] = None
_key: _Key
_waiters: "weakref.WeakSet[SmartTask[T]]"

def __await__(self: Union["SmartFuture", "SmartTask"]) -> Generator[Any, None, T]:
"""
Await the smart future or task, handling waiters and logging.
"""
if self.done():
return self.result() # May raise too.
self._asyncio_future_blocking = True
Expand All @@ -32,24 +44,38 @@ def __await__(self: Union["SmartFuture", "SmartTask"]) -> Generator[Any, None, T
if not self.done():
raise RuntimeError("await wasn't used with future")
return self.result() # May raise too.

@property
def num_waiters(self: Union["SmartFuture", "SmartTask"]) -> int:
# NOTE: we check .done() because the callback may not have ran yet and its very lightweight
"""
Get the number of waiters currently awaiting the future or task.
"""
if self.done():
# if there are any waiters left, there won't be once the event loop runs once
return 0
return sum(getattr(waiter, 'num_waiters', 1) or 1 for waiter in self._waiters)

def _waiter_done_cleanup_callback(self: Union["SmartFuture", "SmartTask"], waiter: "SmartTask") -> None:
"Removes the waiter from _waiters, and _queue._futs if applicable"
"""
Callback to clean up waiters when a waiter task is done.
Removes the waiter from _waiters, and _queue._futs if applicable
"""
if not self.done():
self._waiters.remove(waiter)

def _self_done_cleanup_callback(self: Union["SmartFuture", "SmartTask"]) -> None:
"""
Callback to clean up waiters and remove the future from the queue when done.
"""
self._waiters.clear()
if queue := self._queue:
queue._futs.pop(self._key)


class SmartFuture(_SmartFutureMixin[T], asyncio.Future):
"""
A smart future that tracks waiters and integrates with a smart processing queue.
"""
_queue = None
_key = None
def __init__(
Expand All @@ -59,6 +85,14 @@ def __init__(
key: Optional[_Key] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
"""
Initialize the SmartFuture with an optional queue and key.
Args:
queue: Optional; a smart processing queue.
key: Optional; a key identifying the future.
loop: Optional; the event loop.
"""
super().__init__(loop=loop)
if queue:
self._queue = weakref.proxy(queue)
Expand All @@ -69,7 +103,16 @@ def __init__(
def __repr__(self):
return f"<{type(self).__name__} key={self._key} waiters={self.num_waiters} {self._state}>"
def __lt__(self, other: "SmartFuture[T]") -> bool:
"""heap considers lower values as higher priority so a future with more waiters will be 'less than' a future with less waiters."""
"""
Compare the number of waiters to determine priority in a heap.
Lower values indicate higher priority, so more waiters means 'less than'.
Args:
other: Another SmartFuture to compare with.
Returns:
True if self has more waiters than other.
"""
#other = other_ref()
#if other is None:
# # garbage collected refs should always process first so they can be popped from the queue
Expand All @@ -82,16 +125,38 @@ def create_future(
key: Optional[_Key] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> SmartFuture[V]:
"""
Create a SmartFuture instance.
Args:
queue: Optional; a smart processing queue.
key: Optional; a key identifying the future.
loop: Optional; the event loop.
Returns:
A SmartFuture instance.
"""
return SmartFuture(queue=queue, key=key, loop=loop or asyncio.get_event_loop())

class SmartTask(_SmartFutureMixin[T], asyncio.Task):
"""
A smart task that tracks waiters and integrates with a smart processing queue.
"""
def __init__(
self,
coro: Awaitable[T],
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
name: Optional[str] = None,
) -> None:
"""
Initialize the SmartTask with a coroutine and optional event loop.
Args:
coro: The coroutine to run in the task.
loop: Optional; the event loop.
name: Optional; the name of the task.
"""
super().__init__(coro, loop=loop, name=name)
self._waiters: Set["asyncio.Task[T]"] = set()
self.add_done_callback(SmartTask._self_done_cleanup_callback)
Expand Down
38 changes: 31 additions & 7 deletions a_sync/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,25 @@ class ASyncFlagException(ValueError):
"""
Base exception class for flag-related errors in the a_sync library.
"""
@property
def viable_flags(self) -> Set[str]:
"""
Returns the set of viable flags.
"""
return VIABLE_FLAGS

viable_flags = VIABLE_FLAGS
"""
The set of viable flags.
A-Sync uses 'flags' to indicate whether objects / fn calls will be sync or async.
You can use any of the provided flags, whichever makes most sense for your use case.
"""

def desc(self, target) -> str:
"""
Returns a description of the target for the flag error message.
Args:
target: The target object or string to describe.
Returns:
A string description of the target.
"""
if target == 'kwargs':
return "flags present in 'kwargs'"
else:
Expand Down Expand Up @@ -140,6 +151,9 @@ def __init__(self, fn):
super().__init__(f"`func` must be a coroutine function defined with `def`. You passed {fn}.")

class ASyncRuntimeError(RuntimeError):
"""
Raised for runtime errors in asynchronous operations.
"""
def __init__(self, e: RuntimeError):
"""
Initializes the ASyncRuntimeError exception.
Expand Down Expand Up @@ -196,12 +210,22 @@ class MappingNotEmptyError(MappingError):
_msg = "TaskMapping already contains some data. In order to use `map`, you need a fresh one"

class PersistedTaskException(Exception):
"""
Raised when an exception persists in an asyncio Task.
"""
def __init__(self, exc: E, task: asyncio.Task) -> None:
"""
Initializes the PersistedTaskException exception.
Args:
exc: The exception that persisted.
task: The asyncio Task where the exception occurred.
"""
super().__init__(f"{exc.__class__.__name__}: {exc}", task)
self.exception = exc
self.task = task

class EmptySequenceError(ValueError):
"""
Raised when an operation is attempted on an empty sequence but items are expected.
Raised when an operation is attempted on an empty sequence but items are required.
"""
Loading

0 comments on commit 6674d03

Please sign in to comment.