From 8dc0b05708205b0086924969282e1715a011979c Mon Sep 17 00:00:00 2001 From: Aurora Gaffney Date: Mon, 11 Mar 2024 08:17:32 -0500 Subject: [PATCH] fix: remove use of sync.Once to avoid deadlocks (#523) Fixes #522 --- muxer/muxer.go | 35 +++++++++++++++++++++---------- protocol/handshake/server_test.go | 11 ++++++++-- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/muxer/muxer.go b/muxer/muxer.go index 7cda8212..b8a44c09 100644 --- a/muxer/muxer.go +++ b/muxer/muxer.go @@ -22,6 +22,7 @@ package muxer import ( "bytes" "encoding/binary" + "errors" "fmt" "io" "net" @@ -64,7 +65,6 @@ type Muxer struct { protocolReceivers map[uint16]map[ProtocolRole]chan *Segment protocolReceiversMutex sync.Mutex diffusionMode DiffusionMode - onceStart sync.Once onceStop sync.Once } @@ -78,8 +78,21 @@ func New(conn net.Conn) *Muxer { protocolSenders: make(map[uint16]map[ProtocolRole]chan *Segment), protocolReceivers: make(map[uint16]map[ProtocolRole]chan *Segment), } + // Start read goroutine m.waitGroup.Add(1) go m.readLoop() + // Start cleanup routine + go func() { + // Wait for done signal + <-m.doneChan + // Close underlying connection + // We must do this to break out of pending Read() calls to shut down cleanly + _ = m.conn.Close() + // Wait for other goroutines to shutdown + m.waitGroup.Wait() + // Close ErrorChan to signify to consumer that we're shutting down + close(m.errorChan) + }() return m } @@ -89,9 +102,10 @@ func (m *Muxer) ErrorChan() chan error { // Start unblocks the read loop after the initial handshake to allow it to start processing messages func (m *Muxer) Start() { - m.onceStart.Do(func() { - m.startChan <- true - }) + select { + case m.startChan <- true: + default: + } } // Stop shuts down the muxer @@ -99,13 +113,6 @@ func (m *Muxer) Stop() { m.onceStop.Do(func() { // Close doneChan to signify that we're shutting down close(m.doneChan) - // Close underlying connection - // We must do this to break out of pending Read() calls to shut down cleanly - _ = m.conn.Close() - // Wait for other goroutines to shutdown - m.waitGroup.Wait() - // Close ErrorChan to signify to consumer that we're shutting down - close(m.errorChan) }) } @@ -220,6 +227,9 @@ func (m *Muxer) readLoop() { } header := SegmentHeader{} if err := binary.Read(m.conn, binary.BigEndian, &header); err != nil { + if errors.Is(err, io.ErrClosedPipe) { + err = io.EOF + } m.sendError(err) return } @@ -230,6 +240,9 @@ func (m *Muxer) readLoop() { // We use ReadFull because it guarantees to read the expected number of bytes or // return an error if _, err := io.ReadFull(m.conn, msg.Payload); err != nil { + if errors.Is(err, io.ErrClosedPipe) { + err = io.EOF + } m.sendError(err) return } diff --git a/protocol/handshake/server_test.go b/protocol/handshake/server_test.go index 606fbdef..f29c3eea 100644 --- a/protocol/handshake/server_test.go +++ b/protocol/handshake/server_test.go @@ -106,8 +106,15 @@ func TestServerHandshakeRefuseVersionMismatch(t *testing.T) { InputMessageType: handshake.MessageTypeRefuse, InputMessage: handshake.NewMsgRefuse( []any{ - handshake.RefuseReasonVersionMismatch, - protocol.GetProtocolVersionsNtC(), + uint64(handshake.RefuseReasonVersionMismatch), + // Convert []uint16 to []any + func(in []uint16) []any { + var ret []any + for _, item := range in { + ret = append(ret, item) + } + return ret + }(protocol.GetProtocolVersionsNtC()), }, ), },