Skip to content

Commit

Permalink
conn: Add the possibility to send and receive message through the con…
Browse files Browse the repository at this point in the history
…nection

This commit enable the connection to send a message through the wire in
an ergonomic way.

This feature is a basic blocks for the lnprototest refactoring that
allow to semplify how to write test with lnprototest in the future
by keeping the state with the peer by connection and keep inside
the runner just the necessary logic to interact with the node.

Signed-off-by: Vincenzo Palazzo <[email protected]>
  • Loading branch information
vincenzopalazzo committed Nov 11, 2024
1 parent 381f3b9 commit 3c5ef6f
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 9 deletions.
36 changes: 31 additions & 5 deletions lnprototest/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,16 @@
import coincurve
import functools

import pyln
from pyln.proto.message import Message

from abc import ABC, abstractmethod
from typing import Dict, Optional, List, Union, Any, Callable

from pyln.proto.message import Message

from .bitfield import bitfield
from .errors import SpecFileError
from .structure import Sequence
from .event import Event, MustNotMsg, ExpectMsg
from .utils import privkey_expand
from .utils import privkey_expand, ResolvableStr, ResolvableInt, resolve_args
from .keyset import KeySet
from .namespace import namespace

Expand Down Expand Up @@ -78,6 +77,33 @@ def get_stash(self, event: Event, stashname: str, default: Any = None) -> Any:
raise SpecFileError(event, "Unknown stash name {}".format(stashname))
return self.stash[stashname]

def recv_msg(
self, timeout: int = 1000, skip_filter: Optional[int] = None
) -> Message:
"""Listen on the connection for incoming message.
If the {skip_filter} is specified, the message that
match the filters are skipped.
"""
raw_msg = self.connection.read_message()
msg = Message.read(namespace(), io.BytesIO(raw_msg))
self.add_stash(msg.messagetype.name, msg)
return msg

def send_msg(
self, msg_name: str, **kwargs: Union[ResolvableStr, ResolvableInt]
) -> None:
"""Send a message through the last connection"""
msgtype = namespace().get_msgtype(msg_name)
msg = Message(msgtype, **resolve_args(self, kwargs))
missing = msg.missing_fields()
if missing:
raise SpecFileError(self, "Missing fields {}".format(missing))
binmsg = io.BytesIO()
msg.write(binmsg)
self.connection.send_message(binmsg.getvalue())
# FIXME: we should listen to possible connection here


class Runner(ABC):
"""Abstract base class for runners.
Expand Down Expand Up @@ -189,7 +215,7 @@ def is_running(self) -> bool:
pass

@abstractmethod
def connect(self, event: Event, connprivkey: str) -> None:
def connect(self, event: Event, connprivkey: str) -> RunnerConn:
pass

def send_msg(self, msg: Message) -> None:
Expand Down
6 changes: 6 additions & 0 deletions lnprototest/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
check_hex,
privkey_for_index,
merge_events_sequences,
Resolvable,
ResolvableBool,
ResolvableInt,
ResolvableStr,
resolve_arg,
resolve_args,
)
from .bitcoin_utils import (
ScriptType,
Expand Down
24 changes: 23 additions & 1 deletion lnprototest/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@
import logging
import traceback

from typing import Union, Sequence, List
from typing import Union, Sequence, List, Dict, Callable, Any
from enum import IntEnum

from lnprototest.keyset import KeySet

# Type for arguments: either strings, or functions to call at runtime
ResolvableStr = Union[str, Callable[["RunnerConn", "Event", str], str]]
ResolvableInt = Union[int, Callable[["RunnerConn", "Event", str], int]]
ResolvableBool = Union[int, Callable[["RunnerConn", "Event", str], bool]]
Resolvable = Union[Any, Callable[["RunnerConn", "Event", str], Any]]


class Side(IntEnum):
local = 0
Expand Down Expand Up @@ -106,3 +112,19 @@ def merge_events_sequences(
"""Merge the two list in the pre-post order"""
pre.extend(post)
return pre


def resolve_arg(fieldname: str, conn: "RunnerConn", arg: Resolvable) -> Any:
"""If this is a string, return it, otherwise call it to get result"""
if callable(arg):
return arg(conn, fieldname)
else:
return arg


def resolve_args(conn: "RunnerConn", kwargs: Dict[str, Resolvable]) -> Dict[str, Any]:
"""Take a dict of args, replace callables with their return values"""
ret: Dict[str, str] = {}
for field, str_or_func in kwargs.items():
ret[field] = resolve_arg(field, conn, str_or_func)
return ret
6 changes: 3 additions & 3 deletions tests/test_v2_bolt1-01-init.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ def test_v2_init_is_first_msg(runner: Runner, namespaceoverride: Any) -> None:
"""
runner.start()

runner.connect(None, connprivkey="03")
init_msg = runner.recv_msg()
conn1 = runner.connect(None, connprivkey="03")
init_msg = conn1.recv_msg()
assert (
init_msg.messagetype.number == 16
), f"received not an init msg but: {init_msg.to_str()}"

conn1.send_msg("init", globalfeatures="", features="")
runner.stop()

0 comments on commit 3c5ef6f

Please sign in to comment.