Support automatic recovery (#13)
FZambia authored Feb 3, 2024
1 parent 74a4c57 commit b477898
Showing 4 changed files with 182 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -33,7 +33,7 @@ jobs:
uses: crazy-max/ghaction-setup-docker@v3

- name: Start Centrifugo
run: docker run -d -p 8000:8000 -e CENTRIFUGO_PRESENCE=true -e CENTRIFUGO_HISTORY_TTL=300s -e CENTRIFUGO_HISTORY_SIZE=100 centrifugo/centrifugo:v5 centrifugo --client_insecure
run: docker run -d -p 8000:8000 -e CENTRIFUGO_PRESENCE=true -e CENTRIFUGO_HISTORY_TTL=300s -e CENTRIFUGO_HISTORY_SIZE=100 -e CENTRIFUGO_FORCE_RECOVERY=true centrifugo/centrifugo:v5 centrifugo --client_insecure

- name: Install dependencies
run: |
1 change: 1 addition & 0 deletions README.md
@@ -36,6 +36,7 @@ When using Protobuf protocol:

* all payloads you pass to the library must be `bytes`, or `None` if optional. If you pass non-`bytes` data, an exception will be raised.
* all payloads received from the library will be `bytes`, or `None` if not present.
* note that even with the Protobuf protocol you can still work with JSON payloads – just encode them to `bytes` before passing them to the library.
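For illustration only (not part of this commit), a minimal sketch of that last point, assuming a client created with `use_protobuf=True` and a hypothetical `chat` channel:

```python
import json

from centrifuge import Client

client = Client("ws://localhost:8000/connection/websocket", use_protobuf=True)
sub = client.new_subscription("chat")


async def publish_json(document: dict) -> None:
    # With the Protobuf protocol every payload must be bytes,
    # so serialize the JSON document before publishing.
    await sub.publish(data=json.dumps(document).encode())
    # Received payloads are bytes as well and can be json.loads()-ed back.
```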

## Run tests

136 changes: 110 additions & 26 deletions centrifuge/client.py
@@ -85,6 +85,7 @@
# Turned out legacy is not really legacy in websockets.
# See more in https://websockets.readthedocs.io/en/stable/faq/ (grep "legacy").
from websockets.legacy.client import WebSocketClientProtocol
from asyncio import AbstractEventLoop

logger = logging.getLogger("centrifuge")

@@ -140,7 +141,7 @@ def __init__(
min_reconnect_delay: float = 0.1,
max_reconnect_delay: float = 20.0,
headers: Optional[Dict[str, str]] = None,
loop: Any = None,
loop: Optional["AbstractEventLoop"] = None,
):
"""Initializes new Client instance.
@@ -196,8 +197,12 @@ def new_subscription(
events: Optional[SubscriptionEventHandler] = None,
token: str = "",
get_token: Optional[Callable[[SubscriptionTokenContext], Awaitable[str]]] = None,
data: Optional[Any] = None,
min_resubscribe_delay=0.1,
max_resubscribe_delay=10.0,
positioned: bool = False,
recoverable: bool = False,
join_leave: bool = False,
) -> "Subscription":
"""Creates a new subscription to a channel. If a subscription to the channel already exists,
a DuplicateSubscriptionError exception will be raised.
@@ -212,8 +217,12 @@ def new_subscription(
events=events,
token=token,
get_token=get_token,
data=data,
min_resubscribe_delay=min_resubscribe_delay,
max_resubscribe_delay=max_resubscribe_delay,
positioned=positioned,
recoverable=recoverable,
join_leave=join_leave,
)
self._subs[channel] = sub
return sub
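For reference, a hedged usage sketch (not part of the diff) of the new keyword arguments accepted here. The URL and channel name are illustrative, and recovery additionally requires history to be enabled for the channel on the Centrifugo side:

```python
import asyncio

from centrifuge import Client


async def main() -> None:
    client = Client("ws://localhost:8000/connection/websocket")
    sub = client.new_subscription(
        "chat:index",
        data={"source": "python"},  # optional payload sent with the subscribe command
        positioned=True,            # ask the server for a positioned subscription
        recoverable=True,           # ask the server to recover missed publications on resubscribe
        join_leave=True,            # receive join/leave notifications
    )
    await client.connect()
    await sub.subscribe()
    await asyncio.sleep(1.0)
    await client.disconnect()


asyncio.run(main())
```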
@@ -297,12 +306,7 @@ async def _create_connection(self) -> bool:
asyncio.ensure_future(self._schedule_reconnect())
return False

self._delay = self._min_reconnect_delay
connect = {}

if self._token:
connect["token"] = self._token
elif self._get_token:
if not self._token and self._get_token:
try:
token = await self._get_token(ConnectionTokenContext())
except Exception as e:
@@ -319,19 +323,17 @@ async def _create_connection(self) -> bool:
return False

self._token = token
connect["token"] = token

if self.state != ClientState.CONNECTING:
return False

asyncio.ensure_future(self._listen())
asyncio.ensure_future(self._process_messages())

self._delay = self._min_reconnect_delay

cmd_id = self._next_command_id()
command = {
"id": cmd_id,
"connect": connect,
}
command = self._construct_connect_command(cmd_id)
async with self._register_future_with_done(cmd_id) as future:
await self._send_commands([command])

@@ -418,6 +420,38 @@ async def _create_connection(self) -> bool:

await self._process_server_subs(connect.get("subs", {}))

def _construct_connect_command(self, cmd_id: int) -> Dict[str, Any]:
connect = {}

if self._token:
connect["token"] = self._token

if self._data:
connect["data"] = self._encode_data(self._data)

if self._name:
connect["name"] = self._name

if self._version:
connect["version"] = self._version

subs = {}
for channel, sub in self._server_subs.items():
if sub.recoverable:
subs[channel] = {
"recover": True,
"offset": sub.offset,
"epoch": sub.epoch,
}
if subs:
connect["subs"] = subs

command = {
"id": cmd_id,
"connect": connect,
}
return command

async def _process_server_subs(self, subs: Dict[str, Dict[str, Any]]):
logger.debug("process server subs: %s", subs)
for channel, subscribe in subs.items():
@@ -468,17 +502,19 @@ async def _process_server_publication(self, channel: str, pub: Any):
handler = self.events.on_publication
info = pub.get("info")
client_info = self._extract_client_info(info) if info else None
offset = int(pub.get("offset", 0))
await handler(
ServerPublicationContext(
channel=channel,
pub=Publication(
offset=int(pub.get("offset", 0)),
offset=offset,
data=self._decode_data(pub.get("data")),
info=client_info,
),
)
)
# TODO: manage offsets here.
if offset > 0:
self._server_subs[channel].offset = offset

def _clear_connecting_state(self) -> None:
self._reconnect_attempts = 0
@@ -664,11 +700,7 @@ async def _subscribe(self, channel):

logger.debug("subscribe to channel %s", channel)

subscribe = {"channel": channel}

if sub._token:
subscribe["token"] = sub._token
elif sub._get_token:
if not sub._token and sub._get_token:
try:
token = await sub._get_token(SubscriptionTokenContext(channel=channel))
except Exception as e:
@@ -687,13 +719,9 @@ async def _subscribe(self, channel):
return False

sub._token = token
subscribe["token"] = token

cmd_id = self._next_command_id()
command = {
"id": cmd_id,
"subscribe": subscribe,
}
command = self._construct_subscribe_command(sub, cmd_id)
async with self._register_future_with_done(cmd_id) as future:
await self._send_commands([command])
try:
@@ -745,6 +773,37 @@ async def _subscribe(self, channel):
else:
await sub._move_subscribed(reply["subscribe"])

def _construct_subscribe_command(self, sub: "Subscription", cmd_id: int) -> Dict[str, Any]:
subscribe = {
"channel": sub.channel,
}

if sub._token:
subscribe["token"] = sub._token

if sub._data:
subscribe["data"] = self._encode_data(sub._data)

if sub._positioned:
subscribe["positioned"] = True

if sub._recoverable:
subscribe["recoverable"] = True

if sub._join_leave:
subscribe["join_leave"] = True

if sub._need_recover():
subscribe["recover"] = True
subscribe["epoch"] = sub._epoch
subscribe["offset"] = sub._offset

command = {
"id": cmd_id,
"subscribe": subscribe,
}
return command
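As a rough illustration (all values invented), for a recoverable subscription that resubscribes with a saved stream position, this helper produces a command along these lines; unset options such as `token`, `data`, `positioned` and `join_leave` are simply omitted:

```python
expected_command = {
    "id": 3,
    "subscribe": {
        "channel": "chat:index",
        "recoverable": True,
        "recover": True,
        "epoch": "xyz",
        "offset": 42,
    },
}
```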

async def _resubscribe(self, sub: "Subscription"):
self._subs[sub.channel] = sub
asyncio.ensure_future(self._subscribe(sub.channel))
@@ -1299,8 +1358,12 @@ def _initialize(
events: Optional[SubscriptionEventHandler] = None,
token: str = "",
get_token: Optional[Callable[[SubscriptionTokenContext], Awaitable[str]]] = None,
data: Optional[Any] = None,
min_resubscribe_delay: float = 0.1,
max_resubscribe_delay: float = 10.0,
positioned: bool = False,
recoverable: bool = False,
join_leave: bool = False,
) -> None:
"""Initializes Subscription instance.
Note: use Client.new_subscription method to create new subscriptions in your app.
@@ -1313,11 +1376,18 @@ def _initialize(
self._client: Optional[Client] = client
self._token = token
self._get_token = get_token
self._data = data
self._min_resubscribe_delay = min_resubscribe_delay
self._max_resubscribe_delay = max_resubscribe_delay
self._positioned = positioned
self._recoverable = recoverable
self._join_leave = join_leave
self._resubscribe_attempts = 0
self._refresh_timer: Optional[TimerHandle] = None
self._resubscribe_timer: Optional[TimerHandle] = None
self._recover: bool = False
self._offset: int = 0
self._epoch: str = ""

@classmethod
def _create_instance(cls, *args: Any, **kwargs: Any) -> "Subscription":
@@ -1481,6 +1551,11 @@ async def _move_subscribed(self, subscribe: Dict[str, Any]) -> None:
epoch=subscribe.get("epoch", ""),
)

if recoverable:
self._recover = True
self._offset = stream_position.offset
self._epoch = stream_position.epoch

expires = subscribe.get("expires", False)
if expires:
ttl = subscribe["ttl"]
@@ -1507,15 +1582,18 @@ async def _move_subscribed(self, subscribe: Dict[str, Any]) -> None:
for pub in publications:
info = pub.get("info")
client_info = self._client._extract_client_info(info) if info else None
offset = int(pub.get("offset", 0))
await on_publication_handler(
PublicationContext(
pub=Publication(
offset=int(pub.get("offset", 0)),
offset=offset,
data=self._client._decode_data(pub.get("data")),
info=client_info,
),
),
)
if offset > 0:
self._offset = offset

self._clear_subscribing_state()

@@ -1548,12 +1626,18 @@ async def _resubscribe(self) -> None:
async def _process_publication(self, pub: Any) -> None:
info = pub.get("info")
client_info = self._client._extract_client_info(info) if info else None
offset = int(pub.get("offset", 0))
await self.events.on_publication(
PublicationContext(
pub=Publication(
offset=int(pub.get("offset", 0)),
offset=offset,
data=self._client._decode_data(pub.get("data")),
info=client_info,
),
),
)
if offset > 0:
self._offset = offset

def _need_recover(self):
return self._recover
71 changes: 70 additions & 1 deletion tests/test_client.py
@@ -4,8 +4,15 @@
import logging
import unittest
import uuid
from typing import List

from centrifuge import Client, ClientState, SubscriptionState, PublicationContext
from centrifuge import (
Client,
ClientState,
SubscriptionState,
PublicationContext,
SubscribedContext,
)

logging.basicConfig(
level=logging.INFO,
@@ -104,3 +111,65 @@ async def on_publication(ctx: PublicationContext) -> None:
result = await future
self.assertEqual(result, payload)
await client.disconnect()


class TestAutoRecovery(unittest.IsolatedAsyncioTestCase):
async def test_auto_recovery(self) -> None:
for use_protobuf in (False, True):
with self.subTest(use_protobuf=use_protobuf):
await self._test_auto_recovery(use_protobuf=use_protobuf)

async def _test_auto_recovery(self, use_protobuf=False) -> None:
client1 = Client(
"ws://localhost:8000/connection/websocket",
use_protobuf=use_protobuf,
)

client2 = Client(
"ws://localhost:8000/connection/websocket",
use_protobuf=use_protobuf,
)

# First subscribe both clients to the same channel.
channel = "recovery_channel" + uuid.uuid4().hex
sub1 = client1.new_subscription(channel)
sub2 = client2.new_subscription(channel)

futures: List[asyncio.Future] = [asyncio.Future() for _ in range(5)]

async def on_publication(ctx: PublicationContext) -> None:
futures[ctx.pub.offset - 1].set_result(ctx.pub.data)

async def on_subscribed(ctx: SubscribedContext) -> None:
self.assertFalse(ctx.recovered)
self.assertFalse(ctx.was_recovering)

sub1.events.on_publication = on_publication
sub1.events.on_subscribed = on_subscribed

await client1.connect()
await sub1.subscribe()
await client2.connect()
await sub2.subscribe()

# Now disconnect client1 and publish some messages using client2.
await client1.disconnect()

for _ in range(len(futures)):
payload = {"input": "test"}
if use_protobuf:
payload = json.dumps(payload).encode()
await sub2.publish(data=payload)

async def on_subscribed_after_recovery(ctx: SubscribedContext) -> None:
self.assertTrue(ctx.recovered)
self.assertTrue(ctx.was_recovering)

sub1.events.on_subscribed = on_subscribed_after_recovery

# Now reconnect client1 and check that it receives all missed messages.
await client1.connect()
results = await asyncio.gather(*futures)
self.assertEqual(len(results), 5)
await client1.disconnect()
await client2.disconnect()
