Skip to content

Commit

Permalink
infinite loop for passing messages
Browse files Browse the repository at this point in the history
  • Loading branch information
nerdai committed Jun 6, 2024
1 parent 57cf2e4 commit d920dd8
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 5 deletions.
40 changes: 37 additions & 3 deletions agentfile/message_queues/simple.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Simple Message Queue."""

import asyncio
import random

from queue import Queue
from collections import deque
from typing import Any, Dict, List, Type
from llama_index.core.bridge.pydantic import Field
from agentfile.message_queues.base import BaseMessageQueue
Expand All @@ -19,7 +20,15 @@ class SimpleMessageQueue(BaseMessageQueue):
consumers: Dict[str, Dict[str, BaseMessageQueueConsumer]] = Field(
default_factory=dict
)
queues: Dict[str, Queue] = Field(default_factory=dict)
queues: Dict[str, deque] = Field(default_factory=dict)
running: bool = True

def __init__(
self,
consumers: Dict[str, Dict[str, BaseMessageQueueConsumer]] = {},
queues: Dict[str, deque] = {},
):
super().__init__(consumers=consumers, queues=queues)

def _select_consumer(self, message: BaseMessage) -> BaseMessageQueueConsumer:
"""Select a single consumer to publish a message to."""
Expand All @@ -28,14 +37,36 @@ def _select_consumer(self, message: BaseMessage) -> BaseMessageQueueConsumer:
return self.consumers[message_type_str][consumer_id]

async def _publish(self, message: BaseMessage, **kwargs: Any) -> Any:
"""Publish message to a queue."""
message_type_str = message.class_name()

if message_type_str not in self.consumers:
raise ValueError(f"No consumer for {message_type_str} has been registered.")

if message_type_str not in self.queues:
self.queues[message_type_str] = deque()

self.queues[message_type_str].append(message)

async def _publish_to_consumer(self, message: BaseMessage, **kwargs: Any) -> Any:
"""Publish message to a consumer."""
consumer = self._select_consumer(message)
print(f"consumer: {consumer}")
try:
await consumer.process_message(message, **kwargs)
except Exception:
raise

async def start(self) -> None:
"""A loop for getting messages from queues and sending to consumer."""
while self.running:
print(self.queues)
for queue in self.queues.values():
if queue:
message = queue.popleft()
await self._publish_to_consumer(message)
print(self.queues)
await asyncio.sleep(0.1)

async def register_consumer(
self, consumer: BaseMessageQueueConsumer, **kwargs: Any
) -> None:
Expand All @@ -50,6 +81,9 @@ async def register_consumer(

self.consumers[message_type_str][consumer.id_] = consumer

if message_type_str not in self.queues:
self.queues[message_type_str] = deque()

async def deregister_consumer(self, consumer: BaseMessageQueueConsumer) -> None:
message_type_str = consumer.message_type.class_name()
if consumer.id_ not in self.consumers[message_type_str]:
Expand Down
11 changes: 9 additions & 2 deletions tests/message_queues/test_simple.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import pytest
from typing import Any, List
from agentfile.message_consumers.base import BaseMessageQueueConsumer
Expand All @@ -7,10 +8,11 @@

class MockMessageConsumer(BaseMessageQueueConsumer):
processed_messages: List[BaseMessage] = []
lock: asyncio.Lock = asyncio.Lock()

async def _process_message(self, message: BaseMessage, **kwargs: Any) -> None:
print(f"Processed: {message.class_name()}")
self.processed_messages.append(message)
async with self.lock:
self.processed_messages.append(message)


class MockMessage(BaseMessage):
Expand Down Expand Up @@ -70,6 +72,7 @@ async def test_simple_publish_consumer() -> None:
consumer_one = MockMessageConsumer()
consumer_two = MockMessageConsumer(message_type=MockMessage)
mq = SimpleMessageQueue()
task = asyncio.create_task(mq.start())

await mq.register_consumer(consumer_one)
await mq.register_consumer(consumer_two)
Expand All @@ -79,6 +82,10 @@ async def test_simple_publish_consumer() -> None:
await mq.publish(MockMessage(id_="2"))
await mq.publish(MockMessage(id_="3"))

# Give some time for last message to get published and sent to consumers
await asyncio.sleep(0.5)
task.cancel()

# Assert
assert ["1"] == [m.id_ for m in consumer_one.processed_messages]
assert ["2", "3"] == [m.id_ for m in consumer_two.processed_messages]

0 comments on commit d920dd8

Please sign in to comment.