From c0c42bf408a1c32591bdcb50aafd47c122589982 Mon Sep 17 00:00:00 2001 From: Aurora Gaffney Date: Thu, 1 Feb 2024 18:03:45 -0600 Subject: [PATCH] fix: protect against send on closed channel panics Fixes #481 --- connection.go | 3 +-- muxer/muxer.go | 22 +++++++++++++--------- protocol/protocol.go | 25 +++++++++++++++---------- 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/connection.go b/connection.go index 27e41ea2..14f13436 100644 --- a/connection.go +++ b/connection.go @@ -212,9 +212,8 @@ func (c *Connection) shutdown() { } // Wait for other goroutines to finish c.waitGroup.Wait() - // Close channels + // Close consumer error channel to signify connection shutdown close(c.errorChan) - close(c.protoErrorChan) // We can only close a channel once, so we have to jump through a few hoops select { // The channel is either closed or has an item pending diff --git a/muxer/muxer.go b/muxer/muxer.go index 1c42b374..e6a654b1 100644 --- a/muxer/muxer.go +++ b/muxer/muxer.go @@ -103,13 +103,6 @@ func (m *Muxer) Stop() { _ = m.conn.Close() // Wait for other goroutines to shutdown m.waitGroup.Wait() - // Close protocol receive channels - // We rely on the individual mini-protocols to close the sender channel - for _, protocolRoles := range m.protocolReceivers { - for _, recvChan := range protocolRoles { - close(recvChan) - } - } // Close ErrorChan to signify to consumer that we're shutting down close(m.errorChan) }) @@ -161,7 +154,10 @@ func (m *Muxer) RegisterProtocol( if !ok { return } - case msg := <-senderChan: + case msg, ok := <-senderChan: + if !ok { + return + } if err := m.Send(msg); err != nil { m.sendError(err) return @@ -200,7 +196,15 @@ func (m *Muxer) Send(msg *Segment) error { // readLoop waits for incoming data on the connection, parses the segment, and passes it to the appropriate // protocol func (m *Muxer) readLoop() { - defer m.waitGroup.Done() + defer func() { + m.waitGroup.Done() + // Close receiver channels + for _, protocolRoles := range m.protocolReceivers { + for _, recvChan := range protocolRoles { + close(recvChan) + } + } + }() started := false for { // Break out of read loop if we're shutting down diff --git a/protocol/protocol.go b/protocol/protocol.go index 9bc07689..2af68f74 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -131,11 +131,6 @@ func (p *Protocol) Start() { <-p.doneChan // Wait for all other goroutines to finish p.waitGroup.Wait() - // Close channels - close(p.sendQueueChan) - close(p.sendStateQueueChan) - close(p.recvReadyChan) - close(p.sendReadyChan) // Cancel any timer if p.stateTransitionTimer != nil { p.stateTransitionTimer.Stop() @@ -174,20 +169,30 @@ func (p *Protocol) SendMessage(msg Message) error { // SendError sends an error to the handler in the Ouroboros object func (p *Protocol) SendError(err error) { - p.config.ErrorChan <- err + select { + case p.config.ErrorChan <- err: + default: + // Discard error if the buffer is full + // The connection will get closed on the first error, so any + // additional errors are unnecessary + return + } } func (p *Protocol) sendLoop() { - defer p.waitGroup.Done() + defer func() { + p.waitGroup.Done() + // Close muxer send channel + // We are responsible for closing this channel as the sender, even through it + // was created by the muxer + close(p.muxerSendChan) + }() var setNewState bool var newState State var err error for { select { case <-p.doneChan: - // We are responsible for closing this channel as the sender, even through it - // was created by the muxer - close(p.muxerSendChan) // Break out of send loop if we're shutting down return case <-p.sendReadyChan: