From def5508ba8aff4c871f97de096f6082922d3caa8 Mon Sep 17 00:00:00 2001 From: Dobromir Marinov Date: Thu, 1 Aug 2024 17:22:46 +0100 Subject: [PATCH 1/2] * Added reconnect functionality. * Added `keepLinked` and `keepSynced`. * Added error messages for missing host, node and lane uri when opening downlinks. * Added method for waiting until the client is closed manually. --- swimos/__init__.py | 4 +- swimos/client/__init__.py | 3 +- swimos/client/_connections.py | 136 +++++++++++++++++++++---- swimos/client/_downlinks/_downlinks.py | 51 ++++++++-- swimos/client/_downlinks/_utils.py | 3 +- swimos/client/_swim_client.py | 34 +++++-- swimos/client/_utils.py | 6 ++ 7 files changed, 197 insertions(+), 40 deletions(-) diff --git a/swimos/__init__.py b/swimos/__init__.py index f2a74be..15cce3a 100644 --- a/swimos/__init__.py +++ b/swimos/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .client import SwimClient +from .client import SwimClient, IntervalStrategy -__all__ = [SwimClient] +__all__ = [SwimClient, IntervalStrategy] diff --git a/swimos/client/__init__.py b/swimos/client/__init__.py index 671950c..dea0d42 100644 --- a/swimos/client/__init__.py +++ b/swimos/client/__init__.py @@ -13,5 +13,6 @@ # limitations under the License. from ._swim_client import SwimClient +from ._connections import IntervalStrategy -__all__ = [SwimClient] +__all__ = [SwimClient, IntervalStrategy] diff --git a/swimos/client/_connections.py b/swimos/client/_connections.py index ca82e03..5925959 100644 --- a/swimos/client/_connections.py +++ b/swimos/client/_connections.py @@ -13,39 +13,102 @@ # limitations under the License. import asyncio +from abc import ABC, abstractmethod import websockets from enum import Enum +from websockets import ConnectionClosed from swimos.warp._warp import _Envelope from typing import TYPE_CHECKING, Any +from ._utils import exception_warn if TYPE_CHECKING: from ._downlinks._downlinks import _DownlinkModel from ._downlinks._downlinks import _DownlinkView +class RetryStrategy(ABC): + @abstractmethod + async def retry(self) -> bool: + """ + Wait for a period of time that is defined by the retry strategy. + """ + raise NotImplementedError + + @abstractmethod + def reset(self): + """ + Reset the retry strategy to its original state. + """ + raise NotImplementedError + + +class IntervalStrategy(RetryStrategy): + + def __init__(self, retries_limit=None, delay=3) -> None: + super().__init__() + self.retries_limit = retries_limit + self.delay = delay + self.retries = 0 + + async def retry(self) -> bool: + if self.retries_limit is None or self.retries_limit >= self.retries: + await asyncio.sleep(self.delay) + self.retries += 1 + return True + else: + return False + + def reset(self): + self.retries = 0 + + +class ExponentialStrategy(RetryStrategy): + + def __init__(self, retries_limit=None, max_interval=16) -> None: + super().__init__() + self.retries_limit = retries_limit + self.max_interval = max_interval + self.retries = 0 + + async def retry(self) -> bool: + if self.retries_limit is None or self.retries_limit >= self.retries: + await asyncio.sleep(min(2 ** self.retries, self.max_interval)) + self.retries += 1 + return True + else: + return False + + def reset(self): + self.retries = 0 + + class _ConnectionPool: - def __init__(self) -> None: + def __init__(self, retry_strategy: RetryStrategy = None) -> None: self.__connections = dict() + self.retry_strategy = retry_strategy @property def _size(self) -> int: return len(self.__connections) - async def _get_connection(self, host_uri: str, scheme: str) -> '_WSConnection': + async def _get_connection(self, host_uri: str, scheme: str, keep_linked: bool, + keep_synced: bool) -> '_WSConnection': """ Return a WebSocket connection to the given Host URI. If it is a new host or the existing connection is closing, create a new connection. :param host_uri: - URI of the connection host. :param scheme: - URI scheme. + :param keep_linked: - Whether the link should be automatically re-established after connection failures. + :param keep_synced: - Whether the link should synchronize its state with the remote lane. :return: - WebSocket connection. """ connection = self.__connections.get(host_uri) if connection is None or connection.status == _ConnectionStatus.CLOSED: - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, keep_linked, keep_synced, self.retry_strategy) self.__connections[host_uri] = connection return connection @@ -70,7 +133,9 @@ async def _add_downlink_view(self, downlink_view: '_DownlinkView') -> None: """ host_uri = downlink_view._host_uri scheme = downlink_view._scheme - connection = await self._get_connection(host_uri, scheme) + keep_linked = downlink_view._keep_linked + keep_synced = downlink_view._keep_synced + connection = await self._get_connection(host_uri, scheme, keep_linked, keep_synced) downlink_view._connection = connection await connection._subscribe(downlink_view) @@ -95,12 +160,19 @@ async def _remove_downlink_view(self, downlink_view: '_DownlinkView') -> None: class _WSConnection: - def __init__(self, host_uri: str, scheme: str) -> None: + def __init__(self, host_uri: str, scheme: str, keep_linked, keep_synced, + retry_strategy: RetryStrategy = None) -> None: self.host_uri = host_uri self.scheme = scheme + self.retry_strategy = retry_strategy + self.connected = asyncio.Event() self.websocket = None self.status = _ConnectionStatus.CLOSED + self.init_message = None + + self.keep_linked = keep_linked + self.keep_synced = keep_synced self.__subscribers = _DownlinkManagerPool() @@ -108,16 +180,24 @@ async def _open(self) -> None: if self.status == _ConnectionStatus.CLOSED: self.status = _ConnectionStatus.CONNECTING - try: - if self.scheme == "wss": - self.websocket = await websockets.connect(self.host_uri, ssl=True) - else: - self.websocket = await websockets.connect(self.host_uri) - except Exception as error: - self.status = _ConnectionStatus.CLOSED - raise error - - self.status = _ConnectionStatus.IDLE + while self.status == _ConnectionStatus.CONNECTING: + try: + if self.scheme == "wss": + self.websocket = await websockets.connect(self.host_uri, ssl=True) + self.retry_strategy.reset() + self.status = _ConnectionStatus.IDLE + else: + self.websocket = await websockets.connect(self.host_uri) + self.retry_strategy.reset() + self.status = _ConnectionStatus.IDLE + except Exception as error: + if self.keep_linked and await self.retry_strategy.retry(): + exception_warn(error) + continue + else: + self.status = _ConnectionStatus.CLOSED + raise error + self.connected.set() async def _close(self) -> None: @@ -129,6 +209,20 @@ async def _close(self) -> None: await self.websocket.close() self.connected.clear() + def _set_init_message(self, message: str) -> None: + """ + Set the initial message that gets sent when the underlying downlink is established. + """ + + self.init_message = message + + async def _send_init_message(self) -> None: + """ + Send the initial message for the underlying downlink if it is set. + """ + if self.init_message is not None: + await self._send_message(self.init_message) + def _has_subscribers(self) -> bool: """ Check if the connection has any subscribers. @@ -181,18 +275,20 @@ async def _wait_for_messages(self) -> None: Wait for messages from the remote agent and propagate them to all subscribers. """ - - if self.status == _ConnectionStatus.IDLE: + while self.status == _ConnectionStatus.IDLE: self.status = _ConnectionStatus.RUNNING try: while self.status == _ConnectionStatus.RUNNING: message = await self.websocket.recv() response = _Envelope._parse_recon(message) await self.__subscribers._receive_message(response) - # except: - # pass - finally: + except ConnectionClosed as error: + exception_warn(error) await self._close() + if self.keep_linked and await self.retry_strategy.retry(): + await self._open() + await self._send_init_message() + continue class _ConnectionStatus(Enum): diff --git a/swimos/client/_downlinks/_downlinks.py b/swimos/client/_downlinks/_downlinks.py index 04523fe..8d9e88b 100644 --- a/swimos/client/_downlinks/_downlinks.py +++ b/swimos/client/_downlinks/_downlinks.py @@ -37,6 +37,9 @@ def __init__(self, client: 'SwimClient') -> None: self.host_uri = None self.node_uri = None self.lane_uri = None + self.keep_linked = None + self.keep_synced = None + self.task = None self.connection = None self.linked = asyncio.Event() @@ -143,6 +146,9 @@ def __init__(self, client: 'SwimClient') -> None: self._will_unlink_callback = None self._did_unlink_callback = None + self._keep_linked = True + self._keep_synced = True + self.__registered_classes = dict() self.__deregistered_classes = set() self.__clear_classes = False @@ -181,6 +187,13 @@ def registered_classes(self) -> dict: return self._downlink_manager.registered_classes def open(self) -> '_DownlinkView': + if self._host_uri is None: + raise Exception(f'Downlink cannot be opened without first setting the host URI!') + if self._node_uri is None: + raise Exception(f'Downlink cannot be opened without first setting the node URI!') + if self._lane_uri is None: + raise Exception(f'Downlink cannot be opened without first setting the lane URI!') + if not self._is_open: task = self._client._schedule_task(self._client._add_downlink_view, self) if task is not None: @@ -210,6 +223,16 @@ def set_lane_uri(self, lane_uri: str) -> '_DownlinkView': self._lane_uri = lane_uri return self + @before_open + def keep_linked(self, keep_linked: bool) -> '_DownlinkView': + self._keep_linked = keep_linked + return self + + @before_open + def keep_synced(self, keep_synced: bool) -> '_DownlinkView': + self._keep_synced = keep_synced + return self + def did_open(self, function: Callable) -> '_DownlinkView': """ Set the `did_open` callback of the current downlink view to a given function. @@ -426,6 +449,8 @@ async def _initalise_model(self, manager: '_DownlinkManager', model: '_DownlinkM model.host_uri = self._host_uri model.node_uri = self._node_uri model.lane_uri = self._lane_uri + model.keep_linked = self.keep_linked + model.keep_synced = self.keep_synced async def _assign_manager(self, manager: '_DownlinkManager') -> None: """ @@ -463,8 +488,10 @@ def __register_class(self, custom_class: Any) -> None: class _EventDownlinkModel(_DownlinkModel): async def _establish_downlink(self) -> None: - link_request = _LinkRequest(self.node_uri, self.lane_uri) - await self.connection._send_message(link_request._to_recon()) + request = _LinkRequest(self.node_uri, self.lane_uri) + + self.connection._set_init_message(request._to_recon()) + await self.connection._send_init_message() async def _receive_event(self, message: _Envelope) -> None: converter = RecordConverter.get_converter() @@ -520,8 +547,13 @@ def __init__(self, client: 'SwimClient') -> None: self._synced = asyncio.Event() async def _establish_downlink(self) -> None: - sync_request = _SyncRequest(self.node_uri, self.lane_uri) - await self.connection._send_message(sync_request._to_recon()) + if self.keep_synced: + request = _SyncRequest(self.node_uri, self.lane_uri) + else: + request = _LinkRequest(self.node_uri, self.lane_uri) + + self.connection._set_init_message(request._to_recon()) + await self.connection._send_init_message() async def _receive_event(self, message: '_Envelope') -> None: await self.__set_value(message) @@ -550,7 +582,7 @@ async def _get_value(self) -> Any: async def __set_value(self, message: '_Envelope') -> None: """ - Set the value of the the downlink and trigger the `did_set` callback of the downlink subscribers. + Set the value of the downlink and trigger the `did_set` callback of the downlink subscribers. :param message: - The message from the remote agent. :return: @@ -702,8 +734,13 @@ def __init__(self, client: 'SwimClient') -> None: self._synced = asyncio.Event() async def _establish_downlink(self) -> None: - sync_request = _SyncRequest(self.node_uri, self.lane_uri) - await self.connection._send_message(sync_request._to_recon()) + if self.keep_synced: + request = _SyncRequest(self.node_uri, self.lane_uri) + else: + request = _LinkRequest(self.node_uri, self.lane_uri) + + self.connection._set_init_message(request._to_recon()) + await self.connection._send_init_message() async def _receive_event(self, message: '_Envelope') -> None: if message._body._tag == 'update': diff --git a/swimos/client/_downlinks/_utils.py b/swimos/client/_downlinks/_utils.py index 6dc87b8..01d7d1c 100644 --- a/swimos/client/_downlinks/_utils.py +++ b/swimos/client/_downlinks/_utils.py @@ -56,7 +56,8 @@ def wrapper(*args, **kwargs): return function(*args, **kwargs) else: try: - raise Exception(f'Cannot execute "{function.__name__}" before the downlink has been opened!') + raise Exception( + f'Cannot execute "{function.__name__}" before the downlink has been opened or after it has closed!') except Exception: exc_type, exc_value, exc_traceback = sys.exc_info() args[0]._client._handle_exception(exc_value, exc_traceback) diff --git a/swimos/client/_swim_client.py b/swimos/client/_swim_client.py index a941ec4..c7d540a 100644 --- a/swimos/client/_swim_client.py +++ b/swimos/client/_swim_client.py @@ -16,23 +16,23 @@ import os import sys import traceback -import warnings from asyncio import Future from concurrent.futures import CancelledError from threading import Thread from traceback import TracebackException from typing import Callable, Any, Optional -from ._connections import _ConnectionPool, _WSConnection +from ._connections import _ConnectionPool, _WSConnection, RetryStrategy, IntervalStrategy from ._downlinks._downlinks import _ValueDownlinkView, _EventDownlinkView, _DownlinkView, _MapDownlinkView -from ._utils import _URI, after_started +from ._utils import _URI, after_started, exception_warn from swimos.structures import RecordConverter from swimos.warp._warp import _CommandMessage class SwimClient: - def __init__(self, terminate_on_exception: bool = False, execute_on_exception: Callable = None, + def __init__(self, retry_strategy: RetryStrategy = IntervalStrategy(), terminate_on_exception: bool = False, + execute_on_exception: Callable = None, debug: bool = False) -> None: self.debug = debug self.execute_on_exception = execute_on_exception @@ -41,7 +41,7 @@ def __init__(self, terminate_on_exception: bool = False, execute_on_exception: C self._loop = None self._loop_thread = None self._has_started = False - self.__connection_pool = _ConnectionPool() + self.__connection_pool = _ConnectionPool(retry_strategy) def __enter__(self) -> 'SwimClient': self.start() @@ -70,6 +70,19 @@ def start(self) -> 'SwimClient': return self + def join(self, timeout=None) -> 'SwimClient': + """ + Wait until the Swim client thread terminates. + This blocks the calling thread until the Swim client thread terminates + or until the optional timeout is reached. + It should be noted that when the timeout is reached, the method returns, but the thread is not terminated. + + :param timeout: - Time to wait in seconds (Optional). + """ + self._loop_thread.join(timeout=timeout) + + return self + def stop(self) -> 'SwimClient': """ Stop the client. @@ -129,15 +142,18 @@ async def _remove_downlink_view(self, downlink_view: '_DownlinkView') -> None: """ await self.__connection_pool._remove_downlink_view(downlink_view) - async def _get_connection(self, host_uri: str, scheme: str) -> '_WSConnection': + async def _get_connection(self, host_uri: str, scheme: str, keep_linked: bool, + keep_synced: bool) -> '_WSConnection': """ Get a WebSocket connection to the specified host from the connection pool. :param host_uri: - URI of the host. :param scheme: - URI scheme. + :param keep_linked: - Whether the link should be automatically re-established after connection failures. + :param keep_synced: - Whether the link should synchronize its state with the remote lane. :return: - WebSocket connection to the host. """ - connection = await self.__connection_pool._get_connection(host_uri, scheme) + connection = await self.__connection_pool._get_connection(host_uri, scheme, keep_linked, keep_synced) return connection @after_started @@ -165,7 +181,7 @@ def _handle_exception(self, exc_value: Optional[Exception], exc_traceback: Optio :param exc_value: - Exception value. :param exc_traceback: - Exception traceback. """ - warnings.warn(str(exc_value)) + exception_warn(exc_value) if self.debug: traceback.print_tb(exc_traceback) @@ -203,7 +219,7 @@ async def __send_command(self, host_uri: str, node_uri: str, lane_uri: str, body record = RecordConverter.get_converter().object_to_record(body) host_uri, scheme = _URI._parse_uri(host_uri) message = _CommandMessage(node_uri, lane_uri, body=record) - connection = await self._get_connection(host_uri, scheme) + connection = await self._get_connection(host_uri, scheme, True, False) await connection._send_message(message._to_recon()) def __start_event_loop(self) -> None: diff --git a/swimos/client/_utils.py b/swimos/client/_utils.py index 49845d1..2cb389f 100644 --- a/swimos/client/_utils.py +++ b/swimos/client/_utils.py @@ -13,6 +13,7 @@ # limitations under the License. import sys +import warnings from typing import Callable, Optional, Tuple from urllib.parse import urlparse, ParseResult @@ -39,6 +40,11 @@ def wrapper(*args, **kwargs): return wrapper +def exception_warn(exc_value) -> None: + warnings.simplefilter('always', UserWarning) + warnings.warn(str(exc_value)) + + class _URI: @staticmethod From 874bd8f1a72e34fa51e723dcfd550f30d058fe24 Mon Sep 17 00:00:00 2001 From: Dobromir Marinov Date: Fri, 2 Aug 2024 16:18:08 +0100 Subject: [PATCH 2/2] * Changed keep synced behaviour to override keep linked. * Fixed unit tests. * Added default retry strategy. --- swimos/client/_connections.py | 30 ++--- swimos/client/_downlinks/_downlinks.py | 6 +- test/client/downlinks/test_downlink_utils.py | 6 +- test/client/downlinks/test_downlinks.py | 17 ++- test/client/test_connections.py | 110 +++++++++---------- test/client/test_swim_client.py | 13 ++- test/utils.py | 8 ++ 7 files changed, 108 insertions(+), 82 deletions(-) diff --git a/swimos/client/_connections.py b/swimos/client/_connections.py index 5925959..49c5b0a 100644 --- a/swimos/client/_connections.py +++ b/swimos/client/_connections.py @@ -13,7 +13,6 @@ # limitations under the License. import asyncio -from abc import ABC, abstractmethod import websockets from enum import Enum @@ -27,20 +26,15 @@ from ._downlinks._downlinks import _DownlinkView -class RetryStrategy(ABC): - @abstractmethod +class RetryStrategy: async def retry(self) -> bool: """ Wait for a period of time that is defined by the retry strategy. """ - raise NotImplementedError + return False - @abstractmethod def reset(self): - """ - Reset the retry strategy to its original state. - """ - raise NotImplementedError + pass class IntervalStrategy(RetryStrategy): @@ -52,7 +46,7 @@ def __init__(self, retries_limit=None, delay=3) -> None: self.retries = 0 async def retry(self) -> bool: - if self.retries_limit is None or self.retries_limit >= self.retries: + if self.retries_limit is None or self.retries_limit > self.retries: await asyncio.sleep(self.delay) self.retries += 1 return True @@ -85,7 +79,7 @@ def reset(self): class _ConnectionPool: - def __init__(self, retry_strategy: RetryStrategy = None) -> None: + def __init__(self, retry_strategy: RetryStrategy = RetryStrategy()) -> None: self.__connections = dict() self.retry_strategy = retry_strategy @@ -161,7 +155,7 @@ async def _remove_downlink_view(self, downlink_view: '_DownlinkView') -> None: class _WSConnection: def __init__(self, host_uri: str, scheme: str, keep_linked, keep_synced, - retry_strategy: RetryStrategy = None) -> None: + retry_strategy: RetryStrategy = RetryStrategy()) -> None: self.host_uri = host_uri self.scheme = scheme self.retry_strategy = retry_strategy @@ -191,7 +185,7 @@ async def _open(self) -> None: self.retry_strategy.reset() self.status = _ConnectionStatus.IDLE except Exception as error: - if self.keep_linked and await self.retry_strategy.retry(): + if self.should_reconnect() and await self.retry_strategy.retry(): exception_warn(error) continue else: @@ -209,6 +203,14 @@ async def _close(self) -> None: await self.websocket.close() self.connected.clear() + def should_reconnect(self) -> bool: + """ + Return a boolean flag indicating whether the connection should try to reconnect on failure. + + :return: - True if the connection should try to reconnect on failure. + """ + return self.keep_linked or self.keep_synced + def _set_init_message(self, message: str) -> None: """ Set the initial message that gets sent when the underlying downlink is established. @@ -285,7 +287,7 @@ async def _wait_for_messages(self) -> None: except ConnectionClosed as error: exception_warn(error) await self._close() - if self.keep_linked and await self.retry_strategy.retry(): + if self.should_reconnect() and await self.retry_strategy.retry(): await self._open() await self._send_init_message() continue diff --git a/swimos/client/_downlinks/_downlinks.py b/swimos/client/_downlinks/_downlinks.py index 8d9e88b..95ba9fc 100644 --- a/swimos/client/_downlinks/_downlinks.py +++ b/swimos/client/_downlinks/_downlinks.py @@ -188,11 +188,11 @@ def registered_classes(self) -> dict: def open(self) -> '_DownlinkView': if self._host_uri is None: - raise Exception(f'Downlink cannot be opened without first setting the host URI!') + raise Exception('Downlink cannot be opened without first setting the host URI!') if self._node_uri is None: - raise Exception(f'Downlink cannot be opened without first setting the node URI!') + raise Exception('Downlink cannot be opened without first setting the node URI!') if self._lane_uri is None: - raise Exception(f'Downlink cannot be opened without first setting the lane URI!') + raise Exception('Downlink cannot be opened without first setting the lane URI!') if not self._is_open: task = self._client._schedule_task(self._client._add_downlink_view, self) diff --git a/test/client/downlinks/test_downlink_utils.py b/test/client/downlinks/test_downlink_utils.py index 4a68620..7a655d5 100644 --- a/test/client/downlinks/test_downlink_utils.py +++ b/test/client/downlinks/test_downlink_utils.py @@ -93,7 +93,8 @@ def test_after_open_invalid_with_args(self, mock_warn): # When downlink_view.get(False) # Then - self.assertEqual('Cannot execute "get" before the downlink has been opened!', mock_warn.call_args_list[0][0][0]) + self.assertEqual('Cannot execute "get" before the downlink has been opened or after it has closed!', + mock_warn.call_args_list[0][0][0]) def test_after_open_valid_with_kwargs(self): # Given @@ -115,7 +116,8 @@ def test_after_open_invalid_with_kwargs(self, mock_warn): # When downlink_view.get(wait_sync=False) # Then - self.assertEqual('Cannot execute "get" before the downlink has been opened!', mock_warn.call_args_list[0][0][0]) + self.assertEqual('Cannot execute "get" before the downlink has been opened or after it has closed!', + mock_warn.call_args_list[0][0][0]) def test_map_request_get_key_item_primitive(self): # Given diff --git a/test/client/downlinks/test_downlinks.py b/test/client/downlinks/test_downlinks.py index a886df3..0cf42ae 100644 --- a/test/client/downlinks/test_downlinks.py +++ b/test/client/downlinks/test_downlinks.py @@ -1140,6 +1140,7 @@ async def test_value_downlink_model_establish_downlink(self): downlink_model = _ValueDownlinkModel(client) downlink_model.node_uri = 'foo' downlink_model.lane_uri = 'bar' + downlink_model.keep_synced = True downlink_model.connection = MockConnection() # When @@ -1448,7 +1449,8 @@ async def test_value_downlink_view_get_before_open(self, mock_warn): downlink_view.get() # Then - self.assertEqual('Cannot execute "get" before the downlink has been opened!', mock_warn.call_args_list[0][0][0]) + self.assertEqual('Cannot execute "get" before the downlink has been opened or after it has closed!', + mock_warn.call_args_list[0][0][0]) @patch('concurrent.futures._base.Future.result') async def test_value_downlink_view_set_blocking(self, mock_result): @@ -1511,7 +1513,8 @@ async def test_value_downlink_view_set_before_open(self, mock_warn): downlink_view.set('value') # Then - self.assertEqual('Cannot execute "set" before the downlink has been opened!', mock_warn.call_args_list[0][0][0]) + self.assertEqual('Cannot execute "set" before the downlink has been opened or after it has closed!', + mock_warn.call_args_list[0][0][0]) async def test_value_downlink_view_execute_did_set(self): # Given @@ -1580,6 +1583,7 @@ async def test_map_downlink_model_establish_downlink(self): downlink_model = _MapDownlinkModel(client) downlink_model.node_uri = 'dog' downlink_model.lane_uri = 'bark' + downlink_model.keep_synced = True downlink_model.connection = MockConnection() # When @@ -2060,7 +2064,8 @@ async def test_map_downlink_view_get_before_open(self, mock_warn): downlink_view.get('a') # Then - self.assertEqual('Cannot execute "get" before the downlink has been opened!', mock_warn.call_args_list[0][0][0]) + self.assertEqual('Cannot execute "get" before the downlink has been opened or after it has closed!', + mock_warn.call_args_list[0][0][0]) async def test_map_downlink_view_get_all_immediate(self): # Given @@ -2147,7 +2152,7 @@ async def test_map_downlink_view_get_all_before_open(self, mock_warn): downlink_view.get_all() # Then - self.assertEqual('Cannot execute "get_all" before the downlink has been opened!', + self.assertEqual('Cannot execute "get_all" before the downlink has been opened or after it has closed!', mock_warn.call_args_list[0][0][0]) @patch('concurrent.futures._base.Future.result') @@ -2207,7 +2212,7 @@ async def test_map_downlink_view_put_before_open(self, mock_warn): downlink_view.put('key_map', 'value_map') # Then - self.assertEqual('Cannot execute "put" before the downlink has been opened!', + self.assertEqual('Cannot execute "put" before the downlink has been opened or after it has closed!', mock_warn.call_args_list[0][0][0]) @patch('concurrent.futures._base.Future.result') @@ -2267,7 +2272,7 @@ async def test_map_downlink_view_remove_before_open(self, mock_warn): downlink_view.remove('key_map_remove') # Then - self.assertEqual('Cannot execute "remove" before the downlink has been opened!', + self.assertEqual('Cannot execute "remove" before the downlink has been opened or after it has closed!', mock_warn.call_args_list[0][0][0]) async def test_map_downlink_view_execute_did_update(self): diff --git a/test/client/test_connections.py b/test/client/test_connections.py index bd5fdae..89445b3 100644 --- a/test/client/test_connections.py +++ b/test/client/test_connections.py @@ -44,7 +44,7 @@ async def test_pool_get_connection_new(self): uri = 'ws://foo_bar:9000' scheme = 'ws' # When - actual = await pool._get_connection(uri, scheme) + actual = await pool._get_connection(uri, scheme, True, True) # Then self.assertEqual(uri, actual.host_uri) self.assertEqual(None, actual.websocket) @@ -56,10 +56,10 @@ async def test_pool_get_connection_existing(self): pool = _ConnectionPool() uri = 'ws://foo_bar:9000' scheme = 'ws' - expected = await pool._get_connection(uri, scheme) + expected = await pool._get_connection(uri, scheme, True, True) expected.status = _ConnectionStatus.IDLE # When - actual = await pool._get_connection(uri, scheme) + actual = await pool._get_connection(uri, scheme, True, True) # Then self.assertEqual(expected, actual) self.assertEqual(1, pool._size) @@ -69,9 +69,9 @@ async def test_pool_get_connection_closed(self): pool = _ConnectionPool() uri = 'ws://foo_bar:9000' scheme = 'ws' - expected = await pool._get_connection(uri, scheme) + expected = await pool._get_connection(uri, scheme, True, True) # When - actual = await pool._get_connection(uri, scheme) + actual = await pool._get_connection(uri, scheme, True, True) # Then self.assertNotEqual(expected, actual) self.assertEqual(uri, actual.host_uri) @@ -84,7 +84,7 @@ async def test_pool_remove_connection_existing(self): pool = _ConnectionPool() uri = 'ws://foo_bar:9000' scheme = 'ws' - connection = await pool._get_connection(uri, scheme) + connection = await pool._get_connection(uri, scheme, True, True) connection.status = _ConnectionStatus.IDLE # When await pool._remove_connection(uri) @@ -107,7 +107,7 @@ async def test_pool_add_downlink_view_existing_connection(self, mock_subscribe): pool = _ConnectionPool() uri = 'ws://foo_bar:9000' scheme = 'ws' - connection = await pool._get_connection(uri, scheme) + connection = await pool._get_connection(uri, scheme, True, True) connection.status = _ConnectionStatus.IDLE client = SwimClient() client._has_started = True @@ -172,7 +172,7 @@ async def test_pool_remove_downlink_view_existing_connection_closed(self, mock_u downlink_view = client.downlink_value() downlink_view.set_host_uri(uri) await pool._add_downlink_view(downlink_view) - connection = pool._get_connection(uri, scheme) + connection = pool._get_connection(uri, scheme, True, True) connection.close() # When await pool._remove_downlink_view(downlink_view) @@ -205,7 +205,7 @@ async def test_ws_connection(self): host_uri = 'ws://localhost:9001' scheme = 'ws' # When - actual = _WSConnection(host_uri, scheme) + actual = _WSConnection(host_uri, scheme, True, True) # Then self.assertEqual(host_uri, actual.host_uri) self.assertIsNone(actual.websocket) @@ -217,7 +217,7 @@ async def test_ws_connection_subscribe_single(self, mock_add_view, mock_websocke # Given host_uri = 'ws://localhost:9001' scheme = 'ws' - actual = _WSConnection(host_uri, scheme) + actual = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True downlink_view = client.downlink_value() @@ -240,7 +240,7 @@ async def test_ws_connection_subscribe_multiple(self, mock_add_view, mock_websoc # Given host_uri = 'ws://1.1.1.1:9001' scheme = 'ws' - actual = _WSConnection(host_uri, scheme) + actual = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True @@ -272,7 +272,7 @@ async def test_ws_connection_unsubscribe_all(self, mock_remove_view, mock_add_vi # Given host_uri = 'ws://0.0.0.0:9001' scheme = 'ws' - actual = _WSConnection(host_uri, scheme) + actual = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True downlink_view = client.downlink_value() @@ -300,7 +300,7 @@ async def test_ws_connection_unsubscribe_one(self, mock_remove_view, mock_add_vi # Given host_uri = 'ws://1.2.3.4:9001' scheme = 'ws' - actual = _WSConnection(host_uri, scheme) + actual = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True @@ -334,7 +334,7 @@ async def test_ws_connection_open_new(self, mock_websocket): # Given host_uri = 'ws://1.2.3.4:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) # When await connection._open() # Then @@ -346,7 +346,7 @@ async def test_wss_connection_open_new(self, mock_websocket): # Given host_uri = 'wss://1.2.3.4:9001' scheme = 'wss' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) # When await connection._open() # Then @@ -359,7 +359,7 @@ async def test_ws_connection_open_error(self, mock_websocket): MockWebsocket.get_mock_websocket().raise_exception = True host_uri = 'ws://1.2.3.4:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) # When with self.assertRaises(Exception) as error: # noinspection PyTypeChecker @@ -376,7 +376,7 @@ async def test_wss_connection_open_error(self, mock_websocket): MockWebsocket.get_mock_websocket().raise_exception = True host_uri = 'wss://1.2.3.4:9001' scheme = 'wss' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) # When with self.assertRaises(Exception) as error: # noinspection PyTypeChecker @@ -392,7 +392,7 @@ async def test_ws_connection_open_already_opened(self, mock_websocket): # Given host_uri = 'ws://1.2.3.4:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) await connection._open() # When await connection._open() @@ -405,7 +405,7 @@ async def test_ws_connection_close_opened(self, mock_websocket): # Given host_uri = 'ws://1.2.3.4:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) await connection._open() # When await connection._close() @@ -419,7 +419,7 @@ async def test_ws_connection_close_missing_websocket(self, mock_websocket): # Given host_uri = 'ws://1.2.3.4:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) # When await connection._close() # Then @@ -432,7 +432,7 @@ async def test_ws_connection_close_already_closed(self, mock_websocket): # Given host_uri = 'ws://5.5.5.5:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) await connection._open() await connection._close() # When @@ -448,7 +448,7 @@ async def test_ws_connection_send_message_existing_websocket_single(self, mock_w host_uri = 'ws://1.2.3.4:9001' message = 'Hello, World' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) await connection._open() # When await connection._send_message(message) @@ -463,7 +463,7 @@ async def test_ws_connection_send_message_existing_websocket_multiple(self, mock first_message = 'Hello, World' second_message = 'Hello, Friend' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) await connection._open() # When await connection._send_message(first_message) @@ -479,7 +479,7 @@ async def test_ws_connection_send_message_non_existing_websocket(self, mock_webs host_uri = 'ws://1.2.3.4:9001' message = 'Hello, World' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) # When await connection._send_message(message) # Then @@ -492,7 +492,7 @@ async def test_ws_connection_send_message_closed(self, mock_websocket): host_uri = 'ws://1.2.3.4:9001' message = 'Hello, World' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) await connection._open() await connection._close() # When @@ -506,7 +506,7 @@ async def test_ws_connection_wait_for_message_closed(self): # Given host_uri = 'ws://1.2.3.4:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) # When await connection._wait_for_messages() # Then @@ -520,7 +520,7 @@ async def test_ws_connection_wait_for_message_receive_single(self, mock_receive_ # Given host_uri = 'ws://1.2.3.4:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) MockWebsocket.get_mock_websocket().connection = connection client = SwimClient() @@ -551,7 +551,7 @@ async def test_ws_connection_wait_for_message_receive_multiple(self, mock_receiv # Given host_uri = 'ws://2.2.2.2:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) MockWebsocket.get_mock_websocket().connection = connection mock_receive_message.set_call_count(3) @@ -594,7 +594,7 @@ async def test_ws_connection_wait_for_message_receive_exception(self, mock_recei # Given host_uri = 'ws://5.5.5.5:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) MockWebsocket.get_mock_websocket().connection = connection mock_websocket.set_raise_exception(True) @@ -795,7 +795,7 @@ async def test_downlink_manager(self): # Given host_uri = 'ws://5.5.5.5:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) # When actual = _DownlinkManager(connection) # Then @@ -810,7 +810,7 @@ async def test_downlink_manager_open_new(self, mock_schedule_task): # Given host_uri = 'ws://5.5.5.5:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True downlink_view = client.downlink_value() @@ -833,7 +833,7 @@ async def test_downlink_manager_open_existing(self, mock_schedule_task): # Given host_uri = 'ws://5.5.5.5:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True downlink_view = client.downlink_value() @@ -857,7 +857,7 @@ async def test_downlink_manager_close_running(self, mock_schedule_task): # Given host_uri = 'ws://1.2.3.4:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True downlink_view = client.downlink_value() @@ -880,7 +880,7 @@ async def test_downlink_manager_close_stopped(self, mock_schedule_task): # Given host_uri = 'ws://4.3.2.1:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True downlink_view = client.downlink_value() @@ -901,7 +901,7 @@ async def test_downlink_manager_init_downlink_model(self): # Given host_uri = 'ws://100.100.100.100:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True downlink_view = client.downlink_value() @@ -926,7 +926,7 @@ async def test_downlink_manager_init_downlink_model_strict_classes(self): # Given host_uri = 'ws://100.100.100.100:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client.start() downlink_view = client.downlink_value() @@ -957,7 +957,7 @@ async def test_downlink_manager_add_view_single(self, mock_schedule_task, mock_s # Given host_uri = 'ws://99.99.99.99:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True downlink_view = client.downlink_value() @@ -986,7 +986,7 @@ async def test_downlink_manager_add_view_multiple(self, mock_schedule_task, mock # Given host_uri = 'ws://11.22.33.44:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True first_downlink_view = client.downlink_value() @@ -1025,7 +1025,7 @@ async def test_downlink_manager_remove_view_single(self, mock_schedule_task, moc # Given host_uri = 'ws://11.11.11.11:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True downlink_view = client.downlink_value() @@ -1055,7 +1055,7 @@ async def test_downlink_manager_remove_view_multiple(self, mock_schedule_task, m # Given host_uri = 'ws://44.33.22.11:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True first_downlink_view = client.downlink_value() @@ -1092,7 +1092,7 @@ async def test_downlink_manager_remove_view_non_existing(self, mock_schedule_tas # Given host_uri = 'ws://66.66.66.66:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True downlink_view = client.downlink_value() @@ -1114,7 +1114,7 @@ async def test_downlink_manager_receive_message_linked(self, mock_schedule_task, # Given host_uri = 'ws://66.66.66.66:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True downlink_view = client.downlink_value() @@ -1136,7 +1136,7 @@ async def test_downlink_manager_receive_message_synced(self, mock_schedule_task, # Given host_uri = 'ws://11.11.11.11:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True downlink_view = client.downlink_value() @@ -1158,7 +1158,7 @@ async def test_downlink_manager_receive_message_event(self, mock_schedule_task, # Given host_uri = 'ws://33.33.33.33:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True downlink_view = client.downlink_value() @@ -1181,7 +1181,7 @@ async def test_downlink_manager_receive_message_multiple(self, mock_schedule_tas # Given host_uri = 'ws://44.44.44.44:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True downlink_view = client.downlink_value() @@ -1210,7 +1210,7 @@ async def test_downlink_manager_subscribers_did_set_single(self, mock_schedule_t # Given host_uri = 'ws://4.3.2.1:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True downlink_view = client.downlink_value() @@ -1235,7 +1235,7 @@ async def test_downlink_manager_subscribers_did_set_multiple(self, mock_schedule # Given host_uri = 'ws://10.9.8.7:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True did_set_callback = mock_did_set_callback @@ -1287,7 +1287,7 @@ async def test_downlink_manager_subscribers_on_event_single(self, mock_schedule_ # Given host_uri = 'ws://4.3.2.1:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True downlink_view = client.downlink_event() @@ -1311,7 +1311,7 @@ async def test_downlink_manager_subscribers_on_event_multiple(self, mock_schedul # Given host_uri = 'ws://10.9.8.7:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True on_event_callback = mock_on_event_callback @@ -1352,7 +1352,7 @@ async def test_downlink_manager_subscribers_did_update_single(self, mock_schedul # Given host_uri = 'ws://4.3.2.1:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True downlink_view = client.downlink_map() @@ -1378,7 +1378,7 @@ async def test_downlink_manager_subscribers_did_update_multiple(self, mock_sched # Given host_uri = 'ws://10.9.8.7:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True did_update_callback = mock_did_update_callback @@ -1425,7 +1425,7 @@ async def test_downlink_manager_subscribers_did_remove_single(self, mock_schedul # Given host_uri = 'ws://4.3.2.1:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True downlink_view = client.downlink_map() @@ -1450,7 +1450,7 @@ async def test_downlink_manager_subscribers_did_remove_multiple(self, mock_sched # Given host_uri = 'ws://10.9.8.7:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True did_remove_callback = mock_on_event_callback @@ -1494,7 +1494,7 @@ async def test_downlink_manager_close_views_single(self, mock_schedule_task, moc # Given host_uri = 'ws://4.3.2.1:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True downlink_view = client.downlink_map() @@ -1514,7 +1514,7 @@ async def test_downlink_manager_close_views_multiple(self, mock_schedule_task, m # Given host_uri = 'ws://4.3.2.1:9001' scheme = 'ws' - connection = _WSConnection(host_uri, scheme) + connection = _WSConnection(host_uri, scheme, True, True) client = SwimClient() client._has_started = True actual = _DownlinkManager(connection) diff --git a/test/client/test_swim_client.py b/test/client/test_swim_client.py index 20ff35e..a7c64a1 100644 --- a/test/client/test_swim_client.py +++ b/test/client/test_swim_client.py @@ -221,6 +221,9 @@ def test_swim_client_downlink_event_open_before_client_started(self, mock_warn): swim_client = SwimClient() # When downlink_event = swim_client.downlink_event() + downlink_event.set_host_uri("ws://test.com") + downlink_event.set_node_uri("node") + downlink_event.set_lane_uri("lane") downlink_event.open() # Then self.assertEqual('Cannot execute "_add_downlink_view" before the client has been started!', @@ -242,6 +245,9 @@ def test_swim_client_downlink_map_open_before_client_started(self, mock_warn): swim_client = SwimClient() # When downlink_map = swim_client.downlink_map() + downlink_map.set_host_uri("ws://test.com") + downlink_map.set_node_uri("node") + downlink_map.set_lane_uri("lane") downlink_map.open() # Then self.assertEqual('Cannot execute "_add_downlink_view" before the client has been started!', @@ -263,6 +269,9 @@ def test_swim_client_downlink_value_open_before_client_started(self, mock_warn): swim_client = SwimClient() # When downlink_value = swim_client.downlink_value() + downlink_value.set_host_uri("ws://test.com") + downlink_value.set_node_uri("node") + downlink_value.set_lane_uri("lane") downlink_value.open() # Then self.assertEqual('Cannot execute "_add_downlink_view" before the client has been started!', @@ -327,10 +336,10 @@ async def test_swim_client_get_connection(self, mock_get_connection): downlink_view._node_uri = node_uri downlink_view._lane_uri = lane_uri # When - await swim_client._get_connection(host_uri, scheme) + await swim_client._get_connection(host_uri, scheme, True, True) # Then - mock_get_connection.assert_called_once_with(host_uri, scheme) + mock_get_connection.assert_called_once_with(host_uri, scheme, True, True) async def test_swim_client_test_schedule_task(self): # Given diff --git a/test/utils.py b/test/utils.py index 3a18fb0..a0c3d3f 100644 --- a/test/utils.py +++ b/test/utils.py @@ -174,6 +174,7 @@ def __init__(self): self.owner = None self.messages_sent = list() self.messages_to_receive = list() + self.init_message = None @staticmethod def get_mock_connection(): @@ -197,6 +198,13 @@ async def _wait_for_messages(self): async def _send_message(self, message): self.messages_sent.append(message) + async def _send_init_message(self) -> None: + if self.init_message is not None: + await self._send_message(self.init_message) + + def _set_init_message(self, message: str) -> None: + self.init_message = message + def mock_did_set_confirmation(): print(1)