Skip to content
This repository has been archived by the owner on Sep 2, 2024. It is now read-only.

Commit

Permalink
Merge pull request #92 from getAlby/fix/ldk-remove-mutex
Browse files Browse the repository at this point in the history
Fix: LDK remove mutex
  • Loading branch information
rolznz authored Mar 8, 2024
2 parents fc43a5b + 72864f5 commit 028ec1a
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 63 deletions.
88 changes: 25 additions & 63 deletions ldk.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"sort"
"strconv"
"strings"
"sync"
"time"

"github.com/getAlby/ldk-node-go/ldk_node"
Expand All @@ -20,12 +19,11 @@ import (
)

type LDKService struct {
svc *Service
workdir string
node *ldk_node.LdkNode
cancelLdkEventListenerCtx context.CancelFunc
subscribeLdkEvents func() chan ldk_node.Event
unsubscribeLdkEvents func(chan ldk_node.Event)
svc *Service
workdir string
node *ldk_node.LdkNode
ldkEventBroadcaster LDKEventBroadcaster
cancel context.CancelFunc
}

func NewLDKService(svc *Service, mnemonic, workDir string, network string, esploraServer string, gossipSource string) (result lnclient.LNClient, err error) {
Expand Down Expand Up @@ -72,43 +70,14 @@ func NewLDKService(svc *Service, mnemonic, workDir string, network string, esplo
return nil, err
}

// TODO: move this event handler code
ldkEventListenerCtx, cancelLdkEventListenerCtx := context.WithCancel(context.Background())
ldkEventHandlers := []chan ldk_node.Event{}
var ldkEventHandlersMutex sync.Mutex

subscribeLdkEvents := func() chan ldk_node.Event {
ldkEventHandler := make(chan ldk_node.Event)
svc.Logger.Debugf("Locking event handler mutex")
ldkEventHandlersMutex.Lock()
svc.Logger.Debugf("Locked event handler mutex")
ldkEventHandlers = append(ldkEventHandlers, ldkEventHandler)
ldkEventHandlersMutex.Unlock()
svc.Logger.Debugf("Unlocked event handler mutex")
return ldkEventHandler
}

unsubscribeLdkEvents := func(eventHandler chan ldk_node.Event) {
svc.Logger.Debugf("Locking event handler mutex")
ldkEventHandlersMutex.Lock()
svc.Logger.Debugf("Locked event handler mutex")
for i := 0; i < len(ldkEventHandlers); i++ {
if eventHandler == ldkEventHandlers[i] {
// Replace the element to be removed with the last element of the slice
ldkEventHandlers[i] = ldkEventHandlers[len(ldkEventHandlers)-1]
// Slice off the last element
ldkEventHandlers = ldkEventHandlers[:len(ldkEventHandlers)-1]
break
}
}
ldkEventHandlersMutex.Unlock()
svc.Logger.Debugf("Unlocked event handler mutex")
}
ldkEventConsumer := make(chan *ldk_node.Event)
ctx, cancel := context.WithCancel(svc.ctx)

// check for and forward new LDK events to LDKEventBroadcaster (through ldkEventConsumer)
go func() {
for {
select {
case <-ldkEventListenerCtx.Done():
case <-ctx.Done():
return
default:
// NOTE: do not use WaitNextEvent() as it can block the LDK thread
Expand All @@ -117,15 +86,9 @@ func NewLDKService(svc *Service, mnemonic, workDir string, network string, esplo
time.Sleep(time.Duration(1) * time.Millisecond)
continue
}
svc.Logger.Debugf("Locking event handler mutex")
ldkEventHandlersMutex.Lock()
svc.Logger.Debugf("Locked event handler mutex")
svc.Logger.Infof("Received LDK event %+v (%d listeners)", *event, len(ldkEventHandlers))
for _, eventHandler := range ldkEventHandlers {
eventHandler <- *event
}
ldkEventHandlersMutex.Unlock()
svc.Logger.Debugf("Unlocked event handler mutex")

svc.Logger.Infof("Received LDK event %+v", *event)
ldkEventConsumer <- event

node.EventHandled()
}
Expand All @@ -136,10 +99,9 @@ func NewLDKService(svc *Service, mnemonic, workDir string, network string, esplo
workdir: newpath,
node: node,
//listener: &listener,
svc: svc,
cancelLdkEventListenerCtx: cancelLdkEventListenerCtx,
subscribeLdkEvents: subscribeLdkEvents,
unsubscribeLdkEvents: unsubscribeLdkEvents,
svc: svc,
cancel: cancel,
ldkEventBroadcaster: NewLDKEventBroadcaster(svc.Logger, ctx, ldkEventConsumer),
}

nodeId := node.NodeId()
Expand All @@ -155,16 +117,16 @@ func NewLDKService(svc *Service, mnemonic, workDir string, network string, esplo

func (gs *LDKService) Shutdown() error {
gs.svc.Logger.Infof("shutting down LDK client")
gs.cancelLdkEventListenerCtx()
gs.cancel()
gs.node.Destroy()

return nil
}

func (gs *LDKService) SendPaymentSync(ctx context.Context, payReq string) (preimage string, err error) {
paymentStart := time.Now()
eventListener := gs.subscribeLdkEvents()
defer gs.unsubscribeLdkEvents(eventListener)
ldkEventSubscription := gs.ldkEventBroadcaster.Subscribe()
defer gs.ldkEventBroadcaster.CancelSubscription(ldkEventSubscription)

paymentHash, err := gs.node.SendPayment(payReq)
if err != nil {
Expand All @@ -173,10 +135,10 @@ func (gs *LDKService) SendPaymentSync(ctx context.Context, payReq string) (preim
}

for start := time.Now(); time.Since(start) < time.Second*60; {
event := <-eventListener
event := <-ldkEventSubscription

eventPaymentSuccessful, isEventPaymentSuccessfulEvent := event.(ldk_node.EventPaymentSuccessful)
eventPaymentFailed, isEventPaymentFailedEvent := event.(ldk_node.EventPaymentFailed)
eventPaymentSuccessful, isEventPaymentSuccessfulEvent := (*event).(ldk_node.EventPaymentSuccessful)
eventPaymentFailed, isEventPaymentFailedEvent := (*event).(ldk_node.EventPaymentFailed)

if isEventPaymentSuccessfulEvent && eventPaymentSuccessful.PaymentHash == paymentHash {
gs.svc.Logger.Infof("Got payment success event")
Expand Down Expand Up @@ -435,8 +397,8 @@ func (gs *LDKService) OpenChannel(ctx context.Context, openChannelRequest *lncli
return nil, errors.New("node is not peered yet")
}

eventListener := gs.subscribeLdkEvents()
defer gs.unsubscribeLdkEvents(eventListener)
ldkEventSubscription := gs.ldkEventBroadcaster.Subscribe()
defer gs.ldkEventBroadcaster.CancelSubscription(ldkEventSubscription)

gs.svc.Logger.Infof("Opening channel with: %v", foundPeer.NodeId)
userChannelId, err := gs.node.ConnectOpenChannel(foundPeer.NodeId, foundPeer.Address, uint64(openChannelRequest.Amount), nil, nil, openChannelRequest.Public)
Expand All @@ -449,9 +411,9 @@ func (gs *LDKService) OpenChannel(ctx context.Context, openChannelRequest *lncli
gs.svc.Logger.Infof("Funded channel: %v", userChannelId)

for start := time.Now(); time.Since(start) < time.Second*60; {
event := <-eventListener
event := <-ldkEventSubscription

channelPendingEvent, isChannelPendingEvent := event.(ldk_node.EventChannelPending)
channelPendingEvent, isChannelPendingEvent := (*event).(ldk_node.EventChannelPending)

if !isChannelPendingEvent {
continue
Expand Down
98 changes: 98 additions & 0 deletions ldk_event_broadcaster.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package main

import (
"context"
"slices"
"time"

"github.com/getAlby/ldk-node-go/ldk_node"
"github.com/sirupsen/logrus"
)

// based on https://betterprogramming.pub/how-to-broadcast-messages-in-go-using-channels-b68f42bdf32e
type ldkEventBroadcastServer struct {
logger *logrus.Logger
source <-chan *ldk_node.Event
listeners []chan *ldk_node.Event
addListener chan chan *ldk_node.Event
removeListener chan (<-chan *ldk_node.Event)
}

type LDKEventBroadcaster interface {
Subscribe() chan *ldk_node.Event
CancelSubscription(chan *ldk_node.Event)
}

func NewLDKEventBroadcaster(logger *logrus.Logger, ctx context.Context, source <-chan *ldk_node.Event) LDKEventBroadcaster {
service := &ldkEventBroadcastServer{
logger: logger,
source: source,
listeners: make([]chan *ldk_node.Event, 0),
addListener: make(chan chan *ldk_node.Event),
removeListener: make(chan (<-chan *ldk_node.Event)),
}
go service.serve(ctx)
return service
}

func (s *ldkEventBroadcastServer) Subscribe() chan *ldk_node.Event {
newListener := make(chan *ldk_node.Event)
s.addListener <- newListener
return newListener
}

func (s *ldkEventBroadcastServer) CancelSubscription(channel chan *ldk_node.Event) {
close(channel)
s.removeListener <- channel
}

func (s *ldkEventBroadcastServer) serve(ctx context.Context) {
defer func() {
for _, listener := range s.listeners {
func() {
defer func() {
if r := recover(); r != nil {
s.logger.Errorf("Failed to close channel: %v", r)
}
}()
close(listener)
}()
}
}()

for {
select {
case <-ctx.Done():
return
case newListener := <-s.addListener:
s.listeners = append(s.listeners, newListener)
case listenerToRemove := <-s.removeListener:
for i, listener := range s.listeners {
if listener == listenerToRemove {
s.listeners[i] = s.listeners[len(s.listeners)-1]
s.listeners = slices.Delete(s.listeners, len(s.listeners)-1, len(s.listeners))
break
}
}
case event := <-s.source:
s.logger.Debugf("Sending LDK event %+v to %d listeners", *event, len(s.listeners))
for _, listener := range s.listeners {
func() {
// if we fail to send the event to the listener it was probably closed
defer func() {
if r := recover(); r != nil {
s.logger.Errorf("Failed to send event to listener: %v", r)
}
}()

select {
case listener <- event:
s.logger.Debugln("sent event to listener")
case <-time.After(5 * time.Second):
s.logger.Errorf("Timeout sending %+v to listener", *event)
}
}()
}
}
}
}

0 comments on commit 028ec1a

Please sign in to comment.