diff --git a/protocol/protocol.go b/protocol/protocol.go index 2af68f74..b7391269 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -32,21 +32,17 @@ const maxMessagesPerSegment = 20 // Protocol implements the base functionality of an Ouroboros mini-protocol type Protocol struct { - config ProtocolConfig - muxerSendChan chan *muxer.Segment - muxerRecvChan chan *muxer.Segment - muxerDoneChan chan bool - state State - stateMutex sync.Mutex - recvBuffer *bytes.Buffer - sendQueueChan chan Message - sendStateQueueChan chan Message - recvReadyChan chan bool - sendReadyChan chan bool - doneChan chan bool - waitGroup sync.WaitGroup - stateTransitionTimer *time.Timer - onceStart sync.Once + config ProtocolConfig + muxerSendChan chan *muxer.Segment + muxerRecvChan chan *muxer.Segment + muxerDoneChan chan bool + sendQueueChan chan Message + recvReadyChan chan bool + sendReadyChan chan bool + stateTransitionChan chan<- protocolStateTransition + doneChan chan bool + waitGroup sync.WaitGroup + onceStart sync.Once } // ProtocolConfig provides the configuration for Protocol @@ -92,6 +88,11 @@ type ProtocolOptions struct { Version uint16 } +type protocolStateTransition struct { + msg Message + errorChan chan<- error +} + // MessageHandlerFunc represents a function that handles an incoming message type MessageHandlerFunc func(Message) error @@ -119,28 +120,19 @@ func (p *Protocol) Start() { p.config.ProtocolId, muxerProtocolRole, ) - // Create buffers and channels - p.recvBuffer = bytes.NewBuffer(nil) + + // Create channels p.sendQueueChan = make(chan Message, 50) - p.sendStateQueueChan = make(chan Message, 50) p.recvReadyChan = make(chan bool, 1) p.sendReadyChan = make(chan bool, 1) - // Start goroutine to cleanup when shutting down - go func() { - // Wait for doneChan to be closed - <-p.doneChan - // Wait for all other goroutines to finish - p.waitGroup.Wait() - // Cancel any timer - if p.stateTransitionTimer != nil { - p.stateTransitionTimer.Stop() - p.stateTransitionTimer = nil - } - }() - // Set initial state - p.setState(p.config.InitialState) + + stateTransitionChan := make(chan protocolStateTransition) + p.stateTransitionChan = stateTransitionChan + // Start our send and receive Goroutines p.waitGroup.Add(2) + + go p.stateLoop(stateTransitionChan) go p.recvLoop() go p.sendLoop() }) @@ -187,9 +179,7 @@ func (p *Protocol) sendLoop() { // was created by the muxer close(p.muxerSendChan) }() - var setNewState bool - var newState State - var err error + for { select { case <-p.doneChan: @@ -198,40 +188,7 @@ func (p *Protocol) sendLoop() { case <-p.sendReadyChan: // We are ready to send based on state map } - // Lock the state to prevent collisions - p.stateMutex.Lock() - // Check for queued state changes from previous pipelined sends - setNewState = false - if len(p.sendStateQueueChan) > 0 { - select { - case <-p.doneChan: - // Break out of send loop if we're shutting down - return - case msg, ok := <-p.sendStateQueueChan: - if !ok { - // We're shutting down - return - } - newState, err = p.getNewState(msg) - if err != nil { - p.SendError( - fmt.Errorf( - "%s: error sending message: %s", - p.config.Name, - err, - ), - ) - return - } - setNewState = true - // If there are no queued messages, set the new state now - if len(p.sendQueueChan) == 0 { - p.setState(newState) - p.stateMutex.Unlock() - continue - } - } - } + // Read queued messages and write into buffer payloadBuf := bytes.NewBuffer(nil) msgCount := 0 @@ -248,10 +205,7 @@ func (p *Protocol) sendLoop() { return } msgCount = msgCount + 1 - // Write the message into the send state queue if we already have a new state - if setNewState { - p.sendStateQueueChan <- msg - } + // Get raw CBOR from message data := msg.Cbor() // If message has no raw CBOR, encode the message @@ -264,20 +218,18 @@ func (p *Protocol) sendLoop() { } } payloadBuf.Write(data) - if !setNewState { - newState, err = p.getNewState(msg) - if err != nil { - p.SendError( - fmt.Errorf( - "%s: error sending message: %s", - p.config.Name, - err, - ), - ) - return - } - setNewState = true + + if err := p.transitionState(msg); err != nil { + p.SendError( + fmt.Errorf( + "%s: error sending message: %s", + p.config.Name, + err, + ), + ) + return } + // We don't want more than maxMessagesPerSegment messages in a segment if msgCount >= maxMessagesPerSegment { breakLoop = true @@ -293,16 +245,12 @@ func (p *Protocol) sendLoop() { breakLoop = true break } - // We don't want to block on writes to the send state queue - if len(p.sendStateQueueChan) == cap(p.sendStateQueueChan) { - breakLoop = true - break - } } if breakLoop { break } } + // Send messages in multiple segments (if needed) for { // Determine segment payload length @@ -331,15 +279,14 @@ func (p *Protocol) sendLoop() { break } } - // Set new state and unlock - p.setState(newState) - p.stateMutex.Unlock() } } func (p *Protocol) recvLoop() { defer p.waitGroup.Done() leftoverData := false + recvBuffer := bytes.NewBuffer(nil) + for { var err error // Don't grab the next segment from the muxer if we still have data in the buffer @@ -358,7 +305,7 @@ func (p *Protocol) recvLoop() { return } // Add segment payload to buffer - p.recvBuffer.Write(segment.Payload) + recvBuffer.Write(segment.Payload) } } leftoverData = false @@ -376,9 +323,9 @@ func (p *Protocol) recvLoop() { // This also lets us determine how many bytes the message is. We use RawMessage here to // avoid parsing things that we may not be able to parse var tmpMsg []cbor.RawMessage - numBytesRead, err := cbor.Decode(p.recvBuffer.Bytes(), &tmpMsg) + numBytesRead, err := cbor.Decode(recvBuffer.Bytes(), &tmpMsg) if err != nil { - if err == io.ErrUnexpectedEOF && p.recvBuffer.Len() > 0 { + if err == io.ErrUnexpectedEOF && recvBuffer.Len() > 0 { // This is probably a multi-part message, so we wait until we get more of the message // before trying to process it p.recvReadyChan <- true @@ -393,7 +340,7 @@ func (p *Protocol) recvLoop() { p.SendError(fmt.Errorf("%s: decode error: %s", p.config.Name, err)) } // Create Message object from CBOR - msgData := p.recvBuffer.Bytes()[:numBytesRead] + msgData := recvBuffer.Bytes()[:numBytesRead] msg, err := p.config.MessageFromCborFunc(msgType, msgData) if err != nil { p.SendError(err) @@ -414,22 +361,135 @@ func (p *Protocol) recvLoop() { p.SendError(err) return } - if numBytesRead < p.recvBuffer.Len() { + if numBytesRead < recvBuffer.Len() { // There is another message in the same muxer segment, so we reset the buffer with just // the remaining data - p.recvBuffer = bytes.NewBuffer(p.recvBuffer.Bytes()[numBytesRead:]) + recvBuffer = bytes.NewBuffer(recvBuffer.Bytes()[numBytesRead:]) leftoverData = true } else { // Empty out our buffer since we successfully processed the message - p.recvBuffer.Reset() + recvBuffer.Reset() + } + } +} + +func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) { + var currentState State + var transitionTimer *time.Timer + + setState := func(s State) { + // Disable any previous state transition timer + if transitionTimer != nil && !transitionTimer.Stop() { + <-transitionTimer.C + } + transitionTimer = nil + + // Set the new state + currentState = s + + // Mark protocol as ready to send/receive based on role and agency of the new state + switch p.config.StateMap[currentState].Agency { + case AgencyClient: + switch p.config.Role { + case ProtocolRoleClient: + select { + case p.sendReadyChan <- true: + default: + } + case ProtocolRoleServer: + select { + case p.recvReadyChan <- true: + default: + } + } + case AgencyServer: + switch p.config.Role { + case ProtocolRoleServer: + select { + case p.sendReadyChan <- true: + default: + } + case ProtocolRoleClient: + select { + case p.recvReadyChan <- true: + default: + } + } + } + + // Set timeout for state transition + if p.config.StateMap[currentState].Timeout > 0 { + transitionTimer = time.NewTimer(p.config.StateMap[currentState].Timeout) + } + } + getTimerChan := func() <-chan time.Time { + if transitionTimer == nil { + return nil + } + return transitionTimer.C + } + + protocolDoneChan := p.doneChan + stateDoneChan := make(chan struct{}) + + setState(p.config.InitialState) + + for { + select { + case t := <-ch: + nextState, err := p.nextState(currentState, t.msg) + if err != nil { + t.errorChan <- fmt.Errorf( + "%s: error handling protocol state transition: %s", + p.config.Name, + err, + ) + + // It is the responsibility of the caller to initiate the shutdown of the protocol, + // so the state handler should keep running to ensure other state transitions + // requesters do not encounter a deadlock + continue + } + + setState(nextState) + t.errorChan <- nil + + case <-getTimerChan(): + transitionTimer = nil + + p.SendError( + fmt.Errorf( + "%s: timeout waiting on transition from protocol state %s", + p.config.Name, + currentState, + ), + ) + + case <-protocolDoneChan: + // Disable this case so it doesn't block + protocolDoneChan = nil + + // Wait for all other goroutines to finish before shutting down the state handler + go func() { + p.waitGroup.Wait() + + close(stateDoneChan) + }() + + case <-stateDoneChan: + // All other goroutines have finished, so we can stop the timer and return + if transitionTimer != nil && !transitionTimer.Stop() { + <-transitionTimer.C + } + transitionTimer = nil + + return } } } -func (p *Protocol) getNewState(msg Message) (State, error) { - var newState State - matchFound := false - for _, transition := range p.config.StateMap[p.state].Transitions { +func (p *Protocol) nextState(currentState State, msg Message) (State, error) { + for _, transition := range p.config.StateMap[currentState].Transitions { if transition.MsgType == msg.Type() { if transition.MatchFunc != nil { // Skip item if match function returns false @@ -437,73 +497,29 @@ func (p *Protocol) getNewState(msg Message) (State, error) { continue } } - newState = transition.NewState - matchFound = true - break + return transition.NewState, nil } } - if !matchFound { - return newState, fmt.Errorf( - "message %s not allowed in current protocol state %s", - reflect.TypeOf(msg).Name(), - p.state, - ) - } - return newState, nil + + return State{}, fmt.Errorf( + "message %s not allowed in current protocol state %s", + reflect.TypeOf(msg).Name(), + currentState, + ) } -func (p *Protocol) setState(state State) { - // Disable any previous state transition timer - if p.stateTransitionTimer != nil { - p.stateTransitionTimer.Stop() - p.stateTransitionTimer = nil - } - // Set the new state - p.state = state - // Mark protocol as ready to send/receive based on role and agency of the new state - switch p.config.StateMap[p.state].Agency { - case AgencyClient: - switch p.config.Role { - case ProtocolRoleClient: - p.sendReadyChan <- true - case ProtocolRoleServer: - p.recvReadyChan <- true - } - case AgencyServer: - switch p.config.Role { - case ProtocolRoleServer: - p.sendReadyChan <- true - case ProtocolRoleClient: - p.recvReadyChan <- true - } - } - // Set timeout for state transition - if p.config.StateMap[p.state].Timeout > 0 { - p.stateTransitionTimer = time.AfterFunc( - p.config.StateMap[p.state].Timeout, - func() { - p.SendError( - fmt.Errorf( - "%s: timeout waiting on transition from protocol state %s", - p.config.Name, - p.state, - ), - ) - }, - ) - } +func (p *Protocol) transitionState(msg Message) error { + errorChan := make(chan error, 1) + p.stateTransitionChan <- protocolStateTransition{msg, errorChan} + + return <-errorChan } func (p *Protocol) handleMessage(msg Message) error { - // Lock the state to prevent collisions - p.stateMutex.Lock() - newState, err := p.getNewState(msg) - if err != nil { + if err := p.transitionState(msg); err != nil { return fmt.Errorf("%s: error handling message: %s", p.config.Name, err) } - // Set new state and unlock - p.setState(newState) - p.stateMutex.Unlock() + // Call handler function return p.config.MessageHandlerFunc(msg) }