From 92a780a43d031c80f3ee49875ee8a779fc28fec9 Mon Sep 17 00:00:00 2001 From: bitromortac Date: Sun, 24 Apr 2022 13:26:48 +0200 Subject: [PATCH] multi: async grpc lnmd+lndmd, update tests --- lndmanage/grpc_compiled/build_grpc.sh | 3 + lndmanage/grpc_compiled/manager.proto | 25 ++ lndmanage/grpc_compiled/manager_pb2.py | 171 +++++++++++ lndmanage/grpc_compiled/manager_pb2_grpc.py | 70 +++++ lndmanage/lib/chan_acceptor.py | 61 ++-- lndmanage/lib/managed.py | 69 +++-- lndmanage/lib/node.py | 65 ++-- lndmanage/main_lndmanage.py | 104 ++++--- lndmanage/main_lndmanaged.py | 11 +- test/test_chanacceptor.py | 77 +++-- test/test_circle.py | 218 +++++++------- test/test_lndmanage.py | 8 +- test/test_openchannels.py | 318 ++++++++++---------- test/test_rebalance.py | 192 ++++++++---- test/testing_common.py | 1 + 15 files changed, 932 insertions(+), 461 deletions(-) create mode 100644 lndmanage/grpc_compiled/manager.proto create mode 100644 lndmanage/grpc_compiled/manager_pb2.py create mode 100644 lndmanage/grpc_compiled/manager_pb2_grpc.py diff --git a/lndmanage/grpc_compiled/build_grpc.sh b/lndmanage/grpc_compiled/build_grpc.sh index 1808c21..79f6929 100755 --- a/lndmanage/grpc_compiled/build_grpc.sh +++ b/lndmanage/grpc_compiled/build_grpc.sh @@ -15,6 +15,7 @@ python -m grpc_tools.protoc --proto_path=googleapis:. --python_out=. --grpc_pyth python -m grpc_tools.protoc --proto_path=googleapis:. --python_out=. --grpc_python_out=. router.proto python -m grpc_tools.protoc --proto_path=googleapis:. --python_out=. --grpc_python_out=. walletkit.proto python -m grpc_tools.protoc --proto_path=googleapis:. --python_out=. --grpc_python_out=. signer.proto +python -m grpc_tools.protoc --proto_path=googleapis:. --python_out=. --grpc_python_out=. manager.proto # fix import paths sed -i -- 's@import lightning_pb2 as lightning__pb2@from lndmanage.grpc_compiled import lightning_pb2 as lightning__pb2@' lightning_pb2_grpc.py @@ -30,3 +31,5 @@ sed -i -- 's@import signer_pb2 as signer__pb2@from lndmanage.grpc_compiled impor sed -i -- 's@import signer_pb2 as signer__pb2@from lndmanage.grpc_compiled import signer_pb2 as signer__pb2@' walletkit_pb2_grpc.py sed -i -- 's@import walletkit_pb2 as walletkit__pb2@from lndmanage.grpc_compiled import walletkit_pb2 as walletkit__pb2@' walletkit_pb2_grpc.py + +sed -i -- 's@import manager_pb2 as manager__pb2@from lndmanage.grpc_compiled import manager_pb2 as manager__pb2@' manager_pb2_grpc.py diff --git a/lndmanage/grpc_compiled/manager.proto b/lndmanage/grpc_compiled/manager.proto new file mode 100644 index 0000000..7b592ce --- /dev/null +++ b/lndmanage/grpc_compiled/manager.proto @@ -0,0 +1,25 @@ +syntax = "proto3"; + +option objc_class_prefix = "MNG"; + +package managerpc; + +// blah. +service Mangager { + // blah. + rpc RunningServices(RunningServicesRequest) returns (RunningServicesResponse) {} +} + +// blah. +message RunningServicesRequest { +} + +// blah. +message RunningServicesResponse { + repeated RunningService services = 1; +} + +// blah. +message RunningService { + string name = 1; +} diff --git a/lndmanage/grpc_compiled/manager_pb2.py b/lndmanage/grpc_compiled/manager_pb2.py new file mode 100644 index 0000000..7546f92 --- /dev/null +++ b/lndmanage/grpc_compiled/manager_pb2.py @@ -0,0 +1,171 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: manager.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='manager.proto', + package='managerpc', + syntax='proto3', + serialized_options=b'\242\002\003MNG', + create_key=_descriptor._internal_create_key, + serialized_pb=b'\n\rmanager.proto\x12\tmanagerpc\"\x18\n\x16RunningServicesRequest\"F\n\x17RunningServicesResponse\x12+\n\x08services\x18\x01 \x03(\x0b\x32\x19.managerpc.RunningService\"\x1e\n\x0eRunningService\x12\x0c\n\x04name\x18\x01 \x01(\t2f\n\x08Mangager\x12Z\n\x0fRunningServices\x12!.managerpc.RunningServicesRequest\x1a\".managerpc.RunningServicesResponse\"\x00\x42\x06\xa2\x02\x03MNGb\x06proto3' +) + + + + +_RUNNINGSERVICESREQUEST = _descriptor.Descriptor( + name='RunningServicesRequest', + full_name='managerpc.RunningServicesRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=28, + serialized_end=52, +) + + +_RUNNINGSERVICESRESPONSE = _descriptor.Descriptor( + name='RunningServicesResponse', + full_name='managerpc.RunningServicesResponse', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='services', full_name='managerpc.RunningServicesResponse.services', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=54, + serialized_end=124, +) + + +_RUNNINGSERVICE = _descriptor.Descriptor( + name='RunningService', + full_name='managerpc.RunningService', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='managerpc.RunningService.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=126, + serialized_end=156, +) + +_RUNNINGSERVICESRESPONSE.fields_by_name['services'].message_type = _RUNNINGSERVICE +DESCRIPTOR.message_types_by_name['RunningServicesRequest'] = _RUNNINGSERVICESREQUEST +DESCRIPTOR.message_types_by_name['RunningServicesResponse'] = _RUNNINGSERVICESRESPONSE +DESCRIPTOR.message_types_by_name['RunningService'] = _RUNNINGSERVICE +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +RunningServicesRequest = _reflection.GeneratedProtocolMessageType('RunningServicesRequest', (_message.Message,), { + 'DESCRIPTOR' : _RUNNINGSERVICESREQUEST, + '__module__' : 'manager_pb2' + # @@protoc_insertion_point(class_scope:managerpc.RunningServicesRequest) + }) +_sym_db.RegisterMessage(RunningServicesRequest) + +RunningServicesResponse = _reflection.GeneratedProtocolMessageType('RunningServicesResponse', (_message.Message,), { + 'DESCRIPTOR' : _RUNNINGSERVICESRESPONSE, + '__module__' : 'manager_pb2' + # @@protoc_insertion_point(class_scope:managerpc.RunningServicesResponse) + }) +_sym_db.RegisterMessage(RunningServicesResponse) + +RunningService = _reflection.GeneratedProtocolMessageType('RunningService', (_message.Message,), { + 'DESCRIPTOR' : _RUNNINGSERVICE, + '__module__' : 'manager_pb2' + # @@protoc_insertion_point(class_scope:managerpc.RunningService) + }) +_sym_db.RegisterMessage(RunningService) + + +DESCRIPTOR._options = None + +_MANGAGER = _descriptor.ServiceDescriptor( + name='Mangager', + full_name='managerpc.Mangager', + file=DESCRIPTOR, + index=0, + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_start=158, + serialized_end=260, + methods=[ + _descriptor.MethodDescriptor( + name='RunningServices', + full_name='managerpc.Mangager.RunningServices', + index=0, + containing_service=None, + input_type=_RUNNINGSERVICESREQUEST, + output_type=_RUNNINGSERVICESRESPONSE, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), +]) +_sym_db.RegisterServiceDescriptor(_MANGAGER) + +DESCRIPTOR.services_by_name['Mangager'] = _MANGAGER + +# @@protoc_insertion_point(module_scope) diff --git a/lndmanage/grpc_compiled/manager_pb2_grpc.py b/lndmanage/grpc_compiled/manager_pb2_grpc.py new file mode 100644 index 0000000..d9c3fdd --- /dev/null +++ b/lndmanage/grpc_compiled/manager_pb2_grpc.py @@ -0,0 +1,70 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from lndmanage.grpc_compiled import manager_pb2 as manager__pb2 + + +class MangagerStub(object): + """blah. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.RunningServices = channel.unary_unary( + '/managerpc.Mangager/RunningServices', + request_serializer=manager__pb2.RunningServicesRequest.SerializeToString, + response_deserializer=manager__pb2.RunningServicesResponse.FromString, + ) + + +class MangagerServicer(object): + """blah. + """ + + def RunningServices(self, request, context): + """blah. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_MangagerServicer_to_server(servicer, server): + rpc_method_handlers = { + 'RunningServices': grpc.unary_unary_rpc_method_handler( + servicer.RunningServices, + request_deserializer=manager__pb2.RunningServicesRequest.FromString, + response_serializer=manager__pb2.RunningServicesResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'managerpc.Mangager', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class Mangager(object): + """blah. + """ + + @staticmethod + def RunningServices(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/managerpc.Mangager/RunningServices', + manager__pb2.RunningServicesRequest.SerializeToString, + manager__pb2.RunningServicesResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/lndmanage/lib/chan_acceptor.py b/lndmanage/lib/chan_acceptor.py index 903ec5a..5083a30 100644 --- a/lndmanage/lib/chan_acceptor.py +++ b/lndmanage/lib/chan_acceptor.py @@ -1,6 +1,8 @@ """Implements logic for accepting channels dynamically.""" import asyncio from typing import TYPE_CHECKING +from google.protobuf import text_format +import textwrap import lndmanage.grpc_compiled.lightning_pb2 as lnd @@ -60,46 +62,53 @@ def configure(self): ) ) - async def manage_channel_openings(self): + async def accept_channels(self): + logger.info("Channel acceptor started.") response_queue = asyncio.queues.Queue() - try: - # async way to use a bidirectional streaming grpc endpoint - # with an async iterator - async for r in self.node.async_rpc.ChannelAcceptor( - self.request_iterator(response_queue)): - await response_queue.put(r) - except asyncio.CancelledError: - logger.info("channel acceptor cancelled") - return + + # Use an async bidirectional streaming grpc endpoint with an async iterator. + # Note: no exceptions escape from there, handle them inside the iterator. + async for r in self.node.async_rpc.ChannelAcceptor( + self.request_iterator(response_queue)): + if isinstance(r, Exception): + raise r + await response_queue.put(r) async def request_iterator(self, channel_details: asyncio.Queue): - logger.info("channel acceptor started") - while True: - channel_detail = await channel_details.get() - if self.accept_channel(channel_detail): - yield lnd.ChannelAcceptResponse( - accept=True, pending_chan_id=channel_detail.pending_chan_id - ) - else: - yield lnd.ChannelAcceptResponse( - accept=False, pending_chan_id=channel_detail.pending_chan_id - ) + # Be careful, exceptions don't leave from here, only get logged. + try: + while True: + channel_detail = await channel_details.get() + if self.accept_channel(channel_detail): + yield lnd.ChannelAcceptResponse( + accept=True, pending_chan_id=channel_detail.pending_chan_id + ) + else: + yield lnd.ChannelAcceptResponse( + accept=False, pending_chan_id=channel_detail.pending_chan_id + ) + except asyncio.CancelledError: + logger.info("canceled") def accept_channel(self, channel_detail) -> bool: - # be careful, exceptions from here seem to not get raised up to the main - # loop - # TODO: raise exceptions from here logger.info( - f"about to make a decision about channel:\n{channel_detail}") + f"About to make a decision about a channel:") + logger.info(textwrap.indent(str(channel_detail), " ")) + node_pubkey = channel_detail.node_pubkey.hex() is_private = self.network_analysis.is_private(node_pubkey) - logger.info(f"is private {is_private}") + + # We apply different policies for private or public channels. if is_private: if (self.min_size_private < channel_detail.funding_amt < self.max_size_private): + logger.debug(f"Private channel accepted.") return True else: if (self.min_size_public < channel_detail.funding_amt < self.max_size_public): + logger.debug(f"Public channel accepted.") return True + + logger.debug(f"Channel open rejected.") return False \ No newline at end of file diff --git a/lndmanage/lib/managed.py b/lndmanage/lib/managed.py index 95686f6..23677bf 100644 --- a/lndmanage/lib/managed.py +++ b/lndmanage/lib/managed.py @@ -1,15 +1,19 @@ """Implements a daemon for constant watching of an LND node.""" -import time import asyncio +import time from signal import SIGINT, SIGTERM -from typing import Optional +from typing import Optional, Dict, Coroutine import warnings import os +import grpc + from lndmanage.lib.node import LndNode from lndmanage.lib.chan_acceptor import ChanAcceptor import lndmanage.grpc_compiled.lightning_pb2 as lnd import lndmanage.grpc_compiled.router_pb2 as lndrouter +import lndmanage.grpc_compiled.manager_pb2_grpc as manager_grpc +import lndmanage.grpc_compiled.manager_pb2 as manager from lndmanage import settings @@ -42,6 +46,7 @@ def __init__(self, lndm_config_path: Optional[str] = None, self.lnd_home = lnd_home self.lnd_host = lnd_host self.regtest = regtest + self.running_services: Dict[str, Coroutine] = {} self.node = LndNode(config_file=lndm_config_path, lnd_home=lnd_home, lnd_host=lnd_host, regtest=regtest) self.config = settings.read_config(self.lndmd_config_path) @@ -77,13 +82,22 @@ async def service_channel_acceptor(self): """Handles channel opening requests.""" channel_acceptor = ChanAcceptor(self.node, self.config) try: - await channel_acceptor.manage_channel_openings() + await channel_acceptor.accept_channels() except asyncio.CancelledError: logger.debug('channel_acceptor shutting down') - def run_services(self): + async def service_grpc(self): + server = grpc.aio.server() + manager_grpc.add_MangagerServicer_to_server( + ManagerBackend(self), server) + server.add_insecure_port('[::]:50051') + await server.start() + logger.info("Rpc server started.") + await server.wait_for_termination() + + async def run_services(self): """Main method to start registered services.""" - loop = self.node.loop + loop = asyncio.get_event_loop() if ASYNCIO_DEBUG: loop.set_debug(True) loop.slow_callback_duration = 0.01 @@ -93,36 +107,37 @@ def run_services(self): for sig in (SIGTERM, SIGINT): loop.add_signal_handler(sig, handler, sig) + self.running_services = { + 'grpc': self.service_grpc(), + 'channel_acceptor': self.service_channel_acceptor(), + 'alive_message': self.service_alive_message(), + # 'htlc_stream': self.service_htlc_stream(), + # 'graph_stream': self.service_graph_stream(), + } + # run services try: - services = asyncio.gather( - self.service_channel_acceptor(), - self.service_alive_message(), - self.service_htlc_stream(), - # self.service_graph_stream(), - ) - loop.run_until_complete(services) + async with self.node: + await asyncio.gather(*self.running_services.values()) except asyncio.CancelledError: logger.debug('main shutting down') except Exception as e: logger.exception("exception occured") - finally: - self.node.disconnect_rpcs() -def main(): - lndm_config_path = os.path.join(settings.home_dir, 'config.ini') - lndmd_config_path = os.path.join(settings.home_dir, 'lndmanaged.ini') +class ManagerBackend(manager_grpc.MangagerServicer): + def __init__(self, managed: LNDManageDaemon): + self.managed = managed + + def RunningServices( + self, + request: manager.RunningServicesRequest, + context, + ) -> manager.RunningServicesResponse: - lndmd = LNDManageDaemon( - lndm_config_path=lndm_config_path, - lndmd_config_path=lndmd_config_path, - ) - lndmd.run_services() + service_names = [] + for name, service in self.managed.running_services.items(): + service_names.append(manager.RunningService(name=name)) -if __name__ == '__main__': - lndmd = LNDManageDaemon( - lndm_config_path="/home/user/.lndmanage/config.ini", - lndmd_config_path="/home/user/.lndmanage/lndmanaged.ini") - lndmd.run_services() + return manager.RunningServicesResponse(services=service_names) diff --git a/lndmanage/lib/node.py b/lndmanage/lib/node.py index c2fab8d..c1f1689 100644 --- a/lndmanage/lib/node.py +++ b/lndmanage/lib/node.py @@ -18,6 +18,8 @@ import lndmanage.grpc_compiled.router_pb2_grpc as lndrouterrpc import lndmanage.grpc_compiled.walletkit_pb2 as lndwalletkit import lndmanage.grpc_compiled.walletkit_pb2_grpc as lndwalletkitrpc +import lndmanage.grpc_compiled.manager_pb2 as managermsg +import lndmanage.grpc_compiled.manager_pb2_grpc as managerrpc from lndmanage.lib.network import Network from lndmanage.lib.exceptions import PaymentTimeOut, NoRoute, OurNodeFailure @@ -117,12 +119,12 @@ def __init__(self, config_file: Optional[str] = None, if self.lnd_host is None: raise ValueError( 'if lnd_home is given, lnd_host must be given also') - lnd_host = self.lnd_host + self.lnd_host = self.lnd_host else: cert_file = os.path.expanduser(self.config['network']['tls_cert_file']) macaroon_file = \ os.path.expanduser(self.config['network']['admin_macaroon_file']) - lnd_host = self.config['network']['lnd_grpc_host'] + self.lnd_host = self.config['network']['lnd_grpc_host'] cert = None try: @@ -148,45 +150,68 @@ def metadata_callback(context, callback): cert_creds = grpc.ssl_channel_credentials(cert) auth_creds = grpc.metadata_call_credentials(metadata_callback) - creds = grpc.composite_channel_credentials(cert_creds, auth_creds) + self.creds = grpc.composite_channel_credentials(cert_creds, auth_creds) else: - creds = grpc.ssl_channel_credentials(cert) + self.creds = grpc.ssl_channel_credentials(cert) + async def connect_async_rpcs(self): + logger.debug("connecting async rpcs") # necessary to circumvent standard size limitation - self._sync_channel = grpc.secure_channel( - lnd_host, creds, + self._async_channel = grpc.aio.secure_channel( + self.lnd_host, self.creds, options=[('grpc.max_receive_message_length', 50 * 1024 * 1024)]) + self.loop = self._async_channel._loop + # establish async connections to rpc servers + self.async_rpc = lndrpc.LightningStub(self._async_channel) + self.async_routerrpc = lndrouterrpc.RouterStub(self._async_channel) + + # optionally connect to lndmanged + self._async_manager_channel = grpc.aio.insecure_channel( + 'localhost:50051') + self.async_managerrpc = managerrpc.MangagerStub(self._async_manager_channel) + + def connect_sync_rpcs(self): # necessary to circumvent standard size limitation - self._async_channel = grpc.aio.secure_channel( - lnd_host, creds, + self._sync_channel = grpc.secure_channel( + self.lnd_host, self.creds, options=[('grpc.max_receive_message_length', 50 * 1024 * 1024)]) - self.loop = self._async_channel._loop # necessary to circumvent standard size limitation - channel = grpc.secure_channel(lnd_host, creds, options=[ + lndchannel = grpc.secure_channel(self.lnd_host, self.creds, options=[ ('grpc.max_receive_message_length', 50 * 1024 * 1024) ]) # establish connections to rpc servers - self._rpc = lndrpc.LightningStub(channel) - self._routerrpc = lndrouterrpc.RouterStub(channel) - self._walletrpc = lndwalletkitrpc.WalletKitStub(channel) + self._rpc = lndrpc.LightningStub(lndchannel) + self._routerrpc = lndrouterrpc.RouterStub(lndchannel) + self._walletrpc = lndwalletkitrpc.WalletKitStub(lndchannel) - # establish async connections to rpc servers - self.async_rpc = lndrpc.LightningStub(self._async_channel) - self.async_routerrpc = lndrouterrpc.RouterStub(self._async_channel) + async def start(self): + logger.debug("node interface starting") + # connect rpcs + self.connect_sync_rpcs() + await self.connect_async_rpcs() + # initialize stuff (TODO: clean up) self.set_info() self.network = Network(self) self.update_blockheight() self.set_channel_summary() - def disconnect_rpcs(self): + async def stop(self): logger.debug("disconnecting rpcs") self._sync_channel.close() - asyncio.run_coroutine_threadsafe(self._async_channel.close(), self.loop) + await self._async_channel.close() + # wait a bit to really close all transports + await asyncio.sleep(0.01) + + async def __aenter__(self): + await self.start() + + async def __aexit__(self, exc_type, exc, tb): + await self.stop() def update_blockheight(self): info = self._rpc.GetInfo(lnd.GetInfoRequest()) @@ -1037,3 +1062,7 @@ def pubkey_to_channel_map(self): node_to_channel_map[cv['remote_pubkey']].append(c) return node_to_channel_map + + async def running_services(self): + resp = await self.async_managerrpc.RunningServices(managermsg.RunningServicesRequest()) + print(resp) \ No newline at end of file diff --git a/lndmanage/main_lndmanage.py b/lndmanage/main_lndmanage.py index 8537279..93e23cf 100755 --- a/lndmanage/main_lndmanage.py +++ b/lndmanage/main_lndmanage.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +import asyncio import argparse import time import os @@ -428,6 +429,13 @@ def __init__(self): help='Update the fees without asking the user explicitly.', action='store_true') + # lndmanaged stuff + + # cmd: services + parser_services = subparsers.add_parser( + 'services', + help="displays services running in lndmanaged") + def check_for_lncli(self): """ Looks for lncli in PATH or in LNDMANAGE_HOME folder. Sets self.lncli_path. @@ -446,7 +454,7 @@ def check_for_lncli(self): def parse_arguments(self): return self.parser.parse_args() - def run_commands(self, node, args): + async def run_commands(self, node, args): # program execution if args.loglevel: # update the loglevel of the stdout handler to the user choice @@ -602,9 +610,12 @@ def run_commands(self, node, args): init=args.init, reckless=args.reckless ) + elif args.cmd == 'services': + logger.info("running") + await node.running_services() -def main(): +async def _main(): parser = Parser() # config.ini is expected to be in home/.lndmanage directory @@ -614,8 +625,10 @@ def main(): if len(sys.argv) > 1: # take arguments from sys.argv args = parser.parse_arguments() - node = LndNode(config_file=config_file) - parser.run_commands(node, args) + + lndnode = LndNode(config_file=config_file) + async with lndnode: + await parser.run_commands(lndnode, args) # otherwise enter an interactive mode else: @@ -628,50 +641,55 @@ def main(): logger.info("Running in interactive mode. " "You can type 'help' or 'exit'.") - node = LndNode(config_file=config_file) - if parser.lncli_path: - logger.info("Enabled lncli: using " + parser.lncli_path) - - while True: - try: - user_input = input("$ lndmanage ") - except KeyboardInterrupt: - logger.info("") - continue - except EOFError: - readline.write_history_file(history_file) - logger.info("exit") - return 0 + lndnode = LndNode(config_file=config_file) + async with lndnode: + if parser.lncli_path: + logger.info("Enabled lncli: using " + parser.lncli_path) - if not user_input or user_input in ['help', '-h', '--help']: - parser.parser.print_help() - continue - elif user_input == 'exit': - readline.write_history_file(history_file) - return 0 - - args_list = user_input.split(" ") + while True: + try: + user_input = input("$ lndmanage ") + except KeyboardInterrupt: + logger.info("") + continue + except EOFError: + readline.write_history_file(history_file) + logger.info("exit") + return 0 - # lncli execution - if args_list[0] == 'lncli': - if parser.lncli_path: - lncli = Lncli(parser.lncli_path, config_file) - lncli.lncli(args_list[1:]) + if not user_input or user_input in ['help', '-h', '--help']: + parser.parser.print_help() continue - else: - logger.info("lncli not enabled, put lncli in PATH or in ~/.lndmanage") + elif user_input == 'exit': + readline.write_history_file(history_file) + return 0 + + args_list = user_input.split(" ") + + # lncli execution + if args_list[0] == 'lncli': + if parser.lncli_path: + lncli = Lncli(parser.lncli_path, config_file) + lncli.lncli(args_list[1:]) + continue + else: + logger.info("lncli not enabled, put lncli in PATH or in ~/.lndmanage") + continue + try: + # need to run with parse_known_args to get an exception + args = parser.parser.parse_args(args_list) + await parser.run_commands(lndnode, args) + except SystemExit: + # argparse may raise SystemExit on incorrect user input, + # which is a graceful exit. The user gets the standard output + # from argparse of what went wrong. continue - try: - # need to run with parse_known_args to get an exception - args = parser.parser.parse_args(args_list) - parser.run_commands(node, args) - except SystemExit: - # argparse may raise SystemExit on incorrect user input, - # which is a graceful exit. The user gets the standard output - # from argparse of what went wrong. - continue + + +def main(): + asyncio.run(_main()) if __name__ == '__main__': - main() + asyncio.run(_main()) diff --git a/lndmanage/main_lndmanaged.py b/lndmanage/main_lndmanaged.py index a2ae69c..c20df00 100644 --- a/lndmanage/main_lndmanaged.py +++ b/lndmanage/main_lndmanaged.py @@ -1,8 +1,12 @@ +import asyncio import os + from lndmanage import settings from lndmanage.lib.managed import LNDManageDaemon +# TODO: configuration, command line flags + def main(): lndm_config_path = os.path.join(settings.home_dir, 'config.ini') lndmd_config_path = os.path.join(settings.home_dir, 'lndmanaged.ini') @@ -11,11 +15,8 @@ def main(): lndm_config_path=lndm_config_path, lndmd_config_path=lndmd_config_path, ) - lndmd.run_services() + asyncio.run(lndmd.run_services()) if __name__ == '__main__': - lndmd = LNDManageDaemon( - lndm_config_path="/home/user/.lndmanage/config.ini", - lndmd_config_path="/home/user/.lndmanage/lndmanaged.ini") - lndmd.run_services() + main() diff --git a/test/test_chanacceptor.py b/test/test_chanacceptor.py index 3fc62fe..b74d485 100644 --- a/test/test_chanacceptor.py +++ b/test/test_chanacceptor.py @@ -28,26 +28,42 @@ async def async_open_channel(opener_node: 'LND', pubkey: str, local_sat: int, return info -def was_channel_accepted_helper(chan_acceptor, loop, node_id_from, node_id_to, - local_sat, remote_sat) -> bool: - # accept: test channel open - open_task = async_open_channel( +async def channel_accepted( + lndnode, + chan_acceptor, node_id_from, node_id_to, local_sat, - remote_sat, - ) - acceptor_task = loop.create_task( - chan_acceptor.manage_channel_openings()) - open_task = loop.create_task(open_task) - # callback to cancel acceptor - open_task.add_done_callback(cancel_all_tasks_callback) - results = loop.run_until_complete(asyncio.gather(acceptor_task, open_task)) - - if results[1] == '': # channel open failed - return False - elif len(results[1]['funding_txid']) == 64: # channel open succeeded - return True + remote_sat +) -> bool: + + async with lndnode: + open_task = asyncio.create_task(async_open_channel( + node_id_from, + node_id_to, + local_sat, + remote_sat, + )) + acceptor_task = asyncio.create_task( + chan_acceptor.accept_channels()) + + tasks = [open_task, acceptor_task] + + await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + result = None + for task in tasks: + if not task.done(): + task.cancel() + else: + result = task.result() + + if result == '': # channel open failed + return False + elif len(result['funding_txid']) == 64: # channel open succeeded + return True + elif None: + raise Exception("test didn't work") class LndmanagedTest(TestNetwork): @@ -68,42 +84,43 @@ def test_channel_acceptor(self): chan_acceptor.min_size_private = 0 chan_acceptor.max_size_private = 2_000_000 # public: - chan_acceptor.min_size_public = 4_000_000 + chan_acceptor.max_size_public = 4_000_000 master_node_id = self.testnet.node_mapping['A'] + # We mock the private/public nature of opened channels. with patch.object(chan_acceptor.network_analysis, 'is_private') as mock: - # PRIVATE NODES + # Private nodes mock.return_value = True - accepted = was_channel_accepted_helper( + accepted = asyncio.run(channel_accepted( + self.lndnode, chan_acceptor, - self.lndnode.loop, self.testnet.ln_nodes['B'], master_node_id, local_sat=1_000_000, remote_sat=0, - ) + )) self.assertTrue(accepted) - accepted = was_channel_accepted_helper( + accepted = asyncio.run(channel_accepted( + self.lndnode, chan_acceptor, - self.lndnode.loop, self.testnet.ln_nodes['C'], master_node_id, local_sat=3_000_000, remote_sat=0, - ) + )) self.assertFalse(accepted) - # PUBLIC NODES + # Public nodes mock.return_value = False - accepted = was_channel_accepted_helper( + accepted = asyncio.run(channel_accepted( + self.lndnode, chan_acceptor, - self.lndnode.loop, self.testnet.ln_nodes['C'], master_node_id, local_sat=3_000_000, remote_sat=0, - ) - self.assertFalse(accepted) + )) + self.assertTrue(accepted) diff --git a/test/test_circle.py b/test/test_circle.py index 49d6d61..4a31342 100644 --- a/test/test_circle.py +++ b/test/test_circle.py @@ -1,4 +1,5 @@ """Tests for circular self-payments.""" +import asyncio import math import time from typing import List @@ -28,7 +29,7 @@ class CircleTest(TestNetwork): """ network_definition = None - def circular_rebalance_and_check( + async def circular_rebalance_and_check( self, channel_numbers_send: List[int], channel_numbers_receive: List[int], @@ -49,56 +50,57 @@ def circular_rebalance_and_check( :param dry: if it should be a dry run """ - self.rebalancer = Rebalancer( - self.lndnode, - max_effective_fee_rate=max_effective_fee_rate, - budget_sat=budget_sat, - force=True, - ) - - graph_before = self.testnet.assemble_graph() - send_channels = {} - self.rebalancer.channels = self.lndnode.get_unbalanced_channels() - - for c in channel_numbers_send: - channel_id = self.testnet.channel_mapping[c]['channel_id'] - send_channels[channel_id] = self.rebalancer.channels[channel_id] - receive_channels = {} - for c in channel_numbers_receive: - channel_id = self.testnet.channel_mapping[c]['channel_id'] - receive_channels[channel_id] = self.rebalancer.channels[channel_id] - - invoice = self.lndnode.get_invoice(amount_sat, '') - payment_hash, payment_address = invoice.r_hash, invoice.payment_addr - - fees_msat = self.rebalancer._rebalance( - send_channels=send_channels, - receive_channels=receive_channels, - amt_sat=amount_sat, - payment_hash=payment_hash, - payment_address=payment_address, - budget_sat=budget_sat, - dry=dry - ) - - time.sleep(SLEEP_SEC_AFTER_REBALANCING) # needed to let lnd update the balances - - graph_after = self.testnet.assemble_graph() - - self.assertEqual(expected_fees_msat, fees_msat) - - # check that we send the amount we wanted and that it's conserved - # TODO: this depends on channel reserves, we assume we opened the channels - sent = 0 - received = 0 - for c in channel_numbers_send: - sent += (graph_before['A'][c]['local_balance'] - graph_after['A'][c]['local_balance']) - for c in channel_numbers_receive: - received += (graph_before['A'][c]['remote_balance'] - graph_after['A'][c]['remote_balance']) - assert sent - math.ceil(expected_fees_msat / 1000) == received - - listchannels = ListChannels(self.lndnode) - listchannels.print_all_channels('rev_alias') + async with self.lndnode: + self.rebalancer = Rebalancer( + self.lndnode, + max_effective_fee_rate=max_effective_fee_rate, + budget_sat=budget_sat, + force=True, + ) + + graph_before = self.testnet.assemble_graph() + send_channels = {} + self.rebalancer.channels = self.lndnode.get_unbalanced_channels() + + for c in channel_numbers_send: + channel_id = self.testnet.channel_mapping[c]['channel_id'] + send_channels[channel_id] = self.rebalancer.channels[channel_id] + receive_channels = {} + for c in channel_numbers_receive: + channel_id = self.testnet.channel_mapping[c]['channel_id'] + receive_channels[channel_id] = self.rebalancer.channels[channel_id] + + invoice = self.lndnode.get_invoice(amount_sat, '') + payment_hash, payment_address = invoice.r_hash, invoice.payment_addr + + fees_msat = self.rebalancer._rebalance( + send_channels=send_channels, + receive_channels=receive_channels, + amt_sat=amount_sat, + payment_hash=payment_hash, + payment_address=payment_address, + budget_sat=budget_sat, + dry=dry + ) + + time.sleep(SLEEP_SEC_AFTER_REBALANCING) # needed to let lnd update the balances + + graph_after = self.testnet.assemble_graph() + + self.assertEqual(expected_fees_msat, fees_msat) + + # check that we send the amount we wanted and that it's conserved + # TODO: this depends on channel reserves, we assume we opened the channels + sent = 0 + received = 0 + for c in channel_numbers_send: + sent += (graph_before['A'][c]['local_balance'] - graph_after['A'][c]['local_balance']) + for c in channel_numbers_receive: + received += (graph_before['A'][c]['remote_balance'] - graph_after['A'][c]['remote_balance']) + assert sent - math.ceil(expected_fees_msat / 1000) == received + + listchannels = ListChannels(self.lndnode) + listchannels.print_all_channels('rev_alias') def graph_test(self): """ @@ -108,7 +110,6 @@ def graph_test(self): raise NotImplementedError -@unittest.skip class TestCircleLiquid(CircleTest): network_definition = test_graphs_paths['star_ring_3_liquid'] @@ -125,12 +126,12 @@ def test_circle_success_1_2(self): amount_sat = 10000 expected_fees_msat = 43 - self.circular_rebalance_and_check( + asyncio.run(self.circular_rebalance_and_check( channel_numbers_from, channel_numbers_to, amount_sat, expected_fees_msat - ) + )) def test_circle_success_1_6(self): """ @@ -141,12 +142,12 @@ def test_circle_success_1_6(self): amount_sat = 10000 expected_fees_msat = 33 - self.circular_rebalance_and_check( + asyncio.run(self.circular_rebalance_and_check( channel_numbers_from, channel_numbers_to, amount_sat, expected_fees_msat - ) + )) def test_circle_6_1_fail_rebalance_failure_no_funds(self): """ @@ -160,11 +161,13 @@ def test_circle_6_1_fail_rebalance_failure_no_funds(self): self.assertRaises( OurNodeFailure, - self.circular_rebalance_and_check, - channel_numbers_from, - channel_numbers_to, - amount_sat, - expected_fees_msat, + asyncio.run, + self.circular_rebalance_and_check( + channel_numbers_from, + channel_numbers_to, + amount_sat, + expected_fees_msat + ) ) def test_circle_1_6_fail_budget_too_expensive(self): @@ -179,12 +182,14 @@ def test_circle_1_6_fail_budget_too_expensive(self): self.assertRaises( TooExpensive, - self.circular_rebalance_and_check, - channel_numbers_from, - channel_numbers_to, - amount_sat, - expected_fees_msat, - budget_sat, + asyncio.run, + self.circular_rebalance_and_check( + channel_numbers_from, + channel_numbers_to, + amount_sat, + expected_fees_msat, + budget_sat + ) ) def test_circle_1_6_fail_max_fee_rate_too_expensive(self): @@ -201,14 +206,15 @@ def test_circle_1_6_fail_max_fee_rate_too_expensive(self): self.assertRaises( TooExpensive, - self.circular_rebalance_and_check, - channel_numbers_from, - channel_numbers_to, - amount_sat, - expected_fees_msat, - budget, - max_effective_fee_rate, - + asyncio.run, + self.circular_rebalance_and_check( + channel_numbers_from, + channel_numbers_to, + amount_sat, + expected_fees_msat, + budget, + max_effective_fee_rate, + ) ) def test_circle_1_6_success_channel_reserve(self): @@ -232,11 +238,13 @@ def test_circle_1_6_success_channel_reserve(self): expected_fees_msat = 2_959 - self.circular_rebalance_and_check( - channel_numbers_from, - channel_numbers_to, - amount_sat, - expected_fees_msat, + asyncio.run( + self.circular_rebalance_and_check( + channel_numbers_from, + channel_numbers_to, + amount_sat, + expected_fees_msat, + ) ) def test_circle_1_6_fail_rebalance_dry(self): @@ -250,12 +258,14 @@ def test_circle_1_6_fail_rebalance_dry(self): self.assertRaises( DryRun, - self.circular_rebalance_and_check, - channel_numbers_from, - channel_numbers_to, - amount_sat, - expected_fees_msat, - dry=True + asyncio.run, + self.circular_rebalance_and_check( + channel_numbers_from, + channel_numbers_to, + amount_sat, + expected_fees_msat, + dry=True + ) ) @unittest.skip @@ -283,11 +293,13 @@ def test_circle_fail_2_3_no_route(self): self.assertRaises( NoRoute, - self.circular_rebalance_and_check, - channel_numbers_from, - channel_numbers_to, - amount_sat, - expected_fees_msat + asyncio.run, + self.circular_rebalance_and_check( + channel_numbers_from, + channel_numbers_to, + amount_sat, + expected_fees_msat, + ) ) def test_circle_1_2_fail_max_trials_exhausted(self): @@ -303,11 +315,13 @@ def test_circle_1_2_fail_max_trials_exhausted(self): self.assertRaises( RebalancingTrialsExhausted, - self.circular_rebalance_and_check, - channel_numbers_from, - channel_numbers_to, - amount_sat, - expected_fees_msat + asyncio.run, + self.circular_rebalance_and_check( + channel_numbers_from, + channel_numbers_to, + amount_sat, + expected_fees_msat, + ) ) def test_circle_1_2_fail_no_route_multi_trials(self): @@ -319,9 +333,11 @@ def test_circle_1_2_fail_no_route_multi_trials(self): self.assertRaises( RebalancingTrialsExhausted, - self.circular_rebalance_and_check, - channel_numbers_from, - channel_numbers_to, - amount_sat, - expected_fees_msat + asyncio.run, + self.circular_rebalance_and_check( + channel_numbers_from, + channel_numbers_to, + amount_sat, + expected_fees_msat, + ) ) diff --git a/test/test_lndmanage.py b/test/test_lndmanage.py index c547750..1e002f8 100644 --- a/test/test_lndmanage.py +++ b/test/test_lndmanage.py @@ -1,4 +1,6 @@ """ Integration tests for lndmanage.""" +import asyncio + from test.testing_common import test_graphs_paths, TestNetwork @@ -15,4 +17,8 @@ def graph_test(self): def test_empty(self): # LND interface of lndmanage is initialized in setUp method of super # class, so nothing is needed here. - pass + async def run(): + async with self.lndnode: + pass + + asyncio.run(run()) diff --git a/test/test_openchannels.py b/test/test_openchannels.py index 3ce938b..93b43ab 100644 --- a/test/test_openchannels.py +++ b/test/test_openchannels.py @@ -1,4 +1,5 @@ """Integration tests for batch opening of channels.""" +import asyncio import time from unittest import TestCase @@ -33,163 +34,166 @@ def test_batchopen(self): channel_partner.pubkey for channel_partner in channel_partner_pubkeys ]) - with self.subTest(msg="(implicit), using amounts, change created, high fees"): - amount1 = 111_111 - amount2 = 222_222 - - wallet_utxos_before = self.lndnode.get_utxos() - channels_before = self.lndnode.get_open_channels() - channel_opener.open_channels( - pubkeys=pubkey_input, - amounts=f"{amount1},{amount2}", - reckless=True, - sat_per_vbyte=20, - ) - confirm_transactions(self.testnet) - wallet_utxos_after = self.lndnode.get_utxos() - channels_after = self.lndnode.get_open_channels() - - self.assertEqual(2, len(channels_after) - len(channels_before)) - self.assertEqual(0, len(wallet_utxos_before) - len(wallet_utxos_after)) - - channel_capacities_after = [channel['capacity'] for channel in channels_after.values()] - self.assertIn(amount1, channel_capacities_after) - self.assertIn(amount2, channel_capacities_after) - - with self.subTest(msg="(implicit), total amount, private, change created"): - total_amount = 4_444_444 - - wallet_utxos_before = self.lndnode.get_utxos() - channels_before = self.lndnode.get_open_channels() - channel_opener.open_channels( - pubkeys=pubkey_input, - total_amount=total_amount, - reckless=True, - private=True, - ) - confirm_transactions(self.testnet) - wallet_utxos_after = self.lndnode.get_utxos() - channels_after = self.lndnode.get_open_channels() - - self.assertEqual(2, len(channels_after) - len(channels_before)) - self.assertEqual(0, len(wallet_utxos_before) - len(wallet_utxos_after)) - - total_capacity_before = sum([channel['capacity'] for channel in channels_before.values()]) - total_capacity_after = sum([channel['capacity'] for channel in channels_after.values()]) - self.assertEqual(total_amount, total_capacity_after - total_capacity_before) - num_private_channels = len([True for v in channels_after.values() if v['private']]) - self.assertEqual(2, num_private_channels) - - with self.subTest(msg="(explicit), spend fully, no change created"): - address = self.testnet.master_node.getaddress() - self.testnet.bitcoind.sendtoaddress(address, 0.10_000_000) - confirm_transactions(self.testnet) - - wallet_utxos_before = self.lndnode.get_utxos() - channels_before = self.lndnode.get_open_channels() - # maybe make sure we select the correct utxo - spent_utxo = wallet_utxos_before[0] - utxo_input = f"{spent_utxo.txid}:{spent_utxo.output_index}" - channel_opener.open_channels( - utxos=utxo_input, - pubkeys=pubkey_input, - reckless=True, - ) - confirm_transactions(self.testnet) - wallet_utxos_after = self.lndnode.get_utxos() - channels_after = self.lndnode.get_open_channels() - - self.assertEqual(2, len(channels_after) - len(channels_before)) - self.assertEqual(1, len(wallet_utxos_before) - len(wallet_utxos_after)) - - self.assertNotIn(spent_utxo, wallet_utxos_after) - - # clear wallet, but keep anchor reserves, leaves 50000 sat - self.testnet.master_node.rpc(["sendcoins", "--sweepall", "bcrt1qs758ursh4q9z627kt3pp5yysm78ddny6txaqgw"]) - confirm_transactions(self.testnet) - - with self.subTest(msg="implicit coins, relative amounts, anchor reserve created"): - address = self.testnet.master_node.getaddress() - self.testnet.bitcoind.sendtoaddress(address, 0.10_000_000) - confirm_transactions(self.testnet) - - wallet_utxos_before = self.lndnode.get_utxos() - channels_before = self.lndnode.get_open_channels() - channel_opener.open_channels( - amounts="1,2", - pubkeys=pubkey_input, - reckless=True, - ) - confirm_transactions(self.testnet) - wallet_utxos_after = self.lndnode.get_utxos() - channels_after = self.lndnode.get_open_channels() - - self.assertEqual(2, len(channels_after) - len(channels_before)) - self.assertEqual(1, len(wallet_utxos_before) - len(wallet_utxos_after)) - - wallet_utxo_amounts = [utxo.amount_sat for utxo in wallet_utxos_after] - self.assertIn(openchannels.ANCHOR_RESERVE, wallet_utxo_amounts) - - with self.subTest(msg="implicit coins, nested-P2WKH, too large amounts"): - amount1 = 5_000_000 - amount2 = 6_000_000 - address = self.testnet.master_node.getaddress(address_type='np2wkh') - self.testnet.bitcoind.sendtoaddress(address, 0.10_000_000) - confirm_transactions(self.testnet) - - wallet_utxos_before = self.lndnode.get_utxos() - channels_before = self.lndnode.get_open_channels() - channel_opener.open_channels( - amounts=f"{amount1},{amount2}", - pubkeys=pubkey_input, - reckless=True, - ) - confirm_transactions(self.testnet) - wallet_utxos_after = self.lndnode.get_utxos() - channels_after = self.lndnode.get_open_channels() - - self.assertEqual(2, len(channels_after) - len(channels_before)) - self.assertEqual(1, len(wallet_utxos_before) - len(wallet_utxos_after)) - - total_capacity_before = sum([channel['capacity'] for channel in channels_before.values()]) - total_capacity_after = sum([channel['capacity'] for channel in channels_after.values()]) - # test that we have reduced the amounts - self.assertGreater(amount1 + amount2, total_capacity_after - total_capacity_before) - - with self.subTest(msg="implicit coins, nested-P2WKH, too large amounts"): - address = self.testnet.master_node.getaddress(address_type='np2wkh') - self.testnet.bitcoind.sendtoaddress(address, 0.10_000_000) - confirm_transactions(self.testnet) - total_amount = 20_000_000 - wallet_utxos_before = self.lndnode.get_utxos() - channels_before = self.lndnode.get_open_channels() - channel_opener.open_channels( - total_amount=total_amount, - pubkeys=pubkey_input, - reckless=True, - ) - confirm_transactions(self.testnet) - wallet_utxos_after = self.lndnode.get_utxos() - channels_after = self.lndnode.get_open_channels() - - self.assertEqual(2, len(channels_after) - len(channels_before)) - self.assertEqual(1, len(wallet_utxos_before) - len(wallet_utxos_after)) - - total_capacity_before = sum([channel['capacity'] for channel in channels_before.values()]) - total_capacity_after = sum([channel['capacity'] for channel in channels_after.values()]) - # test that we have reduced the total amount - self.assertGreater(total_amount, total_capacity_after - total_capacity_before) - - with self.subTest(msg="implicit coins, full spend, wumbo violation"): - address = self.testnet.master_node.getaddress() - self.testnet.bitcoind.sendtoaddress(address, (2 * openchannels.WUMBO_LIMIT + 1000) * 1E-8) - - confirm_transactions(self.testnet) - self.assertRaises( - ValueError, channel_opener.open_channels, - pubkeys=pubkey_input, - reckless=True, - ) + async def run_tests(): + async with self.lndnode: + with self.subTest(msg="(implicit), using amounts, change created, high fees"): + amount1 = 111_111 + amount2 = 222_222 + + wallet_utxos_before = self.lndnode.get_utxos() + channels_before = self.lndnode.get_open_channels() + channel_opener.open_channels( + pubkeys=pubkey_input, + amounts=f"{amount1},{amount2}", + reckless=True, + sat_per_vbyte=20, + ) + confirm_transactions(self.testnet) + wallet_utxos_after = self.lndnode.get_utxos() + channels_after = self.lndnode.get_open_channels() + + self.assertEqual(2, len(channels_after) - len(channels_before)) + self.assertEqual(0, len(wallet_utxos_before) - len(wallet_utxos_after)) + + channel_capacities_after = [channel['capacity'] for channel in channels_after.values()] + self.assertIn(amount1, channel_capacities_after) + self.assertIn(amount2, channel_capacities_after) + + with self.subTest(msg="(implicit), total amount, private, change created"): + total_amount = 4_444_444 + + wallet_utxos_before = self.lndnode.get_utxos() + channels_before = self.lndnode.get_open_channels() + channel_opener.open_channels( + pubkeys=pubkey_input, + total_amount=total_amount, + reckless=True, + private=True, + ) + confirm_transactions(self.testnet) + wallet_utxos_after = self.lndnode.get_utxos() + channels_after = self.lndnode.get_open_channels() + + self.assertEqual(2, len(channels_after) - len(channels_before)) + self.assertEqual(0, len(wallet_utxos_before) - len(wallet_utxos_after)) + + total_capacity_before = sum([channel['capacity'] for channel in channels_before.values()]) + total_capacity_after = sum([channel['capacity'] for channel in channels_after.values()]) + self.assertEqual(total_amount, total_capacity_after - total_capacity_before) + num_private_channels = len([True for v in channels_after.values() if v['private']]) + self.assertEqual(2, num_private_channels) + + with self.subTest(msg="(explicit), spend fully, no change created"): + address = self.testnet.master_node.getaddress() + self.testnet.bitcoind.sendtoaddress(address, 0.10_000_000) + confirm_transactions(self.testnet) + + wallet_utxos_before = self.lndnode.get_utxos() + channels_before = self.lndnode.get_open_channels() + # maybe make sure we select the correct utxo + spent_utxo = wallet_utxos_before[0] + utxo_input = f"{spent_utxo.txid}:{spent_utxo.output_index}" + channel_opener.open_channels( + utxos=utxo_input, + pubkeys=pubkey_input, + reckless=True, + ) + confirm_transactions(self.testnet) + wallet_utxos_after = self.lndnode.get_utxos() + channels_after = self.lndnode.get_open_channels() + + self.assertEqual(2, len(channels_after) - len(channels_before)) + self.assertEqual(1, len(wallet_utxos_before) - len(wallet_utxos_after)) + + self.assertNotIn(spent_utxo, wallet_utxos_after) + + # clear wallet, but keep anchor reserves, leaves 50000 sat + self.testnet.master_node.rpc(["sendcoins", "--sweepall", "bcrt1qs758ursh4q9z627kt3pp5yysm78ddny6txaqgw"]) + confirm_transactions(self.testnet) + + with self.subTest(msg="implicit coins, relative amounts, anchor reserve created"): + address = self.testnet.master_node.getaddress() + self.testnet.bitcoind.sendtoaddress(address, 0.10_000_000) + confirm_transactions(self.testnet) + + wallet_utxos_before = self.lndnode.get_utxos() + channels_before = self.lndnode.get_open_channels() + channel_opener.open_channels( + amounts="1,2", + pubkeys=pubkey_input, + reckless=True, + ) + confirm_transactions(self.testnet) + wallet_utxos_after = self.lndnode.get_utxos() + channels_after = self.lndnode.get_open_channels() + + self.assertEqual(2, len(channels_after) - len(channels_before)) + self.assertEqual(1, len(wallet_utxos_before) - len(wallet_utxos_after)) + + wallet_utxo_amounts = [utxo.amount_sat for utxo in wallet_utxos_after] + self.assertIn(openchannels.ANCHOR_RESERVE, wallet_utxo_amounts) + + with self.subTest(msg="implicit coins, nested-P2WKH, too large amounts"): + amount1 = 5_000_000 + amount2 = 6_000_000 + address = self.testnet.master_node.getaddress(address_type='np2wkh') + self.testnet.bitcoind.sendtoaddress(address, 0.10_000_000) + confirm_transactions(self.testnet) + + wallet_utxos_before = self.lndnode.get_utxos() + channels_before = self.lndnode.get_open_channels() + channel_opener.open_channels( + amounts=f"{amount1},{amount2}", + pubkeys=pubkey_input, + reckless=True, + ) + confirm_transactions(self.testnet) + wallet_utxos_after = self.lndnode.get_utxos() + channels_after = self.lndnode.get_open_channels() + + self.assertEqual(2, len(channels_after) - len(channels_before)) + self.assertEqual(1, len(wallet_utxos_before) - len(wallet_utxos_after)) + + total_capacity_before = sum([channel['capacity'] for channel in channels_before.values()]) + total_capacity_after = sum([channel['capacity'] for channel in channels_after.values()]) + # test that we have reduced the amounts + self.assertGreater(amount1 + amount2, total_capacity_after - total_capacity_before) + + with self.subTest(msg="implicit coins, nested-P2WKH, too large amounts"): + address = self.testnet.master_node.getaddress(address_type='np2wkh') + self.testnet.bitcoind.sendtoaddress(address, 0.10_000_000) + confirm_transactions(self.testnet) + total_amount = 20_000_000 + wallet_utxos_before = self.lndnode.get_utxos() + channels_before = self.lndnode.get_open_channels() + channel_opener.open_channels( + total_amount=total_amount, + pubkeys=pubkey_input, + reckless=True, + ) + confirm_transactions(self.testnet) + wallet_utxos_after = self.lndnode.get_utxos() + channels_after = self.lndnode.get_open_channels() + + self.assertEqual(2, len(channels_after) - len(channels_before)) + self.assertEqual(1, len(wallet_utxos_before) - len(wallet_utxos_after)) + + total_capacity_before = sum([channel['capacity'] for channel in channels_before.values()]) + total_capacity_after = sum([channel['capacity'] for channel in channels_after.values()]) + # test that we have reduced the total amount + self.assertGreater(total_amount, total_capacity_after - total_capacity_before) + + with self.subTest(msg="implicit coins, full spend, wumbo violation"): + address = self.testnet.master_node.getaddress() + self.testnet.bitcoind.sendtoaddress(address, (2 * openchannels.WUMBO_LIMIT + 1000) * 1E-8) + + confirm_transactions(self.testnet) + self.assertRaises( + ValueError, channel_opener.open_channels, + pubkeys=pubkey_input, + reckless=True, + ) + asyncio.run(run_tests()) class FeeTest(TestCase): diff --git a/test/test_rebalance.py b/test/test_rebalance.py index bbf5770..2df73e3 100644 --- a/test/test_rebalance.py +++ b/test/test_rebalance.py @@ -1,4 +1,5 @@ """Integration tests for rebalancing of channels.""" +import asyncio import time from typing import Optional @@ -17,7 +18,7 @@ class RebalanceTest(TestNetwork): """ Implements an abstract testing class for channel rebalancing. """ - def rebalance_and_check( + async def rebalance_and_check( self, test_channel_number: int, target: Optional[float], @@ -36,51 +37,52 @@ def rebalance_and_check( :param places: accuracy of the comparison between expected and tested values :type places: int """ - graph_before = self.testnet.assemble_graph() + async with self.lndnode: + graph_before = self.testnet.assemble_graph() - rebalancer = Rebalancer( - self.lndnode, - max_effective_fee_rate=5E-6, - budget_sat=20, - force=allow_uneconomic, - ) + rebalancer = Rebalancer( + self.lndnode, + max_effective_fee_rate=5E-6, + budget_sat=20, + force=allow_uneconomic, + ) - channel_id = self.testnet.channel_mapping[ - test_channel_number]['channel_id'] + channel_id = self.testnet.channel_mapping[ + test_channel_number]['channel_id'] - fees_msat = rebalancer.rebalance( - channel_id, - dry=False, - target=target, - amount_sat=amount_sat, - ) + fees_msat = rebalancer.rebalance( + channel_id, + dry=False, + target=target, + amount_sat=amount_sat, + ) - # sleep a bit to let LNDs update their balances - time.sleep(SLEEP_SEC_AFTER_REBALANCING) + # sleep a bit to let LNDs update their balances + time.sleep(SLEEP_SEC_AFTER_REBALANCING) - # check if graph has the desired channel balances - graph_after = self.testnet.assemble_graph() + # check if graph has the desired channel balances + graph_after = self.testnet.assemble_graph() - channel_data_before = graph_before['A'][test_channel_number] - channel_data_after = graph_after['A'][test_channel_number] - amount_sent = channel_data_before['local_balance'] - channel_data_after['local_balance'] + channel_data_before = graph_before['A'][test_channel_number] + channel_data_after = graph_after['A'][test_channel_number] + amount_sent = channel_data_before['local_balance'] - channel_data_after['local_balance'] - channel_unbalancedness, _ = local_balance_to_unbalancedness( - channel_data_after['local_balance'], - channel_data_after['capacity'], - channel_data_after['commit_fee'], - channel_data_after['initiator'] - ) + channel_unbalancedness, _ = local_balance_to_unbalancedness( + channel_data_after['local_balance'], + channel_data_after['capacity'], + channel_data_after['commit_fee'], + channel_data_after['initiator'] + ) - if target is not None: - self.assertAlmostEqual( - target, channel_unbalancedness, places=places) + if target is not None: + self.assertAlmostEqual( + target, channel_unbalancedness, places=places) - elif amount_sat is not None: - self.assertAlmostEqual( - amount_sat, amount_sent, places=places) + elif amount_sat is not None: + self.assertAlmostEqual( + amount_sat, amount_sent, places=places) - return fees_msat + return fees_msat def graph_test(self): """ @@ -106,21 +108,49 @@ def graph_test(self): def test_non_init_balanced(self): test_channel_number = 6 - self.rebalance_and_check(test_channel_number, target=0.0, amount_sat=None, allow_uneconomic=True) + asyncio.run( + self.rebalance_and_check( + test_channel_number, + target=0.0, + amount_sat=None, + allow_uneconomic=True + ) + ) def test_non_init_small_positive_target(self): test_channel_number = 6 - self.rebalance_and_check(test_channel_number, target=0.2, amount_sat=None, allow_uneconomic=True) + asyncio.run( + self.rebalance_and_check( + test_channel_number, + target=0.2, + amount_sat=None, + allow_uneconomic=True + ) + ) def test_non_init_max_target(self): test_channel_number = 6 - self.rebalance_and_check(test_channel_number, target=1.0, amount_sat=None, allow_uneconomic=True) + asyncio.run( + self.rebalance_and_check( + test_channel_number, + target=1.0, + amount_sat=None, + allow_uneconomic=True + ) + ) def test_non_init_negative_target(self): # this test should fail when unbalancing is not allowed, as it would # unbalance another channel if the full target would be accounted for test_channel_number = 6 - self.rebalance_and_check(test_channel_number, target=-0.2, amount_sat=None, allow_uneconomic=True) + asyncio.run( + self.rebalance_and_check( + test_channel_number, + target=-0.2, + amount_sat=None, + allow_uneconomic=True + ) + ) def test_non_init_fail_due_to_economic(self): # this test should fail when unbalancing is not allowed, as it would @@ -128,19 +158,50 @@ def test_non_init_fail_due_to_economic(self): test_channel_number = 6 self.assertRaises( NoRebalanceCandidates, - self.rebalance_and_check, test_channel_number, target=-0.2, amount_sat=None, allow_uneconomic=False) + asyncio.run, + self.rebalance_and_check( + test_channel_number, + target=-0.2, + amount_sat=None, + allow_uneconomic=False + ) + ) def test_init_balanced(self): test_channel_number = 1 - self.rebalance_and_check(test_channel_number, target=0.0, amount_sat=None, allow_uneconomic=True, places=1) + asyncio.run( + self.rebalance_and_check( + test_channel_number, + target=0.0, + amount_sat=None, + allow_uneconomic=True, + places=1 + ) + ) def test_init_already_balanced(self): test_channel_number = 2 - self.rebalance_and_check(test_channel_number, target=0.0, amount_sat=None, allow_uneconomic=True, places=2) + asyncio.run( + self.rebalance_and_check( + test_channel_number, + target=0.0, + amount_sat=None, + allow_uneconomic=True, + places=2 + ) + ) def test_init_default_amount(self): test_channel_number = 1 - self.rebalance_and_check(test_channel_number, target=None, amount_sat=None, allow_uneconomic=True, places=-1) + asyncio.run( + self.rebalance_and_check( + test_channel_number, + target=None, + amount_sat=None, + allow_uneconomic=True, + places=-1 + ) + ) def test_shuffle_arround(self): """Shuffles sats around in channel 6.""" @@ -148,10 +209,22 @@ def test_shuffle_arround(self): second_target_amount = 0.1 test_channel_number = 6 - self.rebalance_and_check( - test_channel_number, target=first_target_amount, amount_sat=None, allow_uneconomic=True) - self.rebalance_and_check( - test_channel_number, target=second_target_amount, amount_sat=None, allow_uneconomic=True) + asyncio.run( + self.rebalance_and_check( + test_channel_number, + target=first_target_amount, + amount_sat=None, + allow_uneconomic=True + ) + ) + asyncio.run( + self.rebalance_and_check( + test_channel_number, + target=second_target_amount, + amount_sat=None, + allow_uneconomic=True + ) + ) class TestUnbalancedRebalance(RebalanceTest): @@ -169,8 +242,15 @@ def graph_test(self): def test_channel_1(self): """tests multiple rebalance of one channel""" test_channel_number = 1 - print(self.rebalance_and_check( - test_channel_number, target=-0.05, amount_sat=None, allow_uneconomic=True, places=1)) + asyncio.run( + self.rebalance_and_check( + test_channel_number, + target=-0.05, + amount_sat=None, + allow_uneconomic=True, + places=1 + ) + ) class TestIlliquidRebalance(RebalanceTest): @@ -188,7 +268,13 @@ def graph_test(self): def test_channel_1_splitting(self): """Tests multiple payment attempts with splitting.""" test_channel_number = 1 - fees_msat = self.rebalance_and_check( - test_channel_number, target=-0.05, amount_sat=None, allow_uneconomic=True, places=1) + fees_msat = asyncio.run( + self.rebalance_and_check( + test_channel_number, + target=-0.05, + amount_sat=None, + allow_uneconomic=True, + places=1 + ) + ) self.assertAlmostEqual(2000, fees_msat, places=-3) - diff --git a/test/testing_common.py b/test/testing_common.py index 014ff64..97af1bb 100644 --- a/test/testing_common.py +++ b/test/testing_common.py @@ -1,3 +1,4 @@ +import asyncio import os import shutil from unittest import TestCase