diff --git a/a_sync/task.py b/a_sync/task.py
index 5e6b37f8..ca177fbb 100644
--- a/a_sync/task.py
+++ b/a_sync/task.py
@@ -201,7 +201,7 @@ async def _wrapped_set_next(
             self._wrapped_func = _wrapped_set_next
             init_loader_queue: Queue[Tuple[K, "asyncio.Future[V]"]] = Queue()
             self.__init_loader_coro = exhaust_iterator(
-                self._tasks_for_iterables(*iterables), queue=init_loader_queue
+                self._start_tasks_for_iterables(*iterables), queue=init_loader_queue
             )
             with contextlib.suppress(_NoRunningLoop):
                 # its okay if we get this exception, we can start the task as soon as the loop starts
@@ -221,16 +221,7 @@ def __getitem__(self, item: K) -> "asyncio.Task[V]":
         try:
             return dict.__getitem__(self, item)
         except KeyError:
-            if self.concurrency:
-                # NOTE: we use a queue instead of a Semaphore to reduce memory use for use cases involving many many tasks
-                fut = self._queue.put_nowait(item)
-            else:
-                fut = create_task(
-                    coro=self._wrapped_func(item, **self._wrapped_func_kwargs),
-                    name=f"{self._name}[{item}]" if self._name else f"{item}",
-                )
-            dict.__setitem__(self, item, fut)
-            return fut
+            return self.__start_task(item)

     def __await__(self) -> Generator[Any, None, Dict[K, V]]:
         """Wait for all tasks to complete and return a dictionary of the results."""
@@ -592,6 +583,35 @@ async def _tasks_for_iterables(
                     "DEV: figure out how to handle this situation"
                 ) from None

+    @ASyncGeneratorFunction
+    async def _start_tasks_for_iterables(
+        self, *iterables: AnyIterableOrAwaitableIterable[K]
+    ) -> AsyncIterator[Tuple[K, "asyncio.Task[V]"]]:
+        """Start new tasks for each key in the provided iterables."""
+        # if we have any regular containers we can yield their contents right away
+        containers = [
+            iterable
+            for iterable in iterables
+            if not isinstance(iterable, AsyncIterable)
+            and isinstance(iterable, Iterable)
+        ]
+        for iterable in containers:
+            async for key in _yield_keys(iterable):
+                yield key, self.__start_task(key)
+
+        if remaining := [
+            iterable for iterable in iterables if iterable not in containers
+        ]:
+            try:
+                async for key in as_yielded(*[_yield_keys(iterable) for iterable in remaining]):  # type: ignore [attr-defined]
+                    yield key, self.__start_task(key)
+            except _EmptySequenceError:
+                if len(iterables) == 1:
+                    raise
+                raise RuntimeError(
+                    "DEV: figure out how to handle this situation"
+                ) from None
+
     def _if_pop_check_destroyed(self, pop: bool) -> None:
         if pop:
             if self._destroyed:
@@ -622,6 +642,18 @@ async def _wait_for_next_key(self) -> None:
             # check for exceptions
             await task

+    def __start_task(self, item: K) -> "asyncio.Future[V]":
+        if self.concurrency:
+            # NOTE: we use a queue instead of a Semaphore to reduce memory use for use cases involving many many tasks
+            fut = self._queue.put_nowait(item)
+        else:
+            fut = create_task(
+                coro=self._wrapped_func(item, **self._wrapped_func_kwargs),
+                name=f"{self._name}[{item}]" if self._name else f"{item}",
+            )
+        dict.__setitem__(self, item, fut)
+        return fut
+
     def __cleanup(self, t: "asyncio.Task[None]") -> None:
         # clear the slot and let the bound Queue die
         del self.__init_loader_coro
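
To make the refactor easier to review, here is a minimal, self-contained sketch of the lazy-start pattern that `__start_task` now consolidates; `LazyTaskMap`, `fetch`, and the key names are hypothetical and not part of the library. It mirrors only the non-`concurrency` branch: the task for a key is created and cached on first `__getitem__` access, whereas the real `__start_task` additionally swaps `create_task` for a `put_nowait` onto the bounded worker queue when `concurrency` is set.

```python
# Minimal sketch of the lazy-start pattern (hypothetical names, not the library's API).
import asyncio
from typing import Awaitable, Callable


class LazyTaskMap(dict):
    """dict subclass that starts one asyncio.Task per key on first access."""

    def __init__(self, func: Callable[[str], Awaitable[str]]) -> None:
        super().__init__()
        self._func = func

    def __getitem__(self, key: str) -> "asyncio.Task[str]":
        try:
            return dict.__getitem__(self, key)
        except KeyError:
            # Same shape as __start_task's non-concurrency branch:
            # create the task, cache it under the key, and return it.
            task = asyncio.create_task(self._func(key), name=key)
            dict.__setitem__(self, key, task)
            return task


async def main() -> None:
    async def fetch(key: str) -> str:
        await asyncio.sleep(0.01)
        return f"value for {key}"

    mapping = LazyTaskMap(fetch)
    # First access starts the task; awaiting it yields the result.
    print(await mapping["a"])
    # Repeated access returns the cached task; no new work is started.
    assert mapping["a"] is mapping["a"]


asyncio.run(main())
```

With both `__getitem__` and `_start_tasks_for_iterables` delegating to `__start_task`, the lazy single-key path and the eager init-loader path now share one task-creation code path instead of duplicating it.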