diff --git a/rstream/client.py b/rstream/client.py index 95b401a..cc88045 100644 --- a/rstream/client.py +++ b/rstream/client.py @@ -86,9 +86,7 @@ def __init__( } self._corr_id_seq = utils.MonotonicSeq() - self._waiters: dict[ - tuple[constants.Key, Optional[int]], set[asyncio.Future[schema.Frame]] - ] = defaultdict(set) + self._waiters: dict[tuple[constants.Key, Optional[int]], asyncio.Future[schema.Frame]] = {} self._tasks: dict[str, asyncio.Task[None]] = {} self._handlers: dict[Type[schema.Frame], dict[str, HT[Any]]] = defaultdict(dict) @@ -155,8 +153,8 @@ def wait_frame( fut: asyncio.Future[schema.Frame] = asyncio.Future() _key = frame_cls.key, corr_id - self._waiters[_key].add(fut) - fut.add_done_callback(self._waiters[_key].discard) + self._waiters[_key] = fut + fut.add_done_callback(lambda _: self._waiters.pop(_key, None)) return utils.TimeoutWrapper(fut, timeout) async def sync_request(self, frame: schema.Frame, resp_schema: Type[FT], raise_exception=True) -> FT: @@ -198,11 +196,12 @@ async def _listener(self) -> None: break logger.debug("Received frame: %s", frame) - _key = frame.key, frame.corr_id - while self._waiters[_key]: - fut = self._waiters[_key].pop() + _key = frame.key, frame.corr_id + fut = self._waiters.get(_key) + if fut is not None: fut.set_result(frame) + del self._waiters[_key] for _, handler in self._handlers.get(frame.__class__, {}).items(): try: