diff --git a/cmd/gouroboros/chainsync.go b/cmd/gouroboros/chainsync.go index 50a9fd38..3d726c55 100644 --- a/cmd/gouroboros/chainsync.go +++ b/cmd/gouroboros/chainsync.go @@ -208,6 +208,37 @@ func testChainSync(f *globalFlags) { os.Exit(1) } + // REMOVE: Test GetCurrentTip during chain sync. + // for i := 0; i < 10; i++ { + // go func() { + // last := uint64(0) + // count := 0 + + // // REMOVE. + // time.Sleep(3 * time.Second) + + // fmt.Printf("Starting GetCurrentTip test\n") + + // for { + // count++ + + // tip, err := oConn.ChainSync().Client.GetCurrentTip() + // if err != nil { + // fmt.Printf("ERROR: GetCurrentTip: %v\n", err) + // return + // } + + // if tip.BlockNumber != last { + // fmt.Printf("tip: block:%d count:%d\n", tip.BlockNumber, count) + // last = tip.BlockNumber + // count = 0 + // } + + // time.Sleep(10 * time.Millisecond) + // } + // }() + // } + var point common.Point if chainSyncFlags.tip { tip, err := oConn.ChainSync().Client.GetCurrentTip() @@ -226,9 +257,12 @@ func testChainSync(f *globalFlags) { point = common.NewPointOrigin() } if chainSyncFlags.blockRange { + fmt.Printf("client: requesting block range\n") + start, end, err := oConn.ChainSync().Client.GetAvailableBlockRange( []common.Point{point}, ) + fmt.Printf("client: block range: %d -> %d\n", start, end) if err != nil { fmt.Printf("ERROR: failed to get available block range: %s\n", err) os.Exit(1) diff --git a/connection.go b/connection.go index 44d172a6..6e045235 100644 --- a/connection.go +++ b/connection.go @@ -223,6 +223,10 @@ func (c *Connection) ProtocolVersion() (uint16, protocol.VersionData) { // shutdown performs cleanup operations when the connection is shutdown, either due to explicit Close() or an error func (c *Connection) shutdown() { + // Immediately close the chainsync client + if c.chainSync != nil { + c.chainSync.Client.Close() + } // Gracefully stop the muxer if c.muxer != nil { c.muxer.Stop() diff --git a/protocol/chainsync/client.go b/protocol/chainsync/client.go index f2a4bded..0e80f8d6 100644 --- a/protocol/chainsync/client.go +++ b/protocol/chainsync/client.go @@ -17,7 +17,6 @@ package chainsync import ( "encoding/hex" "fmt" - "sync" "github.com/blinklabs-io/gouroboros/ledger" "github.com/blinklabs-io/gouroboros/protocol" @@ -27,19 +26,57 @@ import ( // Client implements the ChainSync client type Client struct { *protocol.Protocol - config *Config - callbackContext CallbackContext - busyMutex sync.Mutex - intersectResultChan chan error + config *Config + callbackContext CallbackContext + + clientDoneChan chan struct{} + messageHandlerDoneChan chan struct{} + requestHandlerDoneChan chan struct{} + + handleMessageChan chan clientHandleMessage + immediateStopChan chan struct{} readyForNextBlockChan chan bool - wantCurrentTip bool - currentTipChan chan Tip - wantFirstBlock bool - firstBlockChan chan common.Point - wantIntersectPoint bool - intersectPointChan chan common.Point - onceStart sync.Once - onceStop sync.Once + + requestFindIntersectChan chan clientFindIntersectRequest + requestGetAvailableBlockRangeChan chan clientGetAvailableBlockRangeRequest + requestStartSyncingChan chan clientStartSyncingRequest + requestStopClientChan chan chan<- error + wantCurrentTipChan chan chan<- Tip + wantFirstBlockChan chan chan<- clientPointResult + wantIntersectFoundChan chan chan<- clientPointResult + wantRollbackChan chan chan<- Tip +} + +type clientFindIntersectRequest struct { + intersectPoints []common.Point + resultChan chan<- clientPointResult +} + +type clientGetAvailableBlockRangeRequest struct { + intersectPoints []common.Point + resultChan chan<- clientGetAvailableBlockRangeResult +} + +type clientGetAvailableBlockRangeResult struct { + start common.Point + end common.Point + error error +} + +type clientHandleMessage struct { + message protocol.Message + errorChan chan<- error +} + +type clientPointResult struct { + tip Tip + point common.Point + error error +} + +type clientStartSyncingRequest struct { + intersectPoints []common.Point + resultChan chan<- error } // NewClient returns a new ChainSync client object @@ -52,17 +89,37 @@ func NewClient(stateContext interface{}, protoOptions protocol.ProtocolOptions, ProtocolId = ProtocolIdNtN msgFromCborFunc = NewMsgFromCborNtN } + // TODO: Storing the config as a pointer is unsafe. if cfg == nil { tmpCfg := NewConfig() cfg = &tmpCfg } + c := &Client{ - config: cfg, - intersectResultChan: make(chan error), + config: cfg, + + clientDoneChan: make(chan struct{}), + messageHandlerDoneChan: make(chan struct{}), + requestHandlerDoneChan: make(chan struct{}), + + handleMessageChan: make(chan clientHandleMessage), + immediateStopChan: make(chan struct{}, 1), + // TODO: This channel set to 0 length, which would block message handling. Review if this is ok + // to set to 1. readyForNextBlockChan: make(chan bool), - currentTipChan: make(chan Tip), - firstBlockChan: make(chan common.Point), - intersectPointChan: make(chan common.Point), + + requestFindIntersectChan: make(chan clientFindIntersectRequest), + requestGetAvailableBlockRangeChan: make(chan clientGetAvailableBlockRangeRequest), + requestStartSyncingChan: make(chan clientStartSyncingRequest), + requestStopClientChan: make(chan chan<- error), + + // TODO: We should only have a buffer size of 1 here, and review the protocol to make sure + // it always responds to messages. If it doesn't, we should add a timeout to the channels + // and error handling in case the node misbehaves. + wantCurrentTipChan: make(chan chan<- Tip), + wantFirstBlockChan: make(chan chan<- clientPointResult, 1), + wantIntersectFoundChan: make(chan chan<- clientPointResult, 1), + wantRollbackChan: make(chan chan<- Tip, 1), } c.callbackContext = CallbackContext{ Client: c, @@ -95,81 +152,90 @@ func NewClient(stateContext interface{}, protoOptions protocol.ProtocolOptions, InitialState: stateIdle, } c.Protocol = protocol.New(protoConfig) - return c -} -func (c *Client) Start() { - c.onceStart.Do(func() { - c.Protocol.Start() - // Start goroutine to cleanup resources on protocol shutdown - go func() { - <-c.Protocol.DoneChan() - close(c.intersectResultChan) - close(c.readyForNextBlockChan) - close(c.currentTipChan) - close(c.firstBlockChan) - close(c.intersectPointChan) - }() - }) + go func() { + select { + case <-c.Protocol.DoneChan(): + case <-c.immediateStopChan: + } + + // Strictly speaking, the client isn't done here as either the request or message handling + // loops might still be running. + close(c.clientDoneChan) + }() + go c.requestHandlerLoop() + go c.messageHandlerLoop() + + return c } -func (c *Client) messageHandler(msg protocol.Message) error { - var err error - switch msg.Type() { - case MessageTypeAwaitReply: - err = c.handleAwaitReply() - case MessageTypeRollForward: - err = c.handleRollForward(msg) - case MessageTypeRollBackward: - err = c.handleRollBackward(msg) - case MessageTypeIntersectFound: - err = c.handleIntersectFound(msg) - case MessageTypeIntersectNotFound: - err = c.handleIntersectNotFound(msg) - default: - err = fmt.Errorf( - "%s: received unexpected message type %d", - ProtocolName, - msg.Type(), - ) +// Close immediately transitions the protocol to the Done state. No more protocol operations will be +// possible afterward. +func (c *Client) Close() error { + select { + case <-c.clientDoneChan: + case c.immediateStopChan <- struct{}{}: } - return err + + <-c.clientDoneChan + <-c.messageHandlerDoneChan + <-c.requestHandlerDoneChan + return nil } -// Stop transitions the protocol to the Done state. No more protocol operations will be possible afterward +// Stop gracefully transitions the protocol to the Done state. No more protocol operations will be +// possible afterward. func (c *Client) Stop() error { - var err error - c.onceStop.Do(func() { - c.busyMutex.Lock() - defer c.busyMutex.Unlock() - msg := NewMsgDone() - if err = c.SendMessage(msg); err != nil { - return + ch := make(chan error) + + select { + case <-c.clientDoneChan: + return nil + case c.requestStopClientChan <- ch: + } + + select { + case <-c.clientDoneChan: + return nil + case err := <-ch: + <-c.clientDoneChan + <-c.messageHandlerDoneChan + <-c.requestHandlerDoneChan + + if err == protocol.ProtocolShuttingDownError { + return nil } - }) - return err + return err + } } -// GetCurrentTip returns the current chain tip +// GetCurrentTip returns the current chain tip. func (c *Client) GetCurrentTip() (*Tip, error) { - c.busyMutex.Lock() - defer c.busyMutex.Unlock() - c.wantCurrentTip = true - msg := NewMsgFindIntersect([]common.Point{}) - if err := c.SendMessage(msg); err != nil { - return nil, err + currentTipChan := make(chan Tip, 1) + resultChan := make(chan clientPointResult, 1) + request := clientFindIntersectRequest{ + intersectPoints: []common.Point{}, + resultChan: resultChan, } - tip, ok := <-c.currentTipChan - if !ok { + + select { + case <-c.clientDoneChan: return nil, protocol.ProtocolShuttingDownError + case c.wantCurrentTipChan <- currentTipChan: + result := <-currentTipChan + return &result, nil + case c.requestFindIntersectChan <- request: } - // Clear out intersect result channel to prevent blocking - _, ok = <-c.intersectResultChan - if !ok { + + select { + case <-c.clientDoneChan: return nil, protocol.ProtocolShuttingDownError + case result := <-resultChan: + if result.error != nil && result.error != IntersectNotFoundError { + return nil, result.error + } + return &result.tip, nil } - c.wantCurrentTip = false - return &tip, nil } // GetAvailableBlockRange returns the start and end of the range of available blocks given the provided intersect @@ -177,126 +243,479 @@ func (c *Client) GetCurrentTip() (*Tip, error) { func (c *Client) GetAvailableBlockRange( intersectPoints []common.Point, ) (common.Point, common.Point, error) { - c.busyMutex.Lock() - defer c.busyMutex.Unlock() - var start, end common.Point - // Find our chain intersection - c.wantCurrentTip = true - c.wantIntersectPoint = true - msgFindIntersect := NewMsgFindIntersect(intersectPoints) - if err := c.SendMessage(msgFindIntersect); err != nil { - return start, end, err - } - gotIntersectResult := false + resultChan := make(chan clientGetAvailableBlockRangeResult, 1) + request := clientGetAvailableBlockRangeRequest{ + intersectPoints: intersectPoints, + resultChan: resultChan, + } + + fmt.Printf("GetAvailableBlockRange: %v\n", intersectPoints) + + select { + case <-c.clientDoneChan: + return common.Point{}, common.Point{}, protocol.ProtocolShuttingDownError + case c.requestGetAvailableBlockRangeChan <- request: + } + + fmt.Printf("GetAvailableBlockRange: waiting for result\n") + + select { + case <-c.clientDoneChan: + return common.Point{}, common.Point{}, protocol.ProtocolShuttingDownError + case result := <-resultChan: + fmt.Printf("GetAvailableBlockRange: result: %v, %v, %v\n", result.start, result.end, result.error) + return result.start, result.end, result.error + } +} + +// Sync begins a chain-sync operation using the provided intersect point(s). Incoming blocks will be delivered +// via the RollForward callback function specified in the protocol config +func (c *Client) Sync(intersectPoints []common.Point) error { + resultChan := make(chan error, 1) + request := clientStartSyncingRequest{ + intersectPoints: intersectPoints, + resultChan: resultChan, + } + + select { + case <-c.clientDoneChan: + return protocol.ProtocolShuttingDownError + case c.requestStartSyncingChan <- request: + return c.waitForErrorChan(resultChan) + } +} + +func (c *Client) sendCurrentTip(tip Tip) { for { select { - case <-c.DoneChan(): - return start, end, protocol.ProtocolShuttingDownError - case tip := <-c.currentTipChan: - end = tip.Point - c.wantCurrentTip = false - case point := <-c.intersectPointChan: - start = point - c.wantIntersectPoint = false - case err := <-c.intersectResultChan: - if err != nil { - return start, end, err + case ch := <-c.wantCurrentTipChan: + fmt.Printf("sendCurrentTip: %v\n", tip) + ch <- tip + default: + return + } + } +} + +func (c *Client) sendReadyForNextBlock(ready bool) error { + select { + case <-c.clientDoneChan: + return protocol.ProtocolShuttingDownError + case c.readyForNextBlockChan <- ready: + return nil + } +} + +// wantFirstBlock returns a channel that will receive the first block after the current tip, and a +// function that can be used to clear the channel if sending the request message fails. +func (c *Client) wantFirstBlock() (<-chan clientPointResult, func()) { + ch := make(chan clientPointResult, 1) + + select { + case <-c.clientDoneChan: + return nil, func() {} + case c.wantFirstBlockChan <- ch: + return ch, func() { + select { + case <-c.wantFirstBlockChan: + default: } - gotIntersectResult = true } - if !c.wantIntersectPoint && !c.wantCurrentTip && gotIntersectResult { - break + } +} + +// wantIntersectFound returns a channel that will receive the result of the next intersect request, +// and a function that can be used to clear the channel if sending the request message fails. +func (c *Client) wantIntersectFound() (<-chan clientPointResult, func()) { + ch := make(chan clientPointResult, 1) + + select { + case <-c.clientDoneChan: + return nil, func() {} + case c.wantIntersectFoundChan <- ch: + return ch, func() { + select { + case <-c.wantIntersectFoundChan: + default: + } } } +} + +// wantRollback returns a channel that will receive the result of the next rollback request, and a +// function that can be used to clear the channel if sending the request message fails. +func (c *Client) wantRollback() (<-chan Tip, func()) { + ch := make(chan Tip, 1) + + select { + case <-c.clientDoneChan: + return nil, func() {} + case c.wantRollbackChan <- ch: + return ch, func() { + select { + case <-c.wantRollbackChan: + default: + } + } + } +} + +func (c *Client) waitForErrorChan(ch <-chan error) error { + select { + case <-c.clientDoneChan: + return protocol.ProtocolShuttingDownError + case err := <-ch: + return err + } +} + +func (c *Client) requestFindIntersect(intersectPoints []common.Point) clientPointResult { + resultChan, cancel := c.wantIntersectFound() + if resultChan == nil { + return clientPointResult{error: protocol.ProtocolShuttingDownError} + } + + msg := NewMsgFindIntersect(intersectPoints) + if err := c.SendMessage(msg); err != nil { + fmt.Printf("requestFindIntersect: error sending message: %v\n", err) + cancel() + return clientPointResult{error: err} + } + + select { + case <-c.clientDoneChan: + return clientPointResult{error: protocol.ProtocolShuttingDownError} + case result := <-resultChan: + fmt.Printf("requestFindIntersect: received intersect: %+v --- %+v --- %+v --- %v\n", intersectPoints, result.tip, result.point, result.error) + return result + } +} + +func (c *Client) requestGetAvailableBlockRange( + intersectPoints []common.Point, +) clientGetAvailableBlockRangeResult { + fmt.Printf("requestGetAvailableBlockRange: waiting for intersect result for: %+v\n", intersectPoints) + + result := c.requestFindIntersect(intersectPoints) + if result.error != nil { + return clientGetAvailableBlockRangeResult{error: result.error} + } + start := result.point + end := result.tip.Point + // If we're already at the chain tip, return an empty range if start.Slot >= end.Slot { - return common.Point{}, common.Point{}, nil + return clientGetAvailableBlockRangeResult{} + } + + fmt.Printf("requestGetAvailableBlockRange: start=%v, end=%v\n", start, end) + + // Request the next block to get the first block after the intersect point. This should result + // in a rollback. + // + // TODO: Verify that the rollback always happends, if not review the code here. + rollbackChan, cancelRollback := c.wantRollback() + if rollbackChan == nil { + return clientGetAvailableBlockRangeResult{error: protocol.ProtocolShuttingDownError} } - // Request the next block to get the first block after the intersect point. This should result in a rollback - c.wantCurrentTip = true - c.wantFirstBlock = true + firstBlockChan, cancelFirstBlock := c.wantFirstBlock() + if firstBlockChan == nil { + return clientGetAvailableBlockRangeResult{error: protocol.ProtocolShuttingDownError} + } + defer func() { + if rollbackChan != nil { + cancelRollback() + } + if firstBlockChan != nil { + cancelFirstBlock() + } + }() + + // TODO: Recommended behavior on error should be to send an empty range. + + fmt.Printf("requestGetAvailableBlockRange: requesting next block\n") + msgRequestNext := NewMsgRequestNext() if err := c.SendMessage(msgRequestNext); err != nil { - return start, end, err + return clientGetAvailableBlockRangeResult{start: start, end: end, error: err} } + + fmt.Printf("requestGetAvailableBlockRange: waiting for rollback\n") + for { select { - case <-c.DoneChan(): - return start, end, protocol.ProtocolShuttingDownError - case tip := <-c.currentTipChan: + case <-c.clientDoneChan: + return clientGetAvailableBlockRangeResult{start: start, end: end, error: protocol.ProtocolShuttingDownError} + case tip := <-rollbackChan: + rollbackChan = nil end = tip.Point - c.wantCurrentTip = false - case point := <-c.firstBlockChan: - start = point - c.wantFirstBlock = false + + fmt.Printf("requestGetAvailableBlockRange: rollback received: %v\n", tip) + + case firstBlock := <-firstBlockChan: + firstBlockChan = nil + + fmt.Printf("requestGetAvailableBlockRange: first block received: %v\n", firstBlock) + + if firstBlock.error != nil { + return clientGetAvailableBlockRangeResult{ + start: start, + end: end, + error: fmt.Errorf("failed to get first block: %w", firstBlock.error), + } + } + start = firstBlock.point case <-c.readyForNextBlockChan: + // TODO: This doesn't check for true/false, verify if it should? + + fmt.Printf("requestGetAvailableBlockRange: ready for next block received\n") + // Request the next block msg := NewMsgRequestNext() if err := c.SendMessage(msg); err != nil { - return start, end, err + return clientGetAvailableBlockRangeResult{start: start, end: end, error: err} } } - if !c.wantFirstBlock && !c.wantCurrentTip { + if firstBlockChan == nil && rollbackChan == nil { break } } + + fmt.Println("GetAvailableBlockRange: done") + // If we're already at the chain tip, return an empty range if start.Slot >= end.Slot { - return common.Point{}, common.Point{}, nil + return clientGetAvailableBlockRangeResult{} } - return start, end, nil + + return clientGetAvailableBlockRangeResult{start: start, end: end} } -// Sync begins a chain-sync operation using the provided intersect point(s). Incoming blocks will be delivered -// via the RollForward callback function specified in the protocol config -func (c *Client) Sync(intersectPoints []common.Point) error { - c.busyMutex.Lock() - defer c.busyMutex.Unlock() +func (c *Client) requestSync(intersectPoints []common.Point) error { + // TODO: Check if we're already syncing, if so return an error or cancel the current sync + // operation. Use a channel for this. + // Use origin if no intersect points were specified if len(intersectPoints) == 0 { intersectPoints = []common.Point{common.NewPointOrigin()} } + + fmt.Printf("Sync: intersectPoints=%v\n", func() string { + var s string + for _, p := range intersectPoints { + s += fmt.Sprintf("%v ", p.Slot) + } + return s + }()) + + intersectResultChan, cancel := c.wantIntersectFound() + if intersectResultChan == nil { + return protocol.ProtocolShuttingDownError + } + msg := NewMsgFindIntersect(intersectPoints) if err := c.SendMessage(msg); err != nil { + cancel() return err } - if err, ok := <-c.intersectResultChan; !ok { + + select { + case <-c.clientDoneChan: return protocol.ProtocolShuttingDownError - } else if err != nil { - return err - } - // Pipeline the initial block requests to speed things up a bit - // Using a value higher than 10 seems to cause problems with NtN - for i := 0; i <= c.config.PipelineLimit; i++ { - msg := NewMsgRequestNext() - if err := c.SendMessage(msg); err != nil { - return err + case result := <-intersectResultChan: + if result.error != nil { + return result.error } } - go c.syncLoop() + + fmt.Println("Sync: starting sync loop") + return nil } -func (c *Client) syncLoop() { +// requestHandlerLoop is the request handler loop for the client. +func (c *Client) requestHandlerLoop() { + defer func() { + close(c.requestHandlerDoneChan) + + select { + case <-c.clientDoneChan: + case c.immediateStopChan <- struct{}{}: + } + }() + + requestFindIntersectChan := c.requestFindIntersectChan + requestGetAvailableBlockRangeChan := c.requestGetAvailableBlockRangeChan + + isSyncing := false + syncPipelineCount := 0 + // syncPipelineLimit := c.config.PipelineLimit + + // REMOVE: Testing pipeline limit + syncPipelineLimit := 10 + + // TODO: Change NewClient to return errors on invalid configuration. + if syncPipelineLimit < 1 { + syncPipelineLimit = 1 + } + for { - // Wait for a block to be received - if ready, ok := <-c.readyForNextBlockChan; !ok { - // Channel is closed, which means we're shutting down + select { + case <-c.clientDoneChan: return - } else if !ready { - // Sync was cancelled + + case request := <-requestFindIntersectChan: + result := c.requestFindIntersect(request.intersectPoints) + request.resultChan <- result + if result.error != nil && result.error != IntersectNotFoundError { + c.SendError(result.error) + return + } + + case request := <-requestGetAvailableBlockRangeChan: + result := c.requestGetAvailableBlockRange(request.intersectPoints) + request.resultChan <- result + if result.error != nil && result.error != IntersectNotFoundError { + c.SendError(result.error) + return + } + + case request := <-c.requestStartSyncingChan: + if isSyncing { + // Already syncing. This should be an error(?) + request.resultChan <- nil + return + } + if syncPipelineCount != 0 { + // TODO: Review this behavior. Should we wait for the current pipeline to finish? + err := fmt.Errorf("sync pipeline is not empty") + request.resultChan <- err + c.SendError(err) + return + } + + // Disable requests that aren't allowed during syncing. (Review this) + isSyncing = true + requestFindIntersectChan = nil + requestGetAvailableBlockRangeChan = nil + + err := c.requestSync(request.intersectPoints) + request.resultChan <- err + if err != nil { + if err == IntersectNotFoundError { + continue + } + c.SendError(err) + return + } + + for syncPipelineCount < syncPipelineLimit { + fmt.Printf("requestNextBlockChan: %v : %v\n", syncPipelineCount, c.config.PipelineLimit) + + msg := NewMsgRequestNext() + if err := c.SendMessage(msg); err != nil { + c.SendError(err) + return + } + syncPipelineCount++ + } + + case ch := <-c.requestStopClientChan: + msg := NewMsgDone() + if err := c.SendMessage(msg); err != nil && err != protocol.ProtocolShuttingDownError { + fmt.Printf("Error sending Done message: %v\n", err) + ch <- err + c.SendError(err) + return + } + + fmt.Printf("Client done: Done message sent\n") + ch <- nil return + + case ready := <-c.readyForNextBlockChan: + if syncPipelineCount != 0 { + syncPipelineCount-- + } + + if !isSyncing { + // We're not syncing, so just ignore the ready signal. This can happen if the + // protocol sends us an unexpected rollforward/rollback message. + // + // TODO: Should this be an error? + fmt.Printf("readyForNextBlock received when not syncing\n") + continue + } + if !ready { + isSyncing = false + requestFindIntersectChan = c.requestFindIntersectChan + requestGetAvailableBlockRangeChan = c.requestGetAvailableBlockRangeChan + continue + } + + fmt.Printf("requestNextBlockChan: %v : %v\n", syncPipelineCount, c.config.PipelineLimit) + + msg := NewMsgRequestNext() + if err := c.SendMessage(msg); err != nil { + c.SendError(err) + return + } + syncPipelineCount++ } - c.busyMutex.Lock() - // Request the next block - // In practice we already have multiple block requests pipelined - // and this just adds another one to the pile - msg := NewMsgRequestNext() - if err := c.SendMessage(msg); err != nil { - c.SendError(err) + } +} + +// messageHandlerLoop is responsible for handling messages from the protocol connection. +func (c *Client) messageHandlerLoop() { + defer func() { + close(c.messageHandlerDoneChan) + + select { + case <-c.clientDoneChan: + case c.immediateStopChan <- struct{}{}: + } + }() + + for { + select { + case <-c.clientDoneChan: return + + case msg := <-c.handleMessageChan: + msg.errorChan <- func() error { + switch msg.message.Type() { + case MessageTypeAwaitReply: + return c.handleAwaitReply() + case MessageTypeRollForward: + return c.handleRollForward(msg.message) + case MessageTypeRollBackward: + return c.handleRollBackward(msg.message) + case MessageTypeIntersectFound: + return c.handleIntersectFound(msg.message) + case MessageTypeIntersectNotFound: + return c.handleIntersectNotFound(msg.message) + default: + return fmt.Errorf( + "%s: received unexpected message type %d", + ProtocolName, + msg.message.Type(), + ) + } + }() } - c.busyMutex.Unlock() + } +} + +// messageHandler handles incoming messages from the protocol. It is called from the underlying +// protocol and is blocking. +func (c *Client) messageHandler(msg protocol.Message) error { + errorChan := make(chan error, 1) + + select { + case <-c.clientDoneChan: + return protocol.ProtocolShuttingDownError + case c.handleMessageChan <- clientHandleMessage{message: msg, errorChan: errorChan}: + return c.waitForErrorChan(errorChan) } } @@ -305,17 +724,29 @@ func (c *Client) handleAwaitReply() error { } func (c *Client) handleRollForward(msgGeneric protocol.Message) error { - if (c.config == nil || c.config.RollForwardFunc == nil) && !c.wantFirstBlock { + firstBlockChan := func() chan<- clientPointResult { + select { + case ch := <-c.wantFirstBlockChan: + return ch + default: + return nil + } + }() + if firstBlockChan == nil && (c.config == nil || c.config.RollForwardFunc == nil) { return fmt.Errorf( "received chain-sync RollForward message but no callback function is defined", ) } + var callbackErr error if c.Mode() == protocol.ProtocolModeNodeToNode { msg := msgGeneric.(*MsgRollForwardNtN) + c.sendCurrentTip(msg.Tip) + var blockHeader ledger.BlockHeader var blockType uint blockEra := msg.WrappedHeader.Era + switch blockEra { case ledger.BlockHeaderTypeByron: blockType = msg.WrappedHeader.ByronType() @@ -325,6 +756,9 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error { msg.WrappedHeader.HeaderCbor(), ) if err != nil { + if firstBlockChan != nil { + firstBlockChan <- clientPointResult{error: err} + } return err } default: @@ -336,58 +770,79 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error { msg.WrappedHeader.HeaderCbor(), ) if err != nil { + if firstBlockChan != nil { + firstBlockChan <- clientPointResult{error: err} + } return err } } - if c.wantFirstBlock { + + if firstBlockChan != nil { blockHash, err := hex.DecodeString(blockHeader.Hash()) if err != nil { + firstBlockChan <- clientPointResult{error: err} return err } point := common.NewPoint(blockHeader.SlotNumber(), blockHash) - c.firstBlockChan <- point + firstBlockChan <- clientPointResult{tip: msg.Tip, point: point} return nil } + // Call the user callback function callbackErr = c.config.RollForwardFunc(c.callbackContext, blockType, blockHeader, msg.Tip) } else { msg := msgGeneric.(*MsgRollForwardNtC) + c.sendCurrentTip(msg.Tip) + blk, err := ledger.NewBlockFromCbor(msg.BlockType(), msg.BlockCbor()) if err != nil { + if firstBlockChan != nil { + firstBlockChan <- clientPointResult{error: err} + } return err } - if c.wantFirstBlock { + + if firstBlockChan != nil { blockHash, err := hex.DecodeString(blk.Hash()) if err != nil { + firstBlockChan <- clientPointResult{error: err} return err } point := common.NewPoint(blk.SlotNumber(), blockHash) - c.firstBlockChan <- point + firstBlockChan <- clientPointResult{tip: msg.Tip, point: point} return nil } + // Call the user callback function callbackErr = c.config.RollForwardFunc(c.callbackContext, msg.BlockType(), blk, msg.Tip) } if callbackErr != nil { if callbackErr == StopSyncProcessError { // Signal that we're cancelling the sync - c.readyForNextBlockChan <- false - return nil + return c.sendReadyForNextBlock(false) } else { return callbackErr } } + // Signal that we're ready for the next block - c.readyForNextBlockChan <- true - return nil + return c.sendReadyForNextBlock(true) } func (c *Client) handleRollBackward(msg protocol.Message) error { msgRollBackward := msg.(*MsgRollBackward) - if c.wantCurrentTip { - c.currentTipChan <- msgRollBackward.Tip + + fmt.Printf("handleRolling back to %v\n", msgRollBackward.Point) + + c.sendCurrentTip(msgRollBackward.Tip) + + select { + case ch := <-c.wantRollbackChan: + ch <- msgRollBackward.Tip + default: } - if !c.wantFirstBlock { + + if len(c.wantFirstBlockChan) == 0 { if c.config.RollBackwardFunc == nil { return fmt.Errorf( "received chain-sync RollBackward message but no callback function is defined", @@ -397,35 +852,46 @@ func (c *Client) handleRollBackward(msg protocol.Message) error { if callbackErr := c.config.RollBackwardFunc(c.callbackContext, msgRollBackward.Point, msgRollBackward.Tip); callbackErr != nil { if callbackErr == StopSyncProcessError { // Signal that we're cancelling the sync - c.readyForNextBlockChan <- false - return nil + return c.sendReadyForNextBlock(false) } else { return callbackErr } } + } else { + fmt.Printf("handleRolling firstBlockchan\n") } - // Signal that we're ready for the next block - c.readyForNextBlockChan <- true - return nil + + return c.sendReadyForNextBlock(true) } func (c *Client) handleIntersectFound(msg protocol.Message) error { msgIntersectFound := msg.(*MsgIntersectFound) - if c.wantCurrentTip { - c.currentTipChan <- msgIntersectFound.Tip - } - if c.wantIntersectPoint { - c.intersectPointChan <- msgIntersectFound.Point + + fmt.Printf("handleIntersect found: %v\n", msgIntersectFound.Point) + + c.sendCurrentTip(msgIntersectFound.Tip) + + select { + case ch := <-c.wantIntersectFoundChan: + ch <- clientPointResult{tip: msgIntersectFound.Tip, point: msgIntersectFound.Point} + default: } - c.intersectResultChan <- nil + return nil } func (c *Client) handleIntersectNotFound(msgGeneric protocol.Message) error { - if c.wantCurrentTip { - msgIntersectNotFound := msgGeneric.(*MsgIntersectNotFound) - c.currentTipChan <- msgIntersectNotFound.Tip + msgIntersectNotFound := msgGeneric.(*MsgIntersectNotFound) + + fmt.Printf("handleIntersect not found\n") + + c.sendCurrentTip(msgIntersectNotFound.Tip) + + select { + case ch := <-c.wantIntersectFoundChan: + ch <- clientPointResult{tip: msgIntersectNotFound.Tip, error: IntersectNotFoundError} + default: } - c.intersectResultChan <- IntersectNotFoundError + return nil }