From fb169605b2c9120fd799dedafe586af6b1ffef72 Mon Sep 17 00:00:00 2001 From: Luca Sbardella Date: Sun, 23 Oct 2016 13:27:40 +0100 Subject: [PATCH] test unregister --- pulsar/apps/data/channels.py | 35 +++++++++++++++++++++++++++++++++++ tests/stores/channels.py | 13 +++++++++++++ 2 files changed, 48 insertions(+) diff --git a/pulsar/apps/data/channels.py b/pulsar/apps/data/channels.py index 1c65a65f..f051f495 100644 --- a/pulsar/apps/data/channels.py +++ b/pulsar/apps/data/channels.py @@ -137,6 +137,17 @@ async def register(self, channel_name, event, callback): self.channels[channel.name] = channel await channel.connect() channel.register(event, callback) + return channel + + async def unregister(self, channel_name, event, callback): + name = channel_name.lower() + channel = self.channels.get(name) + if channel: + channel.unregister(event, callback) + if not channel: + await channel.disconnect() + self.channels.pop(name) + return channel async def close(self): self.pubsub.remove_callback('connection_lost', self._connection_lost) @@ -220,6 +231,15 @@ def __init__(self, channels, name): def __repr__(self): return repr(self.callbacks) + def __len__(self): + return len(self.callbacks) + + def __contains__(self, regex): + return regex in self.callbacks + + def __iter__(self): + return iter(self.channels.values()) + def __call__(self, message): event = message.pop('event', '') data = message.get('data') @@ -241,6 +261,12 @@ async def connect(self): channel_name = channels.prefixed(self.name) await self.channels.pubsub.subscribe(channel_name) + async def disconnect(self): + channels = self.channels + if channels.status == StatusType.connected: + channel_name = channels.prefixed(self.name) + await self.channels.pubsub.unsubscribe(channel_name) + def register(self, event, callback): """Register a ``callback`` for ``event`` """ @@ -254,3 +280,12 @@ def register(self, event, callback): entry.callbacks.append(callback) return entry + + def unregister(self, event, callback): + regex = redis_to_py_pattern(event) + entry = self.callbacks.get(regex) + if entry: + if callback in entry.callbacks: + entry.callbacks.remove(callback) + if not entry.callbacks: + self.callbacks.pop(regex) diff --git a/tests/stores/channels.py b/tests/stores/channels.py index 031f2597..a4a68e0c 100644 --- a/tests/stores/channels.py +++ b/tests/stores/channels.py @@ -84,6 +84,19 @@ async def test_fail_publish(self): self.assertEqual(len(args), 3) self.assertEqual(args[1], channels) + async def test_unregister(self): + channels = self.channels() + + def fire(_, event, data): + return data + + channel = await channels.register('test4', '*', fire) + self.assertEqual(len(channel), 1) + self.assertEqual(len(channels), 1) + channel = await channels.unregister('test4', '*', fire) + self.assertEqual(len(channel), 0) + self.assertEqual(len(channels), 0) + def _log_error(self, coro, *args, **kwargs): coro.switch((args, kwargs))