Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added state context to state transitions and enabled chainsync pipelining #585

Merged
merged 1 commit into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 99 additions & 12 deletions protocol/chainsync/chainsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package chainsync

import (
"sync"
"time"

"github.com/blinklabs-io/gouroboros/connection"
Expand Down Expand Up @@ -44,8 +45,9 @@ var StateMap = protocol.StateMap{
Agency: protocol.AgencyClient,
Transitions: []protocol.StateTransition{
{
MsgType: MessageTypeRequestNext,
NewState: stateCanAwait,
MsgType: MessageTypeRequestNext,
NewState: stateCanAwait,
MatchFunc: IncrementPipelineCount,
},
{
MsgType: MessageTypeFindIntersect,
Expand All @@ -60,17 +62,34 @@ var StateMap = protocol.StateMap{
stateCanAwait: protocol.StateMapEntry{
Agency: protocol.AgencyServer,
Transitions: []protocol.StateTransition{
{
MsgType: MessageTypeRequestNext,
NewState: stateCanAwait,
MatchFunc: IncrementPipelineCount,
},
{
MsgType: MessageTypeAwaitReply,
NewState: stateMustReply,
},
{
MsgType: MessageTypeRollForward,
NewState: stateIdle,
MsgType: MessageTypeRollForward,
NewState: stateIdle,
MatchFunc: DecrementPipelineCountAndIsEmpty,
},
{
MsgType: MessageTypeRollBackward,
NewState: stateIdle,
MsgType: MessageTypeRollForward,
NewState: stateCanAwait,
MatchFunc: DecrementPipelineCountAndIsNotEmpty,
},
{
MsgType: MessageTypeRollBackward,
NewState: stateIdle,
MatchFunc: DecrementPipelineCountAndIsEmpty,
},
{
MsgType: MessageTypeRollBackward,
NewState: stateCanAwait,
MatchFunc: DecrementPipelineCountAndIsNotEmpty,
},
},
},
Expand All @@ -91,12 +110,24 @@ var StateMap = protocol.StateMap{
Agency: protocol.AgencyServer,
Transitions: []protocol.StateTransition{
{
MsgType: MessageTypeRollForward,
NewState: stateIdle,
MsgType: MessageTypeRollForward,
NewState: stateIdle,
MatchFunc: DecrementPipelineCountAndIsEmpty,
},
{
MsgType: MessageTypeRollBackward,
NewState: stateIdle,
MsgType: MessageTypeRollForward,
NewState: stateCanAwait,
MatchFunc: DecrementPipelineCountAndIsNotEmpty,
},
{
MsgType: MessageTypeRollBackward,
NewState: stateIdle,
MatchFunc: DecrementPipelineCountAndIsEmpty,
},
{
MsgType: MessageTypeRollBackward,
NewState: stateCanAwait,
MatchFunc: DecrementPipelineCountAndIsNotEmpty,
},
},
},
Expand All @@ -105,6 +136,60 @@ var StateMap = protocol.StateMap{
},
}

type StateContext struct {
mu sync.Mutex
pipelineCount int
}

var IncrementPipelineCount = func(context interface{}, msg protocol.Message) bool {
s := context.(*StateContext)
s.mu.Lock()
defer s.mu.Unlock()

s.pipelineCount++
return true
}

var DecrementPipelineCountAndIsEmpty = func(context interface{}, msg protocol.Message) bool {
s := context.(*StateContext)
s.mu.Lock()
defer s.mu.Unlock()

if s.pipelineCount == 1 {
s.pipelineCount--
return true
}
return false
}

var DecrementPipelineCountAndIsNotEmpty = func(context interface{}, msg protocol.Message) bool {
s := context.(*StateContext)
s.mu.Lock()
defer s.mu.Unlock()

if s.pipelineCount > 1 {
s.pipelineCount--
return true
}
return false
}

var PipelineIsEmtpy = func(context interface{}, msg protocol.Message) bool {
s := context.(*StateContext)
s.mu.Lock()
defer s.mu.Unlock()

return s.pipelineCount == 0
}

var PipelineIsNotEmpty = func(context interface{}, msg protocol.Message) bool {
s := context.(*StateContext)
s.mu.Lock()
defer s.mu.Unlock()

return s.pipelineCount > 0
}

// ChainSync is a wrapper object that holds the client and server instances
type ChainSync struct {
Client *Client
Expand Down Expand Up @@ -137,9 +222,11 @@ type RequestNextFunc func(CallbackContext) error

// New returns a new ChainSync object
func New(protoOptions protocol.ProtocolOptions, cfg *Config) *ChainSync {
stateContext := &StateContext{}

c := &ChainSync{
Client: NewClient(protoOptions, cfg),
Server: NewServer(protoOptions, cfg),
Client: NewClient(stateContext, protoOptions, cfg),
Server: NewServer(stateContext, protoOptions, cfg),
}
return c
}
Expand Down
3 changes: 2 additions & 1 deletion protocol/chainsync/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ type Client struct {
}

// NewClient returns a new ChainSync client object
func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
func NewClient(stateContext interface{}, protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
// Use node-to-client protocol ID
ProtocolId := ProtocolIdNtC
msgFromCborFunc := NewMsgFromCborNtC
Expand Down Expand Up @@ -91,6 +91,7 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
MessageHandlerFunc: c.messageHandler,
MessageFromCborFunc: msgFromCborFunc,
StateMap: stateMap,
StateContext: stateContext,
InitialState: stateIdle,
}
c.Protocol = protocol.New(protoConfig)
Expand Down
3 changes: 2 additions & 1 deletion protocol/chainsync/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type Server struct {
}

// NewServer returns a new ChainSync server object
func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
func NewServer(stateContext interface{}, protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
// Use node-to-client protocol ID
ProtocolId := ProtocolIdNtC
msgFromCborFunc := NewMsgFromCborNtC
Expand All @@ -56,6 +56,7 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
MessageHandlerFunc: s.messageHandler,
MessageFromCborFunc: msgFromCborFunc,
StateMap: StateMap,
StateContext: stateContext,
InitialState: stateIdle,
}
s.Protocol = protocol.New(protoConfig)
Expand Down
8 changes: 4 additions & 4 deletions protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"bytes"
"fmt"
"io"
"reflect"
"sync"
"time"

Expand Down Expand Up @@ -58,6 +57,7 @@ type ProtocolConfig struct {
MessageHandlerFunc MessageHandlerFunc
MessageFromCborFunc MessageFromCborFunc
StateMap StateMap
StateContext interface{}
InitialState State
}

Expand Down Expand Up @@ -495,7 +495,7 @@ func (p *Protocol) nextState(currentState State, msg Message) (State, error) {
if transition.MsgType == msg.Type() {
if transition.MatchFunc != nil {
// Skip item if match function returns false
if !transition.MatchFunc(msg) {
if !transition.MatchFunc(p.config.StateContext, msg) {
continue
}
}
Expand All @@ -504,8 +504,8 @@ func (p *Protocol) nextState(currentState State, msg Message) (State, error) {
}

return State{}, fmt.Errorf(
"message %s not allowed in current protocol state %s",
reflect.TypeOf(msg).Name(),
"message %T not allowed in current protocol state %s",
msg,
currentState,
)
}
Expand Down
2 changes: 1 addition & 1 deletion protocol/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ type StateTransition struct {

// StateTransitionMatchFunc represents a function that will take a Message and return a bool
// that indicates whether the message is a match for the state transition rule
type StateTransitionMatchFunc func(Message) bool
type StateTransitionMatchFunc func(interface{}, Message) bool

// StateMapEntry represents a protocol state, it's possible state transitions, and an optional timeout
type StateMapEntry struct {
Expand Down
4 changes: 2 additions & 2 deletions protocol/txsubmission/txsubmission.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ var StateMap = protocol.StateMap{
MsgType: MessageTypeRequestTxIds,
NewState: stateTxIdsBlocking,
// Match if blocking
MatchFunc: func(msg protocol.Message) bool {
MatchFunc: func(context interface{}, msg protocol.Message) bool {
msgRequestTxIds := msg.(*MsgRequestTxIds)
return msgRequestTxIds.Blocking
},
Expand All @@ -64,7 +64,7 @@ var StateMap = protocol.StateMap{
MsgType: MessageTypeRequestTxIds,
NewState: stateTxIdsNonblocking,
// Metch if non-blocking
MatchFunc: func(msg protocol.Message) bool {
MatchFunc: func(context interface{}, msg protocol.Message) bool {
msgRequestTxIds := msg.(*MsgRequestTxIds)
return !msgRequestTxIds.Blocking
},
Expand Down