diff --git a/wsdiscovery/discovery.py b/wsdiscovery/discovery.py index 3b40cde..fb039da 100644 --- a/wsdiscovery/discovery.py +++ b/wsdiscovery/discovery.py @@ -13,7 +13,7 @@ class Discovery: - "networking-agnostic generic remote service discovery mixin" + """networking-agnostic generic remote service discovery mixin""" def __init__(self, **kwargs): self._remoteServices = {} diff --git a/wsdiscovery/service.py b/wsdiscovery/service.py index e84927f..da89978 100644 --- a/wsdiscovery/service.py +++ b/wsdiscovery/service.py @@ -1,4 +1,5 @@ """Discoverable WS-Discovery service.""" +import socket from .util import _getNetworkAddrs @@ -38,10 +39,11 @@ def getXAddrs(self): for xAddr in self._xAddrs: if '{ip}' in xAddr: if ipAddrs is None: - ipAddrs = _getNetworkAddrs() + ipAddrs = _getNetworkAddrs(socket.AF_INET) + ipAddrs.append(_getNetworkAddrs(socket.AF_INET6)) for ipAddr in ipAddrs: - if ipAddr != '127.0.0.1': - ret.append(xAddr.format(ip=ipAddr)) + if not ipAddr.is_loopback: + ret.append(xAddr.format(ip=str(ipAddr))) else: ret.append(xAddr) return ret diff --git a/wsdiscovery/threaded.py b/wsdiscovery/threaded.py index 294728a..4d8b0bf 100644 --- a/wsdiscovery/threaded.py +++ b/wsdiscovery/threaded.py @@ -1,31 +1,26 @@ """Threaded networking facilities for implementing threaded WS-Discovery daemons.""" - +import ipaddress import logging -import random -import time -import uuid +import platform +import selectors import socket import struct import threading -import selectors -import platform +import time from typing import cast -from .udp import UDPMessage from .actions import * -from .uri import URI -from .util import _getNetworkAddrs, dom2Str from .message import createSOAPMessage, parseSOAPMessage -from .service import Service - +from .udp import UDPMessage +from .util import _getNetworkAddrs, dom2Str logger = logging.getLogger("threading") - BUFFER_SIZE = 0xffff NETWORK_ADDRESSES_CHECK_TIMEOUT = 5 MULTICAST_PORT = 3702 MULTICAST_IPV4_ADDRESS = "239.255.255.250" +MULTICAST_IPV6_ADDRESS = "FF02::C" class _StoppableDaemonThread(threading.Thread): @@ -33,6 +28,7 @@ class _StoppableDaemonThread(threading.Thread): run() method shall exit, when self._quitEvent.wait() returned True """ + def __init__(self): self._quitEvent = threading.Event() super(_StoppableDaemonThread, self).__init__() @@ -48,14 +44,15 @@ def schedule_stop(self): class AddressMonitorThread(_StoppableDaemonThread): "trigger address change callbacks when local service addresses change" - def __init__(self, wsd): + def __init__(self, wsd, protocol_version): self._addrs = set() self._wsd = wsd + self._protocolVersion = protocol_version super(AddressMonitorThread, self).__init__() self._updateAddrs() def _updateAddrs(self): - addrs = set(_getNetworkAddrs()) + addrs = set(_getNetworkAddrs(self._protocolVersion)) disappeared = self._addrs.difference(addrs) new = addrs.difference(self._addrs) @@ -74,38 +71,60 @@ def run(self): class NetworkingThread(_StoppableDaemonThread): - def __init__(self, observer, capture=None): + def __init__(self, observer): super(NetworkingThread, self).__init__() self.daemon = True - self._queue = [] # FIXME synchronisation + self._queue = [] # FIXME synchronisation self._knownMessageIds = set() self._iidMap = {} self._observer = observer self._capture = observer._capture - self._seqnum = 1 # capture sequence number + + self._seqnum = 1 # capture sequence number self._selector = selectors.DefaultSelector() - @staticmethod - def _makeMreq(addr): - return struct.pack("4s4s", socket.inet_aton(MULTICAST_IPV4_ADDRESS), socket.inet_aton(addr)) + def _makeMreq(self, addr) -> bytes: + pass + + def _get_inet(self) -> int: + pass + + def _get_multicast(self) -> int: + pass - @staticmethod - def _createMulticastOutSocket(addr, ttl): - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + def _get_ip_proto(self) -> int: + pass + + def _get_ip_join(self) -> int: + pass + + def _get_ip_leave(self) -> int: + pass + + def _get_multicast_ttl(self) -> int: + pass + + def _createMulticastOutSocket(self, addr, ttl): + ip_proto = self._get_ip_proto() + sock = socket.socket(self._get_inet(), socket.SOCK_DGRAM) sock.setblocking(0) - sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl) - if addr is None: - sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, socket.INADDR_ANY) + sock.setsockopt(ip_proto, self._get_multicast_ttl(), ttl) + + if not addr: + iface = socket.INADDR_ANY + elif self._get_inet() == socket.AF_INET: + iface = addr.packed else: - sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, socket.inet_aton(addr)) + iface = int(addr.scope_id) + + sock.setsockopt(ip_proto, self._get_multicast(), iface) return sock - @staticmethod - def _createMulticastInSocket(): - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) + def _createMulticastInSocket(self): + sock = socket.socket(self._get_inet(), socket.SOCK_DGRAM, socket.IPPROTO_UDP) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) if platform.system() in ["Darwin", "FreeBSD"]: @@ -119,9 +138,9 @@ def _createMulticastInSocket(): def addSourceAddr(self, addr): """None means 'system default'""" try: - self._multiInSocket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, self._makeMreq(addr)) - except socket.error: # if 1 interface has more than 1 address, exception is raised for the second - pass + self._multiInSocket.setsockopt(self._get_ip_proto(), self._get_ip_join(), self._makeMreq(addr)) + except socket.error as e: + logger.warning(f"Interface has more than 1 address: {e}") sock = self._createMulticastOutSocket(addr, self._observer.ttl) self._multiOutUniInSockets[addr] = sock @@ -129,9 +148,9 @@ def addSourceAddr(self, addr): def removeSourceAddr(self, addr): try: - self._multiInSocket.setsockopt(socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, self._makeMreq(addr)) - except socket.error: # see comments for setsockopt(.., socket.IP_ADD_MEMBERSHIP.. - pass + self._multiInSocket.setsockopt(self._get_ip_proto(), self._get_ip_leave(), self._makeMreq(addr)) + except socket.error as e: + logger.warning(f"Interface has more than 1 address: {e}") sock = self._multiOutUniInSockets[addr] self._selector.unregister(sock) @@ -172,14 +191,15 @@ def _recvMessages(self): env = parseSOAPMessage(data, addr[0]) - if env is None: # fault or failed to parse + if env is None: # fault or failed to parse if self._capture: - self._capture.write("%i WARNING: BAD RECV %s:%s TS=%s\n" % (self._seqnum, addr[0], addr[1], time.time() - self.t0)) + self._capture.write( + "%i WARNING: BAD RECV %s:%s TS=%s\n" % (self._seqnum, addr[0], addr[1], time.time() - self.t0)) self._capture.write(dom2Str(data)) self._seqnum += 1 continue - _own_addrs = self._observer._addrsMonitorThread._addrs + _own_addrs = self._observer._addrsMonitorThread_v4._addrs if addr[0] not in _own_addrs: if env.getAction() == NS_ACTION_PROBE_MATCH: prms = "\n ".join((str(prm) for prm in env.getProbeResolveMatches())) @@ -187,13 +207,14 @@ def _recvMessages(self): logger.debug(msg, addr[0], prms) if self._capture: - self._capture.write("%i RECV %s:%s TS=%s\n" % (self._seqnum, addr[0], addr[1], time.time() - self.t0)) + self._capture.write( + "%i RECV %s:%s TS=%s\n" % (self._seqnum, addr[0], addr[1], time.time() - self.t0)) self._capture.write(dom2Str(data)) self._seqnum += 1 mid = env.getMessageId() if mid in self._knownMessageIds: - continue # https://github.com/andreikop/python-ws-discovery/issues/38 # TODO + continue # https://github.com/andreikop/python-ws-discovery/issues/38 # TODO else: if self._capture: self._capture.write("NEW KNOWN MSG IDS %s\n" % (mid)) @@ -222,7 +243,8 @@ def _sendMsg(self, msg): if msg.msgType() == UDPMessage.UNICAST: self._uniOutSocket.sendto(data, (msg.getAddr(), msg.getPort())) if self._capture: - self._capture.write("%i SEND %s:%s TS=%s\n" % (self._seqnum, msg.getAddr(), msg.getPort(), time.time() - self.t0)) + self._capture.write( + "%i SEND %s:%s TS=%s\n" % (self._seqnum, msg.getAddr(), msg.getPort(), time.time() - self.t0)) self._capture.write(dom2Str(data)) self._seqnum += 1 else: @@ -234,9 +256,10 @@ def _sendMsg(self, msg): # An example of the first case is a wireguard vpn interface. # In either case just log as debug and ignore the error. logger.debug("Interface for %s does not support multicast or is not UP.\n\tOSError %s", - socket.inet_ntoa(sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, 4)), e) + socket.inet_ntoa(sock.getsockopt(self._get_ip_proto(), self._get_multicast(), 4)), e) if self._capture: - self._capture.write("%i SEND %s:%s iface=%s TS=%s\n" % (self._seqnum, msg.getAddr(), msg.getPort(), addr, time.time() - self.t0)) + self._capture.write("%i SEND %s:%s iface=%s TS=%s\n" % ( + self._seqnum, msg.getAddr(), msg.getPort(), addr, time.time() - self.t0)) self._capture.write(dom2Str(data)) self._seqnum += 1 @@ -258,14 +281,14 @@ def _sendPendingMessages(self): def start(self): super(NetworkingThread, self).start() - self._uniOutSocket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self._uniOutSocket = socket.socket(self._get_inet(), socket.SOCK_DGRAM) self._multiInSocket = self._createMulticastInSocket() self._selector.register(self._multiInSocket, selectors.EVENT_WRITE | selectors.EVENT_READ) self._multiOutUniInSockets = {} # FIXME synchronisation - def join(self): + def join(self, **kwargs): assert self._quitEvent.is_set() super(NetworkingThread, self).join() @@ -276,61 +299,136 @@ def join(self): for sock in self._multiOutUniInSockets.values(): try: sock.close() - except socket.error: - ... + except socket.error as e: + logger.error(e) + + +class NetworkingThreadIPv4(NetworkingThread): + def __init__(self, observer): + super().__init__(observer) + + def _makeMreq(self, addr): + return struct.pack("4s4s", socket.inet_aton(MULTICAST_IPV4_ADDRESS), addr.packed) + + def _get_inet(self): + return socket.AF_INET + + def _get_multicast(self): + return socket.IP_MULTICAST_IF + + def _get_ip_proto(self): + return socket.IPPROTO_IP + + def _get_ip_join(self): + return socket.IP_ADD_MEMBERSHIP + + def _get_ip_leave(self): + return socket.IP_DROP_MEMBERSHIP + + def _get_multicast_ttl(self): + return socket.IP_MULTICAST_TTL + + +class NetworkingThreadIPv6(NetworkingThread): + def __init__(self, observer): + super().__init__(observer) + + def _makeMreq(self, addr): + return struct.pack("=16si", socket.inet_pton(socket.AF_INET6, MULTICAST_IPV6_ADDRESS), int(addr.scope_id)) + + def _get_inet(self): + return socket.AF_INET6 + + def _get_multicast(self): + return socket.IPV6_MULTICAST_IF + + def _get_ip_proto(self): + return socket.IPPROTO_IPV6 + + def _get_ip_join(self): + return socket.IPV6_JOIN_GROUP + + def _get_ip_leave(self): + return socket.IPV6_LEAVE_GROUP + + def _get_multicast_ttl(self): + return socket.IPV6_MULTICAST_HOPS class ThreadedNetworking: "handle threaded networking start & stop, address add/remove & message sending" def __init__(self, **kwargs): - self._networkingThread = None + self._networkingThread_v4 = None + self._networkingThread_v6 = None + self._addrsMonitorThread_v4 = None + self._addrsMonitorThread_v6 = None self._serverStarted = False super().__init__(**kwargs) def _startThreads(self): - if self._networkingThread is not None: + if self._networkingThread_v4 is not None: return - self._networkingThread = NetworkingThread(self) - self._networkingThread.start() - logger.debug("networking thread started") - self._addrsMonitorThread = AddressMonitorThread(self) - self._addrsMonitorThread.start() - logger.debug("address monitoring thread started") + self._networkingThread_v4 = NetworkingThreadIPv4(self) + self._networkingThread_v6 = NetworkingThreadIPv6(self) + self._networkingThread_v4.start() + self._networkingThread_v6.start() + logger.debug("networking threads started") + + self._addrsMonitorThread_v4 = AddressMonitorThread(self, socket.AF_INET) + self._addrsMonitorThread_v6 = AddressMonitorThread(self, socket.AF_INET6) + self._addrsMonitorThread_v4.start() + self._addrsMonitorThread_v6.start() + logger.debug("address monitoring threads started") def _stopThreads(self): - if self._networkingThread is None: + if self._networkingThread_v4 is None: return - self._networkingThread.schedule_stop() - self._addrsMonitorThread.schedule_stop() + self._networkingThread_v4.schedule_stop() + self._addrsMonitorThread_v4.schedule_stop() + self._networkingThread_v6.schedule_stop() + self._addrsMonitorThread_v6.schedule_stop() - self._networkingThread.join() - self._addrsMonitorThread.join() + self._networkingThread_v4.join() + self._addrsMonitorThread_v4.join() + self._networkingThread_v6.join() + self._addrsMonitorThread_v6.join() - self._networkingThread = None + self._networkingThread_v4 = None + self._networkingThread_v6 = None def start(self): - "start networking - should be called before using other methods" + """start networking - should be called before using other methods""" self._startThreads() self._serverStarted = True def stop(self): - "cleans up and stops networking" + """cleans up and stops networking""" self._stopThreads() self._serverStarted = False def addSourceAddr(self, addr): - self._networkingThread.addSourceAddr(addr) + version = ipaddress.ip_address(addr).version + if version == 4: + self._networkingThread_v4.addSourceAddr(addr) + elif version == 6: + self._networkingThread_v6.addSourceAddr(addr) def removeSourceAddr(self, addr): - self._networkingThread.removeSourceAddr(addr) + version = ipaddress.ip_address(addr).version + if version == 4: + self._networkingThread_v4.removeSourceAddr(addr) + elif version == 6: + self._networkingThread_v6.removeSourceAddr(addr) def sendUnicastMessage(self, env, host, port, initialDelay=0): "handle unicast message sending" - self._networkingThread.addUnicastMessage(env, host, port, initialDelay) + self._networkingThread_v4.addUnicastMessage(env, host, port, initialDelay) + self._networkingThread_v6.addUnicastMessage(env, host, port, initialDelay) def sendMulticastMessage(self, env, initialDelay=0): "handle multicast message sending" - self._networkingThread.addMulticastMessage(env, MULTICAST_IPV4_ADDRESS, MULTICAST_PORT, initialDelay) + self._networkingThread_v4.addMulticastMessage(env, MULTICAST_IPV4_ADDRESS, MULTICAST_PORT, initialDelay) + self._networkingThread_v6.addMulticastMessage(env, MULTICAST_IPV6_ADDRESS, MULTICAST_PORT, initialDelay) diff --git a/wsdiscovery/util.py b/wsdiscovery/util.py index 0214bd8..3454e66 100644 --- a/wsdiscovery/util.py +++ b/wsdiscovery/util.py @@ -1,6 +1,7 @@ """Various utilities used by different parts of the package.""" import io +import ipaddress import string import random import netifaces @@ -243,16 +244,17 @@ def getQNameFromValue(value, node): return QName(ns, localName, prefix) -def _getNetworkAddrs(): +def _getNetworkAddrs(protocol_version): result = [] for if_name in netifaces.interfaces(): iface_info = netifaces.ifaddresses(if_name) - if netifaces.AF_INET in iface_info: - for addrDict in iface_info[netifaces.AF_INET]: + if protocol_version in iface_info: + for addrDict in iface_info[protocol_version]: addr = addrDict['addr'] - if addr != '127.0.0.1': - result.append(addr) + ip_address = ipaddress.ip_address(addr) + if not ip_address.is_loopback: + result.append(ip_address) return result