Skip to content

Commit

Permalink
Merge pull request #88 from cloudstruct/feature/explicit-protocol-start
Browse files Browse the repository at this point in the history
feat!: require explicitly starting each protocol
  • Loading branch information
agaffney authored Jun 5, 2022
2 parents 3b54dac + d024f6b commit 7890fe3
Show file tree
Hide file tree
Showing 11 changed files with 108 additions and 105 deletions.
14 changes: 7 additions & 7 deletions cmd/go-ouroboros-network/chainsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,11 @@ func testChainSync(f *globalFlags) {
conn := createClientConnection(f)
errorChan := make(chan error)
oOpts := &ouroboros.OuroborosOptions{
Conn: conn,
NetworkMagic: uint32(f.networkMagic),
ErrorChan: errorChan,
UseNodeToNodeProtocol: f.ntnProto,
SendKeepAlives: true,
ChainSyncCallbackConfig: buildChainSyncCallbackConfig(),
BlockFetchCallbackConfig: buildBlockFetchCallbackConfig(),
Conn: conn,
NetworkMagic: uint32(f.networkMagic),
ErrorChan: errorChan,
UseNodeToNodeProtocol: f.ntnProto,
SendKeepAlives: true,
}
go func() {
for {
Expand All @@ -117,6 +115,8 @@ func testChainSync(f *globalFlags) {
fmt.Printf("ERROR: %s\n", err)
os.Exit(1)
}
o.ChainSync.Start(buildChainSyncCallbackConfig())
o.BlockFetch.Start(buildBlockFetchCallbackConfig())

syncState.oConn = o
syncState.readyForNextBlockChan = make(chan bool)
Expand Down
12 changes: 6 additions & 6 deletions cmd/go-ouroboros-network/localtxsubmission.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,11 @@ func testLocalTxSubmission(f *globalFlags) {
conn := createClientConnection(f)
errorChan := make(chan error)
oOpts := &ouroboros.OuroborosOptions{
Conn: conn,
NetworkMagic: uint32(f.networkMagic),
ErrorChan: errorChan,
UseNodeToNodeProtocol: f.ntnProto,
SendKeepAlives: true,
LocalTxSubmissionCallbackConfig: buildLocalTxSubmissionCallbackConfig(),
Conn: conn,
NetworkMagic: uint32(f.networkMagic),
ErrorChan: errorChan,
UseNodeToNodeProtocol: f.ntnProto,
SendKeepAlives: true,
}
go func() {
for {
Expand All @@ -70,6 +69,7 @@ func testLocalTxSubmission(f *globalFlags) {
fmt.Printf("ERROR: %s\n", err)
os.Exit(1)
}
o.LocalTxSubmission.Start(buildLocalTxSubmissionCallbackConfig())

txData, err := ioutil.ReadFile(localTxSubmissionFlags.txFile)
if err != nil {
Expand Down
77 changes: 31 additions & 46 deletions ouroboros.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,51 +26,35 @@ type Ouroboros struct {
sendKeepAlives bool
delayMuxerStart bool
// Mini-protocols
Handshake *handshake.Handshake
ChainSync *chainsync.ChainSync
chainSyncCallbackConfig *chainsync.ChainSyncCallbackConfig
BlockFetch *blockfetch.BlockFetch
blockFetchCallbackConfig *blockfetch.BlockFetchCallbackConfig
KeepAlive *keepalive.KeepAlive
keepAliveCallbackConfig *keepalive.KeepAliveCallbackConfig
LocalTxSubmission *localtxsubmission.LocalTxSubmission
localTxSubmissionCallbackConfig *localtxsubmission.CallbackConfig
LocalStateQuery *localstatequery.LocalStateQuery
localStateQueryCallbackConfig *localstatequery.CallbackConfig
TxSubmission *txsubmission.TxSubmission
txSubmissionCallbackConfig *txsubmission.CallbackConfig
Handshake *handshake.Handshake
ChainSync *chainsync.ChainSync
BlockFetch *blockfetch.BlockFetch
KeepAlive *keepalive.KeepAlive
LocalTxSubmission *localtxsubmission.LocalTxSubmission
LocalStateQuery *localstatequery.LocalStateQuery
TxSubmission *txsubmission.TxSubmission
}

type OuroborosOptions struct {
Conn net.Conn
NetworkMagic uint32
ErrorChan chan error
Server bool
UseNodeToNodeProtocol bool
SendKeepAlives bool
DelayMuxerStart bool
ChainSyncCallbackConfig *chainsync.ChainSyncCallbackConfig
BlockFetchCallbackConfig *blockfetch.BlockFetchCallbackConfig
KeepAliveCallbackConfig *keepalive.KeepAliveCallbackConfig
LocalTxSubmissionCallbackConfig *localtxsubmission.CallbackConfig
LocalStateQueryCallbackConfig *localstatequery.CallbackConfig
Conn net.Conn
NetworkMagic uint32
ErrorChan chan error
Server bool
UseNodeToNodeProtocol bool
SendKeepAlives bool
DelayMuxerStart bool
}

func New(options *OuroborosOptions) (*Ouroboros, error) {
o := &Ouroboros{
conn: options.Conn,
networkMagic: options.NetworkMagic,
server: options.Server,
useNodeToNodeProto: options.UseNodeToNodeProtocol,
chainSyncCallbackConfig: options.ChainSyncCallbackConfig,
blockFetchCallbackConfig: options.BlockFetchCallbackConfig,
keepAliveCallbackConfig: options.KeepAliveCallbackConfig,
localTxSubmissionCallbackConfig: options.LocalTxSubmissionCallbackConfig,
localStateQueryCallbackConfig: options.LocalStateQueryCallbackConfig,
ErrorChan: options.ErrorChan,
sendKeepAlives: options.SendKeepAlives,
delayMuxerStart: options.DelayMuxerStart,
protoErrorChan: make(chan error, 10),
conn: options.Conn,
networkMagic: options.NetworkMagic,
server: options.Server,
useNodeToNodeProto: options.UseNodeToNodeProtocol,
ErrorChan: options.ErrorChan,
sendKeepAlives: options.SendKeepAlives,
delayMuxerStart: options.DelayMuxerStart,
protoErrorChan: make(chan error, 10),
}
if o.ErrorChan == nil {
o.ErrorChan = make(chan error, 10)
Expand Down Expand Up @@ -147,6 +131,7 @@ func (o *Ouroboros) setupConnection() error {
}
// Perform handshake
o.Handshake = handshake.New(protoOptions, protoVersions)
o.Handshake.Start()
// TODO: figure out better way to signify automatic handshaking and returning the chosen version
if !o.server {
err := o.Handshake.ProposeVersions(protoVersions, o.networkMagic)
Expand Down Expand Up @@ -178,22 +163,22 @@ func (o *Ouroboros) setupConnection() error {
if o.useNodeToNodeProto {
versionNtN := GetProtocolVersionNtN(o.Handshake.Version)
protoOptions.Mode = protocol.ProtocolModeNodeToNode
o.ChainSync = chainsync.New(protoOptions, o.chainSyncCallbackConfig)
o.BlockFetch = blockfetch.New(protoOptions, o.blockFetchCallbackConfig)
o.TxSubmission = txsubmission.New(protoOptions, o.txSubmissionCallbackConfig)
o.ChainSync = chainsync.New(protoOptions)
o.BlockFetch = blockfetch.New(protoOptions)
o.TxSubmission = txsubmission.New(protoOptions)
if versionNtN.EnableKeepAliveProtocol {
o.KeepAlive = keepalive.New(protoOptions, o.keepAliveCallbackConfig)
o.KeepAlive = keepalive.New(protoOptions)
if o.sendKeepAlives {
o.KeepAlive.Start()
o.KeepAlive.Start(nil)
}
}
} else {
versionNtC := GetProtocolVersionNtC(o.Handshake.Version)
protoOptions.Mode = protocol.ProtocolModeNodeToClient
o.ChainSync = chainsync.New(protoOptions, o.chainSyncCallbackConfig)
o.LocalTxSubmission = localtxsubmission.New(protoOptions, o.localTxSubmissionCallbackConfig)
o.ChainSync = chainsync.New(protoOptions)
o.LocalTxSubmission = localtxsubmission.New(protoOptions)
if versionNtC.EnableLocalQueryProtocol {
o.LocalStateQuery = localstatequery.New(protoOptions, o.localStateQueryCallbackConfig)
o.LocalStateQuery = localstatequery.New(protoOptions)
}
}
// Start muxer
Expand Down
11 changes: 7 additions & 4 deletions protocol/blockfetch/blockfetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,8 @@ type BlockFetchNoBlocksFunc func() error
type BlockFetchBlockFunc func(uint, interface{}) error
type BlockFetchBatchDoneFunc func() error

func New(options protocol.ProtocolOptions, callbackConfig *BlockFetchCallbackConfig) *BlockFetch {
b := &BlockFetch{
callbackConfig: callbackConfig,
}
func New(options protocol.ProtocolOptions) *BlockFetch {
b := &BlockFetch{}
protoConfig := protocol.ProtocolConfig{
Name: PROTOCOL_NAME,
ProtocolId: PROTOCOL_ID,
Expand All @@ -102,6 +100,11 @@ func New(options protocol.ProtocolOptions, callbackConfig *BlockFetchCallbackCon
return b
}

func (b *BlockFetch) Start(callbackConfig *BlockFetchCallbackConfig) {
b.callbackConfig = callbackConfig
b.Protocol.Start()
}

func (b *BlockFetch) messageHandler(msg protocol.Message, isResponse bool) error {
var err error
switch msg.Type() {
Expand Down
11 changes: 7 additions & 4 deletions protocol/chainsync/chainsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ type ChainSyncIntersectFoundFunc func(interface{}, interface{}) error
type ChainSyncIntersectNotFoundFunc func(interface{}) error
type ChainSyncDoneFunc func() error

func New(options protocol.ProtocolOptions, callbackConfig *ChainSyncCallbackConfig) *ChainSync {
func New(options protocol.ProtocolOptions) *ChainSync {
// Use node-to-client protocol ID
protocolId := PROTOCOL_ID_NTC
msgFromCborFunc := NewMsgFromCborNtC
Expand All @@ -117,9 +117,7 @@ func New(options protocol.ProtocolOptions, callbackConfig *ChainSyncCallbackConf
protocolId = PROTOCOL_ID_NTN
msgFromCborFunc = NewMsgFromCborNtN
}
c := &ChainSync{
callbackConfig: callbackConfig,
}
c := &ChainSync{}
protoConfig := protocol.ProtocolConfig{
Name: PROTOCOL_NAME,
ProtocolId: protocolId,
Expand All @@ -136,6 +134,11 @@ func New(options protocol.ProtocolOptions, callbackConfig *ChainSyncCallbackConf
return c
}

func (c *ChainSync) Start(callbackConfig *ChainSyncCallbackConfig) {
c.callbackConfig = callbackConfig
c.Protocol.Start()
}

func (c *ChainSync) messageHandler(msg protocol.Message, isResponse bool) error {
var err error
switch msg.Type() {
Expand Down
4 changes: 4 additions & 0 deletions protocol/handshake/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ func New(options protocol.ProtocolOptions, allowedVersions []uint16) *Handshake
return h
}

func (h *Handshake) Start() {
h.Protocol.Start()
}

func (h *Handshake) handleMessage(msg protocol.Message, isResponse bool) error {
var err error
switch msg.Type() {
Expand Down
24 changes: 10 additions & 14 deletions protocol/keepalive/keepalive.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,8 @@ type KeepAliveFunc func(uint16) error
type KeepAliveResponseFunc func(uint16) error
type DoneFunc func() error

func New(options protocol.ProtocolOptions, callbackConfig *KeepAliveCallbackConfig) *KeepAlive {
k := &KeepAlive{
callbackConfig: callbackConfig,
}
func New(options protocol.ProtocolOptions) *KeepAlive {
k := &KeepAlive{}
protoConfig := protocol.ProtocolConfig{
Name: PROTOCOL_NAME,
ProtocolId: PROTOCOL_ID,
Expand All @@ -85,6 +83,12 @@ func New(options protocol.ProtocolOptions, callbackConfig *KeepAliveCallbackConf
return k
}

func (k *KeepAlive) Start(callbackConfig *KeepAliveCallbackConfig) {
k.callbackConfig = callbackConfig
k.Protocol.Start()
k.startTimer()
}

func (k *KeepAlive) messageHandler(msg protocol.Message, isResponse bool) error {
var err error
switch msg.Type() {
Expand All @@ -100,22 +104,14 @@ func (k *KeepAlive) messageHandler(msg protocol.Message, isResponse bool) error
return err
}

func (k *KeepAlive) Start() {
func (k *KeepAlive) startTimer() {
k.timer = time.AfterFunc(KEEP_ALIVE_PERIOD*time.Second, func() {
if err := k.KeepAlive(0); err != nil {
k.SendError(err)
}
})
}

func (k *KeepAlive) Stop() {
if k.timer != nil {
k.timer.Stop()
}
// Remove timer, since we check for its presence elsewhere
k.timer = nil
}

func (k *KeepAlive) KeepAlive(cookie uint16) error {
msg := NewMsgKeepAlive(cookie)
return k.SendMessage(msg)
Expand All @@ -137,7 +133,7 @@ func (k *KeepAlive) handleKeepAliveResponse(msgGeneric protocol.Message) error {
msg := msgGeneric.(*MsgKeepAliveResponse)
// Start the timer again if we had one previously
if k.timer != nil {
defer k.Start()
defer k.startTimer()
}
if k.callbackConfig != nil && k.callbackConfig.KeepAliveResponseFunc != nil {
// Call the user callback function
Expand Down
11 changes: 7 additions & 4 deletions protocol/localstatequery/localstatequery.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,8 @@ type ReleaseFunc func() error
type ReAcquireFunc func(interface{}) error
type DoneFunc func() error

func New(options protocol.ProtocolOptions, callbackConfig *CallbackConfig) *LocalStateQuery {
l := &LocalStateQuery{
callbackConfig: callbackConfig,
}
func New(options protocol.ProtocolOptions) *LocalStateQuery {
l := &LocalStateQuery{}
protoConfig := protocol.ProtocolConfig{
Name: PROTOCOL_NAME,
ProtocolId: PROTOCOL_ID,
Expand All @@ -142,6 +140,11 @@ func New(options protocol.ProtocolOptions, callbackConfig *CallbackConfig) *Loca
return l
}

func (l *LocalStateQuery) Start(callbackConfig *CallbackConfig) {
l.callbackConfig = callbackConfig
l.Protocol.Start()
}

func (l *LocalStateQuery) messageHandler(msg protocol.Message, isResponse bool) error {
var err error
switch msg.Type() {
Expand Down
11 changes: 7 additions & 4 deletions protocol/localtxsubmission/localtxsubmission.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,8 @@ type AcceptTxFunc func() error
type RejectTxFunc func(interface{}) error
type DoneFunc func() error

func New(options protocol.ProtocolOptions, callbackConfig *CallbackConfig) *LocalTxSubmission {
l := &LocalTxSubmission{
callbackConfig: callbackConfig,
}
func New(options protocol.ProtocolOptions) *LocalTxSubmission {
l := &LocalTxSubmission{}
protoConfig := protocol.ProtocolConfig{
Name: PROTOCOL_NAME,
ProtocolId: PROTOCOL_ID,
Expand All @@ -82,6 +80,11 @@ func New(options protocol.ProtocolOptions, callbackConfig *CallbackConfig) *Loca
return l
}

func (l *LocalTxSubmission) Start(callbackConfig *CallbackConfig) {
l.callbackConfig = callbackConfig
l.Protocol.Start()
}

func (l *LocalTxSubmission) messageHandler(msg protocol.Message, isResponse bool) error {
var err error
switch msg.Type() {
Expand Down
27 changes: 15 additions & 12 deletions protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,24 +70,27 @@ type MessageHandlerFunc func(Message, bool) error
type MessageFromCborFunc func(uint, []byte) (Message, error)

func New(config ProtocolConfig) *Protocol {
muxerSendChan, muxerRecvChan := config.Muxer.RegisterProtocol(config.ProtocolId)
p := &Protocol{
config: config,
muxerSendChan: muxerSendChan,
muxerRecvChan: muxerRecvChan,
recvBuffer: bytes.NewBuffer(nil),
sendQueueChan: make(chan Message, 50),
sendStateQueueChan: make(chan Message, 50),
recvReadyChan: make(chan bool, 1),
sendReadyChan: make(chan bool, 1),
doneChan: make(chan bool),
config: config,
}
return p
}

func (p *Protocol) Start() {
// Register protocol with muxer
p.muxerSendChan, p.muxerRecvChan = p.config.Muxer.RegisterProtocol(p.config.ProtocolId)
// Create buffers and channels
p.recvBuffer = bytes.NewBuffer(nil)
p.sendQueueChan = make(chan Message, 50)
p.sendStateQueueChan = make(chan Message, 50)
p.recvReadyChan = make(chan bool, 1)
p.sendReadyChan = make(chan bool, 1)
p.doneChan = make(chan bool)
// Set initial state
p.setState(config.InitialState)
p.setState(p.config.InitialState)
// Start our send and receive Goroutines
go p.recvLoop()
go p.sendLoop()
return p
}

func (p *Protocol) Mode() ProtocolMode {
Expand Down
Loading

0 comments on commit 7890fe3

Please sign in to comment.