From 84a6b72ad369c41b3341ad6cf4ec753235550ed2 Mon Sep 17 00:00:00 2001 From: Oleg A Date: Thu, 22 Sep 2022 09:53:35 +0300 Subject: [PATCH 1/8] chore: remove loop --- docs/source/examples/master.py | 3 +-- docs/source/examples/pooling.py | 8 +++----- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/docs/source/examples/master.py b/docs/source/examples/master.py index cc1b824d..750786ea 100644 --- a/docs/source/examples/master.py +++ b/docs/source/examples/master.py @@ -27,5 +27,4 @@ async def main() -> None: if __name__ == "__main__": - loop = asyncio.get_event_loop() - loop.run_until_complete(main()) + asyncio.run(main()) diff --git a/docs/source/examples/pooling.py b/docs/source/examples/pooling.py index 9ed056b1..af8d5c3c 100644 --- a/docs/source/examples/pooling.py +++ b/docs/source/examples/pooling.py @@ -6,18 +6,16 @@ async def main() -> None: - loop = asyncio.get_event_loop() - async def get_connection() -> AbstractRobustConnection: return await aio_pika.connect_robust("amqp://guest:guest@localhost/") - connection_pool: Pool = Pool(get_connection, max_size=2, loop=loop) + connection_pool: Pool = Pool(get_connection, max_size=2) async def get_channel() -> aio_pika.Channel: async with connection_pool.acquire() as connection: return await connection.channel() - channel_pool: Pool = Pool(get_channel, max_size=10, loop=loop) + channel_pool: Pool = Pool(get_channel, max_size=10) queue_name = "pool_queue" async def consume() -> None: @@ -41,7 +39,7 @@ async def publish() -> None: ) async with connection_pool, channel_pool: - task = loop.create_task(consume()) + task = asyncio.create_task(consume()) await asyncio.wait([publish() for _ in range(50)]) await task From 54b814b3c6d6c3357e1a29dbb1369aa31f93d980 Mon Sep 17 00:00:00 2001 From: Oleg A Date: Sun, 25 Sep 2022 22:28:16 +0300 Subject: [PATCH 2/8] chore: remove direct loop --- aio_pika/abc.py | 7 ++--- aio_pika/channel.py | 2 -- aio_pika/connection.py | 19 ++++++------ aio_pika/patterns/master.py | 16 ++-------- aio_pika/patterns/rpc.py | 6 ++-- aio_pika/pool.py | 3 -- aio_pika/robust_channel.py | 2 -- aio_pika/robust_connection.py | 10 ++----- aio_pika/tools.py | 10 +++---- .../examples/6-rpc/rpc_client.py | 10 +++---- tests/conftest.py | 10 +++---- tests/test_amqp.py | 21 +++++++------- tests/test_amqp_robust.py | 4 +-- tests/test_amqp_robust_proxy.py | 29 +++++++++---------- tests/test_amqps.py | 3 +- tests/test_pool.py | 12 ++++---- tests/test_rpc.py | 4 +-- tests/test_tools.py | 2 +- 18 files changed, 68 insertions(+), 102 deletions(-) diff --git a/aio_pika/abc.py b/aio_pika/abc.py index cadc91f1..b6dd3869 100644 --- a/aio_pika/abc.py +++ b/aio_pika/abc.py @@ -669,12 +669,9 @@ class AbstractConnection(PoolInstance, ABC): transport: Optional[UnderlayConnection] @abstractmethod - def __init__( - self, url: URL, loop: Optional[asyncio.AbstractEventLoop] = None, - **kwargs: Any, - ): + def __init__(self, url: URL, **kwargs: Any): raise NotImplementedError( - f"Method not implemented, passed: url={url}, loop={loop!r}", + f"Method not implemented, passed: url={url}", ) @property diff --git a/aio_pika/channel.py b/aio_pika/channel.py index a9852452..17351762 100644 --- a/aio_pika/channel.py +++ b/aio_pika/channel.py @@ -60,8 +60,6 @@ def __init__( """ :param connection: :class:`aio_pika.adapter.AsyncioConnection` instance - :param loop: Event loop (:func:`asyncio.get_event_loop()` - when :class:`None`) :param future_store: :class:`aio_pika.common.FutureStore` instance :param publisher_confirms: False if you don't need delivery confirmations (in pursuit of performance) diff --git a/aio_pika/connection.py b/aio_pika/connection.py index 504b4017..2364c3d5 100644 --- a/aio_pika/connection.py +++ b/aio_pika/connection.py @@ -51,10 +51,12 @@ def _parse_kwargs(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]: return result def __init__( - self, url: URL, loop: Optional[asyncio.AbstractEventLoop] = None, - ssl_context: Optional[SSLContext] = None, **kwargs: Any + self, + url: URL, + ssl_context: Optional[SSLContext] = None, + **kwargs: Any, ): - self.loop = loop or asyncio.get_event_loop() + self._loop = asyncio.get_event_loop() self.transport = None self._closed = False self._close_called = False @@ -112,7 +114,7 @@ def channel( import aio_pika - async def main(loop): + async def main(): connection = await aio_pika.connect( "amqp://guest:guest@127.0.0.1/" ) @@ -136,7 +138,7 @@ async def main(loop): import aio_pika - async def main(loop): + async def main(): connection = await aio_pika.connect( "amqp://guest:guest@127.0.0.1/" ) @@ -178,7 +180,7 @@ async def ready(self) -> None: def __del__(self) -> None: if ( self.is_closed or - self.loop.is_closed() or + self._loop.is_closed() or not hasattr(self, "connection") ): return @@ -256,7 +258,6 @@ async def connect( password: str = "guest", virtualhost: str = "/", ssl: bool = False, - loop: Optional[asyncio.AbstractEventLoop] = None, ssl_options: Optional[SSLOptions] = None, ssl_context: Optional[SSLContext] = None, timeout: TimeoutType = None, @@ -331,8 +332,6 @@ async def main(): :param ssl: use SSL for connection. Should be used with addition kwargs. :param ssl_options: A dict of values for the SSL connection. :param timeout: connection timeout in seconds - :param loop: - Event loop (:func:`asyncio.get_event_loop()` when :class:`None`) :param ssl_context: ssl.SSLContext instance :param connection_class: Factory of a new connection :param kwargs: addition parameters which will be passed to the connection. @@ -357,7 +356,7 @@ async def main(): client_properties=client_properties, **kwargs ), - loop=loop, ssl_context=ssl_context, + ssl_context=ssl_context, ) await connection.connect(timeout=timeout) diff --git a/aio_pika/patterns/master.py b/aio_pika/patterns/master.py index 70511fe1..7767129f 100644 --- a/aio_pika/patterns/master.py +++ b/aio_pika/patterns/master.py @@ -1,4 +1,3 @@ -import asyncio import gzip import json import logging @@ -14,10 +13,8 @@ ConsumerTag, DeliveryMode, ) from aio_pika.message import Message, ReturnedMessage - -from ..tools import create_task from .base import Base, Proxy - +from ..tools import create_task log = logging.getLogger(__name__) T = TypeVar("T") @@ -41,16 +38,11 @@ class Worker: __slots__ = ( "queue", "consumer_tag", - "loop", ) - def __init__( - self, queue: AbstractQueue, consumer_tag: ConsumerTag, - loop: asyncio.AbstractEventLoop, - ): + def __init__(self, queue: AbstractQueue, consumer_tag: ConsumerTag): self.queue = queue self.consumer_tag = consumer_tag - self.loop = loop def close(self) -> Awaitable[None]: """ Cancel subscription to the channel @@ -67,7 +59,6 @@ async def closer() -> None: class Master(Base): __slots__ = ( "channel", - "loop", "proxy", ) @@ -99,7 +90,6 @@ def __init__( :param channel: Initialized instance of :class:`aio_pika.Channel` """ self.channel: AbstractChannel = channel - self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() self.proxy = Proxy(self.create_task) self.channel.return_callbacks.add(self.on_message_returned) @@ -186,7 +176,7 @@ async def create_worker( fn = awaitable(func) consumer_tag = await queue.consume(partial(self.on_message, fn)) - return Worker(queue, consumer_tag, self.loop) + return Worker(queue, consumer_tag) async def create_task( self, channel_name: str, diff --git a/aio_pika/patterns/rpc.py b/aio_pika/patterns/rpc.py index e5832d05..069ab58b 100644 --- a/aio_pika/patterns/rpc.py +++ b/aio_pika/patterns/rpc.py @@ -45,7 +45,7 @@ class RPCMessageType(str, Enum): class RPC(Base): __slots__ = ( "channel", - "loop", + "_loop", "proxy", "result_queue", "result_consumer_tag", @@ -88,7 +88,7 @@ def multiply(*, x, y): def __init__(self, channel: AbstractChannel): self.channel = channel - self.loop = asyncio.get_event_loop() + self._loop = asyncio.get_event_loop() self.proxy = Proxy(self.call) self.futures: Dict[str, asyncio.Future] = {} self.routes: Dict[str, Callable[..., Any]] = {} @@ -100,7 +100,7 @@ def __remove_future(self, future: asyncio.Future) -> None: self.futures.pop(str(id(future)), None) def create_future(self) -> Tuple[asyncio.Future, str]: - future = self.loop.create_future() + future = self._loop.create_future() log.debug("Create future for RPC call") correlation_id = str(uuid.uuid4()) self.futures[correlation_id] = future diff --git a/aio_pika/pool.py b/aio_pika/pool.py index 93fc4080..1c0a81db 100644 --- a/aio_pika/pool.py +++ b/aio_pika/pool.py @@ -38,7 +38,6 @@ class PoolInvalidStateError(RuntimeError): class Pool(Generic[T]): __slots__ = ( - "loop", "__max_size", "__items", "__constructor", @@ -54,9 +53,7 @@ def __init__( constructor: ConstructorType, *args: Any, max_size: Optional[int] = None, - loop: Optional[asyncio.AbstractEventLoop] = None, ): - self.loop = loop or asyncio.get_event_loop() self.__closed = False self.__constructor: Callable[..., Awaitable[Any]] = awaitable( constructor, diff --git a/aio_pika/robust_channel.py b/aio_pika/robust_channel.py index 20714768..643c37fa 100644 --- a/aio_pika/robust_channel.py +++ b/aio_pika/robust_channel.py @@ -42,8 +42,6 @@ def __init__( """ :param connection: :class:`aio_pika.adapter.AsyncioConnection` instance - :param loop: - Event loop (:func:`asyncio.get_event_loop()` when :class:`None`) :param future_store: :class:`aio_pika.common.FutureStore` instance :param publisher_confirms: False if you don't need delivery confirmations diff --git a/aio_pika/robust_connection.py b/aio_pika/robust_connection.py index 8f4b59d9..86524a43 100644 --- a/aio_pika/robust_connection.py +++ b/aio_pika/robust_connection.py @@ -35,10 +35,9 @@ class RobustConnection(Connection, AbstractRobustConnection): def __init__( self, url: URL, - loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any, ): - super().__init__(url=url, loop=loop, **kwargs) + super().__init__(url=url, **kwargs) self.reconnect_interval = self.kwargs.pop("reconnect_interval") self.fail_fast = self.kwargs.pop("fail_fast") @@ -92,7 +91,7 @@ async def __connection_attempt( # from Exception and this needed for catch it first raise except Exception as e: - closing = self.loop.create_future() + closing = self._loop.create_future() closing.set_exception(e) await self.close_callbacks(closing) await asyncio.gather(connection.close(e), return_exceptions=True) @@ -183,7 +182,6 @@ async def connect_robust( password: str = "guest", virtualhost: str = "/", ssl: bool = False, - loop: Optional[asyncio.AbstractEventLoop] = None, ssl_options: Optional[SSLOptions] = None, ssl_context: Optional[SSLContext] = None, timeout: TimeoutType = None, @@ -260,8 +258,6 @@ async def main(): :param ssl: use SSL for connection. Should be used with addition kwargs. :param ssl_options: A dict of values for the SSL connection. :param timeout: connection timeout in seconds - :param loop: - Event loop (:func:`asyncio.get_event_loop()` when :class:`None`) :param ssl_context: ssl.SSLContext instance :param connection_class: Factory of a new connection :param kwargs: addition parameters which will be passed to the connection. @@ -286,7 +282,7 @@ async def main(): client_properties=client_properties, **kwargs ), - loop=loop, ssl_context=ssl_context, + ssl_context=ssl_context, ) await connection.connect(timeout=timeout) diff --git a/aio_pika/tools.py b/aio_pika/tools.py index 0e103d5b..37958a1c 100644 --- a/aio_pika/tools.py +++ b/aio_pika/tools.py @@ -46,13 +46,12 @@ def _task_done(future: asyncio.Future) -> None: def create_task( func: Callable[..., Union[Coroutine[Any, Any, T], Awaitable[T]]], *args: Any, - loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any ) -> Awaitable[T]: - loop = loop or asyncio.get_event_loop() + loop = asyncio.get_event_loop() if iscoroutinepartial(func): - task = loop.create_task(func(*args, **kwargs)) # type: ignore + task = asyncio.create_task(func(*args, **kwargs)) # type: ignore task.add_done_callback(_task_done) return task @@ -201,11 +200,10 @@ def __hash__(self) -> int: class OneShotCallback: - __slots__ = ("loop", "finished", "__lock", "callback", "__task") + __slots__ = ("finished", "__lock", "callback", "__task") def __init__(self, callback: Callable[..., Awaitable[T]]): self.callback: Callable[..., Awaitable[T]] = callback - self.loop = asyncio.get_event_loop() self.finished: asyncio.Event = asyncio.Event() self.__lock: asyncio.Lock = asyncio.Lock() self.__task: asyncio.Future @@ -234,7 +232,7 @@ async def __task_inner(self, *args: Any, **kwargs: Any) -> None: def __call__(self, *args: Any, **kwargs: Any) -> Awaitable[Any]: if self.finished.is_set(): return STUB_AWAITABLE - self.__task = self.loop.create_task( + self.__task = asyncio.create_task( self.__task_inner(*args, **kwargs), ) return self.__task diff --git a/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py b/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py index 3fb56706..24bddd6b 100644 --- a/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py +++ b/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py @@ -12,16 +12,14 @@ class FibonacciRpcClient: connection: AbstractConnection channel: AbstractChannel callback_queue: AbstractQueue - loop: asyncio.AbstractEventLoop + _loop: asyncio.AbstractEventLoop def __init__(self) -> None: self.futures: MutableMapping[str, asyncio.Future] = {} - self.loop = asyncio.get_running_loop() + self._loop = asyncio.get_running_loop() async def connect(self) -> "FibonacciRpcClient": - self.connection = await connect( - "amqp://guest:guest@localhost/", loop=self.loop, - ) + self.connection = await connect("amqp://guest:guest@localhost/") self.channel = await self.connection.channel() self.callback_queue = await self.channel.declare_queue(exclusive=True) await self.callback_queue.consume(self.on_response) @@ -38,7 +36,7 @@ def on_response(self, message: AbstractIncomingMessage) -> None: async def call(self, n: int) -> int: correlation_id = str(uuid.uuid4()) - future = self.loop.create_future() + future = self._loop.create_future() self.futures[correlation_id] = future diff --git a/tests/conftest.py b/tests/conftest.py index 619d1d8a..83c3b547 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ @pytest.fixture -async def add_cleanup(loop): +async def add_cleanup(): entities = [] def payload(func, *args, **kwargs): @@ -34,12 +34,12 @@ def payload(func, *args, **kwargs): @pytest.fixture -async def create_task(loop): +async def create_task(): tasks = [] def payload(coroutine): nonlocal tasks - task = loop.create_task(coroutine) + task = asyncio.create_task(coroutine) tasks.append(task) return task @@ -91,8 +91,8 @@ def connection_fabric(request): @pytest.fixture -def create_connection(connection_fabric, loop, amqp_url): - return partial(connection_fabric, amqp_url, loop=loop) +def create_connection(connection_fabric, amqp_url): + return partial(connection_fabric, amqp_url) @pytest.fixture diff --git a/tests/test_amqp.py b/tests/test_amqp.py index 26092f1e..78ffd4f7 100644 --- a/tests/test_amqp.py +++ b/tests/test_amqp.py @@ -46,7 +46,7 @@ def declare_exchange_(declare_exchange): class TestCaseAmqp(TestCaseAmqpBase): - async def test_properties(self, loop, connection: aio_pika.Connection): + async def test_properties(self, connection: aio_pika.Connection): assert not connection.is_closed async def test_channel_close(self, connection: aio_pika.Connection): @@ -1205,7 +1205,7 @@ async def test_transaction_simple_async_rollback( async with channel.transaction(): raise ValueError - async def test_async_for_queue(self, loop, connection, declare_queue): + async def test_async_for_queue(self, connection, declare_queue): channel2 = await self.create_channel(connection) queue = await declare_queue( @@ -1224,7 +1224,7 @@ async def publisher(): Message(body=str(i).encode()), routing_key=queue.name, ) - loop.create_task(publisher()) + asyncio.create_task(publisher()) count = 0 data = list() @@ -1240,7 +1240,7 @@ async def publisher(): assert data == list(map(lambda x: str(x).encode(), range(messages))) async def test_async_for_queue_context( - self, loop, connection, declare_queue, + self, connection, declare_queue, ): channel2 = await self.create_channel(connection) @@ -1260,7 +1260,7 @@ async def publisher(): Message(body=str(i).encode()), routing_key=queue.name, ) - loop.create_task(publisher()) + asyncio.create_task(publisher()) count = 0 data = list() @@ -1277,7 +1277,7 @@ async def publisher(): assert data == list(map(lambda x: str(x).encode(), range(messages))) async def test_async_with_connection( - self, create_connection: Callable, connection, loop, declare_queue, + self, create_connection: Callable, connection, declare_queue, ): async with await create_connection() as connection: @@ -1299,7 +1299,7 @@ async def publisher(): Message(body=str(i).encode()), routing_key=queue.name, ) - loop.create_task(publisher()) + asyncio.create_task(publisher()) count = 0 data = list() @@ -1401,7 +1401,7 @@ async def task_inner(): future.set_exception(e) raise - task = loop.create_task(task_inner()) + task = asyncio.create_task(task_inner()) await event.wait() loop.call_soon(task.cancel) @@ -1415,7 +1415,6 @@ async def task_inner(): async def test_queue_iterator_close_with_noack( self, create_connection: Callable, - loop, add_cleanup: Callable, declare_queue, ): @@ -1455,7 +1454,7 @@ async def task_inner(): Message(body), routing_key=queue_name, ) - task = loop.create_task(task_inner()) + task = asyncio.create_task(task_inner()) await task @@ -1561,7 +1560,7 @@ async def run(): await channel.set_qos(10) async def test_heartbeat_disabling( - self, loop, amqp_url: URL, connection_fabric, + self, amqp_url: URL, connection_fabric, ): url = amqp_url.update_query(heartbeat=0) connection: AbstractConnection = await connection_fabric(url) diff --git a/tests/test_amqp_robust.py b/tests/test_amqp_robust.py index 92b4d4df..5c0195ca 100644 --- a/tests/test_amqp_robust.py +++ b/tests/test_amqp_robust.py @@ -18,8 +18,8 @@ def connection_fabric(): @pytest.fixture -def create_connection(connection_fabric, loop, amqp_url): - return partial(connection_fabric, amqp_url, loop=loop) +def create_connection(connection_fabric, amqp_url): + return partial(connection_fabric, amqp_url) class TestCaseNoRobust(TestCaseAmqp): diff --git a/tests/test_amqp_robust_proxy.py b/tests/test_amqp_robust_proxy.py index e8faf255..474d46ce 100644 --- a/tests/test_amqp_robust_proxy.py +++ b/tests/test_amqp_robust_proxy.py @@ -56,20 +56,19 @@ def connection_fabric(): @pytest.fixture -def create_direct_connection(loop, amqp_direct_url): +def create_direct_connection(amqp_direct_url): return partial( aio_pika.connect, amqp_direct_url.update_query( name=amqp_direct_url.query["name"] + "::direct", heartbeat=30, ), - loop=loop, ) @pytest.fixture -def create_connection(connection_fabric, loop, amqp_url): - return partial(connection_fabric, amqp_url, loop=loop) +def create_connection(connection_fabric, amqp_url): + return partial(connection_fabric, amqp_url) @pytest.fixture @@ -144,7 +143,7 @@ async def test_robust_reconnect( create_connection, direct_connection, proxy: TCPProxy, loop, add_cleanup: Callable, ): - read_conn = await create_connection() # type: aio_pika.RobustConnection + read_conn = await create_connection() # type: aio_pika.RobustConnection reconnect_event = asyncio.Event() read_conn.reconnect_callbacks.add( @@ -191,7 +190,7 @@ async def reader(queue_name): logging.info("Exit reader task") try: - reader_task = loop.create_task(reader(queue.name)) + reader_task = asyncio.create_task(reader(queue.name)) await consumer_event.wait() logging.info("Disconnect all clients") @@ -251,7 +250,7 @@ async def test_channel_locked_resource2(connection: aio_pika.RobustConnection): async def test_channel_close_when_exclusive_queue( - create_connection, create_direct_connection, proxy: TCPProxy, loop, + create_connection, create_direct_connection, proxy: TCPProxy, ): logging.info("Creating connections") direct_conn, proxy_conn = await asyncio.gather( @@ -290,7 +289,7 @@ async def close_after(delay, closer): await closer() logging.info("Closed") - await loop.create_task(close_after(5, direct_conn.close)) + await asyncio.create_task(close_after(5, direct_conn.close)) # reconnect fired await reconnect_event.wait() @@ -421,14 +420,14 @@ async def reader(queue: aio_pika.Queue): @aiomisc.timeout(10) async def test_channel_restore( - connection_fabric, loop, amqp_url, proxy: TCPProxy, add_cleanup: Callable, + connection_fabric, amqp_url, proxy: TCPProxy, add_cleanup: Callable, ): heartbeat = 10 amqp_url = amqp_url.update_query(heartbeat=heartbeat) on_reopen = asyncio.Event() - conn = await connection_fabric(amqp_url, loop=loop) + conn = await connection_fabric(amqp_url) assert isinstance(conn, aio_pika.RobustConnection) async with conn: @@ -451,11 +450,11 @@ async def test_channel_restore( @aiomisc.timeout(20) async def test_channel_reconnect( - connection_fabric, loop, amqp_url, proxy: TCPProxy, add_cleanup: Callable, + connection_fabric, amqp_url, proxy: TCPProxy, add_cleanup: Callable, ): on_reconnect = asyncio.Event() - conn = await connection_fabric(amqp_url, loop=loop) + conn = await connection_fabric(amqp_url) assert isinstance(conn, aio_pika.RobustConnection) conn.reconnect_callbacks.add(lambda *_: on_reconnect.set(), weak=False) @@ -492,9 +491,8 @@ async def test_channel_reconnect_after_5kb( ): connection = await aio_pika.connect_robust( amqp_url.update_query(reconnect_interval=reconnect_timeout), - loop=loop, ) - direct_connection = await aio_pika.connect(amqp_direct_url, loop=loop) + direct_connection = await aio_pika.connect(amqp_direct_url) on_reconnect = asyncio.Event() connection.reconnect_callbacks.add( @@ -561,9 +559,8 @@ async def test_channel_reconnect_stairway( ): connection = await aio_pika.connect_robust( amqp_url.update_query(reconnect_interval=reconnect_timeout), - loop=loop, ) - direct_connection = await aio_pika.connect(amqp_direct_url, loop=loop) + direct_connection = await aio_pika.connect(amqp_direct_url) on_reconnect = asyncio.Event() connection.reconnect_callbacks.add( diff --git a/tests/test_amqps.py b/tests/test_amqps.py index 6dc08785..7e2c1fbb 100644 --- a/tests/test_amqps.py +++ b/tests/test_amqps.py @@ -15,7 +15,7 @@ def connection_fabric(request): @pytest.fixture -def create_connection(connection_fabric, loop, amqp_url): +def create_connection(connection_fabric, amqp_url): ssl_context = ssl.create_default_context() ssl_context.check_hostname = False ssl_context.verify_mode = ssl.VerifyMode.CERT_NONE @@ -23,7 +23,6 @@ def create_connection(connection_fabric, loop, amqp_url): return partial( connection_fabric, amqp_url.with_scheme("amqps").with_port(5671), - loop=loop, ssl_context=ssl_context, ) diff --git a/tests/test_pool.py b/tests/test_pool.py index 0e916b3d..39b99090 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize("max_size", [50, 10, 5, 1]) -async def test_simple(max_size, loop): +async def test_simple(max_size): counter = 0 async def create_instance(): @@ -16,7 +16,7 @@ async def create_instance(): counter += 1 return counter - pool = Pool(create_instance, max_size=max_size, loop=loop) + pool = Pool(create_instance, max_size=max_size) async def getter(): nonlocal counter, pool @@ -54,7 +54,7 @@ def max_size(self, request): return request.param @pytest.fixture - def pool(self, max_size, instances, loop): + def pool(self, max_size, instances): async def create_instance(): nonlocal instances @@ -62,11 +62,11 @@ async def create_instance(): instances.add(obj) return obj - return Pool(create_instance, max_size=max_size, loop=loop) + return Pool(create_instance, max_size=max_size) class TestInstance(TestInstanceBase): - async def test_close(self, pool, instances, loop, max_size): + async def test_close(self, pool, instances, max_size): async def getter(): async with pool.acquire(): await asyncio.sleep(0.05) @@ -114,7 +114,7 @@ async def getter(): class TestCaseNoMaxSize(TestInstance): - async def test_simple(self, pool, loop): + async def test_simple(self, pool): call_count = 200 counter = 0 diff --git a/tests/test_rpc.py b/tests/test_rpc.py index 98f2bf7d..323110d5 100644 --- a/tests/test_rpc.py +++ b/tests/test_rpc.py @@ -137,7 +137,7 @@ async def test_send_unknown_message( await rpc.close() - async def test_close_cancelling(self, channel: aio_pika.Channel, loop): + async def test_close_cancelling(self, channel: aio_pika.Channel): rpc = await RPC.create(channel, auto_delete=True) async def sleeper(): @@ -150,7 +150,7 @@ async def sleeper(): tasks = set() for _ in range(10): - tasks.add(loop.create_task(rpc.call(method_name))) + tasks.add(asyncio.create_task(rpc.call(method_name))) await rpc.close() diff --git a/tests/test_tools.py b/tests/test_tools.py index 8fcb1feb..106439f8 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -100,7 +100,7 @@ async def coro(arg): shared.append(arg) def task_maker(arg): - return loop.create_task(coro(arg)) + return asyncio.create_task(coro(arg)) collection.add(future.set_result) collection.add(coro) From f4ee7fd6df1409c61c864c688c55b1bcf9fa54e8 Mon Sep 17 00:00:00 2001 From: Oleg A Date: Mon, 26 Sep 2022 14:35:23 +0300 Subject: [PATCH 3/8] Revert "chore: remove direct loop" This reverts commit 54b814b3c6d6c3357e1a29dbb1369aa31f93d980. --- aio_pika/abc.py | 7 +++-- aio_pika/channel.py | 2 ++ aio_pika/connection.py | 19 ++++++------ aio_pika/patterns/master.py | 16 ++++++++-- aio_pika/patterns/rpc.py | 6 ++-- aio_pika/pool.py | 3 ++ aio_pika/robust_channel.py | 2 ++ aio_pika/robust_connection.py | 10 +++++-- aio_pika/tools.py | 10 ++++--- .../examples/6-rpc/rpc_client.py | 10 ++++--- tests/conftest.py | 10 +++---- tests/test_amqp.py | 21 +++++++------- tests/test_amqp_robust.py | 4 +-- tests/test_amqp_robust_proxy.py | 29 ++++++++++--------- tests/test_amqps.py | 3 +- tests/test_pool.py | 12 ++++---- tests/test_rpc.py | 4 +-- tests/test_tools.py | 2 +- 18 files changed, 102 insertions(+), 68 deletions(-) diff --git a/aio_pika/abc.py b/aio_pika/abc.py index b6dd3869..cadc91f1 100644 --- a/aio_pika/abc.py +++ b/aio_pika/abc.py @@ -669,9 +669,12 @@ class AbstractConnection(PoolInstance, ABC): transport: Optional[UnderlayConnection] @abstractmethod - def __init__(self, url: URL, **kwargs: Any): + def __init__( + self, url: URL, loop: Optional[asyncio.AbstractEventLoop] = None, + **kwargs: Any, + ): raise NotImplementedError( - f"Method not implemented, passed: url={url}", + f"Method not implemented, passed: url={url}, loop={loop!r}", ) @property diff --git a/aio_pika/channel.py b/aio_pika/channel.py index 17351762..a9852452 100644 --- a/aio_pika/channel.py +++ b/aio_pika/channel.py @@ -60,6 +60,8 @@ def __init__( """ :param connection: :class:`aio_pika.adapter.AsyncioConnection` instance + :param loop: Event loop (:func:`asyncio.get_event_loop()` + when :class:`None`) :param future_store: :class:`aio_pika.common.FutureStore` instance :param publisher_confirms: False if you don't need delivery confirmations (in pursuit of performance) diff --git a/aio_pika/connection.py b/aio_pika/connection.py index 2364c3d5..504b4017 100644 --- a/aio_pika/connection.py +++ b/aio_pika/connection.py @@ -51,12 +51,10 @@ def _parse_kwargs(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]: return result def __init__( - self, - url: URL, - ssl_context: Optional[SSLContext] = None, - **kwargs: Any, + self, url: URL, loop: Optional[asyncio.AbstractEventLoop] = None, + ssl_context: Optional[SSLContext] = None, **kwargs: Any ): - self._loop = asyncio.get_event_loop() + self.loop = loop or asyncio.get_event_loop() self.transport = None self._closed = False self._close_called = False @@ -114,7 +112,7 @@ def channel( import aio_pika - async def main(): + async def main(loop): connection = await aio_pika.connect( "amqp://guest:guest@127.0.0.1/" ) @@ -138,7 +136,7 @@ async def main(): import aio_pika - async def main(): + async def main(loop): connection = await aio_pika.connect( "amqp://guest:guest@127.0.0.1/" ) @@ -180,7 +178,7 @@ async def ready(self) -> None: def __del__(self) -> None: if ( self.is_closed or - self._loop.is_closed() or + self.loop.is_closed() or not hasattr(self, "connection") ): return @@ -258,6 +256,7 @@ async def connect( password: str = "guest", virtualhost: str = "/", ssl: bool = False, + loop: Optional[asyncio.AbstractEventLoop] = None, ssl_options: Optional[SSLOptions] = None, ssl_context: Optional[SSLContext] = None, timeout: TimeoutType = None, @@ -332,6 +331,8 @@ async def main(): :param ssl: use SSL for connection. Should be used with addition kwargs. :param ssl_options: A dict of values for the SSL connection. :param timeout: connection timeout in seconds + :param loop: + Event loop (:func:`asyncio.get_event_loop()` when :class:`None`) :param ssl_context: ssl.SSLContext instance :param connection_class: Factory of a new connection :param kwargs: addition parameters which will be passed to the connection. @@ -356,7 +357,7 @@ async def main(): client_properties=client_properties, **kwargs ), - ssl_context=ssl_context, + loop=loop, ssl_context=ssl_context, ) await connection.connect(timeout=timeout) diff --git a/aio_pika/patterns/master.py b/aio_pika/patterns/master.py index 7767129f..70511fe1 100644 --- a/aio_pika/patterns/master.py +++ b/aio_pika/patterns/master.py @@ -1,3 +1,4 @@ +import asyncio import gzip import json import logging @@ -13,8 +14,10 @@ ConsumerTag, DeliveryMode, ) from aio_pika.message import Message, ReturnedMessage -from .base import Base, Proxy + from ..tools import create_task +from .base import Base, Proxy + log = logging.getLogger(__name__) T = TypeVar("T") @@ -38,11 +41,16 @@ class Worker: __slots__ = ( "queue", "consumer_tag", + "loop", ) - def __init__(self, queue: AbstractQueue, consumer_tag: ConsumerTag): + def __init__( + self, queue: AbstractQueue, consumer_tag: ConsumerTag, + loop: asyncio.AbstractEventLoop, + ): self.queue = queue self.consumer_tag = consumer_tag + self.loop = loop def close(self) -> Awaitable[None]: """ Cancel subscription to the channel @@ -59,6 +67,7 @@ async def closer() -> None: class Master(Base): __slots__ = ( "channel", + "loop", "proxy", ) @@ -90,6 +99,7 @@ def __init__( :param channel: Initialized instance of :class:`aio_pika.Channel` """ self.channel: AbstractChannel = channel + self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() self.proxy = Proxy(self.create_task) self.channel.return_callbacks.add(self.on_message_returned) @@ -176,7 +186,7 @@ async def create_worker( fn = awaitable(func) consumer_tag = await queue.consume(partial(self.on_message, fn)) - return Worker(queue, consumer_tag) + return Worker(queue, consumer_tag, self.loop) async def create_task( self, channel_name: str, diff --git a/aio_pika/patterns/rpc.py b/aio_pika/patterns/rpc.py index 069ab58b..e5832d05 100644 --- a/aio_pika/patterns/rpc.py +++ b/aio_pika/patterns/rpc.py @@ -45,7 +45,7 @@ class RPCMessageType(str, Enum): class RPC(Base): __slots__ = ( "channel", - "_loop", + "loop", "proxy", "result_queue", "result_consumer_tag", @@ -88,7 +88,7 @@ def multiply(*, x, y): def __init__(self, channel: AbstractChannel): self.channel = channel - self._loop = asyncio.get_event_loop() + self.loop = asyncio.get_event_loop() self.proxy = Proxy(self.call) self.futures: Dict[str, asyncio.Future] = {} self.routes: Dict[str, Callable[..., Any]] = {} @@ -100,7 +100,7 @@ def __remove_future(self, future: asyncio.Future) -> None: self.futures.pop(str(id(future)), None) def create_future(self) -> Tuple[asyncio.Future, str]: - future = self._loop.create_future() + future = self.loop.create_future() log.debug("Create future for RPC call") correlation_id = str(uuid.uuid4()) self.futures[correlation_id] = future diff --git a/aio_pika/pool.py b/aio_pika/pool.py index 1c0a81db..93fc4080 100644 --- a/aio_pika/pool.py +++ b/aio_pika/pool.py @@ -38,6 +38,7 @@ class PoolInvalidStateError(RuntimeError): class Pool(Generic[T]): __slots__ = ( + "loop", "__max_size", "__items", "__constructor", @@ -53,7 +54,9 @@ def __init__( constructor: ConstructorType, *args: Any, max_size: Optional[int] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, ): + self.loop = loop or asyncio.get_event_loop() self.__closed = False self.__constructor: Callable[..., Awaitable[Any]] = awaitable( constructor, diff --git a/aio_pika/robust_channel.py b/aio_pika/robust_channel.py index 643c37fa..20714768 100644 --- a/aio_pika/robust_channel.py +++ b/aio_pika/robust_channel.py @@ -42,6 +42,8 @@ def __init__( """ :param connection: :class:`aio_pika.adapter.AsyncioConnection` instance + :param loop: + Event loop (:func:`asyncio.get_event_loop()` when :class:`None`) :param future_store: :class:`aio_pika.common.FutureStore` instance :param publisher_confirms: False if you don't need delivery confirmations diff --git a/aio_pika/robust_connection.py b/aio_pika/robust_connection.py index 86524a43..8f4b59d9 100644 --- a/aio_pika/robust_connection.py +++ b/aio_pika/robust_connection.py @@ -35,9 +35,10 @@ class RobustConnection(Connection, AbstractRobustConnection): def __init__( self, url: URL, + loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any, ): - super().__init__(url=url, **kwargs) + super().__init__(url=url, loop=loop, **kwargs) self.reconnect_interval = self.kwargs.pop("reconnect_interval") self.fail_fast = self.kwargs.pop("fail_fast") @@ -91,7 +92,7 @@ async def __connection_attempt( # from Exception and this needed for catch it first raise except Exception as e: - closing = self._loop.create_future() + closing = self.loop.create_future() closing.set_exception(e) await self.close_callbacks(closing) await asyncio.gather(connection.close(e), return_exceptions=True) @@ -182,6 +183,7 @@ async def connect_robust( password: str = "guest", virtualhost: str = "/", ssl: bool = False, + loop: Optional[asyncio.AbstractEventLoop] = None, ssl_options: Optional[SSLOptions] = None, ssl_context: Optional[SSLContext] = None, timeout: TimeoutType = None, @@ -258,6 +260,8 @@ async def main(): :param ssl: use SSL for connection. Should be used with addition kwargs. :param ssl_options: A dict of values for the SSL connection. :param timeout: connection timeout in seconds + :param loop: + Event loop (:func:`asyncio.get_event_loop()` when :class:`None`) :param ssl_context: ssl.SSLContext instance :param connection_class: Factory of a new connection :param kwargs: addition parameters which will be passed to the connection. @@ -282,7 +286,7 @@ async def main(): client_properties=client_properties, **kwargs ), - ssl_context=ssl_context, + loop=loop, ssl_context=ssl_context, ) await connection.connect(timeout=timeout) diff --git a/aio_pika/tools.py b/aio_pika/tools.py index 37958a1c..0e103d5b 100644 --- a/aio_pika/tools.py +++ b/aio_pika/tools.py @@ -46,12 +46,13 @@ def _task_done(future: asyncio.Future) -> None: def create_task( func: Callable[..., Union[Coroutine[Any, Any, T], Awaitable[T]]], *args: Any, + loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any ) -> Awaitable[T]: - loop = asyncio.get_event_loop() + loop = loop or asyncio.get_event_loop() if iscoroutinepartial(func): - task = asyncio.create_task(func(*args, **kwargs)) # type: ignore + task = loop.create_task(func(*args, **kwargs)) # type: ignore task.add_done_callback(_task_done) return task @@ -200,10 +201,11 @@ def __hash__(self) -> int: class OneShotCallback: - __slots__ = ("finished", "__lock", "callback", "__task") + __slots__ = ("loop", "finished", "__lock", "callback", "__task") def __init__(self, callback: Callable[..., Awaitable[T]]): self.callback: Callable[..., Awaitable[T]] = callback + self.loop = asyncio.get_event_loop() self.finished: asyncio.Event = asyncio.Event() self.__lock: asyncio.Lock = asyncio.Lock() self.__task: asyncio.Future @@ -232,7 +234,7 @@ async def __task_inner(self, *args: Any, **kwargs: Any) -> None: def __call__(self, *args: Any, **kwargs: Any) -> Awaitable[Any]: if self.finished.is_set(): return STUB_AWAITABLE - self.__task = asyncio.create_task( + self.__task = self.loop.create_task( self.__task_inner(*args, **kwargs), ) return self.__task diff --git a/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py b/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py index 24bddd6b..3fb56706 100644 --- a/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py +++ b/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py @@ -12,14 +12,16 @@ class FibonacciRpcClient: connection: AbstractConnection channel: AbstractChannel callback_queue: AbstractQueue - _loop: asyncio.AbstractEventLoop + loop: asyncio.AbstractEventLoop def __init__(self) -> None: self.futures: MutableMapping[str, asyncio.Future] = {} - self._loop = asyncio.get_running_loop() + self.loop = asyncio.get_running_loop() async def connect(self) -> "FibonacciRpcClient": - self.connection = await connect("amqp://guest:guest@localhost/") + self.connection = await connect( + "amqp://guest:guest@localhost/", loop=self.loop, + ) self.channel = await self.connection.channel() self.callback_queue = await self.channel.declare_queue(exclusive=True) await self.callback_queue.consume(self.on_response) @@ -36,7 +38,7 @@ def on_response(self, message: AbstractIncomingMessage) -> None: async def call(self, n: int) -> int: correlation_id = str(uuid.uuid4()) - future = self._loop.create_future() + future = self.loop.create_future() self.futures[correlation_id] = future diff --git a/tests/conftest.py b/tests/conftest.py index 83c3b547..619d1d8a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ @pytest.fixture -async def add_cleanup(): +async def add_cleanup(loop): entities = [] def payload(func, *args, **kwargs): @@ -34,12 +34,12 @@ def payload(func, *args, **kwargs): @pytest.fixture -async def create_task(): +async def create_task(loop): tasks = [] def payload(coroutine): nonlocal tasks - task = asyncio.create_task(coroutine) + task = loop.create_task(coroutine) tasks.append(task) return task @@ -91,8 +91,8 @@ def connection_fabric(request): @pytest.fixture -def create_connection(connection_fabric, amqp_url): - return partial(connection_fabric, amqp_url) +def create_connection(connection_fabric, loop, amqp_url): + return partial(connection_fabric, amqp_url, loop=loop) @pytest.fixture diff --git a/tests/test_amqp.py b/tests/test_amqp.py index 78ffd4f7..26092f1e 100644 --- a/tests/test_amqp.py +++ b/tests/test_amqp.py @@ -46,7 +46,7 @@ def declare_exchange_(declare_exchange): class TestCaseAmqp(TestCaseAmqpBase): - async def test_properties(self, connection: aio_pika.Connection): + async def test_properties(self, loop, connection: aio_pika.Connection): assert not connection.is_closed async def test_channel_close(self, connection: aio_pika.Connection): @@ -1205,7 +1205,7 @@ async def test_transaction_simple_async_rollback( async with channel.transaction(): raise ValueError - async def test_async_for_queue(self, connection, declare_queue): + async def test_async_for_queue(self, loop, connection, declare_queue): channel2 = await self.create_channel(connection) queue = await declare_queue( @@ -1224,7 +1224,7 @@ async def publisher(): Message(body=str(i).encode()), routing_key=queue.name, ) - asyncio.create_task(publisher()) + loop.create_task(publisher()) count = 0 data = list() @@ -1240,7 +1240,7 @@ async def publisher(): assert data == list(map(lambda x: str(x).encode(), range(messages))) async def test_async_for_queue_context( - self, connection, declare_queue, + self, loop, connection, declare_queue, ): channel2 = await self.create_channel(connection) @@ -1260,7 +1260,7 @@ async def publisher(): Message(body=str(i).encode()), routing_key=queue.name, ) - asyncio.create_task(publisher()) + loop.create_task(publisher()) count = 0 data = list() @@ -1277,7 +1277,7 @@ async def publisher(): assert data == list(map(lambda x: str(x).encode(), range(messages))) async def test_async_with_connection( - self, create_connection: Callable, connection, declare_queue, + self, create_connection: Callable, connection, loop, declare_queue, ): async with await create_connection() as connection: @@ -1299,7 +1299,7 @@ async def publisher(): Message(body=str(i).encode()), routing_key=queue.name, ) - asyncio.create_task(publisher()) + loop.create_task(publisher()) count = 0 data = list() @@ -1401,7 +1401,7 @@ async def task_inner(): future.set_exception(e) raise - task = asyncio.create_task(task_inner()) + task = loop.create_task(task_inner()) await event.wait() loop.call_soon(task.cancel) @@ -1415,6 +1415,7 @@ async def task_inner(): async def test_queue_iterator_close_with_noack( self, create_connection: Callable, + loop, add_cleanup: Callable, declare_queue, ): @@ -1454,7 +1455,7 @@ async def task_inner(): Message(body), routing_key=queue_name, ) - task = asyncio.create_task(task_inner()) + task = loop.create_task(task_inner()) await task @@ -1560,7 +1561,7 @@ async def run(): await channel.set_qos(10) async def test_heartbeat_disabling( - self, amqp_url: URL, connection_fabric, + self, loop, amqp_url: URL, connection_fabric, ): url = amqp_url.update_query(heartbeat=0) connection: AbstractConnection = await connection_fabric(url) diff --git a/tests/test_amqp_robust.py b/tests/test_amqp_robust.py index 5c0195ca..92b4d4df 100644 --- a/tests/test_amqp_robust.py +++ b/tests/test_amqp_robust.py @@ -18,8 +18,8 @@ def connection_fabric(): @pytest.fixture -def create_connection(connection_fabric, amqp_url): - return partial(connection_fabric, amqp_url) +def create_connection(connection_fabric, loop, amqp_url): + return partial(connection_fabric, amqp_url, loop=loop) class TestCaseNoRobust(TestCaseAmqp): diff --git a/tests/test_amqp_robust_proxy.py b/tests/test_amqp_robust_proxy.py index 474d46ce..e8faf255 100644 --- a/tests/test_amqp_robust_proxy.py +++ b/tests/test_amqp_robust_proxy.py @@ -56,19 +56,20 @@ def connection_fabric(): @pytest.fixture -def create_direct_connection(amqp_direct_url): +def create_direct_connection(loop, amqp_direct_url): return partial( aio_pika.connect, amqp_direct_url.update_query( name=amqp_direct_url.query["name"] + "::direct", heartbeat=30, ), + loop=loop, ) @pytest.fixture -def create_connection(connection_fabric, amqp_url): - return partial(connection_fabric, amqp_url) +def create_connection(connection_fabric, loop, amqp_url): + return partial(connection_fabric, amqp_url, loop=loop) @pytest.fixture @@ -143,7 +144,7 @@ async def test_robust_reconnect( create_connection, direct_connection, proxy: TCPProxy, loop, add_cleanup: Callable, ): - read_conn = await create_connection() # type: aio_pika.RobustConnection + read_conn = await create_connection() # type: aio_pika.RobustConnection reconnect_event = asyncio.Event() read_conn.reconnect_callbacks.add( @@ -190,7 +191,7 @@ async def reader(queue_name): logging.info("Exit reader task") try: - reader_task = asyncio.create_task(reader(queue.name)) + reader_task = loop.create_task(reader(queue.name)) await consumer_event.wait() logging.info("Disconnect all clients") @@ -250,7 +251,7 @@ async def test_channel_locked_resource2(connection: aio_pika.RobustConnection): async def test_channel_close_when_exclusive_queue( - create_connection, create_direct_connection, proxy: TCPProxy, + create_connection, create_direct_connection, proxy: TCPProxy, loop, ): logging.info("Creating connections") direct_conn, proxy_conn = await asyncio.gather( @@ -289,7 +290,7 @@ async def close_after(delay, closer): await closer() logging.info("Closed") - await asyncio.create_task(close_after(5, direct_conn.close)) + await loop.create_task(close_after(5, direct_conn.close)) # reconnect fired await reconnect_event.wait() @@ -420,14 +421,14 @@ async def reader(queue: aio_pika.Queue): @aiomisc.timeout(10) async def test_channel_restore( - connection_fabric, amqp_url, proxy: TCPProxy, add_cleanup: Callable, + connection_fabric, loop, amqp_url, proxy: TCPProxy, add_cleanup: Callable, ): heartbeat = 10 amqp_url = amqp_url.update_query(heartbeat=heartbeat) on_reopen = asyncio.Event() - conn = await connection_fabric(amqp_url) + conn = await connection_fabric(amqp_url, loop=loop) assert isinstance(conn, aio_pika.RobustConnection) async with conn: @@ -450,11 +451,11 @@ async def test_channel_restore( @aiomisc.timeout(20) async def test_channel_reconnect( - connection_fabric, amqp_url, proxy: TCPProxy, add_cleanup: Callable, + connection_fabric, loop, amqp_url, proxy: TCPProxy, add_cleanup: Callable, ): on_reconnect = asyncio.Event() - conn = await connection_fabric(amqp_url) + conn = await connection_fabric(amqp_url, loop=loop) assert isinstance(conn, aio_pika.RobustConnection) conn.reconnect_callbacks.add(lambda *_: on_reconnect.set(), weak=False) @@ -491,8 +492,9 @@ async def test_channel_reconnect_after_5kb( ): connection = await aio_pika.connect_robust( amqp_url.update_query(reconnect_interval=reconnect_timeout), + loop=loop, ) - direct_connection = await aio_pika.connect(amqp_direct_url) + direct_connection = await aio_pika.connect(amqp_direct_url, loop=loop) on_reconnect = asyncio.Event() connection.reconnect_callbacks.add( @@ -559,8 +561,9 @@ async def test_channel_reconnect_stairway( ): connection = await aio_pika.connect_robust( amqp_url.update_query(reconnect_interval=reconnect_timeout), + loop=loop, ) - direct_connection = await aio_pika.connect(amqp_direct_url) + direct_connection = await aio_pika.connect(amqp_direct_url, loop=loop) on_reconnect = asyncio.Event() connection.reconnect_callbacks.add( diff --git a/tests/test_amqps.py b/tests/test_amqps.py index 7e2c1fbb..6dc08785 100644 --- a/tests/test_amqps.py +++ b/tests/test_amqps.py @@ -15,7 +15,7 @@ def connection_fabric(request): @pytest.fixture -def create_connection(connection_fabric, amqp_url): +def create_connection(connection_fabric, loop, amqp_url): ssl_context = ssl.create_default_context() ssl_context.check_hostname = False ssl_context.verify_mode = ssl.VerifyMode.CERT_NONE @@ -23,6 +23,7 @@ def create_connection(connection_fabric, amqp_url): return partial( connection_fabric, amqp_url.with_scheme("amqps").with_port(5671), + loop=loop, ssl_context=ssl_context, ) diff --git a/tests/test_pool.py b/tests/test_pool.py index 39b99090..0e916b3d 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize("max_size", [50, 10, 5, 1]) -async def test_simple(max_size): +async def test_simple(max_size, loop): counter = 0 async def create_instance(): @@ -16,7 +16,7 @@ async def create_instance(): counter += 1 return counter - pool = Pool(create_instance, max_size=max_size) + pool = Pool(create_instance, max_size=max_size, loop=loop) async def getter(): nonlocal counter, pool @@ -54,7 +54,7 @@ def max_size(self, request): return request.param @pytest.fixture - def pool(self, max_size, instances): + def pool(self, max_size, instances, loop): async def create_instance(): nonlocal instances @@ -62,11 +62,11 @@ async def create_instance(): instances.add(obj) return obj - return Pool(create_instance, max_size=max_size) + return Pool(create_instance, max_size=max_size, loop=loop) class TestInstance(TestInstanceBase): - async def test_close(self, pool, instances, max_size): + async def test_close(self, pool, instances, loop, max_size): async def getter(): async with pool.acquire(): await asyncio.sleep(0.05) @@ -114,7 +114,7 @@ async def getter(): class TestCaseNoMaxSize(TestInstance): - async def test_simple(self, pool): + async def test_simple(self, pool, loop): call_count = 200 counter = 0 diff --git a/tests/test_rpc.py b/tests/test_rpc.py index 323110d5..98f2bf7d 100644 --- a/tests/test_rpc.py +++ b/tests/test_rpc.py @@ -137,7 +137,7 @@ async def test_send_unknown_message( await rpc.close() - async def test_close_cancelling(self, channel: aio_pika.Channel): + async def test_close_cancelling(self, channel: aio_pika.Channel, loop): rpc = await RPC.create(channel, auto_delete=True) async def sleeper(): @@ -150,7 +150,7 @@ async def sleeper(): tasks = set() for _ in range(10): - tasks.add(asyncio.create_task(rpc.call(method_name))) + tasks.add(loop.create_task(rpc.call(method_name))) await rpc.close() diff --git a/tests/test_tools.py b/tests/test_tools.py index 106439f8..8fcb1feb 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -100,7 +100,7 @@ async def coro(arg): shared.append(arg) def task_maker(arg): - return asyncio.create_task(coro(arg)) + return loop.create_task(coro(arg)) collection.add(future.set_result) collection.add(coro) From effc79c0e3f2be6f69c8014273d4356fc15dbaf9 Mon Sep 17 00:00:00 2001 From: Oleg A Date: Wed, 28 Sep 2022 13:57:13 +0300 Subject: [PATCH 4/8] chore: remove loop from docs --- docs/source/rabbitmq-tutorial/2-work-queues.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/rabbitmq-tutorial/2-work-queues.rst b/docs/source/rabbitmq-tutorial/2-work-queues.rst index 104aec42..d6f26391 100644 --- a/docs/source/rabbitmq-tutorial/2-work-queues.rst +++ b/docs/source/rabbitmq-tutorial/2-work-queues.rst @@ -159,7 +159,7 @@ from the worker, once we're done with a task. async def on_message(message: IncomingMessage): print(" [x] Received %r" % message.body) - await asyncio.sleep(message.body.count(b'.'), loop=loop) + await asyncio.sleep(message.body.count(b'.')) print(" [x] Done") await message.ack() From 66ded0fe978d249c84309eb212a8a27913e1fd70 Mon Sep 17 00:00:00 2001 From: Oleg A Date: Wed, 28 Sep 2022 13:57:38 +0300 Subject: [PATCH 5/8] chore: hide loop in FibonacciRpcClient --- .../rabbitmq-tutorial/examples/6-rpc/rpc_client.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py b/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py index 3fb56706..18098820 100644 --- a/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py +++ b/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py @@ -12,16 +12,13 @@ class FibonacciRpcClient: connection: AbstractConnection channel: AbstractChannel callback_queue: AbstractQueue - loop: asyncio.AbstractEventLoop def __init__(self) -> None: self.futures: MutableMapping[str, asyncio.Future] = {} - self.loop = asyncio.get_running_loop() + self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() async def connect(self) -> "FibonacciRpcClient": - self.connection = await connect( - "amqp://guest:guest@localhost/", loop=self.loop, - ) + self.connection = await connect("amqp://guest:guest@localhost/") self.channel = await self.connection.channel() self.callback_queue = await self.channel.declare_queue(exclusive=True) await self.callback_queue.consume(self.on_response) @@ -38,7 +35,7 @@ def on_response(self, message: AbstractIncomingMessage) -> None: async def call(self, n: int) -> int: correlation_id = str(uuid.uuid4()) - future = self.loop.create_future() + future = self._loop.create_future() self.futures[correlation_id] = future From c8de5e57e012f23f41eb20e7eb2b940cc7637518 Mon Sep 17 00:00:00 2001 From: Oleg A Date: Wed, 11 Oct 2023 15:24:47 +0300 Subject: [PATCH 6/8] fix: turn back running loop --- docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py b/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py index cc11d400..2b98285c 100644 --- a/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py +++ b/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py @@ -15,7 +15,7 @@ class FibonacciRpcClient: def __init__(self) -> None: self.futures: MutableMapping[str, asyncio.Future] = {} - self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() + self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() async def connect(self) -> "FibonacciRpcClient": self.connection = await connect("amqp://guest:guest@localhost/") From c3d44eeb6d171cee233be2ca101831a474e6eb85 Mon Sep 17 00:00:00 2001 From: Oleg A Date: Wed, 11 Oct 2023 15:26:35 +0300 Subject: [PATCH 7/8] chore: rm redundant typehint --- docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py b/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py index 2b98285c..c60f40fc 100644 --- a/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py +++ b/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py @@ -15,7 +15,7 @@ class FibonacciRpcClient: def __init__(self) -> None: self.futures: MutableMapping[str, asyncio.Future] = {} - self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + self._loop = asyncio.get_running_loop() async def connect(self) -> "FibonacciRpcClient": self.connection = await connect("amqp://guest:guest@localhost/") From ff3cc46846d3af86207041150d3da220cb4e2f63 Mon Sep 17 00:00:00 2001 From: Oleg A Date: Wed, 11 Oct 2023 15:28:23 +0300 Subject: [PATCH 8/8] chore: rm loop from __init__ --- docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py b/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py index c60f40fc..bb313784 100644 --- a/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py +++ b/docs/source/rabbitmq-tutorial/examples/6-rpc/rpc_client.py @@ -15,7 +15,6 @@ class FibonacciRpcClient: def __init__(self) -> None: self.futures: MutableMapping[str, asyncio.Future] = {} - self._loop = asyncio.get_running_loop() async def connect(self) -> "FibonacciRpcClient": self.connection = await connect("amqp://guest:guest@localhost/") @@ -35,7 +34,8 @@ async def on_response(self, message: AbstractIncomingMessage) -> None: async def call(self, n: int) -> int: correlation_id = str(uuid.uuid4()) - future = self._loop.create_future() + loop = asyncio.get_running_loop() + future = loop.create_future() self.futures[correlation_id] = future