From 17c8febdd248996c5eb6f86bd848ec1a2dbb80ca Mon Sep 17 00:00:00 2001 From: Joakim Plate Date: Sun, 25 Feb 2024 08:57:50 +0100 Subject: [PATCH] Rework transport to leave re-connect to user (#150) * Rework transport to leave re-connect to user * Make compatible with older python * Don't signal connection loss when asked to close * Avoid blocking teardown * Split connect from constructions * Allow setting a connection timeout * Raise timeout exception on timeout * Convert more errors * Correct linting * Adjust some linting issues * More flake fixes * More lint fixes * Inject constructed transport * Change init order * Use decorators to hide low level exceptions * Make compatible with legacy python * Fix lint * Suppress errors on close --- RFXtrx/__init__.py | 228 +++++++++++++++++++++----------- RFXtrx/lowlevel.py | 4 +- examples/receive.py | 3 +- examples/send.py | 1 + tests/test_base.py | 22 +-- tests/test_transport_network.py | 99 ++++++++++++++ 6 files changed, 267 insertions(+), 90 deletions(-) create mode 100644 tests/test_transport_network.py diff --git a/RFXtrx/__init__.py b/RFXtrx/__init__.py index 95e4afa..9bafb96 100644 --- a/RFXtrx/__init__.py +++ b/RFXtrx/__init__.py @@ -23,11 +23,12 @@ # pylint: disable=R0903, invalid-name # pylint: disable= too-many-lines +import functools import glob import socket import threading -import time import logging +from contextlib import suppress from time import sleep @@ -674,6 +675,21 @@ def __str__(self): return "{0} device=[{1}]".format( type(self), self.device) + +class ConnectionEvent(RFXtrxEvent): + """ Connection event """ + def __init__(self): + super().__init__(None) + + +class ConnectionLost(ConnectionEvent): + """ Connection lost """ + + +class ConnectionDone(ConnectionEvent): + """ Connection lost """ + + ############################################################################### # DummySerial class ############################################################################### @@ -730,10 +746,19 @@ def close(self): self._close_event.set() +############################################################################### +# RFXtrxTransportError class +############################################################################### + + +class RFXtrxTransportError(Exception): + """ Connection error """ + ############################################################################### # RFXtrxTransport class ############################################################################### + class RFXtrxTransport: """ Abstract superclass for all transport mechanisms """ @@ -757,12 +782,40 @@ def parse(data): return obj return None + def connect(self, timeout=None): + """ connect to device """ + def reset(self): """ reset the rfxtrx device """ def close(self): """ close connection to rfxtrx device """ + def receive_blocking(self): + """ Wait until a packet is received and return with an RFXtrxEvent """ + + def send(self, data): + """ Send the given packet """ + + +def transport_errors(message): + """ Decorator to wrap low level errors in known error. """ + def _errors(func): + @functools.wraps(func) + def __errors(instance: RFXtrxTransport, *args, **kargs): + try: + return func(instance, *args, **kargs) + except (socket.error, + serial.SerialException, + OSError) as exception: + _LOGGER.debug("%s failed: %s", message, + str(exception), exc_info=True) + raise RFXtrxTransportError( + "{0} failed: {1}".format(message, exception) + ) from exception + return __errors + return _errors + ############################################################################### # PySerialTransport class ############################################################################### @@ -774,45 +827,39 @@ class PySerialTransport(RFXtrxTransport): def __init__(self, port): self.port = port self.serial = None - self._run_event = threading.Event() - self._run_event.set() - self.connect() - def connect(self): + @transport_errors("connect") + def connect(self, timeout=None): """ Open a serial connexion """ try: - self.serial = serial.Serial(self.port, 38400, timeout=0.1) - except serial.serialutil.SerialException: + self.serial = serial.Serial(self.port, 38400) + except serial.SerialException: port = glob.glob('/dev/serial/by-id/usb-RFXCOM_*-port0') if len(port) < 1: - return - self.serial = serial.Serial(port[0], 38400, timeout=0.1) + raise + _LOGGER.debug("Attempting connection by name %s", port) + self.serial = serial.Serial(port[0], 38400) + @transport_errors("receive") def receive_blocking(self): + return self._receive_packet() + + def _receive_packet(self): """ Wait until a packet is received and return with an RFXtrxEvent """ - data = None - while self._run_event.is_set(): - try: - data = self.serial.read() - except TypeError: - continue - except serial.serialutil.SerialException: - try: - self.connect() - except serial.serialutil.SerialException: - time.sleep(5) - continue - if not data or data == '\x00': - continue - pkt = bytearray(data) - data = self.serial.read(pkt[0]) + data = self.serial.read() + if data == '\x00': + return None + pkt = bytearray(data) + while len(pkt) < pkt[0]+1: + data = self.serial.read(pkt[0]+1 - len(pkt)) pkt.extend(bytearray(data)) - _LOGGER.debug( - "Recv: %s", - " ".join("0x{0:02x}".format(x) for x in pkt) - ) - return self.parse(pkt) + _LOGGER.debug( + "Recv: %s", + " ".join("0x{0:02x}".format(x) for x in pkt) + ) + return self.parse(pkt) + @transport_errors("send") def send(self, data): """ Send the given packet """ if isinstance(data, bytearray): @@ -827,17 +874,19 @@ def send(self, data): ) self.serial.write(pkt) + @transport_errors("reset") def reset(self): """ Reset the RFXtrx """ - self.send(b'\x0D\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00') + self.send(b'\x0D\x00\x00\x00\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00') sleep(0.3) # Should work with 0.05, but not for me self.serial.flushInput() + @transport_errors("close") def close(self): """ close connection to rfxtrx device """ - self._run_event.clear() - self.serial.close() - + with suppress(serial.SerialException): + self.serial.close() ############################################################################### # PyNetworkTransport class @@ -850,44 +899,40 @@ class PyNetworkTransport(RFXtrxTransport): def __init__(self, hostport): self.hostport = hostport # must be a (host, port) tuple self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._run_event = threading.Event() - self._run_event.set() - self.connect() - def connect(self): + @transport_errors("connect") + def connect(self, timeout=None): """ Open a socket connection """ - try: - self.sock.connect(self.hostport) - _LOGGER.info("Connected to network socket") - except socket.error: - _LOGGER.error('Failed to create socket, check host port config') - # This may throw exception for use by caller: - self.sock.connect(self.hostport) + self.sock.settimeout(timeout) + self.sock.connect(self.hostport) + self.sock.settimeout(None) + _LOGGER.debug("Connected to network socket") + @transport_errors("receive") def receive_blocking(self): """ Wait until a packet is received and return with an RFXtrxEvent """ - data = None - while self._run_event.is_set(): - try: - data = self.sock.recv(1) - except socket.error: - try: - self.connect() - except socket.error: - time.sleep(5) - continue - if not data or data == '\x00': - continue - pkt = bytearray(data) - while len(pkt) < pkt[0]+1: - data = self.sock.recv(pkt[0]+1 - len(pkt)) - pkt.extend(bytearray(data)) - _LOGGER.debug( - "Recv: %s", - " ".join("0x{0:02x}".format(x) for x in pkt) - ) - return self.parse(pkt) + return self._receive_packet() + def _receive_packet(self): + """ Wait until a packet is received and return with an RFXtrxEvent """ + data = self.sock.recv(1) + if data == b'': + raise RFXtrxTransportError("Server was shutdown") + if data == '\x00': + return None + pkt = bytearray(data) + while len(pkt) < pkt[0]+1: + data = self.sock.recv(pkt[0]+1 - len(pkt)) + if data == b'': + raise RFXtrxTransportError("Server was shutdown") + pkt.extend(bytearray(data)) + _LOGGER.debug( + "Recv: %s", + " ".join("0x{0:02x}".format(x) for x in pkt) + ) + return self.parse(pkt) + + @transport_errors("send") def send(self, data): """ Send the given packet """ if isinstance(data, bytearray): @@ -902,17 +947,23 @@ def send(self, data): ) self.sock.send(pkt) + @transport_errors("reset") def reset(self): """ Reset the RFXtrx """ - self.send(b'\x0D\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00') - sleep(0.3) - self.sock.sendall(b'') - + try: + self.send(b'\x0D\x00\x00\x00\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00') + sleep(0.3) + self.sock.sendall(b'') + except socket.error as exception: + raise RFXtrxTransportError( + "Reset failed: {0}".format(exception)) from exception + + @transport_errors("close") def close(self): """ close connection to rfxtrx device """ - self._run_event.clear() - self.sock.shutdown(socket.SHUT_RDWR) - self.sock.close() + with suppress(socket.error): + self.sock.close() class DummyTransport(RFXtrxTransport): @@ -922,6 +973,9 @@ def __init__(self, device=""): self.device = device self._close_event = threading.Event() + def connect(self, timeout=None): + pass + def receive(self, data=None): """ Emulate a receive by parsing the given data """ if data is None: @@ -958,6 +1012,8 @@ class DummyTransport2(PySerialTransport): def __init__(self, device=""): self.serial = _dummySerial(device, 38400, timeout=0.1) self._run_event = threading.Event() + + def connect(self, timeout=None): self._run_event.set() @@ -966,22 +1022,34 @@ class Connect: Has methods for sensors. """ # pylint: disable=too-many-instance-attributes, too-many-arguments - def __init__(self, device, event_callback=None, - transport_protocol=PySerialTransport, + def __init__(self, transport, event_callback=None, modes=None): self._run_event = threading.Event() self._sensors = {} self._status = None self._modes = modes + self._thread = threading.Thread(target=self._connect, daemon=True) self.event_callback = event_callback + self.transport: RFXtrxTransport = transport - self.transport = transport_protocol(device) - self._thread = threading.Thread(target=self._connect) - self._thread.daemon = True + def connect(self, timeout=None): + """Connect to device.""" + self.transport.connect(timeout) self._thread.start() - self._run_event.wait() + if not self._run_event.wait(timeout): + self.close_connection() + raise TimeoutError() def _connect(self): + try: + self._connect_internal() + except RFXtrxTransportError as exception: + _LOGGER.info("Connection lost %s", exception) + finally: + if self.event_callback and self._run_event.is_set(): + self.event_callback(ConnectionLost()) + + def _connect_internal(self): """Connect """ self.transport.reset() self._status = self.send_get_status() @@ -998,6 +1066,8 @@ def _connect(self): self.send_start() self._run_event.set() + if self.event_callback: + self.event_callback(ConnectionDone()) while self._run_event.is_set(): event = self.transport.receive_blocking() diff --git a/RFXtrx/lowlevel.py b/RFXtrx/lowlevel.py index 0e91d45..6200e64 100644 --- a/RFXtrx/lowlevel.py +++ b/RFXtrx/lowlevel.py @@ -2278,7 +2278,7 @@ def load_receive(self, data): (data[10] << 8) + data[11]) self.prodwatthours = ((data[12] * pow(2, 24)) + (data[13] << 16) + (data[14] << 8) + data[15]) - self.tarif_num = (data[16] & 0x0f) + self.tarif_num = data[16] & 0x0f self.voltage = data[17] + 200 self.currentwatt = (data[18] << 8) + data[19] self.state_byte = data[20] @@ -2378,7 +2378,7 @@ def set_transmit(self, subtype, seqnbr, id1, id2, sound): self.id2 = id2 self.sound = sound self.rssi = 0 - self.rssi_byte = (self.rssi << 4) + self.rssi_byte = self.rssi << 4 self.data = bytearray([self.packetlength, self.packettype, self.subtype, self.seqnbr, self.id1, self.id2, self.sound, diff --git a/examples/receive.py b/examples/receive.py index 4229f24..576d7d0 100644 --- a/examples/receive.py +++ b/examples/receive.py @@ -38,7 +38,8 @@ def main(): modes_list = sys.argv[2].split() if len(sys.argv) > 2 else None print ("modes: ", modes_list) - core = RFXtrx.Core(rfxcom_device, print_callback, modes=modes_list) + core = RFXtrx.Connect(RFXtrx.PySerialTransport(rfxcom_device), print_callback, modes=modes_list) + core.connect() print (core) while True: diff --git a/examples/send.py b/examples/send.py index 8f2cfa9..4ebdb0d 100644 --- a/examples/send.py +++ b/examples/send.py @@ -27,6 +27,7 @@ from time import sleep transport = PySerialTransport('/dev/cu.usbserial-05VN8GHS') +transport.connect() transport.reset() while True: diff --git a/tests/test_base.py b/tests/test_base.py index dc573ab..108cb79 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -19,7 +19,8 @@ def setUp(self): def test_constructor(self): global num_calbacks - core = RFXtrx.Core(self.path, event_callback=_callback, transport_protocol=RFXtrx.DummyTransport2) + core = RFXtrx.Connect(RFXtrx.DummyTransport2(self.path), event_callback=_callback) + core.connect() while num_calbacks < 7: time.sleep(0.1) @@ -30,13 +31,15 @@ def test_constructor(self): def test_invalid_packet(self): bytes_array = bytearray([0x09, 0x11, 0xd7, 0x00, 0x01, 0x1d, 0x14, 0x02, 0x79, 0x0a]) - core = RFXtrx.Connect(self.path, event_callback=_callback, transport_protocol=RFXtrx.DummyTransport) + core = RFXtrx.Connect(RFXtrx.DummyTransport(self.path), event_callback=_callback) + core.connect() event = core.transport.parse(bytes_array) self.assertIsNone(event) def test_format_packet(self): # Lighting1 - core = RFXtrx.Connect(self.path, event_callback=_callback, transport_protocol=RFXtrx.DummyTransport) + core = RFXtrx.Connect(RFXtrx.DummyTransport(self.path), event_callback=_callback) + core.connect() bytes_array = bytearray([0x07, 0x10, 0x00, 0x2a, 0x45, 0x05, 0x01, 0x70]) event = core.transport.parse(bytes_array) self.assertEqual(RFXtrx.ControlEvent, type(event)) @@ -358,7 +361,8 @@ def test_equal_check(self): self.assertFalse(temphum==energy) def test_equal_device_check(self): - core = RFXtrx.Connect(self.path, event_callback=_callback, transport_protocol=RFXtrx.DummyTransport) + core = RFXtrx.Connect(RFXtrx.DummyTransport(self.path), event_callback=_callback) + core.connect() data1 = bytearray(b'\x11\x5A\x01\x00\x2E\xB2\x03\x00\x00' b'\x02\xB4\x00\x00\x0C\x46\xA8\x11\x69') energy = core.transport.receive(data1) @@ -391,7 +395,8 @@ def test_equal_device_check(self): core.close_connection() def test_get_device(self): - core = RFXtrx.Connect(self.path, event_callback=_callback, transport_protocol=RFXtrx.DummyTransport) + core = RFXtrx.Connect(RFXtrx.DummyTransport(self.path), event_callback=_callback) + core.connect() # Lighting1 bytes_array = bytearray([0x07, 0x10, 0x00, 0x2a, 0x45, 0x05, 0x01, 0x70]) event = core.transport.parse(bytes_array) @@ -437,8 +442,8 @@ def test_get_device(self): core.close_connection() def test_set_recmodes(self): - core = RFXtrx.Connect(self.path, event_callback=_callback, - transport_protocol=RFXtrx.DummyTransport) + core = RFXtrx.Connect(RFXtrx.DummyTransport(self.path), event_callback=_callback) + core.connect() time.sleep(0.2) self.assertEqual(None, core._modes) @@ -459,7 +464,8 @@ def test_set_recmodes(self): core.set_recmodes(['arc', 'oregon', 'unknown-mode']) def test_receive(self): - core = RFXtrx.Connect(self.path, event_callback=_callback, transport_protocol=RFXtrx.DummyTransport) + core = RFXtrx.Connect(RFXtrx.DummyTransport(self.path), event_callback=_callback) + core.connect() time.sleep(0.2) # Lighting1 bytes_array = bytearray([0x07, 0x10, 0x00, 0x2a, 0x45, 0x05, 0x01, 0x70]) diff --git a/tests/test_transport_network.py b/tests/test_transport_network.py new file mode 100644 index 0000000..48ed431 --- /dev/null +++ b/tests/test_transport_network.py @@ -0,0 +1,99 @@ + +import pytest +import RFXtrx + +import socket +import dataclasses +import threading +from typing import Tuple, List + + +@pytest.fixture(name="server_socket") +def fixture_server_socket(): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(('127.0.0.1', 0)) + sock.settimeout(1) + sock.listen(1) + try: + yield sock + finally: + sock.close() + + +@dataclasses.dataclass +class Server: + address: Tuple + connections: List[socket.socket] + event = threading.Event() + + +@pytest.fixture(name="server") +def fixture_server(server_socket: socket.socket): + + server = Server(address=server_socket.getsockname(), connections=[]) + + def runner(): + while True: + try: + connection, address = server_socket.accept() + server.connections.append(connection) + server.event.set() + except socket.timeout: + continue + except socket.error: + return + thread = threading.Thread(target=runner, daemon=True) + thread.start() + try: + yield server + finally: + server_socket.close() + for connection in server.connections: + connection.close() + thread.join() + + +def connected_transport(server: Server): + server.event.clear() + transport = RFXtrx.PyNetworkTransport(server.address) + transport.sock.settimeout(10) + transport.connect() + assert server.event.wait(10) + return transport, server.connections[-1] + + +def test_transport_shutdown_between_packet(server: Server): + transport, connection = connected_transport(server) + connection.sendall(bytes([0x09, 0x03, 0x01, 0x04, 0x28, + 0x0a, 0xb7, 0x66, 0x04, 0x70])) + connection.shutdown(socket.SHUT_RDWR) + + pkt = transport.receive_blocking() + assert isinstance(pkt, RFXtrx.SensorEvent) + with pytest.raises(RFXtrx.RFXtrxTransportError): + transport.receive_blocking() + + +def test_transport_shutdown_mid_packet(server: Server): + transport, connection = connected_transport(server) + connection.sendall(bytes([0x09, 0x03, 0x01, 0x04])) + connection.shutdown(socket.SHUT_RDWR) + + with pytest.raises(RFXtrx.RFXtrxTransportError): + transport.receive_blocking() + + +def test_transport_close_mid_packet(server: Server): + transport, connection = connected_transport(server) + connection.sendall(bytes([0x09, 0x03, 0x01, 0x04])) + connection.close() + + with pytest.raises(RFXtrx.RFXtrxTransportError): + transport.receive_blocking() + + +def test_transport_empty_packet(server: Server): + transport, connection = connected_transport(server) + connection.sendall(bytes([0x00])) + + assert transport.receive_blocking() is None