Skip to content

Commit

Permalink
feat: optimize task mapping init (#442)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Nov 25, 2024
1 parent 007f6b5 commit e5d0232
Showing 1 changed file with 43 additions and 11 deletions.
54 changes: 43 additions & 11 deletions a_sync/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ async def _wrapped_set_next(
self._wrapped_func = _wrapped_set_next
init_loader_queue: Queue[Tuple[K, "asyncio.Future[V]"]] = Queue()
self.__init_loader_coro = exhaust_iterator(
self._tasks_for_iterables(*iterables), queue=init_loader_queue
self._start_tasks_for_iterables(*iterables), queue=init_loader_queue
)
with contextlib.suppress(_NoRunningLoop):
# it's okay if we get this exception, we can start the task as soon as the loop starts
Expand All @@ -221,16 +221,7 @@ def __getitem__(self, item: K) -> "asyncio.Task[V]":
try:
return dict.__getitem__(self, item)
except KeyError:
if self.concurrency:
# NOTE: we use a queue instead of a Semaphore to reduce memory use for use cases involving many many tasks
fut = self._queue.put_nowait(item)
else:
fut = create_task(
coro=self._wrapped_func(item, **self._wrapped_func_kwargs),
name=f"{self._name}[{item}]" if self._name else f"{item}",
)
dict.__setitem__(self, item, fut)
return fut
return self.__start_task(item)

def __await__(self) -> Generator[Any, None, Dict[K, V]]:
"""Wait for all tasks to complete and return a dictionary of the results."""
Expand Down Expand Up @@ -592,6 +583,35 @@ async def _tasks_for_iterables(
"DEV: figure out how to handle this situation"
) from None

@ASyncGeneratorFunction
async def _start_tasks_for_iterables(
    self, *iterables: AnyIterableOrAwaitableIterable[K]
) -> AsyncIterator[Tuple[K, "asyncio.Task[V]"]]:
    """Start a new task for each key in the provided iterables.

    Plain synchronous iterables are consumed first, one after another, so
    their tasks are started right away. Every other input is then consumed
    through ``as_yielded``.

    Args:
        *iterables: The sources of keys to start tasks for.

    Yields:
        ``(key, task)`` pairs, where ``task`` is whatever
        ``self.__start_task(key)`` returned for that key.

    Raises:
        _EmptySequenceError: Re-raised from the second pass when exactly one
            iterable was provided.
        RuntimeError: When ``_EmptySequenceError`` is raised in the second
            pass and any other number of iterables was provided.
    """
    # Partition the inputs once, purely by type. Building ``remaining`` with
    # ``iterable not in containers`` would fall back to ``==`` comparisons,
    # which runs arbitrary ``__eq__`` code on user inputs (potentially
    # expensive, or raising for array-like objects) and silently drops any
    # non-container input that happens to compare equal to a container.
    containers = []
    remaining = []
    for iterable in iterables:
        if not isinstance(iterable, AsyncIterable) and isinstance(
            iterable, Iterable
        ):
            containers.append(iterable)
        else:
            remaining.append(iterable)

    # if we have any regular containers we can yield their contents right away
    for iterable in containers:
        async for key in _yield_keys(iterable):
            yield key, self.__start_task(key)

    if remaining:
        try:
            # NOTE(review): presumably `as_yielded` interleaves keys from all
            # remaining sources as they become available -- confirm.
            async for key in as_yielded(*[_yield_keys(iterable) for iterable in remaining]):  # type: ignore [attr-defined]
                yield key, self.__start_task(key)
        except _EmptySequenceError:
            if len(iterables) == 1:
                raise
            raise RuntimeError(
                "DEV: figure out how to handle this situation"
            ) from None

def _if_pop_check_destroyed(self, pop: bool) -> None:
if pop:
if self._destroyed:
Expand Down Expand Up @@ -622,6 +642,18 @@ async def _wait_for_next_key(self) -> None:
# check for exceptions
await task

def __start_task(self, item: K) -> "asyncio.Future[V]":
    """Start the work for ``item`` and store the resulting future under that key.

    When ``self.concurrency`` is falsy, a task is created directly from
    ``self._wrapped_func``. Otherwise ``item`` is handed to ``self._queue``
    and the object returned by ``put_nowait`` is used instead.

    Args:
        item: The key to start work for.

    Returns:
        The future/task that was stored in the mapping for ``item``.
    """
    if not self.concurrency:
        coro = self._wrapped_func(item, **self._wrapped_func_kwargs)
        # Prefix the task name with the mapping's name when it has one.
        started = create_task(
            coro=coro,
            name=f"{self._name}[{item}]" if self._name else f"{item}",
        )
    else:
        # NOTE: we use a queue instead of a Semaphore to reduce memory use for use cases involving many many tasks
        started = self._queue.put_nowait(item)

    # Write through `dict.__setitem__` directly rather than `self[item] = ...`.
    dict.__setitem__(self, item, started)
    return started

def __cleanup(self, t: "asyncio.Task[None]") -> None:
# clear the slot and let the bound Queue die
del self.__init_loader_coro
Expand Down

0 comments on commit e5d0232

Please sign in to comment.