Skip to content

Commit

Permalink
fix: protocol shutdown deadlocks
Browse files Browse the repository at this point in the history
Fixes #478
  • Loading branch information
agaffney committed Jan 26, 2024
1 parent acb24fa commit 50c3fce
Showing 1 changed file with 88 additions and 66 deletions.
154 changes: 88 additions & 66 deletions protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,7 @@ func (p *Protocol) Start() {
close(p.sendReadyChan)
// Cancel any timer
if p.stateTransitionTimer != nil {
// Stop timer and drain channel
if !p.stateTransitionTimer.Stop() {
<-p.stateTransitionTimer.C
}
p.stateTransitionTimer.Stop()
p.stateTransitionTimer = nil
}
}()
Expand Down Expand Up @@ -201,54 +198,15 @@ func (p *Protocol) sendLoop() {
// Check for queued state changes from previous pipelined sends
setNewState = false
if len(p.sendStateQueueChan) > 0 {
msg := <-p.sendStateQueueChan
newState, err = p.getNewState(msg)
if err != nil {
p.SendError(
fmt.Errorf(
"%s: error sending message: %s",
p.config.Name,
err,
),
)
return
}
setNewState = true
// If there are no queued messages, set the new state now
if len(p.sendQueueChan) == 0 {
p.setState(newState)
p.stateMutex.Unlock()
continue
}
}
// Read queued messages and write into buffer
payloadBuf := bytes.NewBuffer(nil)
msgCount := 0
for {
// Get next message from send queue
msg, ok := <-p.sendQueueChan
if !ok {
// We're shutting down
select {
case <-p.doneChan:
// Break out of send loop if we're shutting down
return
}
msgCount = msgCount + 1
// Write the message into the send state queue if we already have a new state
if setNewState {
p.sendStateQueueChan <- msg
}
// Get raw CBOR from message
data := msg.Cbor()
// If message has no raw CBOR, encode the message
if data == nil {
var err error
data, err = cbor.Encode(msg)
if err != nil {
p.SendError(err)
case msg, ok := <-p.sendStateQueueChan:
if !ok {
// We're shutting down
return
}
}
payloadBuf.Write(data)
if !setNewState {
newState, err = p.getNewState(msg)
if err != nil {
p.SendError(
Expand All @@ -261,21 +219,82 @@ func (p *Protocol) sendLoop() {
return
}
setNewState = true
// If there are no queued messages, set the new state now
if len(p.sendQueueChan) == 0 {
p.setState(newState)
p.stateMutex.Unlock()
continue
}
}
// We don't want more than maxMessagesPerSegment messages in a segment
if msgCount >= maxMessagesPerSegment {
break
}
// We don't want to add more messages once we spill over into a second segment
if payloadBuf.Len() > muxer.SegmentMaxPayloadLength {
break
}
// Check if there are any more queued messages
if len(p.sendQueueChan) == 0 {
break
}
// Read queued messages and write into buffer
payloadBuf := bytes.NewBuffer(nil)
msgCount := 0
breakLoop := false
for {
// Get next message from send queue
select {
case <-p.doneChan:
// Break out of send loop if we're shutting down
return
case msg, ok := <-p.sendQueueChan:
if !ok {
// We're shutting down
return
}
msgCount = msgCount + 1
// Write the message into the send state queue if we already have a new state
if setNewState {
p.sendStateQueueChan <- msg
}
// Get raw CBOR from message
data := msg.Cbor()
// If message has no raw CBOR, encode the message
if data == nil {
var err error
data, err = cbor.Encode(msg)
if err != nil {
p.SendError(err)
return
}
}
payloadBuf.Write(data)
if !setNewState {
newState, err = p.getNewState(msg)
if err != nil {
p.SendError(
fmt.Errorf(
"%s: error sending message: %s",
p.config.Name,
err,
),
)
return
}
setNewState = true
}
// We don't want more than maxMessagesPerSegment messages in a segment
if msgCount >= maxMessagesPerSegment {
breakLoop = true
break
}
// We don't want to add more messages once we spill over into a second segment
if payloadBuf.Len() > muxer.SegmentMaxPayloadLength {
breakLoop = true
break
}
// Check if there are any more queued messages
if len(p.sendQueueChan) == 0 {
breakLoop = true
break
}
// We don't want to block on writes to the send state queue
if len(p.sendStateQueueChan) == cap(p.sendStateQueueChan) {
breakLoop = true
break
}
}
// We don't want to block on writes to the send state queue
if len(p.sendStateQueueChan) == cap(p.sendStateQueueChan) {
if breakLoop {
break
}
}
Expand Down Expand Up @@ -322,6 +341,9 @@ func (p *Protocol) recvLoop() {
if !leftoverData {
// Wait for segment
select {
case <-p.doneChan:
// Break out of receive loop if we're shutting down
return
case <-p.muxerDoneChan:
close(p.doneChan)
return
Expand All @@ -337,6 +359,9 @@ func (p *Protocol) recvLoop() {
leftoverData = false
// Wait until ready to receive based on state map
select {
case <-p.doneChan:
// Break out of receive loop if we're shutting down
return
case <-p.muxerDoneChan:
close(p.doneChan)
return
Expand Down Expand Up @@ -425,10 +450,7 @@ func (p *Protocol) getNewState(msg Message) (State, error) {
func (p *Protocol) setState(state State) {
// Disable any previous state transition timer
if p.stateTransitionTimer != nil {
// Stop timer and drain channel
if !p.stateTransitionTimer.Stop() {
<-p.stateTransitionTimer.C
}
p.stateTransitionTimer.Stop()
p.stateTransitionTimer = nil
}
// Set the new state
Expand Down

0 comments on commit 50c3fce

Please sign in to comment.