From a42ba11e11fe7d600eab214d9cae18174cba552c Mon Sep 17 00:00:00 2001 From: Aurora Gaffney Date: Fri, 5 Apr 2024 16:00:10 -0500 Subject: [PATCH] feat!: provide context object for callback functions This provides a new protocol callback function parameter that provides the current client or server object and the connection ID. This allows easily identifying the connection that triggered a callback when using multiple connections and performing related operations within the protocol. Fixes #578 BREAKING CHANGE: this changes the prototype for protocol callback functions --- cmd/gouroboros/chainsync.go | 7 ++-- cmd/tx-submission/main.go | 4 ++- connection.go | 25 +++++--------- connection/id.go | 33 +++++++++++++++++++ protocol/blockfetch/blockfetch.go | 14 ++++++-- protocol/blockfetch/client.go | 9 +++-- protocol/blockfetch/server.go | 11 +++++-- protocol/chainsync/chainsync.go | 18 +++++++--- protocol/chainsync/client.go | 13 +++++--- protocol/chainsync/server.go | 13 +++++--- protocol/handshake/client.go | 13 +++++--- protocol/handshake/handshake.go | 12 +++++-- protocol/handshake/server.go | 11 +++++-- protocol/keepalive/client.go | 17 ++++++---- protocol/keepalive/keepalive.go | 16 ++++++--- protocol/keepalive/server.go | 14 +++++--- protocol/localstatequery/client.go | 7 +++- protocol/localstatequery/localstatequery.go | 20 +++++++---- protocol/localstatequery/server.go | 21 +++++++----- protocol/localtxmonitor/client.go | 7 +++- protocol/localtxmonitor/localtxmonitor.go | 12 +++++-- protocol/localtxmonitor/server.go | 9 +++-- protocol/localtxsubmission/client.go | 7 +++- .../localtxsubmission/localtxsubmission.go | 12 +++++-- protocol/localtxsubmission/server.go | 11 +++++-- protocol/peersharing/client.go | 11 +++++-- protocol/peersharing/peersharing.go | 12 +++++-- protocol/peersharing/server.go | 11 +++++-- protocol/protocol.go | 8 +++-- protocol/txsubmission/client.go | 14 +++++--- protocol/txsubmission/server.go | 9 +++-- protocol/txsubmission/txsubmission.go | 16 ++++++--- 32 files changed, 305 insertions(+), 112 deletions(-) create mode 100644 connection/id.go diff --git a/cmd/gouroboros/chainsync.go b/cmd/gouroboros/chainsync.go index 4fd97246..8be1f6e0 100644 --- a/cmd/gouroboros/chainsync.go +++ b/cmd/gouroboros/chainsync.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -256,12 +256,13 @@ func testChainSync(f *globalFlags) { select {} } -func chainSyncRollBackwardHandler(point common.Point, tip chainsync.Tip) error { +func chainSyncRollBackwardHandler(ctx chainsync.CallbackContext, point common.Point, tip chainsync.Tip) error { fmt.Printf("roll backward: point = %#v, tip = %#v\n", point, tip) return nil } func chainSyncRollForwardHandler( + ctx chainsync.CallbackContext, blockType uint, blockData interface{}, tip chainsync.Tip, @@ -312,7 +313,7 @@ func chainSyncRollForwardHandler( return nil } -func blockFetchBlockHandler(blockData ledger.Block) error { +func blockFetchBlockHandler(ctx blockfetch.CallbackContext, blockData ledger.Block) error { switch block := blockData.(type) { case *ledger.ByronEpochBoundaryBlock: fmt.Printf("era = Byron (EBB), epoch = %d, slot = %d, id = %s\n", block.Header.ConsensusData.Epoch, block.SlotNumber(), block.Hash()) diff --git a/cmd/tx-submission/main.go b/cmd/tx-submission/main.go index 1b603e94..76e20542 100644 --- a/cmd/tx-submission/main.go +++ b/cmd/tx-submission/main.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -145,6 +145,7 @@ func main() { } func handleRequestTxIds( + ctx txsubmission.CallbackContext, blocking bool, ack uint16, req uint16, @@ -167,6 +168,7 @@ func handleRequestTxIds( } func handleRequestTxs( + ctx txsubmission.CallbackContext, txIds []txsubmission.TxId, ) ([]txsubmission.TxBody, error) { ret := []txsubmission.TxBody{ diff --git a/connection.go b/connection.go index c6a4dbf2..44d172a6 100644 --- a/connection.go +++ b/connection.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -31,6 +31,7 @@ import ( "sync" "time" + "github.com/blinklabs-io/gouroboros/connection" "github.com/blinklabs-io/gouroboros/muxer" "github.com/blinklabs-io/gouroboros/protocol" "github.com/blinklabs-io/gouroboros/protocol/blockfetch" @@ -49,6 +50,8 @@ const ( DefaultConnectTimeout = 30 * time.Second ) +type ConnectionId = connection.ConnectionId + // The Connection type is a wrapper around a net.Conn object that handles communication using the Ouroboros network protocol over that connection type Connection struct { id ConnectionId @@ -90,19 +93,6 @@ type Connection struct { txSubmissionConfig *txsubmission.Config } -type ConnectionId struct { - LocalAddr net.Addr - RemoteAddr net.Addr -} - -func (c ConnectionId) String() string { - return fmt.Sprintf( - "%s<->%s", - c.LocalAddr.String(), - c.RemoteAddr.String(), - ) -} - // NewConnection returns a new Connection object with the specified options. If a connection is provided, the // handshake will be started. An error will be returned if the handshake fails func NewConnection(options ...ConnectionOptionFunc) (*Connection, error) { @@ -289,8 +279,9 @@ func (c *Connection) setupConnection() error { } }() protoOptions := protocol.ProtocolOptions{ - Muxer: c.muxer, - ErrorChan: c.protoErrorChan, + ConnectionId: c.id, + Muxer: c.muxer, + ErrorChan: c.protoErrorChan, } if c.useNodeToNodeProto { protoOptions.Mode = protocol.ProtocolModeNodeToNode @@ -319,7 +310,7 @@ func (c *Connection) setupConnection() error { var handshakeFullDuplex bool handshakeConfig := handshake.NewConfig( handshake.WithProtocolVersionMap(protoVersions), - handshake.WithFinishedFunc(func(version uint16, versionData protocol.VersionData) error { + handshake.WithFinishedFunc(func(ctx handshake.CallbackContext, version uint16, versionData protocol.VersionData) error { c.handshakeVersion = version c.handshakeVersionData = versionData if c.useNodeToNodeProto { diff --git a/connection/id.go b/connection/id.go new file mode 100644 index 00000000..8179d249 --- /dev/null +++ b/connection/id.go @@ -0,0 +1,33 @@ +// Copyright 2024 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connection + +import ( + "fmt" + "net" +) + +type ConnectionId struct { + LocalAddr net.Addr + RemoteAddr net.Addr +} + +func (c ConnectionId) String() string { + return fmt.Sprintf( + "%s<->%s", + c.LocalAddr.String(), + c.RemoteAddr.String(), + ) +} diff --git a/protocol/blockfetch/blockfetch.go b/protocol/blockfetch/blockfetch.go index e535c11f..6c992afa 100644 --- a/protocol/blockfetch/blockfetch.go +++ b/protocol/blockfetch/blockfetch.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package blockfetch import ( "time" + "github.com/blinklabs-io/gouroboros/connection" "github.com/blinklabs-io/gouroboros/protocol" "github.com/blinklabs-io/gouroboros/protocol/common" @@ -92,9 +93,16 @@ type Config struct { BlockTimeout time.Duration } +// Callback context +type CallbackContext struct { + ConnectionId connection.ConnectionId + Client *Client + Server *Server +} + // Callback function types -type BlockFunc func(ledger.Block) error -type RequestRangeFunc func(common.Point, common.Point) error +type BlockFunc func(CallbackContext, ledger.Block) error +type RequestRangeFunc func(CallbackContext, common.Point, common.Point) error func New(protoOptions protocol.ProtocolOptions, cfg *Config) *BlockFetch { b := &BlockFetch{ diff --git a/protocol/blockfetch/client.go b/protocol/blockfetch/client.go index bf83807c..1cc7beb3 100644 --- a/protocol/blockfetch/client.go +++ b/protocol/blockfetch/client.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ import ( type Client struct { *protocol.Protocol config *Config + callbackContext CallbackContext blockChan chan ledger.Block startBatchResultChan chan error busyMutex sync.Mutex @@ -46,6 +47,10 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { blockChan: make(chan ledger.Block), startBatchResultChan: make(chan error), } + c.callbackContext = CallbackContext{ + Client: c, + ConnectionId: protoOptions.ConnectionId, + } // Update state map with timeouts stateMap := StateMap.Copy() if entry, ok := stateMap[StateBusy]; ok { @@ -186,7 +191,7 @@ func (c *Client) handleBlock(msgGeneric protocol.Message) error { } // We use the callback when requesting ranges and the internal channel for a single block if c.blockUseCallback { - if err := c.config.BlockFunc(blk); err != nil { + if err := c.config.BlockFunc(c.callbackContext, blk); err != nil { return err } } else { diff --git a/protocol/blockfetch/server.go b/protocol/blockfetch/server.go index 2369ea38..9fe81672 100644 --- a/protocol/blockfetch/server.go +++ b/protocol/blockfetch/server.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -23,13 +23,18 @@ import ( type Server struct { *protocol.Protocol - config *Config + config *Config + callbackContext CallbackContext } func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { s := &Server{ config: cfg, } + s.callbackContext = CallbackContext{ + Server: s, + ConnectionId: protoOptions.ConnectionId, + } protoConfig := protocol.ProtocolConfig{ Name: ProtocolName, ProtocolId: ProtocolId, @@ -98,7 +103,7 @@ func (s *Server) handleRequestRange(msg protocol.Message) error { ) } msgRequestRange := msg.(*MsgRequestRange) - return s.config.RequestRangeFunc(msgRequestRange.Start, msgRequestRange.End) + return s.config.RequestRangeFunc(s.callbackContext, msgRequestRange.Start, msgRequestRange.End) } func (s *Server) handleClientDone() error { diff --git a/protocol/chainsync/chainsync.go b/protocol/chainsync/chainsync.go index 4df59bef..6bab15d7 100644 --- a/protocol/chainsync/chainsync.go +++ b/protocol/chainsync/chainsync.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package chainsync import ( "time" + "github.com/blinklabs-io/gouroboros/connection" "github.com/blinklabs-io/gouroboros/protocol" "github.com/blinklabs-io/gouroboros/protocol/common" ) @@ -121,11 +122,18 @@ type Config struct { PipelineLimit int } +// Callback context +type CallbackContext struct { + ConnectionId connection.ConnectionId + Client *Client + Server *Server +} + // Callback function types -type RollBackwardFunc func(common.Point, Tip) error -type RollForwardFunc func(uint, interface{}, Tip) error -type FindIntersectFunc func([]common.Point) (common.Point, Tip, error) -type RequestNextFunc func() error +type RollBackwardFunc func(CallbackContext, common.Point, Tip) error +type RollForwardFunc func(CallbackContext, uint, interface{}, Tip) error +type FindIntersectFunc func(CallbackContext, []common.Point) (common.Point, Tip, error) +type RequestNextFunc func(CallbackContext) error // New returns a new ChainSync object func New(protoOptions protocol.ProtocolOptions, cfg *Config) *ChainSync { diff --git a/protocol/chainsync/client.go b/protocol/chainsync/client.go index 341ecade..553c5bfd 100644 --- a/protocol/chainsync/client.go +++ b/protocol/chainsync/client.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ import ( type Client struct { *protocol.Protocol config *Config + callbackContext CallbackContext busyMutex sync.Mutex intersectResultChan chan error readyForNextBlockChan chan bool @@ -63,6 +64,10 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { firstBlockChan: make(chan common.Point), intersectPointChan: make(chan common.Point), } + c.callbackContext = CallbackContext{ + Client: c, + ConnectionId: protoOptions.ConnectionId, + } // Update state map with timeouts stateMap := StateMap.Copy() if entry, ok := stateMap[stateIntersect]; ok { @@ -343,7 +348,7 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error { return nil } // Call the user callback function - callbackErr = c.config.RollForwardFunc(blockType, blockHeader, msg.Tip) + callbackErr = c.config.RollForwardFunc(c.callbackContext, blockType, blockHeader, msg.Tip) } else { msg := msgGeneric.(*MsgRollForwardNtC) blk, err := ledger.NewBlockFromCbor(msg.BlockType(), msg.BlockCbor()) @@ -360,7 +365,7 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error { return nil } // Call the user callback function - callbackErr = c.config.RollForwardFunc(msg.BlockType(), blk, msg.Tip) + callbackErr = c.config.RollForwardFunc(c.callbackContext, msg.BlockType(), blk, msg.Tip) } if callbackErr != nil { if callbackErr == StopSyncProcessError { @@ -388,7 +393,7 @@ func (c *Client) handleRollBackward(msg protocol.Message) error { ) } // Call the user callback function - if callbackErr := c.config.RollBackwardFunc(msgRollBackward.Point, msgRollBackward.Tip); callbackErr != nil { + if callbackErr := c.config.RollBackwardFunc(c.callbackContext, msgRollBackward.Point, msgRollBackward.Tip); callbackErr != nil { if callbackErr == StopSyncProcessError { // Signal that we're cancelling the sync c.readyForNextBlockChan <- false diff --git a/protocol/chainsync/server.go b/protocol/chainsync/server.go index ab487105..8cc76ee4 100644 --- a/protocol/chainsync/server.go +++ b/protocol/chainsync/server.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -25,7 +25,8 @@ import ( // Server implements the ChainSync server type Server struct { *protocol.Protocol - config *Config + config *Config + callbackContext CallbackContext } // NewServer returns a new ChainSync server object @@ -41,6 +42,10 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { s := &Server{ config: cfg, } + s.callbackContext = CallbackContext{ + Server: s, + ConnectionId: protoOptions.ConnectionId, + } protoConfig := protocol.ProtocolConfig{ Name: ProtocolName, ProtocolId: ProtocolId, @@ -112,7 +117,7 @@ func (s *Server) handleRequestNext(msg protocol.Message) error { "received chain-sync RequestNext message but no callback function is defined", ) } - return s.config.RequestNextFunc() + return s.config.RequestNextFunc(s.callbackContext) } func (s *Server) handleFindIntersect(msg protocol.Message) error { @@ -122,7 +127,7 @@ func (s *Server) handleFindIntersect(msg protocol.Message) error { ) } msgFindIntersect := msg.(*MsgFindIntersect) - point, tip, err := s.config.FindIntersectFunc(msgFindIntersect.Points) + point, tip, err := s.config.FindIntersectFunc(s.callbackContext, msgFindIntersect.Points) if err != nil { if err == IntersectNotFoundError { msgResp := NewMsgIntersectNotFound(tip) diff --git a/protocol/handshake/client.go b/protocol/handshake/client.go index 0e4cf8e2..828796ff 100644 --- a/protocol/handshake/client.go +++ b/protocol/handshake/client.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,8 +24,9 @@ import ( // Client implements the Handshake client type Client struct { *protocol.Protocol - config *Config - onceStart sync.Once + config *Config + callbackContext CallbackContext + onceStart sync.Once } // NewClient returns a new Handshake client object @@ -37,6 +38,10 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { c := &Client{ config: cfg, } + c.callbackContext = CallbackContext{ + Client: c, + ConnectionId: protoOptions.ConnectionId, + } // Update state map with timeout stateMap := StateMap.Copy() if entry, ok := stateMap[stateConfirm]; ok { @@ -99,7 +104,7 @@ func (c *Client) handleAcceptVersion(msg protocol.Message) error { if err != nil { return err } - return c.config.FinishedFunc(msgAcceptVersion.Version, versionData) + return c.config.FinishedFunc(c.callbackContext, msgAcceptVersion.Version, versionData) } func (c *Client) handleRefuse(msgGeneric protocol.Message) error { diff --git a/protocol/handshake/handshake.go b/protocol/handshake/handshake.go index 32cc0502..bc325f6e 100644 --- a/protocol/handshake/handshake.go +++ b/protocol/handshake/handshake.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package handshake import ( "time" + "github.com/blinklabs-io/gouroboros/connection" "github.com/blinklabs-io/gouroboros/protocol" ) @@ -75,8 +76,15 @@ type Config struct { Timeout time.Duration } +// Callback context +type CallbackContext struct { + ConnectionId connection.ConnectionId + Client *Client + Server *Server +} + // Callback function types -type FinishedFunc func(uint16, protocol.VersionData) error +type FinishedFunc func(CallbackContext, uint16, protocol.VersionData) error // New returns a new Handshake object func New(protoOptions protocol.ProtocolOptions, cfg *Config) *Handshake { diff --git a/protocol/handshake/server.go b/protocol/handshake/server.go index 94bdfbb9..b0a58ede 100644 --- a/protocol/handshake/server.go +++ b/protocol/handshake/server.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,7 +24,8 @@ import ( // Server implements the Handshake server type Server struct { *protocol.Protocol - config *Config + config *Config + callbackContext CallbackContext } // NewServer returns a new Handshake server object @@ -32,6 +33,10 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { s := &Server{ config: cfg, } + s.callbackContext = CallbackContext{ + Server: s, + ConnectionId: protoOptions.ConnectionId, + } protoConfig := protocol.ProtocolConfig{ Name: ProtocolName, ProtocolId: ProtocolId, @@ -153,5 +158,5 @@ func (s *Server) handleProposeVersions(msg protocol.Message) error { if err := s.SendMessage(msgAcceptVersion); err != nil { return err } - return s.config.FinishedFunc(proposedVersion, proposedVersionData) + return s.config.FinishedFunc(s.callbackContext, proposedVersion, proposedVersionData) } diff --git a/protocol/keepalive/client.go b/protocol/keepalive/client.go index 4037917c..fd219add 100644 --- a/protocol/keepalive/client.go +++ b/protocol/keepalive/client.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,10 +24,11 @@ import ( type Client struct { *protocol.Protocol - config *Config - timer *time.Timer - timerMutex sync.Mutex - onceStart sync.Once + config *Config + callbackContext CallbackContext + timer *time.Timer + timerMutex sync.Mutex + onceStart sync.Once } func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { @@ -38,6 +39,10 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { c := &Client{ config: cfg, } + c.callbackContext = CallbackContext{ + Client: c, + ConnectionId: protoOptions.ConnectionId, + } // Update state map with timeout stateMap := StateMap.Copy() if entry, ok := stateMap[StateServer]; ok { @@ -125,7 +130,7 @@ func (c *Client) handleKeepAliveResponse(msgGeneric protocol.Message) error { } if c.config != nil && c.config.KeepAliveResponseFunc != nil { // Call the user callback function - return c.config.KeepAliveResponseFunc(msg.Cookie) + return c.config.KeepAliveResponseFunc(c.callbackContext, msg.Cookie) } return nil } diff --git a/protocol/keepalive/keepalive.go b/protocol/keepalive/keepalive.go index 310d8cdb..681ebf82 100644 --- a/protocol/keepalive/keepalive.go +++ b/protocol/keepalive/keepalive.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package keepalive import ( "time" + "github.com/blinklabs-io/gouroboros/connection" "github.com/blinklabs-io/gouroboros/protocol" ) @@ -79,10 +80,17 @@ type Config struct { Cookie uint16 } +// Callback context +type CallbackContext struct { + ConnectionId connection.ConnectionId + Client *Client + Server *Server +} + // Callback function types -type KeepAliveFunc func(uint16) error -type KeepAliveResponseFunc func(uint16) error -type DoneFunc func() error +type KeepAliveFunc func(CallbackContext, uint16) error +type KeepAliveResponseFunc func(CallbackContext, uint16) error +type DoneFunc func(CallbackContext) error func New(protoOptions protocol.ProtocolOptions, cfg *Config) *KeepAlive { k := &KeepAlive{ diff --git a/protocol/keepalive/server.go b/protocol/keepalive/server.go index cc238909..ed46786e 100644 --- a/protocol/keepalive/server.go +++ b/protocol/keepalive/server.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,18 +16,24 @@ package keepalive import ( "fmt" + "github.com/blinklabs-io/gouroboros/protocol" ) type Server struct { *protocol.Protocol - config *Config + config *Config + callbackContext CallbackContext } func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { s := &Server{ config: cfg, } + s.callbackContext = CallbackContext{ + Server: s, + ConnectionId: protoOptions.ConnectionId, + } protoConfig := protocol.ProtocolConfig{ Name: ProtocolName, ProtocolId: ProtocolId, @@ -65,7 +71,7 @@ func (s *Server) handleKeepAlive(msgGeneric protocol.Message) error { msg := msgGeneric.(*MsgKeepAlive) if s.config != nil && s.config.KeepAliveFunc != nil { // Call the user callback function - return s.config.KeepAliveFunc(msg.Cookie) + return s.config.KeepAliveFunc(s.callbackContext, msg.Cookie) } else { // Send the keep-alive response resp := NewMsgKeepAliveResponse(msg.Cookie) @@ -76,7 +82,7 @@ func (s *Server) handleKeepAlive(msgGeneric protocol.Message) error { func (s *Server) handleDone() error { if s.config != nil && s.config.DoneFunc != nil { // Call the user callback function - return s.config.DoneFunc() + return s.config.DoneFunc(s.callbackContext) } return nil } diff --git a/protocol/localstatequery/client.go b/protocol/localstatequery/client.go index fc54ace9..2719efe3 100644 --- a/protocol/localstatequery/client.go +++ b/protocol/localstatequery/client.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ import ( type Client struct { *protocol.Protocol config *Config + callbackContext CallbackContext enableGetChainBlockNo bool enableGetChainPoint bool enableGetRewardInfoPoolsBlock bool @@ -52,6 +53,10 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { acquired: false, currentEra: -1, } + c.callbackContext = CallbackContext{ + Client: c, + ConnectionId: protoOptions.ConnectionId, + } // Update state map with timeouts stateMap := StateMap.Copy() if entry, ok := stateMap[stateAcquiring]; ok { diff --git a/protocol/localstatequery/localstatequery.go b/protocol/localstatequery/localstatequery.go index 1f7fa672..5c584413 100644 --- a/protocol/localstatequery/localstatequery.go +++ b/protocol/localstatequery/localstatequery.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package localstatequery import ( "time" + "github.com/blinklabs-io/gouroboros/connection" "github.com/blinklabs-io/gouroboros/protocol" ) @@ -119,13 +120,20 @@ type Config struct { QueryTimeout time.Duration } +// Callback context +type CallbackContext struct { + ConnectionId connection.ConnectionId + Client *Client + Server *Server +} + // Callback function types // TODO: update callbacks -type AcquireFunc func(interface{}) error -type QueryFunc func(interface{}) error -type ReleaseFunc func() error -type ReAcquireFunc func(interface{}) error -type DoneFunc func() error +type AcquireFunc func(CallbackContext, interface{}) error +type QueryFunc func(CallbackContext, interface{}) error +type ReleaseFunc func(CallbackContext) error +type ReAcquireFunc func(CallbackContext, interface{}) error +type DoneFunc func(CallbackContext) error // New returns a new LocalStateQuery object func New(protoOptions protocol.ProtocolOptions, cfg *Config) *LocalStateQuery { diff --git a/protocol/localstatequery/server.go b/protocol/localstatequery/server.go index e10f3fbf..a555b1d3 100644 --- a/protocol/localstatequery/server.go +++ b/protocol/localstatequery/server.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import ( type Server struct { *protocol.Protocol config *Config + callbackContext CallbackContext enableGetChainBlockNo bool enableGetChainPoint bool enableGetRewardInfoPoolsBlock bool @@ -34,6 +35,10 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { s := &Server{ config: cfg, } + s.callbackContext = CallbackContext{ + Server: s, + ConnectionId: protoOptions.ConnectionId, + } protoConfig := protocol.ProtocolConfig{ Name: ProtocolName, ProtocolId: ProtocolId, @@ -94,10 +99,10 @@ func (s *Server) handleAcquire(msg protocol.Message) error { switch msgAcquire := msg.(type) { case *MsgAcquire: // Call the user callback function - return s.config.AcquireFunc(msgAcquire.Point) + return s.config.AcquireFunc(s.callbackContext, msgAcquire.Point) case *MsgAcquireNoPoint: // Call the user callback function - return s.config.AcquireFunc(nil) + return s.config.AcquireFunc(s.callbackContext, nil) } return nil } @@ -110,7 +115,7 @@ func (s *Server) handleQuery(msg protocol.Message) error { } msgQuery := msg.(*MsgQuery) // Call the user callback function - return s.config.QueryFunc(msgQuery.Query) + return s.config.QueryFunc(s.callbackContext, msgQuery.Query) } func (s *Server) handleRelease() error { @@ -120,7 +125,7 @@ func (s *Server) handleRelease() error { ) } // Call the user callback function - return s.config.ReleaseFunc() + return s.config.ReleaseFunc(s.callbackContext) } func (s *Server) handleReAcquire(msg protocol.Message) error { @@ -132,10 +137,10 @@ func (s *Server) handleReAcquire(msg protocol.Message) error { switch msgReAcquire := msg.(type) { case *MsgReAcquire: // Call the user callback function - return s.config.ReAcquireFunc(msgReAcquire.Point) + return s.config.ReAcquireFunc(s.callbackContext, msgReAcquire.Point) case *MsgReAcquireNoPoint: // Call the user callback function - return s.config.ReAcquireFunc(nil) + return s.config.ReAcquireFunc(s.callbackContext, nil) } return nil } @@ -147,5 +152,5 @@ func (s *Server) handleDone() error { ) } // Call the user callback function - return s.config.DoneFunc() + return s.config.DoneFunc(s.callbackContext) } diff --git a/protocol/localtxmonitor/client.go b/protocol/localtxmonitor/client.go index dd574618..9b9df33a 100644 --- a/protocol/localtxmonitor/client.go +++ b/protocol/localtxmonitor/client.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ import ( type Client struct { *protocol.Protocol config *Config + callbackContext CallbackContext busyMutex sync.Mutex acquired bool acquiredSlot uint64 @@ -49,6 +50,10 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { nextTxResultChan: make(chan []byte), getSizesResultChan: make(chan MsgReplyGetSizesResult), } + c.callbackContext = CallbackContext{ + Client: c, + ConnectionId: protoOptions.ConnectionId, + } // Update state map with timeout stateMap := StateMap.Copy() if entry, ok := stateMap[stateAcquiring]; ok { diff --git a/protocol/localtxmonitor/localtxmonitor.go b/protocol/localtxmonitor/localtxmonitor.go index 1d06e4be..d90e72a2 100644 --- a/protocol/localtxmonitor/localtxmonitor.go +++ b/protocol/localtxmonitor/localtxmonitor.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package localtxmonitor import ( "time" + "github.com/blinklabs-io/gouroboros/connection" "github.com/blinklabs-io/gouroboros/ledger" "github.com/blinklabs-io/gouroboros/protocol" ) @@ -127,8 +128,15 @@ type TxAndEraId struct { txObj ledger.Transaction } +// Callback context +type CallbackContext struct { + ConnectionId connection.ConnectionId + Client *Client + Server *Server +} + // Callback function types -type GetMempoolFunc func() (uint64, uint32, []TxAndEraId, error) +type GetMempoolFunc func(CallbackContext) (uint64, uint32, []TxAndEraId, error) // New returns a new LocalTxMonitor object func New(protoOptions protocol.ProtocolOptions, cfg *Config) *LocalTxMonitor { diff --git a/protocol/localtxmonitor/server.go b/protocol/localtxmonitor/server.go index 5497b6fa..498f9776 100644 --- a/protocol/localtxmonitor/server.go +++ b/protocol/localtxmonitor/server.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import ( type Server struct { *protocol.Protocol config *Config + callbackContext CallbackContext mempoolCapacity uint32 mempoolTxs []TxAndEraId mempoolNextTxIdx int @@ -36,6 +37,10 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { s := &Server{ config: cfg, } + s.callbackContext = CallbackContext{ + Server: s, + ConnectionId: protoOptions.ConnectionId, + } protoConfig := protocol.ProtocolConfig{ Name: ProtocolName, ProtocolId: ProtocolId, @@ -84,7 +89,7 @@ func (s *Server) handleAcquire() error { ) } // Call the user callback function to get mempool information - mempoolSlotNumber, mempoolCapacity, mempoolTxs, err := s.config.GetMempoolFunc() + mempoolSlotNumber, mempoolCapacity, mempoolTxs, err := s.config.GetMempoolFunc(s.callbackContext) if err != nil { return err } diff --git a/protocol/localtxsubmission/client.go b/protocol/localtxsubmission/client.go index df61f5db..118364c7 100644 --- a/protocol/localtxsubmission/client.go +++ b/protocol/localtxsubmission/client.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import ( type Client struct { *protocol.Protocol config *Config + callbackContext CallbackContext busyMutex sync.Mutex submitResultChan chan error onceStart sync.Once @@ -42,6 +43,10 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { config: cfg, submitResultChan: make(chan error), } + c.callbackContext = CallbackContext{ + Client: c, + ConnectionId: protoOptions.ConnectionId, + } // Update state map with timeout stateMap := StateMap.Copy() if entry, ok := stateMap[stateBusy]; ok { diff --git a/protocol/localtxsubmission/localtxsubmission.go b/protocol/localtxsubmission/localtxsubmission.go index e4c9be07..91ef8c4e 100644 --- a/protocol/localtxsubmission/localtxsubmission.go +++ b/protocol/localtxsubmission/localtxsubmission.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package localtxsubmission import ( "time" + "github.com/blinklabs-io/gouroboros/connection" "github.com/blinklabs-io/gouroboros/protocol" ) @@ -74,8 +75,15 @@ type Config struct { Timeout time.Duration } +// Callback context +type CallbackContext struct { + ConnectionId connection.ConnectionId + Client *Client + Server *Server +} + // Callback function types -type SubmitTxFunc func(interface{}) error +type SubmitTxFunc func(CallbackContext, interface{}) error // New returns a new LocalTxSubmission object func New( diff --git a/protocol/localtxsubmission/server.go b/protocol/localtxsubmission/server.go index 67d10d45..98c181fb 100644 --- a/protocol/localtxsubmission/server.go +++ b/protocol/localtxsubmission/server.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,7 +24,8 @@ import ( // Server implements the LocalTxSubmission server type Server struct { *protocol.Protocol - config *Config + config *Config + callbackContext CallbackContext } // NewServer returns a new Server object @@ -32,6 +33,10 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { s := &Server{ config: cfg, } + s.callbackContext = CallbackContext{ + Server: s, + ConnectionId: protoOptions.ConnectionId, + } protoConfig := protocol.ProtocolConfig{ Name: ProtocolName, ProtocolId: ProtocolId, @@ -73,7 +78,7 @@ func (s *Server) handleSubmitTx(msg protocol.Message) error { } msgSubmitTx := msg.(*MsgSubmitTx) // Call the user callback function and send Accept/RejectTx based on result - err := s.config.SubmitTxFunc(msgSubmitTx.Transaction) + err := s.config.SubmitTxFunc(s.callbackContext, msgSubmitTx.Transaction) if err == nil { newMsg := NewMsgAcceptTx() if err := s.SendMessage(newMsg); err != nil { diff --git a/protocol/peersharing/client.go b/protocol/peersharing/client.go index 61e2e8d4..a960e4e1 100644 --- a/protocol/peersharing/client.go +++ b/protocol/peersharing/client.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -23,8 +23,9 @@ import ( // Client implements the PeerSharing client type Client struct { *protocol.Protocol - config *Config - sharePeersChan chan []PeerAddress + config *Config + callbackContext CallbackContext + sharePeersChan chan []PeerAddress } // NewClient returns a new PeerSharing client object @@ -37,6 +38,10 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { config: cfg, sharePeersChan: make(chan []PeerAddress), } + c.callbackContext = CallbackContext{ + Client: c, + ConnectionId: protoOptions.ConnectionId, + } // Update state map with timeout stateMap := StateMap.Copy() if entry, ok := stateMap[stateBusy]; ok { diff --git a/protocol/peersharing/peersharing.go b/protocol/peersharing/peersharing.go index ab76201c..45d76f33 100644 --- a/protocol/peersharing/peersharing.go +++ b/protocol/peersharing/peersharing.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package peersharing import ( "time" + "github.com/blinklabs-io/gouroboros/connection" "github.com/blinklabs-io/gouroboros/protocol" ) @@ -74,8 +75,15 @@ type Config struct { Timeout time.Duration } +// Callback context +type CallbackContext struct { + ConnectionId connection.ConnectionId + Client *Client + Server *Server +} + // Callback function types -type ShareRequestFunc func(int) ([]PeerAddress, error) +type ShareRequestFunc func(CallbackContext, int) ([]PeerAddress, error) // New returns a new PeerSharing object func New(protoOptions protocol.ProtocolOptions, cfg *Config) *PeerSharing { diff --git a/protocol/peersharing/server.go b/protocol/peersharing/server.go index 6a48c36b..a2ab6bba 100644 --- a/protocol/peersharing/server.go +++ b/protocol/peersharing/server.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -23,7 +23,8 @@ import ( // Server implements the PeerSharing server type Server struct { *protocol.Protocol - config *Config + config *Config + callbackContext CallbackContext } // NewServer returns a new PeerSharing server object @@ -31,6 +32,10 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { s := &Server{ config: cfg, } + s.callbackContext = CallbackContext{ + Server: s, + ConnectionId: protoOptions.ConnectionId, + } protoConfig := protocol.ProtocolConfig{ Name: ProtocolName, ProtocolId: ProtocolId, @@ -71,7 +76,7 @@ func (s *Server) handleShareRequest(msg protocol.Message) error { ) } msgShareRequest := msg.(*MsgShareRequest) - peers, err := s.config.ShareRequestFunc(int(msgShareRequest.Amount)) + peers, err := s.config.ShareRequestFunc(s.callbackContext, int(msgShareRequest.Amount)) if err != nil { return err } diff --git a/protocol/protocol.go b/protocol/protocol.go index 287629b9..b3110190 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -24,6 +24,7 @@ import ( "time" "github.com/blinklabs-io/gouroboros/cbor" + "github.com/blinklabs-io/gouroboros/connection" "github.com/blinklabs-io/gouroboros/muxer" "github.com/blinklabs-io/gouroboros/utils" ) @@ -81,9 +82,10 @@ const ( // ProtocolOptions provides common arguments for all mini-protocols type ProtocolOptions struct { - Muxer *muxer.Muxer - ErrorChan chan error - Mode ProtocolMode + ConnectionId connection.ConnectionId + Muxer *muxer.Muxer + ErrorChan chan error + Mode ProtocolMode // TODO: remove me Role ProtocolRole Version uint16 diff --git a/protocol/txsubmission/client.go b/protocol/txsubmission/client.go index f0f23fa5..3fa09244 100644 --- a/protocol/txsubmission/client.go +++ b/protocol/txsubmission/client.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,8 +24,9 @@ import ( // Client implements the TxSubmission client type Client struct { *protocol.Protocol - config *Config - onceInit sync.Once + config *Config + callbackContext CallbackContext + onceInit sync.Once } // NewClient returns a new TxSubmission client object @@ -37,6 +38,10 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { c := &Client{ config: cfg, } + c.callbackContext = CallbackContext{ + Client: c, + ConnectionId: protoOptions.ConnectionId, + } // Update state map with timeout stateMap := StateMap.Copy() if entry, ok := stateMap[stateIdle]; ok { @@ -95,6 +100,7 @@ func (c *Client) handleRequestTxIds(msg protocol.Message) error { msgRequestTxIds := msg.(*MsgRequestTxIds) // Call the user callback function txIds, err := c.config.RequestTxIdsFunc( + c.callbackContext, msgRequestTxIds.Blocking, msgRequestTxIds.Ack, msgRequestTxIds.Req, @@ -117,7 +123,7 @@ func (c *Client) handleRequestTxs(msg protocol.Message) error { } msgRequestTxs := msg.(*MsgRequestTxs) // Call the user callback function - txs, err := c.config.RequestTxsFunc(msgRequestTxs.TxIds) + txs, err := c.config.RequestTxsFunc(c.callbackContext, msgRequestTxs.TxIds) if err != nil { return err } diff --git a/protocol/txsubmission/server.go b/protocol/txsubmission/server.go index b7a8af9d..a01998e6 100644 --- a/protocol/txsubmission/server.go +++ b/protocol/txsubmission/server.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ import ( type Server struct { *protocol.Protocol config *Config + callbackContext CallbackContext ackCount int stateDone bool requestTxIdsResultChan chan []TxIdAndSize @@ -39,6 +40,10 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { requestTxIdsResultChan: make(chan []TxIdAndSize), requestTxsResultChan: make(chan []TxBody), } + s.callbackContext = CallbackContext{ + Server: s, + ConnectionId: protoOptions.ConnectionId, + } protoConfig := protocol.ProtocolConfig{ Name: ProtocolName, ProtocolId: ProtocolId, @@ -153,5 +158,5 @@ func (s *Server) handleInit() error { ) } // Call the user callback function - return s.config.InitFunc() + return s.config.InitFunc(s.callbackContext) } diff --git a/protocol/txsubmission/txsubmission.go b/protocol/txsubmission/txsubmission.go index cd8e870e..00ef75ba 100644 --- a/protocol/txsubmission/txsubmission.go +++ b/protocol/txsubmission/txsubmission.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package txsubmission import ( "time" + "github.com/blinklabs-io/gouroboros/connection" "github.com/blinklabs-io/gouroboros/protocol" ) @@ -124,10 +125,17 @@ type Config struct { IdleTimeout time.Duration } +// Callback context +type CallbackContext struct { + ConnectionId connection.ConnectionId + Client *Client + Server *Server +} + // Callback function types -type RequestTxIdsFunc func(bool, uint16, uint16) ([]TxIdAndSize, error) -type RequestTxsFunc func([]TxId) ([]TxBody, error) -type InitFunc func() error +type RequestTxIdsFunc func(CallbackContext, bool, uint16, uint16) ([]TxIdAndSize, error) +type RequestTxsFunc func(CallbackContext, []TxId) ([]TxBody, error) +type InitFunc func(CallbackContext) error // New returns a new TxSubmission object func New(protoOptions protocol.ProtocolOptions, cfg *Config) *TxSubmission {