diff --git a/vpn_slice/__main__.py b/vpn_slice/__main__.py index 689eb7f..01c60b8 100755 --- a/vpn_slice/__main__.py +++ b/vpn_slice/__main__.py @@ -128,7 +128,7 @@ def do_disconnect(env, args): print(f"Killed pid {pid} from {pidfile}", file=stderr) if 'hosts' in providers: - removed = providers.hosts.write_hosts({}, args.name) + removed = providers.hosts.write_hosts((), args.name) if args.verbose: print(f"Removed {removed} hosts from /etc/hosts", file=stderr) diff --git a/vpn_slice/posix.py b/vpn_slice/posix.py index ca267cd..143e3c1 100644 --- a/vpn_slice/posix.py +++ b/vpn_slice/posix.py @@ -99,14 +99,13 @@ def __init__(self, path): def write_hosts(self, host_map, name): tag = f'vpn-slice-{name} AUTOCREATED' - with open(self.path, 'r+') as hostf: + with open(self.path, 'r+b') as hostf: fcntl.flock(hostf, fcntl.LOCK_EX) # POSIX only, obviously lines = hostf.readlines() - keeplines = [l for l in lines if not l.endswith(f'# {tag}\n')] + keeplines = [l for l in lines if not l.endswith(f'# {tag}\n'.encode())] hostf.seek(0, 0) hostf.writelines(keeplines) - for ip, names in host_map: - print(f"{ip} {' '.join(names)}\t\t# {tag}", file=hostf) + hostf.writelines(f"{ip} {' '.join(names)}\t\t# {tag}\n".encode() for ip, names in host_map) hostf.truncate() return len(host_map) or len(lines) - len(keeplines) diff --git a/vpn_slice/provider.py b/vpn_slice/provider.py index 3a5e10e..c9c43f7 100644 --- a/vpn_slice/provider.py +++ b/vpn_slice/provider.py @@ -129,13 +129,19 @@ def lookup_srv(self, query): class HostsProvider(metaclass=ABCMeta): @abstractmethod def write_hosts(self, host_map, name): - """Write information to the hosts file. + """Update local overrides for hostname-to-IP address mapping. - Lines include a tag so we can identify which lines to remove. - The tag is derived from the name. + Each local override added by this instance of vpn-slice should include a tag + a tag derived from the 'name' parameter, so that we can later identify those + owned by this instance in order to remove/replace them, while leaving others + untouched. - host_map maps IP addresses to host names, like the hosts file expects. + 'host_map' should be a list of (IP address, lists of hostnames) tuples, e.g. + from typing import List, Tuple, Union + from ipaddress import ip_address + def write_hosts(self, host_map: List[Tuple[Union[ip_address, str], List[str]]], name: str): + ... """ class TunnelPrepProvider: