Skip to content

Commit

Permalink
Fixing some init bugs, and moving some parts to the constructor again…
Browse files Browse the repository at this point in the history
… if possible.
  • Loading branch information
rcschrg committed Oct 14, 2024
1 parent 6419dab commit 4d6d84c
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 165 deletions.
6 changes: 6 additions & 0 deletions mango/container/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,12 @@ def register(self, agent: Agent, suggested_aid: str = None):
logger.debug("Successfully registered agent;%s", aid)
return agent

def _get_aid(self, agent):
for aid, a in self._agents.items():
if id(a) == id(agent):
return aid
return None

def include(self, agent: A, suggested_aid: str = None) -> A:
"""Include the agent in the container. Return the agent for
convenience.
Expand Down
19 changes: 5 additions & 14 deletions mango/container/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def __init__(
*,
client_id: str,
broker_addr: tuple | dict | str,
loop: asyncio.AbstractEventLoop,
clock: Clock,
codec: Codec,
inbox_topic: None | str = None,
Expand All @@ -76,7 +75,6 @@ def __init__(
super().__init__(
codec=codec,
addr=broker_addr,
loop=loop,
clock=clock,
name=client_id,
**kwargs,
Expand All @@ -92,6 +90,7 @@ def __init__(
self.pending_sub_request: None | asyncio.Future = None

async def start(self):
self._loop = asyncio.get_event_loop()
if not self.client_id:
raise ValueError("client_id is required!")
if not self.addr:
Expand Down Expand Up @@ -137,9 +136,7 @@ async def start(self):
# callbacks to check for successful connection
def on_con(client, userdata, flags, reason_code, properties):
logger.info("Connection Callback with the following flags: %s", flags)
asyncio.get_running_loop().call_soon_threadsafe(
connected.set_result, reason_code
)
self._loop.call_soon_threadsafe(connected.set_result, reason_code)

mqtt_messenger.on_connect = on_con

Expand Down Expand Up @@ -205,9 +202,7 @@ def on_con(client, userdata, flags, reason_code, properties):

# set up subscription callback
def on_sub(client, userdata, mid, reason_code_list, properties):
asyncio.get_running_loop().call_soon_threadsafe(
subscribed.set_result, True
)
self._loop.call_soon_threadsafe(subscribed.set_result, True)

mqtt_messenger.on_subscribe = on_sub

Expand Down Expand Up @@ -268,9 +263,7 @@ def on_discon(client, userdata, disconnect_flags, reason_code, properties):
self.mqtt_client.on_disconnect = on_discon

def on_sub(client, userdata, mid, reason_code_list, properties):
asyncio.get_running_loop().call_soon_threadsafe(
self.pending_sub_request.set_result, 0
)
self._loop.call_soon_threadsafe(self.pending_sub_request.set_result, 0)

self.mqtt_client.on_subscribe = on_sub

Expand All @@ -289,9 +282,7 @@ def on_message(client, userdata, message):
# update meta dict
meta.update(message_meta)
# put information to inbox
asyncio.get_running_loop().call_soon_threadsafe(
self.inbox.put_nowait, (0, content, meta)
)
self._loop.call_soon_threadsafe(self.inbox.put_nowait, (0, content, meta))

self.mqtt_client.on_message = on_message
self.mqtt_client.enable_logger(logger)
Expand Down
13 changes: 6 additions & 7 deletions mango/container/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,9 @@ class TCPConnectionPool:

def __init__(
self,
asyncio_loop,
ttl_in_sec: float = 30.0,
max_connections_per_target: int = 10,
) -> None:
self._loop = asyncio_loop
self._available_connections = {}
self._connection_counts = {}
self._ttl_in_sec = ttl_in_sec
Expand Down Expand Up @@ -93,7 +91,7 @@ async def obtain_connection(
addr_key,
(
(
await self._loop.create_connection(
await asyncio.get_running_loop().create_connection(
lambda: protocol,
host,
port,
Expand Down Expand Up @@ -182,18 +180,17 @@ def __init__(
**kwargs,
)

self._tcp_connection_pool = None
self.server = None # will be set within start
self.running = False

async def start(self):
self._tcp_connection_pool = TCPConnectionPool(
asyncio.get_running_loop(),
ttl_in_sec=self._kwargs.get(TCP_CONNECTION_TTL, 30),
max_connections_per_target=self._kwargs.get(
TCP_MAX_CONNECTIONS_PER_TARGET, 10
),
)

async def start(self):
# create a TCP server bound to host and port that uses the
# specified protocol
self.server = await asyncio.get_running_loop().create_server(
Expand Down Expand Up @@ -303,7 +300,9 @@ async def shutdown(self):
calls shutdown() from super class Container and closes the server
"""
await super().shutdown()
await self._tcp_connection_pool.shutdown()

if self._tcp_connection_pool is not None:
await self._tcp_connection_pool.shutdown()

if self.server is not None:
self.server.close()
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/test_message_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def handle_message(self, content, meta):
async def start(self):
if getattr(self.container, "subscribe_for_agent", None):
await self.container.subscribe_for_agent(
aid=self.aid, topic=self.target.addr
aid=self.aid, topic=self.target.protocol_addr
)

await asyncio.sleep(0.1)
Expand Down Expand Up @@ -119,7 +119,7 @@ def handle_message(self, content, meta):
async def start(self):
if getattr(self.container, "subscribe_for_agent", None):
await self.container.subscribe_for_agent(
aid=self.aid, topic=self.target.addr
aid=self.aid, topic=self.target.protocol_addr
)

# await "Hello"
Expand Down
30 changes: 9 additions & 21 deletions tests/unit_tests/container/test_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ async def test_connection_pool_obtain_release():
await c2.start()

addr = "127.0.0.2", 5556
connection_pool = TCPConnectionPool(asyncio.get_event_loop())
raw_prot = ContainerProtocol(
container=c, loop=asyncio.get_event_loop(), codec=c.codec
)
connection_pool = TCPConnectionPool()
raw_prot = ContainerProtocol(container=c, codec=c.codec)
protocol = await connection_pool.obtain_connection(addr[0], addr[1], raw_prot)

assert connection_pool._available_connections[addr].qsize() == 0
Expand All @@ -49,18 +47,14 @@ async def test_connection_pool_double_obtain_release():
await c2.start()

addr = "127.0.0.2", 5556
connection_pool = TCPConnectionPool(asyncio.get_event_loop())
raw_prot = ContainerProtocol(
container=c, loop=asyncio.get_event_loop(), codec=c.codec
)
connection_pool = TCPConnectionPool()
raw_prot = ContainerProtocol(container=c, codec=c.codec)
protocol = await connection_pool.obtain_connection(addr[0], addr[1], raw_prot)

assert connection_pool._available_connections[addr].qsize() == 0
assert connection_pool._connection_counts[addr] == 1

raw_prot = ContainerProtocol(
container=c, loop=asyncio.get_event_loop(), codec=c.codec
)
raw_prot = ContainerProtocol(container=c, codec=c.codec)
protocol2 = await connection_pool.obtain_connection(addr[0], addr[1], raw_prot)

assert connection_pool._available_connections[addr].qsize() == 0
Expand Down Expand Up @@ -92,10 +86,8 @@ async def test_ttl():
await c2.start()
await c3.start()

connection_pool = TCPConnectionPool(asyncio.get_event_loop(), ttl_in_sec=0.1)
raw_prot = ContainerProtocol(
container=c, loop=asyncio.get_event_loop(), codec=c.codec
)
connection_pool = TCPConnectionPool(ttl_in_sec=0.1)
raw_prot = ContainerProtocol(container=c, codec=c.codec)
protocol = await connection_pool.obtain_connection(addr[0], addr[1], raw_prot)

assert connection_pool._available_connections[addr].qsize() == 0
Expand Down Expand Up @@ -134,12 +126,8 @@ async def test_max_connections():
await c2.start()

addr = "127.0.0.2", 5556
connection_pool = TCPConnectionPool(
asyncio.get_event_loop(), max_connections_per_target=1
)
raw_prot = ContainerProtocol(
container=c, loop=asyncio.get_event_loop(), codec=c.codec
)
connection_pool = TCPConnectionPool(max_connections_per_target=1)
raw_prot = ContainerProtocol(container=c, codec=c.codec)
protocol = await connection_pool.obtain_connection(addr[0], addr[1], raw_prot)

with pytest.raises(asyncio.TimeoutError):
Expand Down
25 changes: 13 additions & 12 deletions tests/unit_tests/core/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ async def test_register_aid_pattern_match():
suggested_aid = "agent12"

# WHEN
actual_aid = c.register(agent, suggested_aid)
agent_r = c.register(agent, suggested_aid)

# THEN
assert actual_aid == "agent0"
assert c._get_aid(agent_r) == "agent0"
await c.shutdown()


Expand All @@ -35,10 +35,10 @@ async def test_register_aid_success():
suggested_aid = "cagent12"

# WHEN
actual_aid = c.register(agent, suggested_aid)
agent_r = c.register(agent, suggested_aid)

# THEN
assert actual_aid == suggested_aid
assert c._get_aid(agent_r) == suggested_aid
await c.shutdown()


Expand All @@ -49,10 +49,10 @@ async def test_register_no_suggested():
agent = LooksLikeAgent()

# WHEN
actual_aid = c.register(agent)
agent_r = c.register(agent)

# THEN
assert actual_aid == "agent0"
assert c._get_aid(agent_r) == "agent0"
await c.shutdown()


Expand All @@ -64,10 +64,10 @@ async def test_register_pattern_half_match():
suggested_aid = "agentABC"

# WHEN
actual_aid = c.register(agent, suggested_aid)
agent_r = c.register(agent, suggested_aid)

# THEN
assert actual_aid == "agentABC"
assert c._get_aid(agent_r) == "agentABC"
await c.shutdown()


Expand All @@ -76,15 +76,16 @@ async def test_register_existing():
# GIVEN
c = create_tcp_container(addr=("127.0.0.2", 5555))
agent = LooksLikeAgent()
agent2 = LooksLikeAgent()
suggested_aid = "agentABC"

# WHEN
actual_aid = c.register(agent, suggested_aid)
actual_aid2 = c.register(agent, suggested_aid)
agent_r = c.register(agent, suggested_aid)
agent_r2 = c.register(agent2, suggested_aid)

# THEN
assert actual_aid == "agentABC"
assert actual_aid2 == "agent0"
assert c._get_aid(agent_r) == "agentABC"
assert c._get_aid(agent_r2) == "agent0"
await c.shutdown()


Expand Down
Loading

0 comments on commit 4d6d84c

Please sign in to comment.