Skip to content

Commit

Permalink
feat: added state context to state transitions and enabled chainsync …
Browse files Browse the repository at this point in the history
…pipelining (#585)
  • Loading branch information
rakshasa authored Apr 16, 2024
1 parent 1b07d4f commit 5ac94a5
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 21 deletions.
111 changes: 99 additions & 12 deletions protocol/chainsync/chainsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package chainsync

import (
"sync"
"time"

"github.com/blinklabs-io/gouroboros/connection"
Expand Down Expand Up @@ -44,8 +45,9 @@ var StateMap = protocol.StateMap{
Agency: protocol.AgencyClient,
Transitions: []protocol.StateTransition{
{
MsgType: MessageTypeRequestNext,
NewState: stateCanAwait,
MsgType: MessageTypeRequestNext,
NewState: stateCanAwait,
MatchFunc: IncrementPipelineCount,
},
{
MsgType: MessageTypeFindIntersect,
Expand All @@ -60,17 +62,34 @@ var StateMap = protocol.StateMap{
stateCanAwait: protocol.StateMapEntry{
Agency: protocol.AgencyServer,
Transitions: []protocol.StateTransition{
{
MsgType: MessageTypeRequestNext,
NewState: stateCanAwait,
MatchFunc: IncrementPipelineCount,
},
{
MsgType: MessageTypeAwaitReply,
NewState: stateMustReply,
},
{
MsgType: MessageTypeRollForward,
NewState: stateIdle,
MsgType: MessageTypeRollForward,
NewState: stateIdle,
MatchFunc: DecrementPipelineCountAndIsEmpty,
},
{
MsgType: MessageTypeRollBackward,
NewState: stateIdle,
MsgType: MessageTypeRollForward,
NewState: stateCanAwait,
MatchFunc: DecrementPipelineCountAndIsNotEmpty,
},
{
MsgType: MessageTypeRollBackward,
NewState: stateIdle,
MatchFunc: DecrementPipelineCountAndIsEmpty,
},
{
MsgType: MessageTypeRollBackward,
NewState: stateCanAwait,
MatchFunc: DecrementPipelineCountAndIsNotEmpty,
},
},
},
Expand All @@ -91,12 +110,24 @@ var StateMap = protocol.StateMap{
Agency: protocol.AgencyServer,
Transitions: []protocol.StateTransition{
{
MsgType: MessageTypeRollForward,
NewState: stateIdle,
MsgType: MessageTypeRollForward,
NewState: stateIdle,
MatchFunc: DecrementPipelineCountAndIsEmpty,
},
{
MsgType: MessageTypeRollBackward,
NewState: stateIdle,
MsgType: MessageTypeRollForward,
NewState: stateCanAwait,
MatchFunc: DecrementPipelineCountAndIsNotEmpty,
},
{
MsgType: MessageTypeRollBackward,
NewState: stateIdle,
MatchFunc: DecrementPipelineCountAndIsEmpty,
},
{
MsgType: MessageTypeRollBackward,
NewState: stateCanAwait,
MatchFunc: DecrementPipelineCountAndIsNotEmpty,
},
},
},
Expand All @@ -105,6 +136,60 @@ var StateMap = protocol.StateMap{
},
}

type StateContext struct {
mu sync.Mutex
pipelineCount int
}

var IncrementPipelineCount = func(context interface{}, msg protocol.Message) bool {
s := context.(*StateContext)
s.mu.Lock()
defer s.mu.Unlock()

s.pipelineCount++
return true
}

var DecrementPipelineCountAndIsEmpty = func(context interface{}, msg protocol.Message) bool {
s := context.(*StateContext)
s.mu.Lock()
defer s.mu.Unlock()

if s.pipelineCount == 1 {
s.pipelineCount--
return true
}
return false
}

var DecrementPipelineCountAndIsNotEmpty = func(context interface{}, msg protocol.Message) bool {
s := context.(*StateContext)
s.mu.Lock()
defer s.mu.Unlock()

if s.pipelineCount > 1 {
s.pipelineCount--
return true
}
return false
}

var PipelineIsEmtpy = func(context interface{}, msg protocol.Message) bool {
s := context.(*StateContext)
s.mu.Lock()
defer s.mu.Unlock()

return s.pipelineCount == 0
}

var PipelineIsNotEmpty = func(context interface{}, msg protocol.Message) bool {
s := context.(*StateContext)
s.mu.Lock()
defer s.mu.Unlock()

return s.pipelineCount > 0
}

// ChainSync is a wrapper object that holds the client and server instances
type ChainSync struct {
Client *Client
Expand Down Expand Up @@ -137,9 +222,11 @@ type RequestNextFunc func(CallbackContext) error

// New returns a new ChainSync object
func New(protoOptions protocol.ProtocolOptions, cfg *Config) *ChainSync {
stateContext := &StateContext{}

c := &ChainSync{
Client: NewClient(protoOptions, cfg),
Server: NewServer(protoOptions, cfg),
Client: NewClient(stateContext, protoOptions, cfg),
Server: NewServer(stateContext, protoOptions, cfg),
}
return c
}
Expand Down
3 changes: 2 additions & 1 deletion protocol/chainsync/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ type Client struct {
}

// NewClient returns a new ChainSync client object
func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
func NewClient(stateContext interface{}, protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
// Use node-to-client protocol ID
ProtocolId := ProtocolIdNtC
msgFromCborFunc := NewMsgFromCborNtC
Expand Down Expand Up @@ -91,6 +91,7 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
MessageHandlerFunc: c.messageHandler,
MessageFromCborFunc: msgFromCborFunc,
StateMap: stateMap,
StateContext: stateContext,
InitialState: stateIdle,
}
c.Protocol = protocol.New(protoConfig)
Expand Down
3 changes: 2 additions & 1 deletion protocol/chainsync/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type Server struct {
}

// NewServer returns a new ChainSync server object
func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
func NewServer(stateContext interface{}, protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
// Use node-to-client protocol ID
ProtocolId := ProtocolIdNtC
msgFromCborFunc := NewMsgFromCborNtC
Expand All @@ -56,6 +56,7 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
MessageHandlerFunc: s.messageHandler,
MessageFromCborFunc: msgFromCborFunc,
StateMap: StateMap,
StateContext: stateContext,
InitialState: stateIdle,
}
s.Protocol = protocol.New(protoConfig)
Expand Down
8 changes: 4 additions & 4 deletions protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"bytes"
"fmt"
"io"
"reflect"
"sync"
"time"

Expand Down Expand Up @@ -58,6 +57,7 @@ type ProtocolConfig struct {
MessageHandlerFunc MessageHandlerFunc
MessageFromCborFunc MessageFromCborFunc
StateMap StateMap
StateContext interface{}
InitialState State
}

Expand Down Expand Up @@ -495,7 +495,7 @@ func (p *Protocol) nextState(currentState State, msg Message) (State, error) {
if transition.MsgType == msg.Type() {
if transition.MatchFunc != nil {
// Skip item if match function returns false
if !transition.MatchFunc(msg) {
if !transition.MatchFunc(p.config.StateContext, msg) {
continue
}
}
Expand All @@ -504,8 +504,8 @@ func (p *Protocol) nextState(currentState State, msg Message) (State, error) {
}

return State{}, fmt.Errorf(
"message %s not allowed in current protocol state %s",
reflect.TypeOf(msg).Name(),
"message %T not allowed in current protocol state %s",
msg,
currentState,
)
}
Expand Down
2 changes: 1 addition & 1 deletion protocol/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ type StateTransition struct {

// StateTransitionMatchFunc represents a function that will take a Message and return a bool
// that indicates whether the message is a match for the state transition rule
type StateTransitionMatchFunc func(Message) bool
type StateTransitionMatchFunc func(interface{}, Message) bool

// StateMapEntry represents a protocol state, it's possible state transitions, and an optional timeout
type StateMapEntry struct {
Expand Down
4 changes: 2 additions & 2 deletions protocol/txsubmission/txsubmission.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ var StateMap = protocol.StateMap{
MsgType: MessageTypeRequestTxIds,
NewState: stateTxIdsBlocking,
// Match if blocking
MatchFunc: func(msg protocol.Message) bool {
MatchFunc: func(context interface{}, msg protocol.Message) bool {
msgRequestTxIds := msg.(*MsgRequestTxIds)
return msgRequestTxIds.Blocking
},
Expand All @@ -64,7 +64,7 @@ var StateMap = protocol.StateMap{
MsgType: MessageTypeRequestTxIds,
NewState: stateTxIdsNonblocking,
// Metch if non-blocking
MatchFunc: func(msg protocol.Message) bool {
MatchFunc: func(context interface{}, msg protocol.Message) bool {
msgRequestTxIds := msg.(*MsgRequestTxIds)
return !msgRequestTxIds.Blocking
},
Expand Down

0 comments on commit 5ac94a5

Please sign in to comment.