diff --git a/protocol/chainsync/client.go b/protocol/chainsync/client.go index 66b1f3f9..1123d424 100644 --- a/protocol/chainsync/client.go +++ b/protocol/chainsync/client.go @@ -18,6 +18,7 @@ import ( "encoding/hex" "fmt" "sync" + "sync/atomic" "github.com/blinklabs-io/gouroboros/ledger" "github.com/blinklabs-io/gouroboros/protocol" @@ -30,16 +31,25 @@ type Client struct { config *Config callbackContext CallbackContext busyMutex sync.Mutex - intersectResultChan chan error readyForNextBlockChan chan bool - wantCurrentTip bool - currentTipChan chan Tip - wantFirstBlock bool - firstBlockChan chan common.Point - wantIntersectPoint bool - intersectPointChan chan common.Point onceStart sync.Once onceStop sync.Once + + // waitingForCurrentTipChan will process all the requests for the current tip until the channel + // is empty. + // + // want* only processes one request per message reply received from the server. If the message + // request fails, it is the responsibility of the caller to clear the channel. + waitingForCurrentTipChan chan chan<- Tip + wantCurrentTipChan chan chan<- Tip + wantFirstBlockChan chan chan<- clientPointResult + wantIntersectFoundChan chan chan<- clientPointResult +} + +type clientPointResult struct { + tip Tip + point common.Point + error error } // NewClient returns a new ChainSync client object @@ -62,11 +72,12 @@ func NewClient( } c := &Client{ config: cfg, - intersectResultChan: make(chan error), readyForNextBlockChan: make(chan bool), - currentTipChan: make(chan Tip), - firstBlockChan: make(chan common.Point), - intersectPointChan: make(chan common.Point), + + waitingForCurrentTipChan: make(chan chan<- Tip, 20), + wantCurrentTipChan: make(chan chan<- Tip, 1), + wantFirstBlockChan: make(chan chan<- clientPointResult, 1), + wantIntersectFoundChan: make(chan chan<- clientPointResult, 1), } c.callbackContext = CallbackContext{ Client: c, @@ -108,11 +119,7 @@ func (c *Client) Start() { // Start goroutine to cleanup resources on protocol shutdown go func() { <-c.Protocol.DoneChan() - close(c.intersectResultChan) close(c.readyForNextBlockChan) - close(c.currentTipChan) - close(c.firstBlockChan) - close(c.intersectPointChan) }() }) } @@ -156,24 +163,56 @@ func (c *Client) Stop() error { // GetCurrentTip returns the current chain tip func (c *Client) GetCurrentTip() (*Tip, error) { - c.busyMutex.Lock() - defer c.busyMutex.Unlock() - c.wantCurrentTip = true - msg := NewMsgFindIntersect([]common.Point{}) - if err := c.SendMessage(msg); err != nil { - return nil, err - } - tip, ok := <-c.currentTipChan - if !ok { - return nil, protocol.ProtocolShuttingDownError - } - // Clear out intersect result channel to prevent blocking - _, ok = <-c.intersectResultChan - if !ok { - return nil, protocol.ProtocolShuttingDownError + done := atomic.Bool{} + requestResultChan := make(chan Tip, 1) + requestErrorChan := make(chan error, 1) + + go func() { + c.busyMutex.Lock() + defer c.busyMutex.Unlock() + + if done.Load() { + return + } + + currentTipChan, cancelCurrentTip := c.wantCurrentTip() + msg := NewMsgFindIntersect([]common.Point{}) + if err := c.SendMessage(msg); err != nil { + cancelCurrentTip() + requestErrorChan <- err + return + } + select { + case <-c.Protocol.DoneChan(): + case tip := <-currentTipChan: + requestResultChan <- tip + } + }() + + waitingResultChan := make(chan Tip, 1) + waitingForCurrentTipChan := c.waitingForCurrentTipChan + + for { + select { + case <-c.Protocol.DoneChan(): + done.Store(true) + return nil, protocol.ProtocolShuttingDownError + case waitingForCurrentTipChan <- waitingResultChan: + // The request is being handled by another request, wait for the result. + waitingForCurrentTipChan = nil + case tip := <-waitingResultChan: + // The result from the other request is ready. + done.Store(true) + return &tip, nil + case tip := <-requestResultChan: + // If waitingForCurrentTipChan is full, the for loop that empties it might finish the + // loop before the select statement that writes to it is triggered. For that reason we + // require requestResultChan here. + return &tip, nil + case err := <-requestErrorChan: + return nil, err + } } - c.wantCurrentTip = false - return &tip, nil } // GetAvailableBlockRange returns the start and end of the range of available blocks given the provided intersect @@ -183,42 +222,32 @@ func (c *Client) GetAvailableBlockRange( ) (common.Point, common.Point, error) { c.busyMutex.Lock() defer c.busyMutex.Unlock() - var start, end common.Point + // Find our chain intersection - c.wantCurrentTip = true - c.wantIntersectPoint = true - msgFindIntersect := NewMsgFindIntersect(intersectPoints) - if err := c.SendMessage(msgFindIntersect); err != nil { - return start, end, err - } - gotIntersectResult := false - for { - select { - case <-c.DoneChan(): - return start, end, protocol.ProtocolShuttingDownError - case tip := <-c.currentTipChan: - end = tip.Point - c.wantCurrentTip = false - case point := <-c.intersectPointChan: - start = point - c.wantIntersectPoint = false - case err := <-c.intersectResultChan: - if err != nil { - return start, end, err - } - gotIntersectResult = true - } - if !c.wantIntersectPoint && !c.wantCurrentTip && gotIntersectResult { - break - } + result := c.requestFindIntersect(intersectPoints) + if result.error != nil { + return common.Point{}, common.Point{}, result.error } + start := result.point + end := result.tip.Point + // If we're already at the chain tip, return an empty range if start.Slot >= end.Slot { return common.Point{}, common.Point{}, nil } + // Request the next block to get the first block after the intersect point. This should result in a rollback - c.wantCurrentTip = true - c.wantFirstBlock = true + currentTipChan, cancelCurrentTip := c.wantCurrentTip() + firstBlockChan, cancelFirstBlock := c.wantFirstBlock() + defer func() { + if currentTipChan != nil { + cancelCurrentTip() + } + if firstBlockChan != nil { + cancelFirstBlock() + } + }() + msgRequestNext := NewMsgRequestNext() if err := c.SendMessage(msgRequestNext); err != nil { return start, end, err @@ -227,12 +256,15 @@ func (c *Client) GetAvailableBlockRange( select { case <-c.DoneChan(): return start, end, protocol.ProtocolShuttingDownError - case tip := <-c.currentTipChan: + case tip := <-currentTipChan: + currentTipChan = nil end = tip.Point - c.wantCurrentTip = false - case point := <-c.firstBlockChan: - start = point - c.wantFirstBlock = false + case firstBlock := <-firstBlockChan: + firstBlockChan = nil + if firstBlock.error != nil { + return start, end, fmt.Errorf("failed to get first block: %w", firstBlock.error) + } + start = firstBlock.point case <-c.readyForNextBlockChan: // Request the next block msg := NewMsgRequestNext() @@ -240,7 +272,7 @@ func (c *Client) GetAvailableBlockRange( return start, end, err } } - if !c.wantFirstBlock && !c.wantCurrentTip { + if currentTipChan == nil && firstBlockChan == nil { break } } @@ -260,15 +292,22 @@ func (c *Client) Sync(intersectPoints []common.Point) error { if len(intersectPoints) == 0 { intersectPoints = []common.Point{common.NewPointOrigin()} } + + intersectResultChan, cancel := c.wantIntersectFound() msg := NewMsgFindIntersect(intersectPoints) if err := c.SendMessage(msg); err != nil { + cancel() return err } - if err, ok := <-c.intersectResultChan; !ok { + select { + case <-c.Protocol.DoneChan(): return protocol.ProtocolShuttingDownError - } else if err != nil { - return err + case result := <-intersectResultChan: + if result.error != nil { + return result.error + } } + // Pipeline the initial block requests to speed things up a bit // Using a value higher than 10 seems to cause problems with NtN for i := 0; i <= c.config.PipelineLimit; i++ { @@ -304,13 +343,109 @@ func (c *Client) syncLoop() { } } +func (c *Client) sendCurrentTip(tip Tip) { + // Sends to the requester. + select { + case ch := <-c.wantCurrentTipChan: + ch <- tip + default: + } + + // Sends to all passive listeners that are in the queue. + for { + select { + case ch := <-c.waitingForCurrentTipChan: + ch <- tip + default: + return + } + } +} + +// wantCurrentTip returns a channel that will receive the current tip, and a function that can be +// used to clear the channel if sending the request message fails. +func (c *Client) wantCurrentTip() (<-chan Tip, func()) { + ch := make(chan Tip, 1) + + select { + case <-c.Protocol.DoneChan(): + return nil, func() {} + case c.wantCurrentTipChan <- ch: + return ch, func() { + select { + case <-c.wantCurrentTipChan: + default: + } + } + } +} + +// wantFirstBlock returns a channel that will receive the first block after the current tip, and a +// function that can be used to clear the channel if sending the request message fails. +func (c *Client) wantFirstBlock() (<-chan clientPointResult, func()) { + ch := make(chan clientPointResult, 1) + + select { + case <-c.Protocol.DoneChan(): + return nil, func() {} + case c.wantFirstBlockChan <- ch: + return ch, func() { + select { + case <-c.wantFirstBlockChan: + default: + } + } + } +} + +// wantIntersectFound returns a channel that will receive the result of the next intersect request, +// and a function that can be used to clear the channel if sending the request message fails. +func (c *Client) wantIntersectFound() (<-chan clientPointResult, func()) { + ch := make(chan clientPointResult, 1) + + select { + case <-c.Protocol.DoneChan(): + return nil, func() {} + case c.wantIntersectFoundChan <- ch: + return ch, func() { + select { + case <-c.wantIntersectFoundChan: + default: + } + } + } +} + +func (c *Client) requestFindIntersect(intersectPoints []common.Point) clientPointResult { + resultChan, cancel := c.wantIntersectFound() + msg := NewMsgFindIntersect(intersectPoints) + if err := c.SendMessage(msg); err != nil { + cancel() + return clientPointResult{error: err} + } + + select { + case <-c.Protocol.DoneChan(): + return clientPointResult{error: protocol.ProtocolShuttingDownError} + case result := <-resultChan: + return result + } +} + func (c *Client) handleAwaitReply() error { return nil } func (c *Client) handleRollForward(msgGeneric protocol.Message) error { - if (c.config == nil || c.config.RollForwardFunc == nil) && - !c.wantFirstBlock { + firstBlockChan := func() chan<- clientPointResult { + select { + case ch := <-c.wantFirstBlockChan: + return ch + default: + return nil + } + }() + if firstBlockChan == nil && (c.config == nil || c.config.RollForwardFunc == nil) { return fmt.Errorf( "received chain-sync RollForward message but no callback function is defined", ) @@ -318,9 +453,12 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error { var callbackErr error if c.Mode() == protocol.ProtocolModeNodeToNode { msg := msgGeneric.(*MsgRollForwardNtN) + c.sendCurrentTip(msg.Tip) + var blockHeader ledger.BlockHeader var blockType uint blockEra := msg.WrappedHeader.Era + switch blockEra { case ledger.BlockHeaderTypeByron: blockType = msg.WrappedHeader.ByronType() @@ -330,6 +468,9 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error { msg.WrappedHeader.HeaderCbor(), ) if err != nil { + if firstBlockChan != nil { + firstBlockChan <- clientPointResult{error: err} + } return err } default: @@ -341,16 +482,20 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error { msg.WrappedHeader.HeaderCbor(), ) if err != nil { + if firstBlockChan != nil { + firstBlockChan <- clientPointResult{error: err} + } return err } } - if c.wantFirstBlock { + if firstBlockChan != nil { blockHash, err := hex.DecodeString(blockHeader.Hash()) if err != nil { + firstBlockChan <- clientPointResult{error: err} return err } point := common.NewPoint(blockHeader.SlotNumber(), blockHash) - c.firstBlockChan <- point + firstBlockChan <- clientPointResult{tip: msg.Tip, point: point} return nil } // Call the user callback function @@ -362,17 +507,23 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error { ) } else { msg := msgGeneric.(*MsgRollForwardNtC) + c.sendCurrentTip(msg.Tip) + blk, err := ledger.NewBlockFromCbor(msg.BlockType(), msg.BlockCbor()) if err != nil { + if firstBlockChan != nil { + firstBlockChan <- clientPointResult{error: err} + } return err } - if c.wantFirstBlock { + if firstBlockChan != nil { blockHash, err := hex.DecodeString(blk.Hash()) if err != nil { + firstBlockChan <- clientPointResult{error: err} return err } point := common.NewPoint(blk.SlotNumber(), blockHash) - c.firstBlockChan <- point + firstBlockChan <- clientPointResult{tip: msg.Tip, point: point} return nil } // Call the user callback function @@ -394,10 +545,8 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error { func (c *Client) handleRollBackward(msg protocol.Message) error { msgRollBackward := msg.(*MsgRollBackward) - if c.wantCurrentTip { - c.currentTipChan <- msgRollBackward.Tip - } - if !c.wantFirstBlock { + c.sendCurrentTip(msgRollBackward.Tip) + if len(c.wantFirstBlockChan) == 0 { if c.config.RollBackwardFunc == nil { return fmt.Errorf( "received chain-sync RollBackward message but no callback function is defined", @@ -421,21 +570,24 @@ func (c *Client) handleRollBackward(msg protocol.Message) error { func (c *Client) handleIntersectFound(msg protocol.Message) error { msgIntersectFound := msg.(*MsgIntersectFound) - if c.wantCurrentTip { - c.currentTipChan <- msgIntersectFound.Tip - } - if c.wantIntersectPoint { - c.intersectPointChan <- msgIntersectFound.Point + c.sendCurrentTip(msgIntersectFound.Tip) + + select { + case ch := <-c.wantIntersectFoundChan: + ch <- clientPointResult{tip: msgIntersectFound.Tip, point: msgIntersectFound.Point} + default: } - c.intersectResultChan <- nil return nil } func (c *Client) handleIntersectNotFound(msgGeneric protocol.Message) error { - if c.wantCurrentTip { - msgIntersectNotFound := msgGeneric.(*MsgIntersectNotFound) - c.currentTipChan <- msgIntersectNotFound.Tip + msgIntersectNotFound := msgGeneric.(*MsgIntersectNotFound) + c.sendCurrentTip(msgIntersectNotFound.Tip) + + select { + case ch := <-c.wantIntersectFoundChan: + ch <- clientPointResult{tip: msgIntersectNotFound.Tip, error: IntersectNotFoundError} + default: } - c.intersectResultChan <- IntersectNotFoundError return nil }