From c875d08b4774ba819557b7b461816324d6c00eb6 Mon Sep 17 00:00:00 2001 From: Haiko Schol Date: Wed, 2 Oct 2024 12:13:16 +0200 Subject: [PATCH] feat(dot/sync): improve worker pool The main difference in the worker pool API is that SubmitBatch() does not block until the whole batch has been processed. Instead, it returns an ID which can be used to retrieve the current state of the batch. In addition, Results() returns a channel over which task results are sent as they become available. The main improvement this brings is increased concurrency, since results can be processed before the whole batch has been completed. What has not changed is the overall flow of the Strategy interface; getting a new batch of tasks with NextActions() and processing the results with Process(). Closes #4232 --- dot/sync/configuration.go | 2 +- dot/sync/fullsync.go | 310 +++++++++++++------------- dot/sync/fullsync_test.go | 74 ++++--- dot/sync/service.go | 54 ++++- dot/sync/worker_pool.go | 419 ++++++++++++++++++++++++----------- dot/sync/worker_pool_test.go | 268 ++++++++++++++++++++++ 6 files changed, 806 insertions(+), 321 deletions(-) create mode 100644 dot/sync/worker_pool_test.go diff --git a/dot/sync/configuration.go b/dot/sync/configuration.go index e144a87cbc..646c77ce87 100644 --- a/dot/sync/configuration.go +++ b/dot/sync/configuration.go @@ -17,7 +17,7 @@ func WithStrategies(currentStrategy, defaultStrategy Strategy) ServiceConfig { func WithNetwork(net Network) ServiceConfig { return func(svc *SyncService) { svc.network = net - svc.workerPool = newSyncWorkerPool(net) + //svc.workerPool = newSyncWorkerPool(net) } } diff --git a/dot/sync/fullsync.go b/dot/sync/fullsync.go index 79db8b4a29..789743a5d9 100644 --- a/dot/sync/fullsync.go +++ b/dot/sync/fullsync.go @@ -19,7 +19,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" ) -const defaultNumOfTasks = 3 +const defaultNumOfTasks = 10 var _ Strategy = (*FullSyncStrategy)(nil) @@ -86,7 +86,7 @@ func NewFullSyncStrategy(cfg *FullSyncConfig) *FullSyncStrategy { } } -func (f *FullSyncStrategy) NextActions() ([]*SyncTask, error) { +func (f *FullSyncStrategy) NextActions() ([]Task, error) { f.startedAt = time.Now() f.syncedBlocks = 0 @@ -129,12 +129,11 @@ func (f *FullSyncStrategy) NextActions() ([]*SyncTask, error) { return f.createTasks(reqsFromQueue), nil } -func (f *FullSyncStrategy) createTasks(requests []*messages.BlockRequestMessage) []*SyncTask { - tasks := make([]*SyncTask, 0, len(requests)) +func (f *FullSyncStrategy) createTasks(requests []*messages.BlockRequestMessage) []Task { + tasks := make([]Task, 0, len(requests)) for _, req := range requests { - tasks = append(tasks, &SyncTask{ + tasks = append(tasks, &syncTask{ request: req, - response: &messages.BlockResponseMessage{}, requestMaker: f.reqMaker, }) } @@ -146,126 +145,139 @@ func (f *FullSyncStrategy) createTasks(requests []*messages.BlockRequestMessage) // or complete an incomplete block or is part of a disjoint block set which will // as a result it returns the if the strategy is finished, the peer reputations to change, // peers to block/ban, or an error. FullSyncStrategy is intended to run as long as the node lives. -func (f *FullSyncStrategy) Process(results []*SyncTaskResult) ( +func (f *FullSyncStrategy) Process(results <-chan TaskResult) ( isFinished bool, reputations []Change, bans []peer.ID, err error) { - repChanges, peersToIgnore, validResp := validateResults(results, f.badBlocks) - logger.Debugf("evaluating %d task results, %d valid responses", len(results), len(validResp)) - var highestFinalized *types.Header - highestFinalized, err = f.blockState.GetHighestFinalisedHeader() + highestFinalized, err := f.blockState.GetHighestFinalisedHeader() if err != nil { return false, nil, nil, fmt.Errorf("getting highest finalized header") } - readyBlocks := make([][]*types.BlockData, 0, len(validResp)) - for _, reqRespData := range validResp { - // if Gossamer requested the header, then the response data should contains + // This is safe as long as we are the only goroutine reading from the channel. + for len(results) > 0 { + readyBlocks := make([][]*types.BlockData, 0) + result := <-results + repChange, ignorePeer, validResp := validateResult(result, f.badBlocks) + + if repChange != nil { + reputations = append(reputations, *repChange) + } + + if ignorePeer { + bans = append(bans, result.Who) + } + + if validResp == nil || len(validResp.responseData) == 0 { + continue + } + + // if Gossamer requested the header, then the response data should contain // the full blocks to be imported. If Gossamer didn't request the header, // then the response should only contain the missing parts that will complete // the unreadyBlocks and then with the blocks completed we should be able to import them - if reqRespData.req.RequestField(messages.RequestedDataHeader) { - updatedFragment, ok := f.unreadyBlocks.updateDisjointFragments(reqRespData.responseData) + if validResp.req.RequestField(messages.RequestedDataHeader) { + updatedFragment, ok := f.unreadyBlocks.updateDisjointFragments(validResp.responseData) if ok { validBlocks := validBlocksUnderFragment(highestFinalized.Number, updatedFragment) if len(validBlocks) > 0 { readyBlocks = append(readyBlocks, validBlocks) } } else { - readyBlocks = append(readyBlocks, reqRespData.responseData) + readyBlocks = append(readyBlocks, validResp.responseData) } - continue - } - - completedBlocks := f.unreadyBlocks.updateIncompleteBlocks(reqRespData.responseData) - readyBlocks = append(readyBlocks, completedBlocks) - } - - // disjoint fragments are pieces of the chain that could not be imported right now - // because is blocks too far ahead or blocks that belongs to forks - sortFragmentsOfChain(readyBlocks) - orderedFragments := mergeFragmentsOfChain(readyBlocks) - - nextBlocksToImport := make([]*types.BlockData, 0) - disjointFragments := make([][]*types.BlockData, 0) - - for _, fragment := range orderedFragments { - ok, err := f.blockState.HasHeader(fragment[0].Header.ParentHash) - if err != nil && !errors.Is(err, database.ErrNotFound) { - return false, nil, nil, fmt.Errorf("checking block parent header: %w", err) + completedBlocks := f.unreadyBlocks.updateIncompleteBlocks(validResp.responseData) + if len(completedBlocks) > 0 { + readyBlocks = append(readyBlocks, completedBlocks) + } } - if ok { - nextBlocksToImport = append(nextBlocksToImport, fragment...) - continue - } + // disjoint fragments are pieces of the chain that could not be imported right now + // because is blocks too far ahead or blocks that belongs to forks + sortFragmentsOfChain(readyBlocks) + orderedFragments := mergeFragmentsOfChain(readyBlocks) - disjointFragments = append(disjointFragments, fragment) - } + nextBlocksToImport := make([]*types.BlockData, 0) + disjointFragments := make([][]*types.BlockData, 0) - // this loop goal is to import ready blocks as well as update the highestFinalized header - for len(nextBlocksToImport) > 0 || len(disjointFragments) > 0 { - for _, blockToImport := range nextBlocksToImport { - imported, err := f.blockImporter.importBlock(blockToImport, networkInitialSync) - if err != nil { - return false, nil, nil, fmt.Errorf("while handling ready block: %w", err) + for _, fragment := range orderedFragments { + ok, err := f.blockState.HasHeader(fragment[0].Header.ParentHash) + if err != nil && !errors.Is(err, database.ErrNotFound) { + return false, nil, nil, fmt.Errorf("checking block parent header: %w", err) } - if imported { - f.syncedBlocks += 1 + if ok { + nextBlocksToImport = append(nextBlocksToImport, fragment...) + continue } - } - nextBlocksToImport = make([]*types.BlockData, 0) - highestFinalized, err = f.blockState.GetHighestFinalisedHeader() - if err != nil { - return false, nil, nil, fmt.Errorf("getting highest finalized header") + disjointFragments = append(disjointFragments, fragment) } - // check if blocks from the disjoint set can be imported on their on forks - // given that fragment contains chains and these chains contains blocks - // check if the first block in the chain contains a parent known by us - for _, fragment := range disjointFragments { - validFragment := validBlocksUnderFragment(highestFinalized.Number, fragment) - if len(validFragment) == 0 { - continue + // this loop goal is to import ready blocks as well as update the highestFinalized header + for len(nextBlocksToImport) > 0 || len(disjointFragments) > 0 { + for _, blockToImport := range nextBlocksToImport { + imported, err := f.blockImporter.importBlock(blockToImport, networkInitialSync) + if err != nil { + return false, nil, nil, fmt.Errorf("while handling ready block: %w", err) + } + + if imported { + f.syncedBlocks += 1 + } } - ok, err := f.blockState.HasHeader(validFragment[0].Header.ParentHash) - if err != nil && !errors.Is(err, database.ErrNotFound) { - return false, nil, nil, err + nextBlocksToImport = make([]*types.BlockData, 0) + highestFinalized, err = f.blockState.GetHighestFinalisedHeader() + if err != nil { + return false, nil, nil, fmt.Errorf("getting highest finalized header") } - if !ok { - // if the parent of this valid fragment is behind our latest finalized number - // then we can discard the whole fragment since it is a invalid fork - if (validFragment[0].Header.Number - 1) <= highestFinalized.Number { + // check if blocks from the disjoint set can be imported or they're on forks + // given that fragment contains chains and these chains contains blocks + // check if the first block in the chain contains a parent known by us + for _, fragment := range disjointFragments { + validFragment := validBlocksUnderFragment(highestFinalized.Number, fragment) + if len(validFragment) == 0 { continue } - logger.Infof("starting an acestor search from %s parent of #%d (%s)", - validFragment[0].Header.ParentHash, - validFragment[0].Header.Number, - validFragment[0].Header.Hash(), - ) - - f.unreadyBlocks.newDisjointFragment(validFragment) - request := messages.NewBlockRequest( - *messages.NewFromBlock(validFragment[0].Header.ParentHash), - messages.MaxBlocksInResponse, - messages.BootstrapRequestData, messages.Descending) - f.requestQueue.PushBack(request) - } else { - // inserting them in the queue to be processed after the main chain - nextBlocksToImport = append(nextBlocksToImport, validFragment...) + ok, err := f.blockState.HasHeader(validFragment[0].Header.ParentHash) + if err != nil && !errors.Is(err, database.ErrNotFound) { + return false, nil, nil, err + } + + if !ok { + // if the parent of this valid fragment is behind our latest finalized number + // then we can discard the whole fragment since it is a invalid fork + if (validFragment[0].Header.Number - 1) <= highestFinalized.Number { + continue + } + + logger.Infof("starting an ancestor search from %s parent of #%d (%s)", + validFragment[0].Header.ParentHash, + validFragment[0].Header.Number, + validFragment[0].Header.Hash(), + ) + + f.unreadyBlocks.newDisjointFragment(validFragment) + request := messages.NewBlockRequest( + *messages.NewFromBlock(validFragment[0].Header.ParentHash), + messages.MaxBlocksInResponse, + messages.BootstrapRequestData, messages.Descending) + f.requestQueue.PushBack(request) + } else { + // inserting them in the queue to be processed in the next loop iteration + nextBlocksToImport = append(nextBlocksToImport, validFragment...) + } } - } - disjointFragments = nil + disjointFragments = nil + } } f.unreadyBlocks.removeIrrelevantFragments(highestFinalized.Number) - return false, repChanges, peersToIgnore, nil + return false, reputations, bans, nil } func (f *FullSyncStrategy) ShowMetrics() { @@ -395,85 +407,79 @@ type RequestResponseData struct { responseData []*types.BlockData } -func validateResults(results []*SyncTaskResult, badBlocks []string) (repChanges []Change, - peersToBlock []peer.ID, validRes []RequestResponseData) { - - repChanges = make([]Change, 0) - peersToBlock = make([]peer.ID, 0) - validRes = make([]RequestResponseData, 0, len(results)) - -resultLoop: - for _, result := range results { - request := result.request.(*messages.BlockRequestMessage) +func validateResult(result TaskResult, badBlocks []string) (repChange *Change, + blockPeer bool, validRes *RequestResponseData) { - if !result.completed { - continue - } - - response := result.response.(*messages.BlockResponseMessage) - if request.Direction == messages.Descending { - // reverse blocks before pre-validating and placing in ready queue - slices.Reverse(response.BlockData) - } + if !result.Completed { + return + } - err := validateResponseFields(request, response.BlockData) - if err != nil { - logger.Warnf("validating fields: %s", err) - // TODO: check the reputation change for nil body in response - // and nil justification in response - if errors.Is(err, errNilHeaderInResponse) { - repChanges = append(repChanges, Change{ - who: result.who, - rep: peerset.ReputationChange{ - Value: peerset.IncompleteHeaderValue, - Reason: peerset.IncompleteHeaderReason, - }, - }) - } + task, ok := result.Task.(*syncTask) + if !ok { + logger.Warnf("skipping unexpected task type in TaskResult: %T", result.Task) + return + } - continue - } + request := task.request.(*messages.BlockRequestMessage) + response := result.Result.(*messages.BlockResponseMessage) + if request.Direction == messages.Descending { + // reverse blocks before pre-validating and placing in ready queue + slices.Reverse(response.BlockData) + } - // only check if the responses forms a chain if the response contains the headers - // of each block, othewise the response might only have the body/justification for - // a block - if request.RequestField(messages.RequestedDataHeader) && !isResponseAChain(response.BlockData) { - logger.Warnf("response from %s is not a chain", result.who) - repChanges = append(repChanges, Change{ - who: result.who, + err := validateResponseFields(request, response.BlockData) + if err != nil { + logger.Warnf("validating fields: %s", err) + // TODO: check the reputation change for nil body in response + // and nil justification in response + if errors.Is(err, errNilHeaderInResponse) { + repChange = &Change{ + who: result.Who, rep: peerset.ReputationChange{ Value: peerset.IncompleteHeaderValue, Reason: peerset.IncompleteHeaderReason, }, - }) - continue + } + return } + } - for _, block := range response.BlockData { - if slices.Contains(badBlocks, block.Hash.String()) { - logger.Warnf("%s sent a known bad block: #%d (%s)", - result.who, block.Number(), block.Hash.String()) - - peersToBlock = append(peersToBlock, result.who) - repChanges = append(repChanges, Change{ - who: result.who, - rep: peerset.ReputationChange{ - Value: peerset.BadBlockAnnouncementValue, - Reason: peerset.BadBlockAnnouncementReason, - }, - }) - - continue resultLoop - } + // only check if the block data in the response forms a chain if it contains the headers + // of each block, othewise the response might only have the body/justification for a block + if request.RequestField(messages.RequestedDataHeader) && !isResponseAChain(response.BlockData) { + logger.Warnf("response from %s is not a chain", result.Who) + repChange = &Change{ + who: result.Who, + rep: peerset.ReputationChange{ + Value: peerset.IncompleteHeaderValue, + Reason: peerset.IncompleteHeaderReason, + }, } + return + } - validRes = append(validRes, RequestResponseData{ - req: request, - responseData: response.BlockData, - }) + for _, block := range response.BlockData { + if slices.Contains(badBlocks, block.Hash.String()) { + logger.Warnf("%s sent a known bad block: #%d (%s)", + result.Who, block.Number(), block.Hash.String()) + + blockPeer = true + repChange = &Change{ + who: result.Who, + rep: peerset.ReputationChange{ + Value: peerset.BadBlockAnnouncementValue, + Reason: peerset.BadBlockAnnouncementReason, + }, + } + return + } } - return repChanges, peersToBlock, validRes + validRes = &RequestResponseData{ + req: request, + responseData: response.BlockData, + } + return } // sortFragmentsOfChain will organise the fragments diff --git a/dot/sync/fullsync_test.go b/dot/sync/fullsync_test.go index 0c9bbd4122..9b1bc1e794 100644 --- a/dot/sync/fullsync_test.go +++ b/dot/sync/fullsync_test.go @@ -7,6 +7,10 @@ import ( "container/list" "testing" + "gopkg.in/yaml.v3" + + _ "embed" + "github.com/ChainSafe/gossamer/dot/network" "github.com/ChainSafe/gossamer/dot/network/messages" "github.com/ChainSafe/gossamer/dot/peerset" @@ -15,9 +19,6 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" - "gopkg.in/yaml.v3" - - _ "embed" ) //go:embed testdata/westend_blocks.yaml @@ -69,8 +70,7 @@ func TestFullSyncNextActions(t *testing.T) { task, err := fs.NextActions() require.NoError(t, err) - require.Len(t, task, int(maxRequestsAllowed)) - request := task[0].request.(*messages.BlockRequestMessage) + request := task[0].(*syncTask).request.(*messages.BlockRequestMessage) require.Equal(t, uint(1), request.StartingBlock.RawValue()) require.Equal(t, uint32(128), *request.Max) }) @@ -171,7 +171,7 @@ func TestFullSyncNextActions(t *testing.T) { task, err := fs.NextActions() require.NoError(t, err) - require.Equal(t, task[0].request, tt.expectedTasks[0]) + require.Equal(t, task[0].(*syncTask).request, tt.expectedTasks[0]) require.Equal(t, fs.requestQueue.Len(), tt.expectedQueueLen) }) } @@ -192,37 +192,45 @@ func TestFullSyncProcess(t *testing.T) { require.NoError(t, err) t.Run("requested_max_but_received_less_blocks", func(t *testing.T) { - syncTaskResults := []*SyncTaskResult{ + ctrl := gomock.NewController(t) + requestMaker := NewMockRequestMaker(ctrl) + + syncTaskResults := []TaskResult{ // first task // 1 -> 10 { - who: peer.ID("peerA"), - request: messages.NewBlockRequest(*messages.NewFromBlock(uint(1)), 127, - messages.BootstrapRequestData, messages.Ascending), - completed: true, - response: fstTaskBlockResponse, + Who: peer.ID("peerA"), + Task: &syncTask{ + request: messages.NewBlockRequest(*messages.NewFromBlock(uint(1)), 127, + messages.BootstrapRequestData, messages.Ascending), + requestMaker: requestMaker, + }, + Completed: true, + Result: fstTaskBlockResponse, }, // there is gap from 11 -> 128 // second task // 129 -> 256 { - who: peer.ID("peerA"), - request: messages.NewBlockRequest(*messages.NewFromBlock(uint(129)), 127, - messages.BootstrapRequestData, messages.Ascending), - completed: true, - response: sndTaskBlockResponse, + Who: peer.ID("peerA"), + Task: &syncTask{ + request: messages.NewBlockRequest(*messages.NewFromBlock(uint(129)), 127, + messages.BootstrapRequestData, messages.Ascending), + requestMaker: requestMaker, + }, + Completed: true, + Result: sndTaskBlockResponse, }, } genesisHeader := types.NewHeader(fstTaskBlockResponse.BlockData[0].Header.ParentHash, common.Hash{}, common.Hash{}, 0, types.NewDigest()) - ctrl := gomock.NewController(t) mockBlockState := NewMockBlockState(ctrl) mockBlockState.EXPECT().GetHighestFinalisedHeader(). Return(genesisHeader, nil). - Times(4) + Times(5) mockBlockState.EXPECT(). HasHeader(fstTaskBlockResponse.BlockData[0].Header.ParentHash). @@ -247,7 +255,12 @@ func TestFullSyncProcess(t *testing.T) { fs := NewFullSyncStrategy(cfg) fs.blockImporter = mockImporter - done, _, _, err := fs.Process(syncTaskResults) + results := make(chan TaskResult, len(syncTaskResults)) + for _, result := range syncTaskResults { + results <- result + } + + done, _, _, err := fs.Process(results) require.NoError(t, err) require.False(t, done) @@ -271,18 +284,19 @@ func TestFullSyncProcess(t *testing.T) { err = ancestorSearchResponse.Decode(common.MustHexToBytes(westendBlocks.Blocks1To128)) require.NoError(t, err) - syncTaskResults = []*SyncTaskResult{ + results <- TaskResult{ // ancestor search task // 128 -> 1 - { - who: peer.ID("peerA"), - request: expectedAncestorRequest, - completed: true, - response: ancestorSearchResponse, + Who: peer.ID("peerA"), + Task: &syncTask{ + request: expectedAncestorRequest, + requestMaker: requestMaker, }, + Completed: true, + Result: ancestorSearchResponse, } - done, _, _, err = fs.Process(syncTaskResults) + done, _, _, err = fs.Process(results) require.NoError(t, err) require.False(t, done) @@ -293,7 +307,7 @@ func TestFullSyncProcess(t *testing.T) { } func TestFullSyncBlockAnnounce(t *testing.T) { - t.Run("announce_a_far_block_without_any_commom_ancestor", func(t *testing.T) { + t.Run("announce_a_far_block_without_any_common_ancestor", func(t *testing.T) { highestFinalizedHeader := &types.Header{ ParentHash: common.BytesToHash([]byte{0}), StateRoot: common.BytesToHash([]byte{3, 3, 3, 3}), @@ -347,7 +361,7 @@ func TestFullSyncBlockAnnounce(t *testing.T) { require.Zero(t, fs.requestQueue.Len()) }) - t.Run("announce_closer_valid_block_without_any_commom_ancestor", func(t *testing.T) { + t.Run("announce_closer_valid_block_without_any_common_ancestor", func(t *testing.T) { highestFinalizedHeader := &types.Header{ ParentHash: common.BytesToHash([]byte{0}), StateRoot: common.BytesToHash([]byte{3, 3, 3, 3}), @@ -457,7 +471,7 @@ func TestFullSyncBlockAnnounce(t *testing.T) { requests := make([]messages.P2PMessage, len(tasks)) for idx, task := range tasks { - requests[idx] = task.request + requests[idx] = task.(*syncTask).request } block17 := types.NewHeader(announceOfBlock17.ParentHash, diff --git a/dot/sync/service.go b/dot/sync/service.go index 11013ff9dc..70ced5a047 100644 --- a/dot/sync/service.go +++ b/dot/sync/service.go @@ -10,6 +10,7 @@ import ( "time" "github.com/ChainSafe/gossamer/dot/network" + "github.com/ChainSafe/gossamer/dot/network/messages" "github.com/ChainSafe/gossamer/dot/peerset" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/internal/log" @@ -88,12 +89,27 @@ type Change struct { type Strategy interface { OnBlockAnnounce(from peer.ID, msg *network.BlockAnnounceMessage) (repChange *Change, err error) OnBlockAnnounceHandshake(from peer.ID, msg *network.BlockAnnounceHandshake) error - NextActions() ([]*SyncTask, error) - Process(results []*SyncTaskResult) (done bool, repChanges []Change, blocks []peer.ID, err error) + NextActions() ([]Task, error) + Process(results <-chan TaskResult) (done bool, repChanges []Change, blocks []peer.ID, err error) ShowMetrics() IsSynced() bool } +type syncTask struct { + requestMaker network.RequestMaker + request messages.P2PMessage +} + +func (s *syncTask) ID() TaskID { + return TaskID(s.request.String()) +} + +func (s *syncTask) Do(p peer.ID) (Result, error) { + response := messages.BlockResponseMessage{} + err := s.requestMaker.Do(p, s.request, &response) + return &response, err +} + type SyncService struct { mu sync.Mutex wg sync.WaitGroup @@ -103,7 +119,7 @@ type SyncService struct { currentStrategy Strategy defaultStrategy Strategy - workerPool *syncWorkerPool + workerPool WorkerPool waitPeersDuration time.Duration minPeers int slotDuration time.Duration @@ -119,6 +135,11 @@ func NewSyncService(cfgs ...ServiceConfig) *SyncService { waitPeersDuration: waitPeersDefaultTimeout, stopCh: make(chan struct{}), seenBlockSyncRequests: lrucache.NewLRUCache[common.Hash, uint](100), + workerPool: NewWorkerPool(WorkerPoolConfig{ + MaxRetries: 5, + // TODO: This should depend on the actual configuration of the currently used sync strategy. + Capacity: defaultNumOfTasks * 10, + }), } for _, cfg := range cfgs { @@ -135,7 +156,7 @@ func (s *SyncService) waitWorkers() { } for { - total := s.workerPool.totalWorkers() + total := s.workerPool.NumPeers() if total >= s.minPeers { return } @@ -168,6 +189,7 @@ func (s *SyncService) Start() error { } func (s *SyncService) Stop() error { + s.workerPool.Shutdown() close(s.stopCh) s.wg.Wait() return nil @@ -175,7 +197,9 @@ func (s *SyncService) Stop() error { func (s *SyncService) HandleBlockAnnounceHandshake(from peer.ID, msg *network.BlockAnnounceHandshake) error { logger.Infof("receiving a block announce handshake from %s", from.String()) - if err := s.workerPool.fromBlockAnnounceHandshake(from); err != nil { + logger.Infof("len(s.workerPool.Results())=%d", len(s.workerPool.Results())) // TODO: remove + if err := s.workerPool.AddPeer(from); err != nil { + logger.Warnf("failed to add peer to worker pool: %s", err) return err } @@ -203,7 +227,7 @@ func (s *SyncService) HandleBlockAnnounce(from peer.ID, msg *network.BlockAnnoun func (s *SyncService) OnConnectionClosed(who peer.ID) { logger.Tracef("removing peer worker: %s", who.String()) - s.workerPool.removeWorker(who) + s.workerPool.RemovePeer(who) } func (s *SyncService) IsSynced() bool { @@ -253,19 +277,20 @@ func (s *SyncService) runStrategy() { finalisedHeader, err := s.blockState.GetHighestFinalisedHeader() if err != nil { - logger.Criticalf("getting highest finalized header: %w", err) + logger.Criticalf("getting highest finalized header: %s", err) return } bestBlockHeader, err := s.blockState.BestBlockHeader() if err != nil { - logger.Criticalf("getting best block header: %w", err) + logger.Criticalf("getting best block header: %s", err) return } logger.Infof( - "🚣 currently syncing, %d peers connected, finalized #%d (%s), best #%d (%s)", + "🚣 currently syncing, %d peers connected, %d peers in the worker pool, finalized #%d (%s), best #%d (%s)", len(s.network.AllConnectedPeersIDs()), + s.workerPool.NumPeers(), finalisedHeader.Number, finalisedHeader.Hash().Short(), bestBlockHeader.Number, @@ -283,8 +308,13 @@ func (s *SyncService) runStrategy() { return } - results := s.workerPool.submitRequests(tasks) - done, repChanges, peersToIgnore, err := s.currentStrategy.Process(results) + _, err = s.workerPool.SubmitBatch(tasks) + if err != nil { + logger.Criticalf("current sync strategy next actions failed with: %s", err.Error()) + return + } + + done, repChanges, peersToIgnore, err := s.currentStrategy.Process(s.workerPool.Results()) if err != nil { logger.Criticalf("current sync strategy failed with: %s", err.Error()) return @@ -295,7 +325,7 @@ func (s *SyncService) runStrategy() { } for _, block := range peersToIgnore { - s.workerPool.ignorePeerAsWorker(block) + s.workerPool.IgnorePeer(block) } s.currentStrategy.ShowMetrics() diff --git a/dot/sync/worker_pool.go b/dot/sync/worker_pool.go index b11b726db7..a85a33053b 100644 --- a/dot/sync/worker_pool.go +++ b/dot/sync/worker_pool.go @@ -1,191 +1,358 @@ -// Copyright 2023 ChainSafe Systems (ON) +// Copyright 2024 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only package sync import ( + "container/list" + "context" "errors" + "fmt" "sync" "time" - "github.com/ChainSafe/gossamer/dot/network" - "github.com/ChainSafe/gossamer/dot/network/messages" "github.com/libp2p/go-libp2p/core/peer" - "golang.org/x/exp/maps" ) +const defaultWorkerPoolCapacity = 100 + var ( - ErrNoPeersToMakeRequest = errors.New("no peers to make requests") - ErrPeerIgnored = errors.New("peer ignored") + ErrNoPeers = errors.New("no peers available") + ErrPeerIgnored = errors.New("peer ignored") ) -const ( - punishmentBaseTimeout = 5 * time.Minute - maxRequestsAllowed uint = 3 -) +type TaskID string +type Result any -type SyncTask struct { - requestMaker network.RequestMaker - request messages.P2PMessage - response messages.P2PMessage +type Task interface { + ID() TaskID + Do(p peer.ID) (Result, error) } -type SyncTaskResult struct { - who peer.ID - completed bool - request messages.P2PMessage - response messages.P2PMessage +type TaskResult struct { + Task Task + Completed bool + Result Result + Error error + Retries uint + Who peer.ID } -type syncWorkerPool struct { - mtx sync.RWMutex +func (t TaskResult) Failed() bool { + return t.Error != nil +} - network Network - workers map[peer.ID]struct{} - ignorePeers map[peer.ID]struct{} +type BatchStatus struct { + Failed map[TaskID]TaskResult + Success map[TaskID]TaskResult } -func newSyncWorkerPool(net Network) *syncWorkerPool { - swp := &syncWorkerPool{ - network: net, - workers: make(map[peer.ID]struct{}), - ignorePeers: make(map[peer.ID]struct{}), +func (bs BatchStatus) Completed(todo int) bool { + if len(bs.Failed)+len(bs.Success) < todo { + return false } - return swp + for _, tr := range bs.Failed { + if !tr.Completed { + return false + } + } + + for _, tr := range bs.Success { + if !tr.Completed { + return false + } + } + + return true } -// fromBlockAnnounceHandshake stores the peer which send us a handshake as -// a possible source for requesting blocks/state/warp proofs -func (s *syncWorkerPool) fromBlockAnnounceHandshake(who peer.ID) error { - s.mtx.Lock() - defer s.mtx.Unlock() +type BatchID string + +type WorkerPool interface { + SubmitBatch(tasks []Task) (id BatchID, err error) + GetBatch(id BatchID) (status BatchStatus, ok bool) + Results() chan TaskResult + AddPeer(p peer.ID) error + RemovePeer(p peer.ID) + IgnorePeer(p peer.ID) + NumPeers() int + Shutdown() +} - if _, ok := s.ignorePeers[who]; ok { - return ErrPeerIgnored +type WorkerPoolConfig struct { + Capacity uint + MaxRetries uint +} + +// NewWorkerPool creates a new worker pool with the given configuration. +func NewWorkerPool(cfg WorkerPoolConfig) WorkerPool { + ctx, cancel := context.WithCancel(context.Background()) + + if cfg.Capacity == 0 { + cfg.Capacity = defaultWorkerPoolCapacity } - _, has := s.workers[who] - if has { - return nil + return &workerPool{ + maxRetries: cfg.MaxRetries, + ignoredPeers: make(map[peer.ID]struct{}), + statuses: make(map[BatchID]BatchStatus), + resChan: make(chan TaskResult, cfg.Capacity), + ctx: ctx, + cancel: cancel, } +} - s.workers[who] = struct{}{} - logger.Tracef("potential worker added, total in the pool %d", len(s.workers)) - return nil +type workerPool struct { + mtx sync.RWMutex + wg sync.WaitGroup + + maxRetries uint + peers list.List + ignoredPeers map[peer.ID]struct{} + statuses map[BatchID]BatchStatus + resChan chan TaskResult + ctx context.Context + cancel context.CancelFunc } -// submitRequests blocks until all tasks have been completed or there are no workers -// left in the pool to retry failed tasks -func (s *syncWorkerPool) submitRequests(tasks []*SyncTask) []*SyncTaskResult { - if len(tasks) == 0 { - return nil +// SubmitBatch accepts a list of tasks and immediately returns a batch ID. The batch ID can be used to query the status +// of the batch using [GetBatchStatus]. +// TODO +// If tasks are submitted faster than they are completed, resChan will run full, blocking the calling goroutine. +// Ideally this method would provide backpressure to the caller in that case. The rejected tasks should then stay in +// FullSyncStrategy.requestQueue until the next round. But this would need to be supported in all sync strategies. +func (w *workerPool) SubmitBatch(tasks []Task) (id BatchID, err error) { + w.mtx.Lock() + defer w.mtx.Unlock() + + bID := BatchID(fmt.Sprintf("%d", time.Now().UnixNano())) + + w.statuses[bID] = BatchStatus{ + Failed: make(map[TaskID]TaskResult), + Success: make(map[TaskID]TaskResult), } - s.mtx.RLock() - defer s.mtx.RUnlock() + w.wg.Add(1) + go func() { + defer w.wg.Done() + w.executeBatch(tasks, bID) + }() - pids := maps.Keys(s.workers) - workerPool := make(chan peer.ID, len(pids)) - for _, worker := range pids { - workerPool <- worker + return bID, nil +} + +// GetBatch returns the status of a batch previously submitted using [SubmitBatch]. +func (w *workerPool) GetBatch(id BatchID) (status BatchStatus, ok bool) { + w.mtx.RLock() + defer w.mtx.RUnlock() + + status, ok = w.statuses[id] + return +} + +// Results returns a channel that can be used to receive the results of completed tasks. +func (w *workerPool) Results() chan TaskResult { + return w.resChan +} + +// AddPeer adds a peer to the worker pool unless it has been ignored previously. +func (w *workerPool) AddPeer(who peer.ID) error { + w.mtx.Lock() + defer w.mtx.Unlock() + + if _, ok := w.ignoredPeers[who]; ok { + return ErrPeerIgnored + } + + for e := w.peers.Front(); e != nil; e = e.Next() { + if e.Value.(peer.ID) == who { + return nil + } } - failedTasks := make(chan *SyncTask, len(tasks)) - results := make(chan *SyncTaskResult, len(tasks)) + w.peers.PushBack(who) + logger.Tracef("peer added, total in the pool %d", w.peers.Len()) + return nil +} + +// RemovePeer removes a peer from the worker pool. +func (w *workerPool) RemovePeer(who peer.ID) { + w.mtx.Lock() + defer w.mtx.Unlock() + + w.removePeer(who) +} + +// IgnorePeer removes a peer from the worker pool and prevents it from being added again. +func (w *workerPool) IgnorePeer(who peer.ID) { + w.mtx.Lock() + defer w.mtx.Unlock() + + w.removePeer(who) + w.ignoredPeers[who] = struct{}{} +} + +// NumPeers returns the number of peers in the worker pool, both busy and free. +func (w *workerPool) NumPeers() int { + w.mtx.RLock() + defer w.mtx.RUnlock() - var wg sync.WaitGroup - for _, task := range tasks { - wg.Add(1) - go func(t *SyncTask) { - defer wg.Done() - executeTask(t, workerPool, failedTasks, results) - }(task) + return w.peers.Len() +} + +// Shutdown stops the worker pool and waits for all tasks to complete. +func (w *workerPool) Shutdown() { + w.cancel() + w.wg.Wait() +} + +func (w *workerPool) executeBatch(tasks []Task, bID BatchID) { + batchResults := make(chan TaskResult, len(tasks)) + + for _, t := range tasks { + w.wg.Add(1) + go func(t Task) { + defer w.wg.Done() + w.executeTask(t, batchResults) + }(t) } - wg.Add(1) - go func() { - defer wg.Done() - for task := range failedTasks { - if len(workerPool) > 0 { - wg.Add(1) - go func(t *SyncTask) { - defer wg.Done() - executeTask(t, workerPool, failedTasks, results) - }(task) + for { + select { + case <-w.ctx.Done(): + return + + case tr := <-batchResults: + if tr.Failed() { + w.handleFailedTask(tr, bID, batchResults) } else { - results <- &SyncTaskResult{ - completed: false, - request: task.request, - response: nil, - } + w.handleSuccessfulTask(tr, bID) } - } - }() - allResults := make(chan []*SyncTaskResult, 1) - wg.Add(1) - go func(expectedResults int) { - defer wg.Done() - var taskResults []*SyncTaskResult - - for result := range results { - taskResults = append(taskResults, result) - if len(taskResults) == expectedResults { - close(failedTasks) - break + if w.batchCompleted(bID, len(tasks)) { + return } } + } +} - allResults <- taskResults - }(len(tasks)) - - wg.Wait() - close(workerPool) - close(results) +func (w *workerPool) executeTask(task Task, ch chan TaskResult) { + if errors.Is(w.ctx.Err(), context.Canceled) { + logger.Tracef("[CANCELED] task=%s, shutting down", task.ID()) + return + } - return <-allResults -} + who, err := w.reservePeer() + if errors.Is(err, ErrNoPeers) { + logger.Tracef("no peers available for task=%s", task.ID()) + ch <- TaskResult{Task: task, Error: ErrNoPeers} + return + } -func executeTask(task *SyncTask, workerPool chan peer.ID, failedTasks chan *SyncTask, results chan *SyncTaskResult) { - worker := <-workerPool - logger.Infof("[EXECUTING] worker %s", worker) + logger.Infof("[EXECUTING] task=%s", task.ID()) - err := task.requestMaker.Do(worker, task.request, task.response) + result, err := task.Do(who) if err != nil { - logger.Infof("[ERR] worker %s, request: %s, err: %s", worker, task.request.String(), err.Error()) - failedTasks <- task + logger.Tracef("[FAILED] task=%s peer=%s, err=%s", task.ID(), who, err.Error()) } else { - logger.Infof("[FINISHED] worker %s, request: %s", worker, task.request.String()) - workerPool <- worker - results <- &SyncTaskResult{ - who: worker, - completed: true, - request: task.request, - response: task.response, + logger.Tracef("[FINISHED] task=%s peer=%s", task.ID(), who) + } + + w.mtx.Lock() + w.peers.PushBack(who) + w.mtx.Unlock() + + ch <- TaskResult{ + Task: task, + Who: who, + Result: result, + Error: err, + Retries: 0, + } +} + +func (w *workerPool) reservePeer() (who peer.ID, err error) { + w.mtx.Lock() + defer w.mtx.Unlock() + + peerElement := w.peers.Front() + + if peerElement == nil { + return who, ErrNoPeers + } + + w.peers.Remove(peerElement) + return peerElement.Value.(peer.ID), nil +} + +func (w *workerPool) removePeer(who peer.ID) { + var toRemove *list.Element + for e := w.peers.Front(); e != nil; e = e.Next() { + if e.Value.(peer.ID) == who { + toRemove = e + break } } + + if toRemove != nil { + w.peers.Remove(toRemove) + } } -func (s *syncWorkerPool) ignorePeerAsWorker(who peer.ID) { - s.mtx.Lock() - defer s.mtx.Unlock() +func (w *workerPool) handleSuccessfulTask(tr TaskResult, batchID BatchID) { + w.mtx.Lock() + defer w.mtx.Unlock() + + tID := tr.Task.ID() + + if failedTr, ok := w.statuses[batchID].Failed[tID]; ok { + tr.Retries = failedTr.Retries + 1 + delete(w.statuses[batchID].Failed, tID) + } - delete(s.workers, who) - s.ignorePeers[who] = struct{}{} + tr.Completed = true + w.statuses[batchID].Success[tID] = tr + logger.Infof("handleSuccessfulTask(): len(w.resChan)=%d", len(w.resChan)) // TODO: remove + w.resChan <- tr } -func (s *syncWorkerPool) removeWorker(who peer.ID) { - s.mtx.Lock() - defer s.mtx.Unlock() +func (w *workerPool) handleFailedTask(tr TaskResult, batchID BatchID, batchResults chan TaskResult) { + w.mtx.Lock() + defer w.mtx.Unlock() + + tID := tr.Task.ID() + + if oldTr, ok := w.statuses[batchID].Failed[tID]; ok { + // It is only considered a retry if the task was actually executed. + if errors.Is(oldTr.Error, ErrNoPeers) { + // TODO Should we sleep a bit to wait for peers? + } else { + tr.Retries = oldTr.Retries + 1 + tr.Completed = tr.Retries >= w.maxRetries + } + } + + w.statuses[batchID].Failed[tID] = tr + + if tr.Completed { + logger.Infof("handleFailedTask(): len(w.resChan)=%d", len(w.resChan)) // TODO: remove + w.resChan <- tr + return + } - delete(s.workers, who) + // retry task + w.wg.Add(1) + go func() { + defer w.wg.Done() + w.executeTask(tr.Task, batchResults) + }() } -// totalWorkers only returns available or busy workers -func (s *syncWorkerPool) totalWorkers() (total int) { - s.mtx.RLock() - defer s.mtx.RUnlock() +func (w *workerPool) batchCompleted(id BatchID, todo int) bool { + w.mtx.Lock() + defer w.mtx.Unlock() - return len(s.workers) + b, ok := w.statuses[id] + return !ok || b.Completed(todo) } diff --git a/dot/sync/worker_pool_test.go b/dot/sync/worker_pool_test.go new file mode 100644 index 0000000000..676b787e78 --- /dev/null +++ b/dot/sync/worker_pool_test.go @@ -0,0 +1,268 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package sync + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/assert" +) + +type mockTask struct { + id TaskID + err error + execCount uint + succeedAfter uint +} + +func (m *mockTask) ID() TaskID { + return m.id +} + +func (m *mockTask) Do(p peer.ID) (Result, error) { + time.Sleep(time.Millisecond * 100) // simulate network roundtrip + defer func() { + m.execCount++ + }() + + res := Result(fmt.Sprintf("%s - %s great success!", m.id, p)) + if m.err != nil { + if m.succeedAfter > 0 && m.execCount >= m.succeedAfter { + return res, nil + } + return nil, m.err + } + return res, nil +} + +func (m *mockTask) String() string { + return fmt.Sprintf("mockTask %s", m.id) +} + +func makeTasksAndPeers(num, idOffset int) ([]Task, []peer.ID) { + tasks := make([]Task, num) + peers := make([]peer.ID, num) + + for i := 0; i < num; i++ { + tasks[i] = &mockTask{id: TaskID(fmt.Sprintf("t-%d", i+idOffset))} + peers[i] = peer.ID(fmt.Sprintf("p-%d", i+idOffset)) + } + return tasks, peers +} + +func waitForCompletion(wp WorkerPool, numTasks int) { + resultsReceived := 0 + + for { + <-wp.Results() + resultsReceived++ + + if resultsReceived == numTasks { + break + } + } +} + +func TestWorkerPoolHappyPath(t *testing.T) { + numTasks := 10 + + var setup = func() (WorkerPool, []Task) { + tasks, peers := makeTasksAndPeers(numTasks, 0) + wp := NewWorkerPool(WorkerPoolConfig{}) + + for _, who := range peers { + err := wp.AddPeer(who) + assert.NoError(t, err) + } + + return wp, tasks + } + + t.Run("receive_results_on_channel", func(t *testing.T) { + wp, tasks := setup() + results := make([]TaskResult, 0, numTasks) + _, err := wp.SubmitBatch(tasks) + + assert.NoError(t, err) + + for { + result := <-wp.Results() + assert.True(t, result.Completed) + assert.False(t, result.Failed()) + assert.Equal(t, uint(0), result.Retries) + + results = append(results, result) + if len(results) == numTasks { + break + } + } + }) + + t.Run("check_batch_status_on_completion", func(t *testing.T) { + wp, tasks := setup() + batchID, err := wp.SubmitBatch(tasks) + assert.NoError(t, err) + + waitForCompletion(wp, numTasks) + status, ok := wp.GetBatch(batchID) + + assert.True(t, ok) + assert.True(t, status.Completed(numTasks)) + assert.Equal(t, numTasks, len(status.Success)) + assert.Equal(t, 0, len(status.Failed)) + }) +} + +func TestWorkerPoolPeerHandling(t *testing.T) { + numTasks := 3 + + t.Run("accepts_batch_without_any_peers", func(t *testing.T) { + tasks, _ := makeTasksAndPeers(numTasks, 0) + wp := NewWorkerPool(WorkerPoolConfig{}) + + _, err := wp.SubmitBatch(tasks) + assert.NoError(t, err) + + wp.Shutdown() + }) + + t.Run("completes_batch_with_fewer_peers_than_tasks", func(t *testing.T) { + tasks, peers := makeTasksAndPeers(numTasks, 0) + wp := NewWorkerPool(WorkerPoolConfig{}) + assert.NoError(t, wp.AddPeer(peers[0])) + assert.NoError(t, wp.AddPeer(peers[1])) + + bID, err := wp.SubmitBatch(tasks) + assert.NoError(t, err) + + waitForCompletion(wp, numTasks) + status, ok := wp.GetBatch(bID) + assert.True(t, ok) + assert.True(t, status.Completed(numTasks)) + assert.Equal(t, numTasks, len(status.Success)) + assert.Equal(t, 0, len(status.Failed)) + }) + + t.Run("refuses_to_re_add_ignored_peer", func(t *testing.T) { + _, peers := makeTasksAndPeers(numTasks, 0) + wp := NewWorkerPool(WorkerPoolConfig{}) + + for _, who := range peers { + err := wp.AddPeer(who) + assert.NoError(t, err) + } + assert.Equal(t, len(peers), wp.NumPeers()) + + badPeer := peers[2] + wp.IgnorePeer(badPeer) + assert.Equal(t, len(peers)-1, wp.NumPeers()) + + err := wp.AddPeer(badPeer) + assert.ErrorIs(t, err, ErrPeerIgnored) + assert.Equal(t, len(peers)-1, wp.NumPeers()) + }) +} + +func TestWorkerPoolTaskFailures(t *testing.T) { + numTasks := 3 + taskErr := errors.New("kaput") + + setup := func(maxRetries uint) (failOnce *mockTask, failTwice *mockTask, batchID BatchID, wp WorkerPool) { + tasks, peers := makeTasksAndPeers(numTasks, 0) + + failOnce = tasks[1].(*mockTask) + failOnce.err = taskErr + failOnce.succeedAfter = 1 + + failTwice = tasks[2].(*mockTask) + failTwice.err = taskErr + failTwice.succeedAfter = 2 + + wp = NewWorkerPool(WorkerPoolConfig{MaxRetries: maxRetries}) + for _, who := range peers { + err := wp.AddPeer(who) + assert.NoError(t, err) + } + + var err error + batchID, err = wp.SubmitBatch(tasks) + assert.NoError(t, err) + return + } + + t.Run("retries_failed_tasks", func(t *testing.T) { + failOnce, failTwice, batchID, wp := setup(10) + waitForCompletion(wp, numTasks) + + status, ok := wp.GetBatch(batchID) + assert.True(t, ok) + assert.True(t, status.Completed(numTasks)) + assert.Equal(t, numTasks, len(status.Success)) + assert.Equal(t, 0, len(status.Failed)) + + assert.Nil(t, status.Failed[failOnce.ID()].Error) + assert.Equal(t, uint(1), status.Success[failOnce.ID()].Retries) + + assert.Nil(t, status.Failed[failTwice.ID()].Error) + assert.Equal(t, uint(2), status.Success[failTwice.ID()].Retries) + }) + + t.Run("honours_max_retries", func(t *testing.T) { + failOnce, failTwice, batchID, wp := setup(1) + waitForCompletion(wp, numTasks) + + status, ok := wp.GetBatch(batchID) + assert.True(t, ok) + assert.True(t, status.Completed(numTasks)) + assert.Equal(t, numTasks-1, len(status.Success)) + assert.Equal(t, 1, len(status.Failed)) + + assert.Nil(t, status.Failed[failOnce.ID()].Error) + assert.Equal(t, uint(1), status.Success[failOnce.ID()].Retries) + + assert.ErrorIs(t, taskErr, status.Failed[failTwice.ID()].Error) + assert.Equal(t, uint(1), status.Failed[failTwice.ID()].Retries) + }) +} + +func TestWorkerPoolMultipleBatches(t *testing.T) { + b1NumTasks := 10 + b2NumTasks := 12 + + t.Run("completes_all_batches", func(t *testing.T) { + b1Tasks, b1Peers := makeTasksAndPeers(b1NumTasks, 0) + b2Tasks, b2Peers := makeTasksAndPeers(b2NumTasks, b1NumTasks) + peers := append(b1Peers, b2Peers...) + + wp := NewWorkerPool(WorkerPoolConfig{}) + for _, who := range peers { + err := wp.AddPeer(who) + assert.NoError(t, err) + } + + b1ID, err := wp.SubmitBatch(b1Tasks) + assert.NoError(t, err) + + b2ID, err := wp.SubmitBatch(b2Tasks) + assert.NoError(t, err) + + waitForCompletion(wp, b1NumTasks+b2NumTasks) + + b1Status, ok := wp.GetBatch(b1ID) + assert.True(t, ok) + assert.True(t, b1Status.Completed(b1NumTasks)) + assert.Equal(t, b1NumTasks, len(b1Status.Success)) + assert.Equal(t, 0, len(b1Status.Failed)) + + b2Status, ok := wp.GetBatch(b2ID) + assert.True(t, ok) + assert.True(t, b2Status.Completed(b2NumTasks)) + assert.Equal(t, b2NumTasks, len(b2Status.Success)) + assert.Equal(t, 0, len(b2Status.Failed)) + }) +}