Skip to content

Commit

Permalink
set connection_name - default or set from user (#137)
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielePalaia authored Oct 19, 2023
1 parent cd50fcb commit 6038781
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 13 deletions.
29 changes: 24 additions & 5 deletions rstream/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
ssl_context: Optional[ssl.SSLContext] = None,
frame_max: int,
heartbeat: int,
connection_name: Optional[str] = "",
connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None,
):
self.host = host
Expand All @@ -79,7 +80,9 @@ def __init__(
self._conn: Optional[Connection] = None

self.server_properties: Optional[dict[str, str]] = None

self._client_properties = {
"connection_name": str(connection_name),
"product": "RabbitMQ Stream",
"platform": "Python",
"version": __version__,
Expand Down Expand Up @@ -621,6 +624,7 @@ def __init__(

async def get(
self,
connection_name: Optional[str],
addr: Optional[Addr] = None,
connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None,
) -> Client:
Expand All @@ -637,25 +641,36 @@ async def get(
if desired_addr not in self._clients:
if addr and self.load_balancer_mode:
self._clients[desired_addr] = await self._resolve_broker(
desired_addr, connection_closed_handler
addr=desired_addr,
connection_closed_handler=connection_closed_handler,
connection_name=connection_name,
)
else:
self._clients[desired_addr] = await self.new(
addr=desired_addr, connection_closed_handler=connection_closed_handler
addr=desired_addr,
connection_closed_handler=connection_closed_handler,
connection_name=connection_name,
)

assert self._clients[desired_addr].is_started
return self._clients[desired_addr]

async def _resolve_broker(
self, addr: Addr, connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None
self,
connection_name: Optional[str],
addr: Addr,
connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None,
) -> Client:
desired_host, desired_port = addr.host, str(addr.port)

connection_attempts = 0

while connection_attempts < self.max_retries:
client = await self.new(addr=self.addr, connection_closed_handler=connection_closed_handler)
client = await self.new(
addr=self.addr,
connection_closed_handler=connection_closed_handler,
connection_name=connection_name,
)

assert client.server_properties is not None

Expand All @@ -675,7 +690,10 @@ async def _resolve_broker(
)

async def new(
self, addr: Addr, connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None
self,
connection_name: Optional[str],
addr: Addr,
connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None,
) -> Client:
host, port = addr
client = Client(
Expand All @@ -684,6 +702,7 @@ async def new(
ssl_context=self.ssl_context,
frame_max=self._frame_max,
heartbeat=self._heartbeat,
connection_name=connection_name,
connection_closed_handler=connection_closed_handler,
)
await client.start()
Expand Down
12 changes: 10 additions & 2 deletions rstream/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
load_balancer_mode: bool = False,
max_retries: int = 20,
connection_closed_handler: Optional[CB_CONN[DisconnectionErrorInfo]] = None,
connection_name: str = None,
):
self._pool = ClientPool(
host,
Expand All @@ -93,6 +94,9 @@ def __init__(
self._stop_event = asyncio.Event()
self._lock = asyncio.Lock()
self._connection_closed_handler = connection_closed_handler
self._connection_name = connection_name
if self._connection_name is None:
self._connection_name = "rstream-consumer"

@property
def default_client(self) -> Client:
Expand All @@ -108,7 +112,9 @@ async def __aexit__(self, *_: Any) -> None:
await self.close()

async def start(self) -> None:
self._default_client = await self._pool.get(connection_closed_handler=self._connection_closed_handler)
self._default_client = await self._pool.get(
connection_closed_handler=self._connection_closed_handler, connection_name=self._connection_name
)

def stop(self) -> None:
self._stop_event.set()
Expand All @@ -134,7 +140,9 @@ async def _get_or_create_client(self, stream: str) -> Client:
leader, replicas = await self.default_client.query_leader_and_replicas(stream)
broker = random.choice(replicas) if replicas else leader
self._clients[stream] = await self._pool.get(
addr=Addr(broker.host, broker.port), connection_closed_handler=self._connection_closed_handler
addr=Addr(broker.host, broker.port),
connection_closed_handler=self._connection_closed_handler,
connection_name=self._connection_name,
)

return self._clients[stream]
Expand Down
12 changes: 10 additions & 2 deletions rstream/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
max_retries: int = 20,
default_batch_publishing_delay: float = 0.2,
default_context_switch_value: int = 1000,
connection_name: str = None,
connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None,
):
self._pool = ClientPool(
Expand Down Expand Up @@ -110,6 +111,9 @@ def __init__(
self._default_context_switch_value = default_context_switch_value
self._connection_closed_handler = connection_closed_handler
self._close_called = False
self._connection_name = connection_name
if self._connection_name is None:
self._connection_name = "rstream-producer"

@property
def default_client(self) -> Client:
Expand All @@ -126,7 +130,9 @@ async def __aexit__(self, *_: Any) -> None:

async def start(self) -> None:
self._close_called = False
self._default_client = await self._pool.get(connection_closed_handler=self._connection_closed_handler)
self._default_client = await self._pool.get(
connection_closed_handler=self._connection_closed_handler, connection_name=self._connection_name
)

async def close(self) -> None:
self._close_called = True
Expand Down Expand Up @@ -161,7 +167,9 @@ async def _get_or_create_client(self, stream: str) -> Client:
if stream not in self._clients:
leader, _ = await self.default_client.query_leader_and_replicas(stream)
self._clients[stream] = await self._pool.get(
Addr(leader.host, leader.port), self._connection_closed_handler
connection_name=self._connection_name,
addr=Addr(leader.host, leader.port),
connection_closed_handler=self._connection_closed_handler,
)

return self._clients[stream]
Expand Down
12 changes: 10 additions & 2 deletions rstream/superstream_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
load_balancer_mode: bool = False,
max_retries: int = 20,
super_stream: str,
connection_name: str = None,
connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None,
):
self._pool = ClientPool(
Expand Down Expand Up @@ -78,6 +79,9 @@ def __init__(
self._stop_event = asyncio.Event()
self._subscribers: dict[str, str] = defaultdict(str)
self._connection_closed_handler = connection_closed_handler
self._connection_name = connection_name
if self._connection_name is None:
self._connection_name = "rstream-consumer"

@property
def default_client(self) -> Client:
Expand All @@ -93,7 +97,9 @@ async def __aexit__(self, *_: Any) -> None:
await self.close()

async def start(self) -> None:
self._default_client = await self._pool.get(connection_closed_handler=self._connection_closed_handler)
self._default_client = await self._pool.get(
connection_closed_handler=self._connection_closed_handler, connection_name="rstream-locator"
)

def stop(self) -> None:
self._stop_event.set()
Expand All @@ -116,7 +122,9 @@ async def _get_or_create_client(self, stream: str) -> Client:
leader, replicas = await self.default_client.query_leader_and_replicas(stream)
broker = random.choice(replicas) if replicas else leader
self._clients[stream] = await self._pool.get(
Addr(broker.host, broker.port), connection_closed_handler=self._connection_closed_handler
addr=Addr(broker.host, broker.port),
connection_closed_handler=self._connection_closed_handler,
connection_name=self._connection_name,
)

return self._clients[stream]
Expand Down
11 changes: 9 additions & 2 deletions rstream/superstream_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def __init__(
load_balancer_mode: bool = False,
max_retries: int = 20,
default_batch_publishing_delay: float = 0.2,
connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None
connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None,
connection_name: str = None,
):
self._pool = ClientPool(
host,
Expand Down Expand Up @@ -84,6 +85,9 @@ def __init__(
self._producer: Producer | None = None
self._routing_strategy: RoutingStrategy
self._connection_closed_handler = connection_closed_handler
self._connection_name = connection_name
if self._connection_name is None:
self._connection_name = "rstream-producer"

async def _get_producer(self) -> Producer:
if self._producer is None:
Expand All @@ -99,6 +103,7 @@ async def _get_producer(self) -> Producer:
load_balancer_mode=self.load_balancer_mode,
default_batch_publishing_delay=self.default_batch_publishing_delay,
connection_closed_handler=self._connection_closed_handler,
connection_name=self._connection_name,
)
await producer.start()
self._producer = producer
Expand Down Expand Up @@ -130,7 +135,9 @@ async def __aexit__(self, *_: Any) -> None:
await self.close()

async def start(self) -> None:
self._default_client = await self._pool.get(connection_closed_handler=self._connection_closed_handler)
self._default_client = await self._pool.get(
connection_closed_handler=self._connection_closed_handler, connection_name="rstream-locator"
)
self.super_stream_metadata = DefaultSuperstreamMetadata(self.super_stream, self._default_client)
if self.routing == RouteType.Hash:
self._routing_strategy = HashRoutingMurmurStrategy(self.routing_extractor)
Expand Down

0 comments on commit 6038781

Please sign in to comment.