diff --git a/protocol/chainsync/chainsync.go b/protocol/chainsync/chainsync.go index 6bab15d7..89b367b6 100644 --- a/protocol/chainsync/chainsync.go +++ b/protocol/chainsync/chainsync.go @@ -16,6 +16,7 @@ package chainsync import ( + "sync" "time" "github.com/blinklabs-io/gouroboros/connection" @@ -44,8 +45,9 @@ var StateMap = protocol.StateMap{ Agency: protocol.AgencyClient, Transitions: []protocol.StateTransition{ { - MsgType: MessageTypeRequestNext, - NewState: stateCanAwait, + MsgType: MessageTypeRequestNext, + NewState: stateCanAwait, + MatchFunc: IncrementPipelineCount, }, { MsgType: MessageTypeFindIntersect, @@ -60,17 +62,34 @@ var StateMap = protocol.StateMap{ stateCanAwait: protocol.StateMapEntry{ Agency: protocol.AgencyServer, Transitions: []protocol.StateTransition{ + { + MsgType: MessageTypeRequestNext, + NewState: stateCanAwait, + MatchFunc: IncrementPipelineCount, + }, { MsgType: MessageTypeAwaitReply, NewState: stateMustReply, }, { - MsgType: MessageTypeRollForward, - NewState: stateIdle, + MsgType: MessageTypeRollForward, + NewState: stateIdle, + MatchFunc: DecrementPipelineCountAndIsEmpty, }, { - MsgType: MessageTypeRollBackward, - NewState: stateIdle, + MsgType: MessageTypeRollForward, + NewState: stateCanAwait, + MatchFunc: DecrementPipelineCountAndIsNotEmpty, + }, + { + MsgType: MessageTypeRollBackward, + NewState: stateIdle, + MatchFunc: DecrementPipelineCountAndIsEmpty, + }, + { + MsgType: MessageTypeRollBackward, + NewState: stateCanAwait, + MatchFunc: DecrementPipelineCountAndIsNotEmpty, }, }, }, @@ -91,12 +110,24 @@ var StateMap = protocol.StateMap{ Agency: protocol.AgencyServer, Transitions: []protocol.StateTransition{ { - MsgType: MessageTypeRollForward, - NewState: stateIdle, + MsgType: MessageTypeRollForward, + NewState: stateIdle, + MatchFunc: DecrementPipelineCountAndIsEmpty, }, { - MsgType: MessageTypeRollBackward, - NewState: stateIdle, + MsgType: MessageTypeRollForward, + NewState: stateCanAwait, + MatchFunc: DecrementPipelineCountAndIsNotEmpty, + }, + { + MsgType: MessageTypeRollBackward, + NewState: stateIdle, + MatchFunc: DecrementPipelineCountAndIsEmpty, + }, + { + MsgType: MessageTypeRollBackward, + NewState: stateCanAwait, + MatchFunc: DecrementPipelineCountAndIsNotEmpty, }, }, }, @@ -105,6 +136,60 @@ var StateMap = protocol.StateMap{ }, } +type StateContext struct { + mu sync.Mutex + pipelineCount int +} + +var IncrementPipelineCount = func(context interface{}, msg protocol.Message) bool { + s := context.(*StateContext) + s.mu.Lock() + defer s.mu.Unlock() + + s.pipelineCount++ + return true +} + +var DecrementPipelineCountAndIsEmpty = func(context interface{}, msg protocol.Message) bool { + s := context.(*StateContext) + s.mu.Lock() + defer s.mu.Unlock() + + if s.pipelineCount == 1 { + s.pipelineCount-- + return true + } + return false +} + +var DecrementPipelineCountAndIsNotEmpty = func(context interface{}, msg protocol.Message) bool { + s := context.(*StateContext) + s.mu.Lock() + defer s.mu.Unlock() + + if s.pipelineCount > 1 { + s.pipelineCount-- + return true + } + return false +} + +var PipelineIsEmtpy = func(context interface{}, msg protocol.Message) bool { + s := context.(*StateContext) + s.mu.Lock() + defer s.mu.Unlock() + + return s.pipelineCount == 0 +} + +var PipelineIsNotEmpty = func(context interface{}, msg protocol.Message) bool { + s := context.(*StateContext) + s.mu.Lock() + defer s.mu.Unlock() + + return s.pipelineCount > 0 +} + // ChainSync is a wrapper object that holds the client and server instances type ChainSync struct { Client *Client @@ -137,9 +222,11 @@ type RequestNextFunc func(CallbackContext) error // New returns a new ChainSync object func New(protoOptions protocol.ProtocolOptions, cfg *Config) *ChainSync { + stateContext := &StateContext{} + c := &ChainSync{ - Client: NewClient(protoOptions, cfg), - Server: NewServer(protoOptions, cfg), + Client: NewClient(stateContext, protoOptions, cfg), + Server: NewServer(stateContext, protoOptions, cfg), } return c } diff --git a/protocol/chainsync/client.go b/protocol/chainsync/client.go index 553c5bfd..f2a4bded 100644 --- a/protocol/chainsync/client.go +++ b/protocol/chainsync/client.go @@ -43,7 +43,7 @@ type Client struct { } // NewClient returns a new ChainSync client object -func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { +func NewClient(stateContext interface{}, protoOptions protocol.ProtocolOptions, cfg *Config) *Client { // Use node-to-client protocol ID ProtocolId := ProtocolIdNtC msgFromCborFunc := NewMsgFromCborNtC @@ -91,6 +91,7 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { MessageHandlerFunc: c.messageHandler, MessageFromCborFunc: msgFromCborFunc, StateMap: stateMap, + StateContext: stateContext, InitialState: stateIdle, } c.Protocol = protocol.New(protoConfig) diff --git a/protocol/chainsync/server.go b/protocol/chainsync/server.go index 8cc76ee4..c39e7ebf 100644 --- a/protocol/chainsync/server.go +++ b/protocol/chainsync/server.go @@ -30,7 +30,7 @@ type Server struct { } // NewServer returns a new ChainSync server object -func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { +func NewServer(stateContext interface{}, protoOptions protocol.ProtocolOptions, cfg *Config) *Server { // Use node-to-client protocol ID ProtocolId := ProtocolIdNtC msgFromCborFunc := NewMsgFromCborNtC @@ -56,6 +56,7 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { MessageHandlerFunc: s.messageHandler, MessageFromCborFunc: msgFromCborFunc, StateMap: StateMap, + StateContext: stateContext, InitialState: stateIdle, } s.Protocol = protocol.New(protoConfig) diff --git a/protocol/protocol.go b/protocol/protocol.go index b3110190..0976babb 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -19,7 +19,6 @@ import ( "bytes" "fmt" "io" - "reflect" "sync" "time" @@ -58,6 +57,7 @@ type ProtocolConfig struct { MessageHandlerFunc MessageHandlerFunc MessageFromCborFunc MessageFromCborFunc StateMap StateMap + StateContext interface{} InitialState State } @@ -495,7 +495,7 @@ func (p *Protocol) nextState(currentState State, msg Message) (State, error) { if transition.MsgType == msg.Type() { if transition.MatchFunc != nil { // Skip item if match function returns false - if !transition.MatchFunc(msg) { + if !transition.MatchFunc(p.config.StateContext, msg) { continue } } @@ -504,8 +504,8 @@ func (p *Protocol) nextState(currentState State, msg Message) (State, error) { } return State{}, fmt.Errorf( - "message %s not allowed in current protocol state %s", - reflect.TypeOf(msg).Name(), + "message %T not allowed in current protocol state %s", + msg, currentState, ) } diff --git a/protocol/state.go b/protocol/state.go index 01faafe7..c0fc7c95 100644 --- a/protocol/state.go +++ b/protocol/state.go @@ -55,7 +55,7 @@ type StateTransition struct { // StateTransitionMatchFunc represents a function that will take a Message and return a bool // that indicates whether the message is a match for the state transition rule -type StateTransitionMatchFunc func(Message) bool +type StateTransitionMatchFunc func(interface{}, Message) bool // StateMapEntry represents a protocol state, it's possible state transitions, and an optional timeout type StateMapEntry struct { diff --git a/protocol/txsubmission/txsubmission.go b/protocol/txsubmission/txsubmission.go index 00ef75ba..daf577c1 100644 --- a/protocol/txsubmission/txsubmission.go +++ b/protocol/txsubmission/txsubmission.go @@ -55,7 +55,7 @@ var StateMap = protocol.StateMap{ MsgType: MessageTypeRequestTxIds, NewState: stateTxIdsBlocking, // Match if blocking - MatchFunc: func(msg protocol.Message) bool { + MatchFunc: func(context interface{}, msg protocol.Message) bool { msgRequestTxIds := msg.(*MsgRequestTxIds) return msgRequestTxIds.Blocking }, @@ -64,7 +64,7 @@ var StateMap = protocol.StateMap{ MsgType: MessageTypeRequestTxIds, NewState: stateTxIdsNonblocking, // Metch if non-blocking - MatchFunc: func(msg protocol.Message) bool { + MatchFunc: func(context interface{}, msg protocol.Message) bool { msgRequestTxIds := msg.(*MsgRequestTxIds) return !msgRequestTxIds.Blocking },