From d024f6b2106bec3358ea1949c16ca55f48506855 Mon Sep 17 00:00:00 2001 From: Andrew Gaffney Date: Sun, 5 Jun 2022 16:29:31 -0500 Subject: [PATCH] feat!: require explicitly starting each protocol This decentralizes the protocol configuration and reduces resource usage when not using all protocols. Fixes #87 BREAKING CHANGE: protocols must now be explicitly started before they can be used --- cmd/go-ouroboros-network/chainsync.go | 14 ++-- cmd/go-ouroboros-network/localtxsubmission.go | 12 +-- ouroboros.go | 77 ++++++++----------- protocol/blockfetch/blockfetch.go | 11 ++- protocol/chainsync/chainsync.go | 11 ++- protocol/handshake/handshake.go | 4 + protocol/keepalive/keepalive.go | 24 +++--- protocol/localstatequery/localstatequery.go | 11 ++- .../localtxsubmission/localtxsubmission.go | 11 ++- protocol/protocol.go | 27 ++++--- protocol/txsubmission/txsubmission.go | 11 ++- 11 files changed, 108 insertions(+), 105 deletions(-) diff --git a/cmd/go-ouroboros-network/chainsync.go b/cmd/go-ouroboros-network/chainsync.go index 9a583e23..2ec62dc7 100644 --- a/cmd/go-ouroboros-network/chainsync.go +++ b/cmd/go-ouroboros-network/chainsync.go @@ -97,13 +97,11 @@ func testChainSync(f *globalFlags) { conn := createClientConnection(f) errorChan := make(chan error) oOpts := &ouroboros.OuroborosOptions{ - Conn: conn, - NetworkMagic: uint32(f.networkMagic), - ErrorChan: errorChan, - UseNodeToNodeProtocol: f.ntnProto, - SendKeepAlives: true, - ChainSyncCallbackConfig: buildChainSyncCallbackConfig(), - BlockFetchCallbackConfig: buildBlockFetchCallbackConfig(), + Conn: conn, + NetworkMagic: uint32(f.networkMagic), + ErrorChan: errorChan, + UseNodeToNodeProtocol: f.ntnProto, + SendKeepAlives: true, } go func() { for { @@ -117,6 +115,8 @@ func testChainSync(f *globalFlags) { fmt.Printf("ERROR: %s\n", err) os.Exit(1) } + o.ChainSync.Start(buildChainSyncCallbackConfig()) + o.BlockFetch.Start(buildBlockFetchCallbackConfig()) syncState.oConn = o syncState.readyForNextBlockChan = make(chan bool) diff --git a/cmd/go-ouroboros-network/localtxsubmission.go b/cmd/go-ouroboros-network/localtxsubmission.go index 6d3a54a3..8d5fa6b1 100644 --- a/cmd/go-ouroboros-network/localtxsubmission.go +++ b/cmd/go-ouroboros-network/localtxsubmission.go @@ -51,12 +51,11 @@ func testLocalTxSubmission(f *globalFlags) { conn := createClientConnection(f) errorChan := make(chan error) oOpts := &ouroboros.OuroborosOptions{ - Conn: conn, - NetworkMagic: uint32(f.networkMagic), - ErrorChan: errorChan, - UseNodeToNodeProtocol: f.ntnProto, - SendKeepAlives: true, - LocalTxSubmissionCallbackConfig: buildLocalTxSubmissionCallbackConfig(), + Conn: conn, + NetworkMagic: uint32(f.networkMagic), + ErrorChan: errorChan, + UseNodeToNodeProtocol: f.ntnProto, + SendKeepAlives: true, } go func() { for { @@ -70,6 +69,7 @@ func testLocalTxSubmission(f *globalFlags) { fmt.Printf("ERROR: %s\n", err) os.Exit(1) } + o.LocalTxSubmission.Start(buildLocalTxSubmissionCallbackConfig()) txData, err := ioutil.ReadFile(localTxSubmissionFlags.txFile) if err != nil { diff --git a/ouroboros.go b/ouroboros.go index b6e836b6..b312f761 100644 --- a/ouroboros.go +++ b/ouroboros.go @@ -26,51 +26,35 @@ type Ouroboros struct { sendKeepAlives bool delayMuxerStart bool // Mini-protocols - Handshake *handshake.Handshake - ChainSync *chainsync.ChainSync - chainSyncCallbackConfig *chainsync.ChainSyncCallbackConfig - BlockFetch *blockfetch.BlockFetch - blockFetchCallbackConfig *blockfetch.BlockFetchCallbackConfig - KeepAlive *keepalive.KeepAlive - keepAliveCallbackConfig *keepalive.KeepAliveCallbackConfig - LocalTxSubmission *localtxsubmission.LocalTxSubmission - localTxSubmissionCallbackConfig *localtxsubmission.CallbackConfig - LocalStateQuery *localstatequery.LocalStateQuery - localStateQueryCallbackConfig *localstatequery.CallbackConfig - TxSubmission *txsubmission.TxSubmission - txSubmissionCallbackConfig *txsubmission.CallbackConfig + Handshake *handshake.Handshake + ChainSync *chainsync.ChainSync + BlockFetch *blockfetch.BlockFetch + KeepAlive *keepalive.KeepAlive + LocalTxSubmission *localtxsubmission.LocalTxSubmission + LocalStateQuery *localstatequery.LocalStateQuery + TxSubmission *txsubmission.TxSubmission } type OuroborosOptions struct { - Conn net.Conn - NetworkMagic uint32 - ErrorChan chan error - Server bool - UseNodeToNodeProtocol bool - SendKeepAlives bool - DelayMuxerStart bool - ChainSyncCallbackConfig *chainsync.ChainSyncCallbackConfig - BlockFetchCallbackConfig *blockfetch.BlockFetchCallbackConfig - KeepAliveCallbackConfig *keepalive.KeepAliveCallbackConfig - LocalTxSubmissionCallbackConfig *localtxsubmission.CallbackConfig - LocalStateQueryCallbackConfig *localstatequery.CallbackConfig + Conn net.Conn + NetworkMagic uint32 + ErrorChan chan error + Server bool + UseNodeToNodeProtocol bool + SendKeepAlives bool + DelayMuxerStart bool } func New(options *OuroborosOptions) (*Ouroboros, error) { o := &Ouroboros{ - conn: options.Conn, - networkMagic: options.NetworkMagic, - server: options.Server, - useNodeToNodeProto: options.UseNodeToNodeProtocol, - chainSyncCallbackConfig: options.ChainSyncCallbackConfig, - blockFetchCallbackConfig: options.BlockFetchCallbackConfig, - keepAliveCallbackConfig: options.KeepAliveCallbackConfig, - localTxSubmissionCallbackConfig: options.LocalTxSubmissionCallbackConfig, - localStateQueryCallbackConfig: options.LocalStateQueryCallbackConfig, - ErrorChan: options.ErrorChan, - sendKeepAlives: options.SendKeepAlives, - delayMuxerStart: options.DelayMuxerStart, - protoErrorChan: make(chan error, 10), + conn: options.Conn, + networkMagic: options.NetworkMagic, + server: options.Server, + useNodeToNodeProto: options.UseNodeToNodeProtocol, + ErrorChan: options.ErrorChan, + sendKeepAlives: options.SendKeepAlives, + delayMuxerStart: options.DelayMuxerStart, + protoErrorChan: make(chan error, 10), } if o.ErrorChan == nil { o.ErrorChan = make(chan error, 10) @@ -147,6 +131,7 @@ func (o *Ouroboros) setupConnection() error { } // Perform handshake o.Handshake = handshake.New(protoOptions, protoVersions) + o.Handshake.Start() // TODO: figure out better way to signify automatic handshaking and returning the chosen version if !o.server { err := o.Handshake.ProposeVersions(protoVersions, o.networkMagic) @@ -178,22 +163,22 @@ func (o *Ouroboros) setupConnection() error { if o.useNodeToNodeProto { versionNtN := GetProtocolVersionNtN(o.Handshake.Version) protoOptions.Mode = protocol.ProtocolModeNodeToNode - o.ChainSync = chainsync.New(protoOptions, o.chainSyncCallbackConfig) - o.BlockFetch = blockfetch.New(protoOptions, o.blockFetchCallbackConfig) - o.TxSubmission = txsubmission.New(protoOptions, o.txSubmissionCallbackConfig) + o.ChainSync = chainsync.New(protoOptions) + o.BlockFetch = blockfetch.New(protoOptions) + o.TxSubmission = txsubmission.New(protoOptions) if versionNtN.EnableKeepAliveProtocol { - o.KeepAlive = keepalive.New(protoOptions, o.keepAliveCallbackConfig) + o.KeepAlive = keepalive.New(protoOptions) if o.sendKeepAlives { - o.KeepAlive.Start() + o.KeepAlive.Start(nil) } } } else { versionNtC := GetProtocolVersionNtC(o.Handshake.Version) protoOptions.Mode = protocol.ProtocolModeNodeToClient - o.ChainSync = chainsync.New(protoOptions, o.chainSyncCallbackConfig) - o.LocalTxSubmission = localtxsubmission.New(protoOptions, o.localTxSubmissionCallbackConfig) + o.ChainSync = chainsync.New(protoOptions) + o.LocalTxSubmission = localtxsubmission.New(protoOptions) if versionNtC.EnableLocalQueryProtocol { - o.LocalStateQuery = localstatequery.New(protoOptions, o.localStateQueryCallbackConfig) + o.LocalStateQuery = localstatequery.New(protoOptions) } } // Start muxer diff --git a/protocol/blockfetch/blockfetch.go b/protocol/blockfetch/blockfetch.go index 41ef47f9..b61d0103 100644 --- a/protocol/blockfetch/blockfetch.go +++ b/protocol/blockfetch/blockfetch.go @@ -82,10 +82,8 @@ type BlockFetchNoBlocksFunc func() error type BlockFetchBlockFunc func(uint, interface{}) error type BlockFetchBatchDoneFunc func() error -func New(options protocol.ProtocolOptions, callbackConfig *BlockFetchCallbackConfig) *BlockFetch { - b := &BlockFetch{ - callbackConfig: callbackConfig, - } +func New(options protocol.ProtocolOptions) *BlockFetch { + b := &BlockFetch{} protoConfig := protocol.ProtocolConfig{ Name: PROTOCOL_NAME, ProtocolId: PROTOCOL_ID, @@ -102,6 +100,11 @@ func New(options protocol.ProtocolOptions, callbackConfig *BlockFetchCallbackCon return b } +func (b *BlockFetch) Start(callbackConfig *BlockFetchCallbackConfig) { + b.callbackConfig = callbackConfig + b.Protocol.Start() +} + func (b *BlockFetch) messageHandler(msg protocol.Message, isResponse bool) error { var err error switch msg.Type() { diff --git a/protocol/chainsync/chainsync.go b/protocol/chainsync/chainsync.go index 8e2985d0..61348cc3 100644 --- a/protocol/chainsync/chainsync.go +++ b/protocol/chainsync/chainsync.go @@ -108,7 +108,7 @@ type ChainSyncIntersectFoundFunc func(interface{}, interface{}) error type ChainSyncIntersectNotFoundFunc func(interface{}) error type ChainSyncDoneFunc func() error -func New(options protocol.ProtocolOptions, callbackConfig *ChainSyncCallbackConfig) *ChainSync { +func New(options protocol.ProtocolOptions) *ChainSync { // Use node-to-client protocol ID protocolId := PROTOCOL_ID_NTC msgFromCborFunc := NewMsgFromCborNtC @@ -117,9 +117,7 @@ func New(options protocol.ProtocolOptions, callbackConfig *ChainSyncCallbackConf protocolId = PROTOCOL_ID_NTN msgFromCborFunc = NewMsgFromCborNtN } - c := &ChainSync{ - callbackConfig: callbackConfig, - } + c := &ChainSync{} protoConfig := protocol.ProtocolConfig{ Name: PROTOCOL_NAME, ProtocolId: protocolId, @@ -136,6 +134,11 @@ func New(options protocol.ProtocolOptions, callbackConfig *ChainSyncCallbackConf return c } +func (c *ChainSync) Start(callbackConfig *ChainSyncCallbackConfig) { + c.callbackConfig = callbackConfig + c.Protocol.Start() +} + func (c *ChainSync) messageHandler(msg protocol.Message, isResponse bool) error { var err error switch msg.Type() { diff --git a/protocol/handshake/handshake.go b/protocol/handshake/handshake.go index 6552f658..7b4a6174 100644 --- a/protocol/handshake/handshake.go +++ b/protocol/handshake/handshake.go @@ -75,6 +75,10 @@ func New(options protocol.ProtocolOptions, allowedVersions []uint16) *Handshake return h } +func (h *Handshake) Start() { + h.Protocol.Start() +} + func (h *Handshake) handleMessage(msg protocol.Message, isResponse bool) error { var err error switch msg.Type() { diff --git a/protocol/keepalive/keepalive.go b/protocol/keepalive/keepalive.go index 3fe74213..31b352ef 100644 --- a/protocol/keepalive/keepalive.go +++ b/protocol/keepalive/keepalive.go @@ -65,10 +65,8 @@ type KeepAliveFunc func(uint16) error type KeepAliveResponseFunc func(uint16) error type DoneFunc func() error -func New(options protocol.ProtocolOptions, callbackConfig *KeepAliveCallbackConfig) *KeepAlive { - k := &KeepAlive{ - callbackConfig: callbackConfig, - } +func New(options protocol.ProtocolOptions) *KeepAlive { + k := &KeepAlive{} protoConfig := protocol.ProtocolConfig{ Name: PROTOCOL_NAME, ProtocolId: PROTOCOL_ID, @@ -85,6 +83,12 @@ func New(options protocol.ProtocolOptions, callbackConfig *KeepAliveCallbackConf return k } +func (k *KeepAlive) Start(callbackConfig *KeepAliveCallbackConfig) { + k.callbackConfig = callbackConfig + k.Protocol.Start() + k.startTimer() +} + func (k *KeepAlive) messageHandler(msg protocol.Message, isResponse bool) error { var err error switch msg.Type() { @@ -100,7 +104,7 @@ func (k *KeepAlive) messageHandler(msg protocol.Message, isResponse bool) error return err } -func (k *KeepAlive) Start() { +func (k *KeepAlive) startTimer() { k.timer = time.AfterFunc(KEEP_ALIVE_PERIOD*time.Second, func() { if err := k.KeepAlive(0); err != nil { k.SendError(err) @@ -108,14 +112,6 @@ func (k *KeepAlive) Start() { }) } -func (k *KeepAlive) Stop() { - if k.timer != nil { - k.timer.Stop() - } - // Remove timer, since we check for its presence elsewhere - k.timer = nil -} - func (k *KeepAlive) KeepAlive(cookie uint16) error { msg := NewMsgKeepAlive(cookie) return k.SendMessage(msg) @@ -137,7 +133,7 @@ func (k *KeepAlive) handleKeepAliveResponse(msgGeneric protocol.Message) error { msg := msgGeneric.(*MsgKeepAliveResponse) // Start the timer again if we had one previously if k.timer != nil { - defer k.Start() + defer k.startTimer() } if k.callbackConfig != nil && k.callbackConfig.KeepAliveResponseFunc != nil { // Call the user callback function diff --git a/protocol/localstatequery/localstatequery.go b/protocol/localstatequery/localstatequery.go index b538a03c..239a3cda 100644 --- a/protocol/localstatequery/localstatequery.go +++ b/protocol/localstatequery/localstatequery.go @@ -114,10 +114,8 @@ type ReleaseFunc func() error type ReAcquireFunc func(interface{}) error type DoneFunc func() error -func New(options protocol.ProtocolOptions, callbackConfig *CallbackConfig) *LocalStateQuery { - l := &LocalStateQuery{ - callbackConfig: callbackConfig, - } +func New(options protocol.ProtocolOptions) *LocalStateQuery { + l := &LocalStateQuery{} protoConfig := protocol.ProtocolConfig{ Name: PROTOCOL_NAME, ProtocolId: PROTOCOL_ID, @@ -142,6 +140,11 @@ func New(options protocol.ProtocolOptions, callbackConfig *CallbackConfig) *Loca return l } +func (l *LocalStateQuery) Start(callbackConfig *CallbackConfig) { + l.callbackConfig = callbackConfig + l.Protocol.Start() +} + func (l *LocalStateQuery) messageHandler(msg protocol.Message, isResponse bool) error { var err error switch msg.Type() { diff --git a/protocol/localtxsubmission/localtxsubmission.go b/protocol/localtxsubmission/localtxsubmission.go index 9599a4c2..bfb7b716 100644 --- a/protocol/localtxsubmission/localtxsubmission.go +++ b/protocol/localtxsubmission/localtxsubmission.go @@ -62,10 +62,8 @@ type AcceptTxFunc func() error type RejectTxFunc func(interface{}) error type DoneFunc func() error -func New(options protocol.ProtocolOptions, callbackConfig *CallbackConfig) *LocalTxSubmission { - l := &LocalTxSubmission{ - callbackConfig: callbackConfig, - } +func New(options protocol.ProtocolOptions) *LocalTxSubmission { + l := &LocalTxSubmission{} protoConfig := protocol.ProtocolConfig{ Name: PROTOCOL_NAME, ProtocolId: PROTOCOL_ID, @@ -82,6 +80,11 @@ func New(options protocol.ProtocolOptions, callbackConfig *CallbackConfig) *Loca return l } +func (l *LocalTxSubmission) Start(callbackConfig *CallbackConfig) { + l.callbackConfig = callbackConfig + l.Protocol.Start() +} + func (l *LocalTxSubmission) messageHandler(msg protocol.Message, isResponse bool) error { var err error switch msg.Type() { diff --git a/protocol/protocol.go b/protocol/protocol.go index a6605899..d4bedb4d 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -70,24 +70,27 @@ type MessageHandlerFunc func(Message, bool) error type MessageFromCborFunc func(uint, []byte) (Message, error) func New(config ProtocolConfig) *Protocol { - muxerSendChan, muxerRecvChan := config.Muxer.RegisterProtocol(config.ProtocolId) p := &Protocol{ - config: config, - muxerSendChan: muxerSendChan, - muxerRecvChan: muxerRecvChan, - recvBuffer: bytes.NewBuffer(nil), - sendQueueChan: make(chan Message, 50), - sendStateQueueChan: make(chan Message, 50), - recvReadyChan: make(chan bool, 1), - sendReadyChan: make(chan bool, 1), - doneChan: make(chan bool), + config: config, } + return p +} + +func (p *Protocol) Start() { + // Register protocol with muxer + p.muxerSendChan, p.muxerRecvChan = p.config.Muxer.RegisterProtocol(p.config.ProtocolId) + // Create buffers and channels + p.recvBuffer = bytes.NewBuffer(nil) + p.sendQueueChan = make(chan Message, 50) + p.sendStateQueueChan = make(chan Message, 50) + p.recvReadyChan = make(chan bool, 1) + p.sendReadyChan = make(chan bool, 1) + p.doneChan = make(chan bool) // Set initial state - p.setState(config.InitialState) + p.setState(p.config.InitialState) // Start our send and receive Goroutines go p.recvLoop() go p.sendLoop() - return p } func (p *Protocol) Mode() ProtocolMode { diff --git a/protocol/txsubmission/txsubmission.go b/protocol/txsubmission/txsubmission.go index bb080382..8081834d 100644 --- a/protocol/txsubmission/txsubmission.go +++ b/protocol/txsubmission/txsubmission.go @@ -110,10 +110,8 @@ type ReplyTxsFunc func(interface{}) error type DoneFunc func() error type HelloFunc func() error -func New(options protocol.ProtocolOptions, callbackConfig *CallbackConfig) *TxSubmission { - t := &TxSubmission{ - callbackConfig: callbackConfig, - } +func New(options protocol.ProtocolOptions) *TxSubmission { + t := &TxSubmission{} protoConfig := protocol.ProtocolConfig{ Name: PROTOCOL_NAME, ProtocolId: PROTOCOL_ID, @@ -130,6 +128,11 @@ func New(options protocol.ProtocolOptions, callbackConfig *CallbackConfig) *TxSu return t } +func (t *TxSubmission) Start(callbackConfig *CallbackConfig) { + t.callbackConfig = callbackConfig + t.Protocol.Start() +} + func (t *TxSubmission) messageHandler(msg protocol.Message, isResponse bool) error { var err error switch msg.Type() {