Skip to content

Commit

Permalink
fix: remove use of sync.Once to avoid deadlocks (#523)
Browse files Browse the repository at this point in the history
Fixes #522
  • Loading branch information
agaffney authored Mar 11, 2024
1 parent 9f9589f commit 8dc0b05
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 13 deletions.
35 changes: 24 additions & 11 deletions muxer/muxer.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ package muxer
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -64,7 +65,6 @@ type Muxer struct {
protocolReceivers map[uint16]map[ProtocolRole]chan *Segment
protocolReceiversMutex sync.Mutex
diffusionMode DiffusionMode
onceStart sync.Once
onceStop sync.Once
}

Expand All @@ -78,8 +78,21 @@ func New(conn net.Conn) *Muxer {
protocolSenders: make(map[uint16]map[ProtocolRole]chan *Segment),
protocolReceivers: make(map[uint16]map[ProtocolRole]chan *Segment),
}
// Start read goroutine
m.waitGroup.Add(1)
go m.readLoop()
// Start cleanup routine
go func() {
// Wait for done signal
<-m.doneChan
// Close underlying connection
// We must do this to break out of pending Read() calls to shut down cleanly
_ = m.conn.Close()
// Wait for other goroutines to shutdown
m.waitGroup.Wait()
// Close ErrorChan to signify to consumer that we're shutting down
close(m.errorChan)
}()
return m
}

Expand All @@ -89,23 +102,17 @@ func (m *Muxer) ErrorChan() chan error {

// Start unblocks the read loop after the initial handshake to allow it to start processing messages
func (m *Muxer) Start() {
m.onceStart.Do(func() {
m.startChan <- true
})
select {
case m.startChan <- true:
default:
}
}

// Stop shuts down the muxer
func (m *Muxer) Stop() {
m.onceStop.Do(func() {
// Close doneChan to signify that we're shutting down
close(m.doneChan)
// Close underlying connection
// We must do this to break out of pending Read() calls to shut down cleanly
_ = m.conn.Close()
// Wait for other goroutines to shutdown
m.waitGroup.Wait()
// Close ErrorChan to signify to consumer that we're shutting down
close(m.errorChan)
})
}

Expand Down Expand Up @@ -220,6 +227,9 @@ func (m *Muxer) readLoop() {
}
header := SegmentHeader{}
if err := binary.Read(m.conn, binary.BigEndian, &header); err != nil {
if errors.Is(err, io.ErrClosedPipe) {
err = io.EOF
}
m.sendError(err)
return
}
Expand All @@ -230,6 +240,9 @@ func (m *Muxer) readLoop() {
// We use ReadFull because it guarantees to read the expected number of bytes or
// return an error
if _, err := io.ReadFull(m.conn, msg.Payload); err != nil {
if errors.Is(err, io.ErrClosedPipe) {
err = io.EOF
}
m.sendError(err)
return
}
Expand Down
11 changes: 9 additions & 2 deletions protocol/handshake/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,15 @@ func TestServerHandshakeRefuseVersionMismatch(t *testing.T) {
InputMessageType: handshake.MessageTypeRefuse,
InputMessage: handshake.NewMsgRefuse(
[]any{
handshake.RefuseReasonVersionMismatch,
protocol.GetProtocolVersionsNtC(),
uint64(handshake.RefuseReasonVersionMismatch),
// Convert []uint16 to []any
func(in []uint16) []any {
var ret []any
for _, item := range in {
ret = append(ret, item)
}
return ret
}(protocol.GetProtocolVersionsNtC()),
},
),
},
Expand Down

0 comments on commit 8dc0b05

Please sign in to comment.