Skip to content

Commit

Permalink
Fix two bugs by adding more type hints to CallbackCollection.
Browse files Browse the repository at this point in the history
The first bug is `Channel` passing `Optional[BaseException]` to `self.close()` while `RobustChannel` passed `asyncio.Future`

The second is registering a `CallbackCollection` instance as a callback for a different `CallbackCollection`. (Which was not supported before)
  • Loading branch information
Darsstar committed Sep 13, 2024
1 parent db400b7 commit ac79041
Show file tree
Hide file tree
Showing 15 changed files with 180 additions and 86 deletions.
30 changes: 22 additions & 8 deletions aio_pika/abc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import dataclasses
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -248,7 +250,10 @@ class AbstractQueue:
arguments: Arguments
passive: bool
declaration_result: aiormq.spec.Queue.DeclareOk
close_callbacks: CallbackCollection
close_callbacks: CallbackCollection[
AbstractQueue,
[Optional[BaseException]],
]

@abstractmethod
def __init__(
Expand Down Expand Up @@ -504,8 +509,14 @@ class AbstractChannel(PoolInstance, ABC):
QUEUE_CLASS: Type[AbstractQueue]
EXCHANGE_CLASS: Type[AbstractExchange]

close_callbacks: CallbackCollection
return_callbacks: CallbackCollection
close_callbacks: CallbackCollection[
AbstractChannel,
[Optional[BaseException]],
]
return_callbacks: CallbackCollection[
AbstractChannel,
[AbstractIncomingMessage],
]
default_exchange: AbstractExchange

publisher_confirms: bool
Expand Down Expand Up @@ -715,7 +726,10 @@ def parse(self, value: Optional[str]) -> Any:
class AbstractConnection(PoolInstance, ABC):
PARAMETERS: Tuple[ConnectionParameter, ...]

close_callbacks: CallbackCollection
close_callbacks: CallbackCollection[
AbstractConnection,
[Optional[BaseException]],
]
connected: asyncio.Event
transport: Optional[UnderlayConnection]
kwargs: Mapping[str, Any]
Expand Down Expand Up @@ -832,7 +846,7 @@ async def bind(


class AbstractRobustChannel(AbstractChannel):
reopen_callbacks: CallbackCollection
reopen_callbacks: CallbackCollection[AbstractRobustChannel, []]

@abstractmethod
def reopen(self) -> Awaitable[None]:
Expand Down Expand Up @@ -875,7 +889,7 @@ async def declare_queue(


class AbstractRobustConnection(AbstractConnection):
reconnect_callbacks: CallbackCollection
reconnect_callbacks: CallbackCollection[AbstractRobustConnection, []]

@property
@abstractmethod
Expand All @@ -897,10 +911,10 @@ def channel(


ChannelCloseCallback = Callable[
[AbstractChannel, Optional[BaseException]], Any,
[Optional[AbstractChannel], Optional[BaseException]], Any,
]
ConnectionCloseCallback = Callable[
[AbstractConnection, Optional[BaseException]], Any,
[Optional[AbstractConnection], Optional[BaseException]], Any,
]
ConnectionType = TypeVar("ConnectionType", bound=AbstractConnection)

Expand Down
3 changes: 2 additions & 1 deletion aio_pika/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ async def _on_close(

async def _set_closed_callback(
self,
_: AbstractChannel, exc: BaseException
_: Optional[AbstractChannel],
exc: Optional[BaseException],
) -> None:
if not self._closed.done():
self._closed.set_result(True)
Expand Down
8 changes: 4 additions & 4 deletions aio_pika/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def encode_expiration_timedelta(value: timedelta) -> str:
return str(int(value.total_seconds() * MILLISECONDS))


@encode_expiration.register(NoneType) # type: ignore
@encode_expiration.register(NoneType)
def encode_expiration_none(_: Any) -> None:
return None

Expand All @@ -62,7 +62,7 @@ def decode_expiration_str(t: str) -> float:
return float(t) / MILLISECONDS


@decode_expiration.register(NoneType) # type: ignore
@decode_expiration.register(NoneType)
def decode_expiration_none(_: Any) -> None:
return None

Expand All @@ -88,7 +88,7 @@ def encode_timestamp_timedelta(value: timedelta) -> datetime:
return datetime.now(tz=timezone.utc) + value


@encode_timestamp.register(NoneType) # type: ignore
@encode_timestamp.register(NoneType)
def encode_timestamp_none(_: Any) -> None:
return None

Expand All @@ -103,7 +103,7 @@ def decode_timestamp_datetime(value: datetime) -> datetime:
return value


@decode_timestamp.register(NoneType) # type: ignore
@decode_timestamp.register(NoneType)
def decode_timestamp_none(_: Any) -> None:
return None

Expand Down
6 changes: 3 additions & 3 deletions aio_pika/patterns/master.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
AbstractChannel, AbstractExchange, AbstractIncomingMessage, AbstractQueue,
ConsumerTag, DeliveryMode,
)
from aio_pika.message import Message, ReturnedMessage
from aio_pika.message import Message

from ..tools import create_task, ensure_awaitable
from .base import Base, CallbackType, Proxy, T
Expand Down Expand Up @@ -113,8 +113,8 @@ def exchange(self) -> AbstractExchange:

@staticmethod
def on_message_returned(
channel: AbstractChannel,
message: ReturnedMessage,
channel: Optional[AbstractChannel],
message: AbstractIncomingMessage,
) -> None:
log.warning(
"Message returned. Probably destination queue does not exists: %r",
Expand Down
9 changes: 6 additions & 3 deletions aio_pika/patterns/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from aio_pika.exceptions import MessageProcessError
from aio_pika.exchange import ExchangeType
from aio_pika.message import IncomingMessage, Message, ReturnedMessage
from aio_pika.message import IncomingMessage, Message

from ..tools import ensure_awaitable
from .base import Base, CallbackType, Proxy, T
Expand Down Expand Up @@ -193,7 +193,8 @@ async def initialize(
self.channel.return_callbacks.add(self.on_message_returned)

def on_close(
self, channel: AbstractChannel,
self,
channel: Optional[AbstractChannel],
exc: Optional[ExceptionType] = None,
) -> None:
log.debug("Closing RPC futures because %r", exc)
Expand All @@ -218,7 +219,9 @@ async def create(cls, channel: AbstractChannel, **kwargs: Any) -> "RPC":
return rpc

def on_message_returned(
self, channel: AbstractChannel, message: ReturnedMessage,
self,
channel: Optional[AbstractChannel],
message: AbstractIncomingMessage,
) -> None:
if message.correlation_id is None:
log.warning(
Expand Down
8 changes: 4 additions & 4 deletions aio_pika/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,21 +418,21 @@ def consumer_tag(self) -> Optional[ConsumerTag]:
return getattr(self, "_consumer_tag", None)

async def close(self) -> None:
await self._on_close(self._amqp_queue.channel, None)
await self._on_close(self._amqp_queue, None)
if not self._closed.done():
self._closed.set_result(True)

async def _set_closed(
self,
_channel: AbstractChannel,
_channel: Optional[AbstractQueue],
exc: Optional[BaseException]
) -> None:
if not self._closed.done():
self._closed.set_result(True)

async def _on_close(
self,
_channel: AbstractChannel,
_channel: Optional[AbstractQueue],
_exc: Optional[BaseException]
) -> None:
log.debug("Cancelling queue iterator %r", self)
Expand Down Expand Up @@ -503,7 +503,7 @@ def __repr__(self) -> str:

def __init__(self, queue: Queue, **kwargs: Any):
self._consumer_tag: ConsumerTag
self._amqp_queue: AbstractQueue = queue
self._amqp_queue: Queue = queue
self._queue = asyncio.Queue()
self._closed = asyncio.get_running_loop().create_future()
self._message_or_closed = asyncio.Event()
Expand Down
2 changes: 1 addition & 1 deletion aio_pika/robust_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
self._prefetch_count: int = 0
self._prefetch_size: int = 0
self._global_qos: bool = False
self.reopen_callbacks: CallbackCollection = CallbackCollection(self)
self.reopen_callbacks = CallbackCollection(self)
self.__restore_lock = asyncio.Lock()
self.__restored = asyncio.Event()

Expand Down
6 changes: 2 additions & 4 deletions aio_pika/robust_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
self.__reconnection_task: Optional[asyncio.Task] = None

self._reconnect_lock = asyncio.Lock()
self.reconnect_callbacks: CallbackCollection = CallbackCollection(self)
self.reconnect_callbacks = CallbackCollection(self)

self.__connection_close_event.set()

Expand Down Expand Up @@ -104,9 +104,7 @@ async def _on_connected(self) -> None:
log.exception("Failed to reopen channel")
raise
except Exception as e:
closing = self.loop.create_future()
closing.set_exception(e)
await self.close_callbacks(closing)
await self.close_callbacks(e)
await asyncio.gather(
transport.connection.close(e),
return_exceptions=True,
Expand Down
Loading

0 comments on commit ac79041

Please sign in to comment.