diff --git a/README.md b/README.md index 51504af0..9344b2bb 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ A Go client implementation of the Cardano Ouroboros network protocol This is loosely based on the [official Haskell implementation](https://github.com/input-output-hk/ouroboros-network) +NOTE: this library is under heavily development, and the interface should not be considered stable until it reaches `v1.0.0` + ## Implementation status The Ouroboros protocol consists of a simple multiplexer protocol and various mini-protocols that run on top of it. diff --git a/protocol/blockfetch/blockfetch.go b/protocol/blockfetch/blockfetch.go index 9e0ab3c5..8038bca3 100644 --- a/protocol/blockfetch/blockfetch.go +++ b/protocol/blockfetch/blockfetch.go @@ -20,6 +20,51 @@ var ( STATE_DONE = protocol.NewState(4, "Done") ) +var stateMap = protocol.StateMap{ + STATE_IDLE: protocol.StateMapEntry{ + Agency: protocol.AGENCY_CLIENT, + Transitions: []protocol.StateTransition{ + { + MsgType: MESSAGE_TYPE_REQUEST_RANGE, + NewState: STATE_BUSY, + }, + { + MsgType: MESSAGE_TYPE_CLIENT_DONE, + NewState: STATE_DONE, + }, + }, + }, + STATE_BUSY: protocol.StateMapEntry{ + Agency: protocol.AGENCY_SERVER, + Transitions: []protocol.StateTransition{ + { + MsgType: MESSAGE_TYPE_START_BATCH, + NewState: STATE_STREAMING, + }, + { + MsgType: MESSAGE_TYPE_NO_BLOCKS, + NewState: STATE_IDLE, + }, + }, + }, + STATE_STREAMING: protocol.StateMapEntry{ + Agency: protocol.AGENCY_SERVER, + Transitions: []protocol.StateTransition{ + { + MsgType: MESSAGE_TYPE_BLOCK, + NewState: STATE_STREAMING, + }, + { + MsgType: MESSAGE_TYPE_BATCH_DONE, + NewState: STATE_IDLE, + }, + }, + }, + STATE_DONE: protocol.StateMapEntry{ + Agency: protocol.AGENCY_NONE, + }, +} + type BlockFetch struct { proto *protocol.Protocol callbackConfig *BlockFetchCallbackConfig @@ -42,9 +87,17 @@ func New(m *muxer.Muxer, errorChan chan error, callbackConfig *BlockFetchCallbac b := &BlockFetch{ callbackConfig: callbackConfig, } - b.proto = protocol.New(PROTOCOL_NAME, PROTOCOL_ID, m, errorChan, b.messageHandler, NewMsgFromCbor) - // Set initial state - b.proto.SetState(STATE_IDLE) + protoConfig := protocol.ProtocolConfig{ + Name: PROTOCOL_NAME, + ProtocolId: PROTOCOL_ID, + Muxer: m, + ErrorChan: errorChan, + MessageHandlerFunc: b.messageHandler, + MessageFromCborFunc: NewMsgFromCbor, + StateMap: stateMap, + InitialState: STATE_IDLE, + } + b.proto = protocol.New(protoConfig) return b } @@ -66,57 +119,32 @@ func (b *BlockFetch) messageHandler(msg protocol.Message) error { } func (b *BlockFetch) RequestRange(start []interface{}, end []interface{}) error { - if err := b.proto.LockState([]protocol.State{STATE_IDLE}); err != nil { - return fmt.Errorf("%s: RequestRange: protocol not in expected state", PROTOCOL_NAME) - } msg := newMsgRequestRange(start, end) - // Unlock and change state when we're done - defer b.proto.UnlockState(STATE_BUSY) - // Send request return b.proto.SendMessage(msg, false) } func (b *BlockFetch) ClientDone() error { - if err := b.proto.LockState([]protocol.State{STATE_IDLE}); err != nil { - return fmt.Errorf("%s: ClientDone: protocol not in expected state", PROTOCOL_NAME) - } msg := newMsgClientDone() - // Unlock and change state when we're done - defer b.proto.UnlockState(STATE_BUSY) - // Send request return b.proto.SendMessage(msg, false) } func (b *BlockFetch) handleStartBatch() error { - if err := b.proto.LockState([]protocol.State{STATE_BUSY}); err != nil { - return fmt.Errorf("received block-fetch StartBatch message when protocol not in expected state") - } if b.callbackConfig.StartBatchFunc == nil { return fmt.Errorf("received block-fetch StartBatch message but no callback function is defined") } - // Unlock and change state when we're done - defer b.proto.UnlockState(STATE_STREAMING) // Call the user callback function return b.callbackConfig.StartBatchFunc() } func (b *BlockFetch) handleNoBlocks() error { - if err := b.proto.LockState([]protocol.State{STATE_BUSY}); err != nil { - return fmt.Errorf("received block-fetch NoBlocks message when protocol not in expected state") - } if b.callbackConfig.NoBlocksFunc == nil { return fmt.Errorf("received block-fetch NoBlocks message but no callback function is defined") } - // Unlock and change state when we're done - defer b.proto.UnlockState(STATE_IDLE) // Call the user callback function return b.callbackConfig.NoBlocksFunc() } func (b *BlockFetch) handleBlock(msgGeneric protocol.Message) error { - if err := b.proto.LockState([]protocol.State{STATE_STREAMING}); err != nil { - return fmt.Errorf("received block-fetch Block message when protocol not in expected state") - } if b.callbackConfig.BlockFunc == nil { return fmt.Errorf("received block-fetch Block message but no callback function is defined") } @@ -130,21 +158,14 @@ func (b *BlockFetch) handleBlock(msgGeneric protocol.Message) error { if err != nil { return err } - // Unlock and change state when we're done - defer b.proto.UnlockState(STATE_STREAMING) // Call the user callback function return b.callbackConfig.BlockFunc(wrapBlock.Type, blk) } func (b *BlockFetch) handleBatchDone() error { - if err := b.proto.LockState([]protocol.State{STATE_STREAMING}); err != nil { - return fmt.Errorf("received block-fetch BatchDone message when protocol not in expected state") - } if b.callbackConfig.BatchDoneFunc == nil { return fmt.Errorf("received block-fetch BatchDone message but no callback function is defined") } - // Unlock and change state when we're done - defer b.proto.UnlockState(STATE_IDLE) // Call the user callback function return b.callbackConfig.BatchDoneFunc() } diff --git a/protocol/chainsync/chainsync.go b/protocol/chainsync/chainsync.go index 0ba3b3a1..7595e125 100644 --- a/protocol/chainsync/chainsync.go +++ b/protocol/chainsync/chainsync.go @@ -22,6 +22,72 @@ var ( STATE_DONE = protocol.NewState(5, "Done") ) +var stateMap = protocol.StateMap{ + STATE_IDLE: protocol.StateMapEntry{ + Agency: protocol.AGENCY_CLIENT, + Transitions: []protocol.StateTransition{ + { + MsgType: MESSAGE_TYPE_REQUEST_NEXT, + NewState: STATE_CAN_AWAIT, + }, + { + MsgType: MESSAGE_TYPE_FIND_INTERSECT, + NewState: STATE_INTERSECT, + }, + { + MsgType: MESSAGE_TYPE_DONE, + NewState: STATE_DONE, + }, + }, + }, + STATE_CAN_AWAIT: protocol.StateMapEntry{ + Agency: protocol.AGENCY_SERVER, + Transitions: []protocol.StateTransition{ + { + MsgType: MESSAGE_TYPE_AWAIT_REPLY, + NewState: STATE_MUST_REPLY, + }, + { + MsgType: MESSAGE_TYPE_ROLL_FORWARD, + NewState: STATE_IDLE, + }, + { + MsgType: MESSAGE_TYPE_ROLL_BACKWARD, + NewState: STATE_IDLE, + }, + }, + }, + STATE_INTERSECT: protocol.StateMapEntry{ + Agency: protocol.AGENCY_SERVER, + Transitions: []protocol.StateTransition{ + { + MsgType: MESSAGE_TYPE_INTERSECT_FOUND, + NewState: STATE_IDLE, + }, + { + MsgType: MESSAGE_TYPE_INTERSECT_NOT_FOUND, + NewState: STATE_IDLE, + }, + }, + }, + STATE_MUST_REPLY: protocol.StateMapEntry{ + Agency: protocol.AGENCY_SERVER, + Transitions: []protocol.StateTransition{ + { + MsgType: MESSAGE_TYPE_ROLL_FORWARD, + NewState: STATE_IDLE, + }, + { + MsgType: MESSAGE_TYPE_ROLL_BACKWARD, + NewState: STATE_IDLE, + }, + }, + }, + STATE_DONE: protocol.StateMapEntry{ + Agency: protocol.AGENCY_NONE, + }, +} + type ChainSync struct { proto *protocol.Protocol nodeToNode bool @@ -56,9 +122,17 @@ func New(m *muxer.Muxer, errorChan chan error, nodeToNode bool, callbackConfig * nodeToNode: nodeToNode, callbackConfig: callbackConfig, } - c.proto = protocol.New(PROTOCOL_NAME, protocolId, m, errorChan, c.messageHandler, c.NewMsgFromCbor) - // Set initial state - c.proto.SetState(STATE_IDLE) + protoConfig := protocol.ProtocolConfig{ + Name: PROTOCOL_NAME, + ProtocolId: protocolId, + Muxer: m, + ErrorChan: errorChan, + MessageHandlerFunc: c.messageHandler, + MessageFromCborFunc: c.NewMsgFromCbor, + StateMap: stateMap, + InitialState: STATE_IDLE, + } + c.proto = protocol.New(protoConfig) return c } @@ -84,45 +158,24 @@ func (c *ChainSync) messageHandler(msg protocol.Message) error { } func (c *ChainSync) RequestNext() error { - if err := c.proto.LockState([]protocol.State{STATE_IDLE}); err != nil { - return fmt.Errorf("%s: RequestNext: protocol not in expected state", PROTOCOL_NAME) - } - // Create our request msg := newMsgRequestNext() - // Unlock and change state when we're done - defer c.proto.UnlockState(STATE_CAN_AWAIT) - // Send request return c.proto.SendMessage(msg, false) } func (c *ChainSync) FindIntersect(points []interface{}) error { - if err := c.proto.LockState([]protocol.State{STATE_IDLE}); err != nil { - return fmt.Errorf("%s: FindIntersect: protocol not in expected state", PROTOCOL_NAME) - } msg := newMsgFindIntersect(points) - // Unlock and change state when we're done - defer c.proto.UnlockState(STATE_INTERSECT) - // Send request return c.proto.SendMessage(msg, false) } func (c *ChainSync) handleAwaitReply() error { - if err := c.proto.LockState([]protocol.State{STATE_CAN_AWAIT}); err != nil { - return fmt.Errorf("received chain-sync AwaitReply message when protocol not in expected state") - } if c.callbackConfig.AwaitReplyFunc == nil { return fmt.Errorf("received chain-sync AwaitReply message but no callback function is defined") } - // Unlock and change state when we're done - defer c.proto.UnlockState(STATE_MUST_REPLY) // Call the user callback function return c.callbackConfig.AwaitReplyFunc() } func (c *ChainSync) handleRollForward(msgGeneric protocol.Message) error { - if err := c.proto.LockState([]protocol.State{STATE_CAN_AWAIT, STATE_MUST_REPLY}); err != nil { - return fmt.Errorf("received chain-sync RollForward message when protocol not in expected state") - } if c.callbackConfig.RollForwardFunc == nil { return fmt.Errorf("received chain-sync RollForward message but no callback function is defined") } @@ -163,8 +216,6 @@ func (c *ChainSync) handleRollForward(msgGeneric protocol.Message) error { return err } } - // Unlock and change state when we're done - defer c.proto.UnlockState(STATE_IDLE) // Call the user callback function return c.callbackConfig.RollForwardFunc(blockType, blockHeader) } else { @@ -178,64 +229,42 @@ func (c *ChainSync) handleRollForward(msgGeneric protocol.Message) error { if err != nil { return err } - // Unlock and change state when we're done - defer c.proto.UnlockState(STATE_IDLE) // Call the user callback function return c.callbackConfig.RollForwardFunc(wrapBlock.Type, blk) } } func (c *ChainSync) handleRollBackward(msgGeneric protocol.Message) error { - if err := c.proto.LockState([]protocol.State{STATE_CAN_AWAIT, STATE_MUST_REPLY}); err != nil { - return fmt.Errorf("received chain-sync RollBackward message when protocol not in expected state") - } if c.callbackConfig.RollBackwardFunc == nil { return fmt.Errorf("received chain-sync RollBackward message but no callback function is defined") } msg := msgGeneric.(*msgRollBackward) - // Unlock and change state when we're done - defer c.proto.UnlockState(STATE_IDLE) // Call the user callback function return c.callbackConfig.RollBackwardFunc(msg.Point, msg.Tip) } func (c *ChainSync) handleIntersectFound(msgGeneric protocol.Message) error { - if err := c.proto.LockState([]protocol.State{STATE_INTERSECT}); err != nil { - return fmt.Errorf("received chain-sync IntersectFound message when protocol not in expected state") - } if c.callbackConfig.IntersectFoundFunc == nil { return fmt.Errorf("received chain-sync IntersectFound message but no callback function is defined") } msg := msgGeneric.(*msgIntersectFound) - // Unlock and change state when we're done - defer c.proto.UnlockState(STATE_IDLE) // Call the user callback function return c.callbackConfig.IntersectFoundFunc(msg.Point, msg.Tip) } func (c *ChainSync) handleIntersectNotFound(msgGeneric protocol.Message) error { - if err := c.proto.LockState([]protocol.State{STATE_INTERSECT}); err != nil { - return fmt.Errorf("received chain-sync IntersectNotFound message when protocol not in expected state") - } if c.callbackConfig.IntersectNotFoundFunc == nil { return fmt.Errorf("received chain-sync IntersectNotFound message but no callback function is defined") } msg := msgGeneric.(*msgIntersectNotFound) - // Unlock and change state when we're done - defer c.proto.UnlockState(STATE_IDLE) // Call the user callback function return c.callbackConfig.IntersectNotFoundFunc(msg.Tip) } func (c *ChainSync) handleDone() error { - if err := c.proto.LockState([]protocol.State{STATE_IDLE}); err != nil { - return fmt.Errorf("received chain-sync Done message when protocol not in expected state") - } if c.callbackConfig.DoneFunc == nil { return fmt.Errorf("received chain-sync Done message but no callback function is defined") } - // Unlock and change state when we're done - defer c.proto.UnlockState(STATE_DONE) // Call the user callback function return c.callbackConfig.DoneFunc() } diff --git a/protocol/handshake/handshake.go b/protocol/handshake/handshake.go index 42ffb140..8d5c0b1b 100644 --- a/protocol/handshake/handshake.go +++ b/protocol/handshake/handshake.go @@ -20,6 +20,34 @@ var ( STATE_DONE = protocol.NewState(3, "Done") ) +var stateMap = protocol.StateMap{ + STATE_PROPOSE: protocol.StateMapEntry{ + Agency: protocol.AGENCY_CLIENT, + Transitions: []protocol.StateTransition{ + { + MsgType: MESSAGE_TYPE_PROPOSE_VERSIONS, + NewState: STATE_CONFIRM, + }, + }, + }, + STATE_CONFIRM: protocol.StateMapEntry{ + Agency: protocol.AGENCY_SERVER, + Transitions: []protocol.StateTransition{ + { + MsgType: MESSAGE_TYPE_ACCEPT_VERSION, + NewState: STATE_DONE, + }, + { + MsgType: MESSAGE_TYPE_REFUSE, + NewState: STATE_DONE, + }, + }, + }, + STATE_DONE: protocol.StateMapEntry{ + Agency: protocol.AGENCY_NONE, + }, +} + type Handshake struct { proto *protocol.Protocol nodeToNode bool @@ -32,9 +60,17 @@ func New(m *muxer.Muxer, errorChan chan error, nodeToNode bool) *Handshake { nodeToNode: nodeToNode, Finished: make(chan bool, 1), } - h.proto = protocol.New(PROTOCOL_NAME, PROTOCOL_ID, m, errorChan, h.handleMessage, NewMsgFromCbor) - // Set initial state - h.proto.SetState(STATE_PROPOSE) + protoConfig := protocol.ProtocolConfig{ + Name: PROTOCOL_NAME, + ProtocolId: PROTOCOL_ID, + Muxer: m, + ErrorChan: errorChan, + MessageHandlerFunc: h.handleMessage, + MessageFromCborFunc: NewMsgFromCbor, + StateMap: stateMap, + InitialState: STATE_PROPOSE, + } + h.proto = protocol.New(protoConfig) return h } @@ -54,9 +90,6 @@ func (h *Handshake) handleMessage(msg protocol.Message) error { } func (h *Handshake) ProposeVersions(versions []uint16, networkMagic uint32) error { - if err := h.proto.LockState([]protocol.State{STATE_PROPOSE}); err != nil { - return fmt.Errorf("protocol not in expected state") - } // Create our request versionMap := make(map[uint16]interface{}) for _, version := range versions { @@ -67,36 +100,22 @@ func (h *Handshake) ProposeVersions(versions []uint16, networkMagic uint32) erro } } msg := newMsgProposeVersions(versionMap) - // Unlock and change state when we're done - defer h.proto.UnlockState(STATE_CONFIRM) - // Send request return h.proto.SendMessage(msg, false) } func (h *Handshake) handleProposeVersions(msgGeneric protocol.Message) error { - if err := h.proto.LockState([]protocol.State{STATE_CONFIRM}); err != nil { - return fmt.Errorf("received handshake request when protocol is in wrong state") - } // TODO: implement me return fmt.Errorf("handshake request handling not yet implemented") } func (h *Handshake) handleAcceptVersion(msgGeneric protocol.Message) error { - if err := h.proto.LockState([]protocol.State{STATE_CONFIRM}); err != nil { - return fmt.Errorf("received handshake accept response when protocol is in wrong state") - } msg := msgGeneric.(*msgAcceptVersion) h.Version = msg.Version h.Finished <- true - // Unlock and change state when we're done - defer h.proto.UnlockState(STATE_DONE) return nil } func (h *Handshake) handleRefuse(msgGeneric protocol.Message) error { - if err := h.proto.LockState([]protocol.State{STATE_CONFIRM}); err != nil { - return fmt.Errorf("received handshake refuse response when protocol is in wrong state") - } msg := msgGeneric.(*msgRefuse) var err error switch msg.Reason[0].(uint64) { @@ -107,7 +126,5 @@ func (h *Handshake) handleRefuse(msgGeneric protocol.Message) error { case REFUSE_REASON_REFUSED: err = fmt.Errorf("%s: refused: %s", PROTOCOL_NAME, msg.Reason[2].(string)) } - // Unlock and change state when we're done - defer h.proto.UnlockState(STATE_DONE) return err } diff --git a/protocol/keepalive/keepalive.go b/protocol/keepalive/keepalive.go index b070ac7a..c4e8e0ed 100644 --- a/protocol/keepalive/keepalive.go +++ b/protocol/keepalive/keepalive.go @@ -21,6 +21,34 @@ var ( STATE_DONE = protocol.NewState(3, "Done") ) +var stateMap = protocol.StateMap{ + STATE_CLIENT: protocol.StateMapEntry{ + Agency: protocol.AGENCY_CLIENT, + Transitions: []protocol.StateTransition{ + { + MsgType: MESSAGE_TYPE_KEEP_ALIVE, + NewState: STATE_SERVER, + }, + { + MsgType: MESSAGE_TYPE_DONE, + NewState: STATE_DONE, + }, + }, + }, + STATE_SERVER: protocol.StateMapEntry{ + Agency: protocol.AGENCY_SERVER, + Transitions: []protocol.StateTransition{ + { + MsgType: MESSAGE_TYPE_KEEP_ALIVE_RESPONSE, + NewState: STATE_CLIENT, + }, + }, + }, + STATE_DONE: protocol.StateMapEntry{ + Agency: protocol.AGENCY_NONE, + }, +} + type KeepAlive struct { proto *protocol.Protocol callbackConfig *KeepAliveCallbackConfig @@ -42,9 +70,17 @@ func New(m *muxer.Muxer, errorChan chan error, callbackConfig *KeepAliveCallback k := &KeepAlive{ callbackConfig: callbackConfig, } - k.proto = protocol.New(PROTOCOL_NAME, PROTOCOL_ID, m, errorChan, k.messageHandler, NewMsgFromCbor) - // Set initial state - k.proto.SetState(STATE_CLIENT) + protoConfig := protocol.ProtocolConfig{ + Name: PROTOCOL_NAME, + ProtocolId: PROTOCOL_ID, + Muxer: m, + ErrorChan: errorChan, + MessageHandlerFunc: k.messageHandler, + MessageFromCborFunc: NewMsgFromCbor, + StateMap: stateMap, + InitialState: STATE_CLIENT, + } + k.proto = protocol.New(protoConfig) return k } @@ -80,23 +116,12 @@ func (k *KeepAlive) Stop() { } func (k *KeepAlive) KeepAlive(cookie uint16) error { - if err := k.proto.LockState([]protocol.State{STATE_CLIENT}); err != nil { - return fmt.Errorf("%s: KeepAlive: protocol not in expected state", PROTOCOL_NAME) - } msg := newMsgKeepAlive(cookie) - // Unlock and change state when we're done - defer k.proto.UnlockState(STATE_SERVER) - // Send request return k.proto.SendMessage(msg, false) } func (k *KeepAlive) handleKeepAlive(msgGeneric protocol.Message) error { - if err := k.proto.LockState([]protocol.State{STATE_CLIENT}); err != nil { - return fmt.Errorf("received keep-alive KeepAlive message when protocol not in expected state") - } msg := msgGeneric.(*msgKeepAlive) - // Unlock and change state when we're done - defer k.proto.UnlockState(STATE_CLIENT) if k.callbackConfig != nil && k.callbackConfig.KeepAliveFunc != nil { // Call the user callback function return k.callbackConfig.KeepAliveFunc(msg.Cookie) @@ -108,12 +133,7 @@ func (k *KeepAlive) handleKeepAlive(msgGeneric protocol.Message) error { } func (k *KeepAlive) handleKeepAliveResponse(msgGeneric protocol.Message) error { - if err := k.proto.LockState([]protocol.State{STATE_SERVER}); err != nil { - return fmt.Errorf("received keep-alive KeepAliveResponse message when protocol not in expected state") - } msg := msgGeneric.(*msgKeepAliveResponse) - // Unlock and change state when we're done - defer k.proto.UnlockState(STATE_CLIENT) // Start the timer again if we had one previously if k.timer != nil { defer k.Start() @@ -126,11 +146,6 @@ func (k *KeepAlive) handleKeepAliveResponse(msgGeneric protocol.Message) error { } func (k *KeepAlive) handleDone() error { - if err := k.proto.LockState([]protocol.State{STATE_CLIENT}); err != nil { - return fmt.Errorf("received keep-alive Done message when protocol not in expected state") - } - // Unlock and change state when we're done - defer k.proto.UnlockState(STATE_DONE) if k.callbackConfig != nil && k.callbackConfig.DoneFunc != nil { // Call the user callback function return k.callbackConfig.DoneFunc() diff --git a/protocol/protocol.go b/protocol/protocol.go index b25ecbac..03ee0bf1 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -10,95 +10,67 @@ import ( ) type Protocol struct { - protocolId uint16 - name string - errorChan chan error - sendChan chan *muxer.Segment - recvChan chan *muxer.Segment - state State - stateMutex sync.Mutex - recvBuffer *bytes.Buffer - msgHandlerFunc MessageHandlerFunc - msgFromCborFunc MessageFromCborFunc + config ProtocolConfig + sendChan chan *muxer.Segment + recvChan chan *muxer.Segment + state State + stateMutex sync.Mutex + recvBuffer *bytes.Buffer } -type State struct { - Id uint - Name string -} - -func NewState(id uint, name string) State { - return State{ - Id: id, - Name: name, - } -} - -func (s State) String() string { - return s.Name +type ProtocolConfig struct { + Name string + ProtocolId uint16 + ErrorChan chan error + Muxer *muxer.Muxer + MessageHandlerFunc MessageHandlerFunc + MessageFromCborFunc MessageFromCborFunc + StateMap StateMap + InitialState State } type MessageHandlerFunc func(Message) error type MessageFromCborFunc func(uint, []byte) (Message, error) -func New(name string, protocolId uint16, m *muxer.Muxer, errorChan chan error, handlerFunc MessageHandlerFunc, msgFromCborFunc MessageFromCborFunc) *Protocol { - sendChan, recvChan := m.RegisterProtocol(protocolId) +func New(config ProtocolConfig) *Protocol { + sendChan, recvChan := config.Muxer.RegisterProtocol(config.ProtocolId) p := &Protocol{ - name: name, - protocolId: protocolId, - errorChan: errorChan, - sendChan: sendChan, - recvChan: recvChan, - recvBuffer: bytes.NewBuffer(nil), - msgHandlerFunc: handlerFunc, - msgFromCborFunc: msgFromCborFunc, + config: config, + sendChan: sendChan, + recvChan: recvChan, + recvBuffer: bytes.NewBuffer(nil), } + // Set initial state + p.state = config.InitialState // Start our receiver Goroutine go p.recvLoop() return p } -func (p *Protocol) GetState() State { - return p.state -} - -func (p *Protocol) SetState(state State) { - p.state = state -} - -func (p *Protocol) LockState(allowedStates []State) error { +func (p *Protocol) SendMessage(msg Message, isResponse bool) error { + // Lock the state to prevent collisions p.stateMutex.Lock() - inAllowedState := false - for _, state := range allowedStates { - if state == p.state { - inAllowedState = true - break - } + if err := p.checkCurrentState(); err != nil { + return fmt.Errorf("%s: error sending message: %s", p.config.Name, err) } - if !inAllowedState { - p.stateMutex.Unlock() - return fmt.Errorf("protocol is not in allowed state (currently in state %s)", p.state.Name) + newState, err := p.getNewState(msg) + if err != nil { + return fmt.Errorf("%s: error sending message: %s", p.config.Name, err) } - return nil -} - -func (p *Protocol) UnlockState(newState State) { - p.state = newState - p.stateMutex.Unlock() -} - -func (p *Protocol) SendMessage(msg interface{}, isResponse bool) error { data, err := utils.CborEncode(msg) if err != nil { return err } - segment := muxer.NewSegment(p.protocolId, data, isResponse) + segment := muxer.NewSegment(p.config.ProtocolId, data, isResponse) p.sendChan <- segment + // Set new state and unlock + p.state = newState + p.stateMutex.Unlock() return nil } func (p *Protocol) SendError(err error) { - p.errorChan <- err + p.config.ErrorChan <- err } func (p *Protocol) recvLoop() { @@ -123,18 +95,20 @@ func (p *Protocol) recvLoop() { // before trying to process it continue } - p.errorChan <- fmt.Errorf("%s: decode error: %s", p.name, err) + p.config.ErrorChan <- fmt.Errorf("%s: decode error: %s", p.config.Name, err) } + // Create Message object from CBOR msgType := uint(tmpMsg[0].(uint64)) - msg, err := p.msgFromCborFunc(msgType, p.recvBuffer.Bytes()) + msg, err := p.config.MessageFromCborFunc(msgType, p.recvBuffer.Bytes()) if err != nil { - p.errorChan <- err + p.config.ErrorChan <- err } if msg == nil { - p.errorChan <- fmt.Errorf("%s: received unknown message type: %#v", p.name, tmpMsg) + p.config.ErrorChan <- fmt.Errorf("%s: received unknown message type: %#v", p.config.Name, tmpMsg) } - if err := p.msgHandlerFunc(msg); err != nil { - p.errorChan <- err + // Handle message + if err := p.handleMessage(msg); err != nil { + p.config.ErrorChan <- err } if numBytesRead < p.recvBuffer.Len() { // There is another message in the same muxer segment, so we reset the buffer with just @@ -147,3 +121,48 @@ func (p *Protocol) recvLoop() { } } } + +func (p *Protocol) checkCurrentState() error { + if currentStateMapEntry, ok := p.config.StateMap[p.state]; ok { + if currentStateMapEntry.Agency == AGENCY_NONE { + return fmt.Errorf("protocol is in state with no agency") + } + // TODO: check client/server agency + } else { + return fmt.Errorf("protocol in unknown state") + } + return nil +} + +func (p *Protocol) getNewState(msg Message) (State, error) { + var newState State + matchFound := false + for _, transition := range p.config.StateMap[p.state].Transitions { + if transition.MsgType == msg.Type() { + newState = transition.NewState + matchFound = true + break + } + } + if !matchFound { + return newState, fmt.Errorf("message not allowed in current protocol state") + } + return newState, nil +} + +func (p *Protocol) handleMessage(msg Message) error { + // Lock the state to prevent collisions + p.stateMutex.Lock() + if err := p.checkCurrentState(); err != nil { + return fmt.Errorf("%s: error handling message: %s", p.config.Name, err) + } + newState, err := p.getNewState(msg) + if err != nil { + return fmt.Errorf("%s: error handling message: %s", p.config.Name, err) + } + // Set new state and unlock + p.state = newState + p.stateMutex.Unlock() + // Call handler function + return p.config.MessageHandlerFunc(msg) +} diff --git a/protocol/state.go b/protocol/state.go new file mode 100644 index 00000000..58fbbaac --- /dev/null +++ b/protocol/state.go @@ -0,0 +1,35 @@ +package protocol + +const ( + AGENCY_NONE uint = 0 + AGENCY_CLIENT uint = 1 + AGENCY_SERVER uint = 2 +) + +type State struct { + Id uint + Name string +} + +func NewState(id uint, name string) State { + return State{ + Id: id, + Name: name, + } +} + +func (s State) String() string { + return s.Name +} + +type StateTransition struct { + MsgType uint8 + NewState State +} + +type StateMapEntry struct { + Agency uint + Transitions []StateTransition +} + +type StateMap map[State]StateMapEntry