diff --git a/aioamqp/channel.py b/aioamqp/channel.py index 7f0f402..a1d54ed 100644 --- a/aioamqp/channel.py +++ b/aioamqp/channel.py @@ -41,6 +41,7 @@ def __init__(self, protocol, channel_id, return_callback=None): self._queue_bind_lock = asyncio.Lock() self._futures = {} self._ctag_events = {} + self._ctags_queue_map = {} def _set_waiter(self, rpc_name): if rpc_name in self._futures: @@ -466,6 +467,77 @@ async def basic_server_nack(self, frame, delivery_tag=None): logger.debug('Received nack for delivery tag %r', delivery_tag) fut.set_exception(exceptions.PublishFailed(delivery_tag)) + def consume(self, queue_name='', consumer_tag='', no_local=False, no_ack=False, + exclusive=False, no_wait=False, arguments=None): + + + consumer_tag = consumer_tag or 'ctag%i.%s' % (self.channel_id, uuid.uuid4().hex) + if arguments is None: + arguments = {} + + + class MessageQueue: + def __init__(self, channel, queue_name, consumer_tag, no_local, no_ack, exclusive, no_wait, arguments): + self.channel = channel + self.aio_queue = asyncio.Queue(loop=channel._loop) + self.amqp_response = None + self.queue_name = queue_name + self.consumer_tag = consumer_tag + self.no_local = no_local + self.no_ack = no_ack + self.exclusive = exclusive + self.no_wait = no_wait + self.arguments = arguments + + async def __aenter__(self): + request = pamqp.specification.Basic.Consume( + queue=self.queue_name, + consumer_tag=self.consumer_tag, + no_local=self.no_local, + no_ack=self.no_ack, + exclusive=self.exclusive, + nowait=self.no_wait, + arguments=self.arguments + ) + if not self.no_wait: + self.channel._ctag_events[consumer_tag] = asyncio.Event(loop=self.channel._loop) + + self.amqp_response = await self.channel._write_frame_awaiting_response( + 'basic_consume' + consumer_tag, self.channel.channel_id, request, no_wait + ) + + return self + + def __aiter__(self): + return self + + async def __aexit__(self, *args, **kwars): + pass + + async def __anext__(self): + return await self.aio_queue.get() + + async def put(self, message_list): + return await self.aio_queue.put(message_list) + + async def qos(self, prefetch_size=0, prefetch_count=0, connection_global=False): + return await self.channel.basic_qos(prefetch_size, prefetch_count, connection_global) + + + queue = MessageQueue( + self, + queue_name, + consumer_tag, + no_local, + no_ack, + exclusive, + no_wait, + arguments, + ) + self._ctags_queue_map[consumer_tag] = queue + + return queue + async def basic_consume(self, callback, queue_name='', consumer_tag='', no_local=False, no_ack=False, exclusive=False, no_wait=False, arguments=None): """Starts the consumption of message into a queue. @@ -502,13 +574,13 @@ async def basic_consume(self, callback, queue_name='', consumer_tag='', no_local self.consumer_callbacks[consumer_tag] = callback self.last_consumer_tag = consumer_tag - - return_value = await self._write_frame_awaiting_response( - 'basic_consume' + consumer_tag, self.channel_id, request, no_wait) if no_wait: return_value = {'consumer_tag': consumer_tag} else: - self._ctag_events[consumer_tag].set() + self._ctag_events[consumer_tag] = asyncio.Event(loop=self._loop) + + return_value = await self._write_frame_awaiting_response( + 'basic_consume' + consumer_tag, self.channel_id, request, no_wait) return return_value async def basic_consume_ok(self, frame): @@ -518,7 +590,7 @@ async def basic_consume_ok(self, frame): } future = self._get_waiter('basic_consume' + ctag) future.set_result(results) - self._ctag_events[ctag] = asyncio.Event(loop=self._loop) + self._ctag_events[ctag].set() async def basic_deliver(self, frame): consumer_tag = frame.consumer_tag @@ -538,14 +610,19 @@ async def basic_deliver(self, frame): envelope = Envelope(consumer_tag, delivery_tag, exchange_name, routing_key, is_redeliver) properties = amqp_properties.from_pamqp(content_header_frame.properties) - callback = self.consumer_callbacks[consumer_tag] + consumer_queue = self._ctags_queue_map.get(consumer_tag) + if consumer_queue: + await consumer_queue.put([body, envelope, properties]) + + callback = self.consumer_callbacks.get(consumer_tag) event = self._ctag_events.get(consumer_tag) if event: await event.wait() del self._ctag_events[consumer_tag] - await callback(self, body, envelope, properties) + if callback: + await callback(self, body, envelope, properties) async def server_basic_cancel(self, frame): # https://www.rabbitmq.com/consumer-cancel.html diff --git a/examples/receive.py b/examples/receive.py index b1ab163..f2dfa97 100644 --- a/examples/receive.py +++ b/examples/receive.py @@ -7,16 +7,18 @@ import aioamqp -async def callback(channel, body, envelope, properties): - print(" [x] Received %r" % body) - async def receive(): transport, protocol = await aioamqp.connect() channel = await protocol.channel() await channel.queue_declare(queue_name='hello') - await channel.basic_consume(callback, queue_name='hello') + x = channel.consume(queue_name='hello') + await x.qos(prefetch_size=0, prefetch_count=1) + async with x as consumer: + async for message in consumer: + body, envelope, properties = message + await channel.basic_client_ack(delivery_tag=envelope.delivery_tag) event_loop = asyncio.get_event_loop()