diff --git a/api.go b/api.go index bc32c5d..bcd2154 100644 --- a/api.go +++ b/api.go @@ -57,13 +57,12 @@ type Storage interface { } type Communication interface { + // Send sends a message to the given destination node + Send(msg *Message, destination NodeID) // Nodes returns all nodes that participate in the epoch. Nodes() []NodeID - // Send sends a message to the given destination node - Send(msg *Message, destination NodeID) - // Broadcast broadcasts the given message to all nodes. // Does not send it to yourself. Broadcast(msg *Message) diff --git a/epoch.go b/epoch.go index 830027e..b380200 100644 --- a/epoch.go +++ b/epoch.go @@ -6,12 +6,10 @@ package simplex import ( "bytes" "context" - "crypto/rand" "encoding/binary" "errors" "fmt" "math" - "math/big" "slices" "sync" "sync/atomic" @@ -25,14 +23,13 @@ import ( var ErrAlreadyStarted = errors.New("epoch already started") const ( - DefaultMaxRoundWindow = 10 - DefaultMaxPendingBlocks = 20 - - DefaultMaxProposalWaitTime = 5 * time.Second - DefaultReplicationRequestTimeout = 5 * time.Second - DefaultEmptyVoteRebroadcastTimeout = 5 * time.Second - DefaultFinalizeVoteRebroadcastTimeout = 6 * time.Second - EmptyVoteTimeoutID = "rebroadcast_empty_vote" + DefaultMaxRoundWindow = 10 + DefaultMaxPendingBlocks = 20 + DefaultMaxProposalWaitTime = 5 * time.Second + DefaultReplicationRequestTimeout = 5 * time.Second + DefaultEmptyVoteRebroadcastTimeout = 5 * time.Second + DefaultFinalizeVoteRebroadcastTimeout = 6 * time.Second + EmptyVoteTimeoutID uint64 = 1 ) type EmptyVoteSet struct { @@ -199,7 +196,7 @@ func (e *Epoch) init() error { e.eligibleNodeIDs = make(map[string]struct{}, len(e.nodes)) e.futureMessages = make(messagesFromNode, len(e.nodes)) e.replicationState = NewReplicationState(e.Logger, e.Comm, e.ID, e.maxRoundWindow, e.ReplicationEnabled, e.StartTime, &e.lock) - e.timeoutHandler = NewTimeoutHandler(e.Logger, e.StartTime, e.nodes) + e.timeoutHandler = NewTimeoutHandler(e.Logger, e.StartTime, e.MaxRebroadcastWait, e.emptyVoteTimeoutTaskRunner) for _, node := range e.nodes { e.futureMessages[string(node)] = make(map[uint64]*messagesForRound) @@ -620,7 +617,8 @@ func (e *Epoch) handleFinalizationForPendingOrFutureRound(message *Finalization, e.Logger.Debug("We are the leader of this round, but a higher round has been finalized. Aborting block building.") e.blockBuilderCancelFunc() } - e.replicationState.replicateBlocks(message, nextSeqToCommit) + + e.replicationState.receivedFutureFinalization(message, nextSeqToCommit) } func (e *Epoch) handleFinalizeVoteMessage(message *FinalizeVote, from NodeID) error { @@ -702,7 +700,7 @@ func (e *Epoch) handleEmptyVoteMessage(message *EmptyVote, from NodeID) error { vote := message.Vote e.Logger.Verbo("Received empty vote message", - zap.Stringer("from", from), zap.Uint64("round", vote.Round)) + zap.Stringer("from", from), zap.Uint64("round", vote.Round), zap.Uint64("our round", e.round)) // Only process point to point empty votes. // A node will never need to forward to us someone else's vote. 
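Note: the reworked TimeoutHandler constructed above is not itself shown in this diff. Its call sites imply an ID-based API: AddTask and RemoveTask take a uint64 task ID (e.g. EmptyVoteTimeoutID), and a single runner callback receives the IDs of every expired task when the clock ticks. The following is a minimal sketch consistent with those call sites, assuming the handler re-arms expired tasks (since the empty vote runner below no longer re-schedules itself); it is an illustration, not the package's actual timeout handler.

package simplex

import (
	"sync"
	"time"
)

// TimeoutHandler (sketch): one deadline per task ID, a shared timeout
// duration, and a single runner invoked with all expired IDs on Tick.
// Logger is assumed to be the package's existing Logger interface.
type TimeoutHandler struct {
	lock      sync.Mutex
	logger    Logger
	now       time.Time
	timeout   time.Duration
	runner    func(expiredIDs []uint64)
	deadlines map[uint64]time.Time
}

func NewTimeoutHandler(logger Logger, start time.Time, timeout time.Duration, runner func([]uint64)) *TimeoutHandler {
	return &TimeoutHandler{
		logger:    logger,
		now:       start,
		timeout:   timeout,
		runner:    runner,
		deadlines: make(map[uint64]time.Time),
	}
}

// AddTask schedules (or re-arms) the task with the given ID.
func (t *TimeoutHandler) AddTask(id uint64) {
	t.lock.Lock()
	defer t.lock.Unlock()
	t.deadlines[id] = t.now.Add(t.timeout)
}

// RemoveTask cancels the task with the given ID, if it is scheduled.
func (t *TimeoutHandler) RemoveTask(id uint64) {
	t.lock.Lock()
	defer t.lock.Unlock()
	delete(t.deadlines, id)
}

// Tick advances the handler's clock and invokes the runner once with all
// expired task IDs, re-arming them until RemoveTask is called. The runner is
// invoked without holding the handler's lock, since runners such as
// emptyVoteTimeoutTaskRunner acquire the epoch lock themselves.
func (t *TimeoutHandler) Tick(now time.Time) {
	t.lock.Lock()
	t.now = now
	var expired []uint64
	for id, deadline := range t.deadlines {
		if !now.Before(deadline) {
			expired = append(expired, id)
			t.deadlines[id] = now.Add(t.timeout)
		}
	}
	t.lock.Unlock()
	if len(expired) > 0 {
		t.runner(expired)
	}
}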
@@ -716,6 +714,12 @@ func (e *Epoch) handleEmptyVoteMessage(message *EmptyVote, from NodeID) error { e.Logger.Debug("Got empty vote from a past round", zap.Uint64("round", vote.Round), zap.Uint64("my round", e.round), zap.Stringer("from", from)) + // if this node has sent us an empty vote for a past round, it may be behind. + // Send it both the latest finalization and the highest round to help it catch up and initiate the replication process. + e.sendLatestFinalization(from) + e.sendHighestRound(from) + + // also send the notarization or finalization for the vote's round e.maybeSendNotarizationOrFinalization(from, vote.Round) return nil } @@ -731,7 +735,7 @@ func (e *Epoch) handleEmptyVoteMessage(message *EmptyVote, from NodeID) error { // Else, this is an empty vote for current round e.Logger.Debug("Received an empty vote for the current round", - zap.Uint64("round", vote.Round), zap.Stringer("from", from)) + zap.Uint64("round", vote.Round), zap.Stringer("from", from), zap.Bool("isReplicationDone", e.replicationState.isReplicationComplete(e.nextSeqToCommit(), e.round))) signature := message.Signature @@ -749,10 +753,60 @@ func (e *Epoch) handleEmptyVoteMessage(message *EmptyVote, from NodeID) error { return e.maybeAssembleEmptyNotarization() } +func (e *Epoch) sendLatestFinalization(to NodeID) { + if e.lastBlock == nil { + e.Logger.Debug("No blocks committed yet, cannot send latest block", zap.Stringer("to", to)) + return + } + + msg := &Message{ + Finalization: &e.lastBlock.Finalization, + } + e.Logger.Debug("Node appears behind, sending them the latest block", zap.Stringer("to", to), zap.Uint64("round", e.lastBlock.VerifiedBlock.BlockHeader().Round)) + e.Comm.Send(msg, to) +} + +func (e *Epoch) sendHighestRound(to NodeID) { + latestQR := e.getLatestVerifiedQuorumRound() + + if latestQR == nil { + e.Logger.Debug("Cannot send latest round because there is none", zap.Stringer("to", to)) + return + } + + if latestQR.Notarization != nil { + msg := &Message{ + Notarization: latestQR.Notarization, + } + e.Logger.Debug("Node appears behind, sending them the highest round", zap.Stringer("to", to), zap.Uint64("round", latestQR.Notarization.Vote.Round)) + e.Comm.Send(msg, to) + return + } + + if latestQR.EmptyNotarization != nil { + msg := &Message{ + EmptyNotarization: latestQR.EmptyNotarization, + } + e.Logger.Debug("Node appears behind, sending them the highest empty notarized round", zap.Stringer("to", to), zap.Uint64("round", latestQR.EmptyNotarization.Vote.Round)) + e.Comm.Send(msg, to) + return + } +} + +// maybeSendNotarizationOrFinalization sends the notarization, finalization, or empty notarization for the given round, if we have one. func (e *Epoch) maybeSendNotarizationOrFinalization(to NodeID, round uint64) { r, ok := e.rounds[round] if !ok { + // round could be an empty notarized round + evs, ok := e.emptyVotes[round] + if ok && evs.emptyNotarization != nil { + msg := &Message{ + EmptyNotarization: evs.emptyNotarization, + } + e.Logger.Debug("Node appears behind, sending them an empty notarization", zap.Stringer("to", to), zap.Uint64("round", round)) + e.Comm.Send(msg, to) + } return } @@ -999,7 +1053,7 @@ func (e *Epoch) persistFinalization(finalization Finalization) error { // we receive a finalization for a future round e.Logger.Debug("Received a finalization for a future sequence", zap.Uint64("seq", finalization.Finalization.Seq), zap.Uint64("nextSeqToCommit", nextSeqToCommit)) - e.replicationState.replicateBlocks(&finalization, nextSeqToCommit) + e.replicationState.receivedFutureFinalization(&finalization, nextSeqToCommit) if err :=
e.rebroadcastPastFinalizeVotes(); err != nil { return err @@ -1362,16 +1416,27 @@ func (e *Epoch) handleEmptyNotarizationMessage(emptyNotarization *EmptyNotarizat e.Logger.Verbo("Received empty notarization message", zap.Uint64("round", vote.Round)) - if e.isRoundTooFarAhead(vote.Round) { - e.Logger.Debug("Received an empty notarization for a too high round", - zap.Uint64("round", vote.Round), zap.Uint64("our round", e.round)) + if err := VerifyQC(emptyNotarization.QC, e.Logger, "Empty notarization", e.quorumSize, e.eligibleNodeIDs, emptyNotarization, from); err != nil { return nil } // Ignore votes for previous rounds - if !e.isVoteRoundValid(vote.Round) { - e.Logger.Debug("Empty notarization is invalid", + if vote.Round < e.round { + e.Logger.Debug("Received an empty notarization for a past round", zap.Uint64("round", vote.Round), zap.Uint64("my round", e.round)) + return nil + } + + if e.round < vote.Round { + e.Logger.Debug("Received an empty notarization for a higher round", zap.Uint64("round", vote.Round), zap.Uint64("our round", e.round)) + + e.replicationState.receivedFutureRound(vote.Round, emptyNotarization.QC.Signers(), e.round) + + // store in future state if within max round window + if e.isWithinMaxRoundWindow(vote.Round) { + emptyVotes := e.getOrCreateEmptyVoteSetForRound(vote.Round) + emptyVotes.emptyNotarization = emptyNotarization + } return nil } @@ -1384,11 +1449,7 @@ func (e *Epoch) handleEmptyNotarizationMessage(emptyNotarization *EmptyNotarizat return nil } - // Otherwise, this round is not notarized or finalized yet, so verify the empty notarization and store it. - if err := VerifyQC(emptyNotarization.QC, e.Logger, "Empty notarization", e.quorumSize, e.eligibleNodeIDs, emptyNotarization, from); err != nil { - return nil - } - + // Otherwise, this round is not notarized or finalized yet, so store it. 
emptyVotes := e.getOrCreateEmptyVoteSetForRound(vote.Round) emptyVotes.emptyNotarization = emptyNotarization if e.round != vote.Round { @@ -1407,11 +1468,17 @@ func (e *Epoch) handleNotarizationMessage(message *Notarization, from NodeID) er e.Logger.Verbo("Received notarization message", zap.Stringer("from", from), zap.Uint64("round", vote.Round)) - if !e.isVoteRoundValid(vote.Round) { + if err := VerifyQC(message.QC, e.Logger, "Notarization", e.quorumSize, e.eligibleNodeIDs, message, from); err != nil { return nil } - if err := VerifyQC(message.QC, e.Logger, "Notarization", e.quorumSize, e.eligibleNodeIDs, message, from); err != nil { + if e.round < vote.Round { + e.Logger.Debug("Received a notarization for a future round", + zap.Uint64("round", vote.Round), zap.Uint64("our round", e.round)) + e.replicationState.receivedFutureRound(vote.Round, message.QC.Signers(), e.round) + } + + if !e.isVoteRoundValid(vote.Round) { return nil } @@ -1651,6 +1718,18 @@ func (e *Epoch) processNotarizedBlock(block Block, notarization *Notarization) e return nil } + finalizeVote, finalizeVoteMsg, err := e.constructFinalizeVoteMessage(md) + if err != nil { + e.Logger.Warn("Failed to construct finalize vote message", zap.Error(err)) + return err + } + e.Comm.Broadcast(finalizeVoteMsg) + + if err := e.handleFinalizeVoteMessage(&finalizeVote, e.ID); err != nil { + e.Logger.Warn("Failed to handle finalize vote message", zap.Error(err)) + return err + } + return e.processReplicationState() } @@ -1779,15 +1858,11 @@ func (e *Epoch) createFinalizedBlockVerificationTask(block Block, finalization F if err != nil { e.Logger.Debug("Failed verifying block", zap.Error(err)) // if we fail to verify the block, we re-add to request timeout - numSigners := int64(len(finalization.QC.Signers())) - index, err := rand.Int(rand.Reader, big.NewInt(numSigners)) + err = e.replicationState.resendFinalizationRequest(md.Seq, finalization.QC.Signers()) if err != nil { e.haltedError = err - e.Logger.Debug("Failed to generate random index", zap.Error(err)) - return md.Digest + e.Logger.Debug("Failed to resend finalization", zap.Error(err)) } - - e.replicationState.sendRequestToNode(md.Seq, md.Seq, finalization.QC.Signers(), int(index.Int64())) return md.Digest } @@ -1834,6 +1909,7 @@ func (e *Epoch) createNotarizedBlockVerificationTask(block Block, notarization N verifiedBlock, err := block.Verify(context.Background()) if err != nil { e.Logger.Debug("Failed verifying block", zap.Error(err)) + // TODO: if we fail to verify the block, we should re-request it from the replication state return md.Digest } @@ -1865,6 +1941,23 @@ func (e *Epoch) createNotarizedBlockVerificationTask(block Block, notarization N if err := e.persistNotarization(notarization); err != nil { e.haltedError = err + e.Logger.Error("Failed to persist notarization", zap.Error(err)) + return md.Digest + } + + // create finalized votes for notarizations we process during replication + finalizeVote, finalizeVoteMsg, err := e.constructFinalizeVoteMessage(md) + if err != nil { + e.haltedError = err + e.Logger.Error("Failed to construct finalize vote message", zap.Error(err)) + return md.Digest + } + e.Comm.Broadcast(finalizeVoteMsg) + + if err := e.handleFinalizeVoteMessage(&finalizeVote, e.ID); err != nil { + e.haltedError = err + e.Logger.Error("Failed to handle finalize vote message", zap.Error(err)) + return md.Digest } err = e.processReplicationState() @@ -2234,7 +2327,7 @@ func (e *Epoch) triggerEmptyBlockNotarization(round uint64) { 
e.Comm.Broadcast(&Message{EmptyVoteMessage: &signedEV}) - e.addEmptyVoteRebroadcastTimeout(&signedEV) + e.addEmptyVoteRebroadcastTimeout() if err := e.maybeAssembleEmptyNotarization(); err != nil { e.Logger.Error("Failed assembling empty notarization", zap.Error(err)) @@ -2242,23 +2335,32 @@ } } -func (e *Epoch) addEmptyVoteRebroadcastTimeout(vote *EmptyVote) { - task := &TimeoutTask{ - NodeID: e.ID, - TaskID: EmptyVoteTimeoutID, - Deadline: e.timeoutHandler.GetTime().Add(e.EpochConfig.MaxRebroadcastWait), - Task: func() { - e.Logger.Debug("Rebroadcasting empty vote because round has not advanced", zap.Uint64("round", vote.Vote.Round)) - e.Comm.Broadcast(&Message{EmptyVoteMessage: vote}) - e.addEmptyVoteRebroadcastTimeout(vote) - }, +func (e *Epoch) emptyVoteTimeoutTaskRunner(_ []uint64) { + e.lock.Lock() + roundVotes, ok := e.emptyVotes[e.round] + e.lock.Unlock() + + if !ok { + e.Logger.Debug("Expected to rebroadcast an empty vote, but found no empty vote set for the current round", zap.Uint64("round", e.round)) + return + } + + ourVote, voted := roundVotes.votes[string(e.ID)] + if !voted { + e.Logger.Debug("Expected to rebroadcast our empty vote, but it is missing from the empty vote set", zap.Uint64("round", e.round)) + return } - e.timeoutHandler.AddTask(task) + e.Logger.Debug("Rebroadcasting empty vote because round has not advanced", zap.Uint64("round", ourVote.Vote.Round)) + e.Comm.Broadcast(&Message{EmptyVoteMessage: ourVote}) +} + +func (e *Epoch) addEmptyVoteRebroadcastTimeout() { + e.timeoutHandler.AddTask(EmptyVoteTimeoutID) } func (e *Epoch) monitorProgress(round uint64) { - e.Logger.Debug("Monitoring progress", zap.Uint64("round", round)) + e.Logger.Debug("Monitoring progress", zap.Uint64("round", round), zap.Uint64("currentRound", e.round)) ctx, cancelContext := context.WithCancel(context.Background()) noop := func() {} @@ -2453,11 +2555,18 @@ func (e *Epoch) deleteRounds(round uint64) { } } -func (e *Epoch) deleteEmptyVoteForPreviousRound() { - if e.round == 0 { +// maybeDeleteEmptyVotes deletes all previous empty votes if the current round was notarized or finalized.
+func (e *Epoch) maybeDeleteEmptyVotes() { + round, ok := e.rounds[e.round] + if !ok || (round.notarization == nil && round.finalization == nil) { return } - delete(e.emptyVotes, e.round-1) + + for r := range e.emptyVotes { + if r < e.round { + delete(e.emptyVotes, r) + } + } } func (e *Epoch) increaseRound() { @@ -2469,8 +2578,8 @@ func (e *Epoch) increaseRound() { e.blockBuilderCancelFunc() // remove the rebroadcast empty vote task - e.timeoutHandler.RemoveTask(e.ID, EmptyVoteTimeoutID) - e.deleteEmptyVoteForPreviousRound() + e.timeoutHandler.RemoveTask(EmptyVoteTimeoutID) + e.maybeDeleteEmptyVotes() prevLeader := LeaderForRound(e.nodes, e.round) nextLeader := LeaderForRound(e.nodes, e.round+1) @@ -2650,28 +2759,53 @@ func (e *Epoch) handleReplicationRequest(req *ReplicationRequest, from NodeID) e } response := &VerifiedReplicationResponse{} - latestRound := e.getLatestVerifiedQuorumRound() - - if latestRound != nil && latestRound.GetRound() > req.LatestRound { - response.LatestRound = latestRound + if req.LatestRound > 0 { + latestRound := e.getLatestVerifiedQuorumRound() + if latestRound != nil && latestRound.GetRound() > req.LatestRound { + response.LatestRound = latestRound + } + } + if req.LatestFinalizedSeq > 0 { + if e.lastBlock != nil && e.lastBlock.Finalization.Finalization.Seq > req.LatestFinalizedSeq { + response.LatestFinalizedSeq = &VerifiedQuorumRound{ + VerifiedBlock: e.lastBlock.VerifiedBlock, + Finalization: &e.lastBlock.Finalization, + } + } } seqs := req.Seqs slices.Sort(seqs) - data := make([]VerifiedQuorumRound, len(seqs)) + seqData := make([]VerifiedQuorumRound, len(seqs)) for i, seq := range seqs { quorumRound := e.locateQuorumRecord(seq) if quorumRound == nil { // since we are sorted, we can break early - data = data[:i] + seqData = seqData[:i] break } - data[i] = *quorumRound + seqData[i] = *quorumRound } + rounds := req.Rounds + roundData := make([]VerifiedQuorumRound, 0, len(rounds)) + slices.Sort(rounds) + for _, roundNum := range rounds { + quorumRound := e.locateQuorumRecordByRound(roundNum) + if quorumRound == nil { + // we cannot break early: unlike seqs, rounds are not contiguous in memory, so later rounds may still be found + continue + } + roundData = append(roundData, *quorumRound) + } + + data := make([]VerifiedQuorumRound, 0, len(seqData)+len(roundData)) + data = append(data, seqData...) + data = append(data, roundData...)
response.Data = data - if len(data) == 0 && response.LatestRound == nil { + + if len(data) == 0 && response.LatestRound == nil && response.LatestFinalizedSeq == nil { e.Logger.Debug("No data found for replication request", zap.Stringer("from", from)) return nil } @@ -2728,6 +2862,32 @@ func (e *Epoch) locateQuorumRecord(seq uint64) *VerifiedQuorumRound { } } +// if this round is in memory, we do not need to retrieve it from storage +func (e *Epoch) locateQuorumRecordByRound(targetRound uint64) *VerifiedQuorumRound { + for _, round := range e.rounds { + blockRound := round.block.BlockHeader().Round + if blockRound == targetRound { + if round.finalization != nil || round.notarization != nil { + return &VerifiedQuorumRound{ + VerifiedBlock: round.block, + Finalization: round.finalization, + Notarization: round.notarization, + } + } + } + } + + // check if the round is empty notarized + emptyVoteForRound, exists := e.emptyVotes[targetRound] + if exists && emptyVoteForRound.emptyNotarization != nil { + return &VerifiedQuorumRound{ + EmptyNotarization: emptyVoteForRound.emptyNotarization, + } + } + + return nil +} + func (e *Epoch) haveNotFinalizedNotarizedRound() (uint64, bool) { e.lock.Lock() defer e.lock.Unlock() @@ -2751,42 +2911,36 @@ func (e *Epoch) handleReplicationResponse(resp *ReplicationResponse, from NodeID return nil } - e.Logger.Debug("Received replication response", zap.Stringer("from", from), zap.Int("num seqs", len(resp.Data)), zap.Stringer("latest round", resp.LatestRound)) + e.Logger.Debug("Received replication response", zap.Stringer("from", from), zap.Int("num seqs", len(resp.Data)), zap.Stringer("latest round", resp.LatestRound), zap.Stringer("latest seq", resp.LatestSeq)) nextSeqToCommit := e.nextSeqToCommit() - validRounds := make([]QuorumRound, 0, len(resp.Data)) for _, data := range resp.Data { - if err := data.IsWellFormed(); err != nil { - e.Logger.Debug("Malformed Quorum Round Received", zap.Error(err)) - continue - } - - if data.EmptyNotarization == nil && nextSeqToCommit > data.GetSequence() { - e.Logger.Debug("Received quorum round for a seq that is too far behind", zap.Uint64("seq", data.GetSequence())) - continue - } - - if data.GetSequence() > nextSeqToCommit+e.maxRoundWindow { + if data.Finalization != nil && data.GetSequence() > nextSeqToCommit+e.maxRoundWindow { e.Logger.Debug("Received quorum round for a seq that is too far ahead", zap.Uint64("seq", data.GetSequence())) // we are too far behind, we should ignore this message continue } - if err := e.verifyQuorumRound(data, from); err != nil { - e.Logger.Debug("Received invalid quorum round", zap.Uint64("seq", data.GetSequence()), zap.Stringer("from", from)) + // TODO: if empty notarizations occur for long periods, we may receive a nextSeqToCommit that has a round considered too far ahead. + // For now we allow only the nextSeqToCommit but we might want to accept a few more seqs ahead regardless of round.
+ if data.GetRound() > e.round+e.maxRoundWindow && data.GetSequence() != nextSeqToCommit { + e.Logger.Debug("Received quorum round for a round that is too far ahead", zap.Uint64("round", data.GetRound())) + // we are too far behind, we should ignore this message continue } - validRounds = append(validRounds, data) - e.replicationState.StoreQuorumRound(data) + if err := e.processQuorumRound(&data, from); err != nil { + e.Logger.Debug("Failed processing quorum round", zap.Error(err)) + } } - if err := e.processLatestRoundReceived(resp.LatestRound, from); err != nil { + if err := e.processQuorumRound(resp.LatestRound, from); err != nil { e.Logger.Debug("Failed processing latest round", zap.Error(err)) - return nil } - e.replicationState.receivedReplicationResponse(validRounds, from) + if err := e.processQuorumRound(resp.LatestSeq, from); err != nil { + e.Logger.Debug("Failed processing latest seq", zap.Error(err)) + } return e.processReplicationState() } @@ -2832,27 +2986,35 @@ func (e *Epoch) processEmptyNotarization(emptyNotarization *EmptyNotarization) e return e.processReplicationState() } -func (e *Epoch) processLatestRoundReceived(latestRound *QuorumRound, from NodeID) error { +// processQuorumRound processes a quorum round received from another node. +// It verifies the quorum round and stores it in the replication state if valid. +func (e *Epoch) processQuorumRound(latestRound *QuorumRound, from NodeID) error { if latestRound == nil { return nil } + nextSeqToCommit := e.nextSeqToCommit() + if latestRound.EmptyNotarization == nil && nextSeqToCommit > latestRound.GetSequence() { + return fmt.Errorf("quorum round too far behind: %d > %d", nextSeqToCommit, latestRound.GetSequence()) + } + // make sure the latest round is well formed if err := latestRound.IsWellFormed(); err != nil { - e.Logger.Debug("Received invalid latest round", zap.Error(err)) - return err + return fmt.Errorf("received malformed latest round: %w", err) } if err := e.verifyQuorumRound(*latestRound, from); err != nil { - e.Logger.Debug("Received invalid latest round", zap.Error(err)) - return err + return fmt.Errorf("failed verifying latest round: %w", err) } - e.replicationState.StoreQuorumRound(*latestRound) + e.replicationState.storeQuorumRound(*latestRound, from) return nil } func (e *Epoch) processReplicationState() error { + // We might have advanced the round from non-replicating paths such as future messages. Clean up the replication map accordingly. + e.replicationState.maybeAdvancedState(e.nextSeqToCommit(), e.round) + nextSeqToCommit := e.nextSeqToCommit() // check if we are done replicating and should start a new round @@ -2863,30 +3025,25 @@ func (e *Epoch) processReplicationState() error { return e.startRound() } - e.replicationState.maybeCollectFutureSequences(e.nextSeqToCommit()) - // first we check if we can commit the next sequence, it is ok to try and commit the next sequence // directly, since if there are any empty notarizations, `indexFinalization` will // increment the round properly.
- block, finalization, exists := e.replicationState.GetFinalizedBlockForSequence(nextSeqToCommit) + block, finalization, exists := e.replicationState.getFinalizedBlockForSequence(nextSeqToCommit) if exists { - delete(e.replicationState.receivedQuorumRounds, block.BlockHeader().Round) return e.processFinalizedBlock(block, finalization) } - qRound, ok := e.replicationState.receivedQuorumRounds[e.round] - if ok && qRound.Notarization != nil { + qRound := e.replicationState.getNonFinalizedQuorumRound(e.round) + if qRound != nil && qRound.Notarization != nil { if qRound.Finalization != nil { - e.Logger.Debug("Delaying processing a QuorumRound that has an Finalization != NextSeqToCommit", zap.Stringer("QuorumRound", &qRound)) + e.Logger.Debug("Delaying processing a QuorumRound that has a Finalization != NextSeqToCommit", zap.Stringer("QuorumRound", qRound)) return nil } - delete(e.replicationState.receivedQuorumRounds, e.round) return e.processNotarizedBlock(qRound.Block, qRound.Notarization) } // the current round is an empty notarization - if ok && qRound.EmptyNotarization != nil { - delete(e.replicationState.receivedQuorumRounds, qRound.GetRound()) + if qRound != nil && qRound.EmptyNotarization != nil { return e.processEmptyNotarization(qRound.EmptyNotarization) } @@ -2914,11 +3071,12 @@ func (e *Epoch) maybeAdvanceRoundFromEmptyNotarizations() (bool, error) { round := e.round expectedSeq := e.metadata().Seq - nextSeqQuorum := e.replicationState.GetQuorumRoundWithSeq(expectedSeq) - if nextSeqQuorum != nil { + block, exists := e.replicationState.getBlockWithSeq(expectedSeq) + if exists { + bh := block.BlockHeader() // num empty notarizations - if round < nextSeqQuorum.GetRound() { - for range nextSeqQuorum.GetRound() - round { + if round < bh.Round { + for range bh.Round - round { e.increaseRound() } return true, nil @@ -2967,7 +3125,6 @@ func (e *Epoch) getLatestVerifiedQuorumRound() *VerifiedQuorumRound { return GetLatestVerifiedQuorumRound( e.getHighestRound(), e.getHighestEmptyNotarization(), - e.lastBlock, ) } diff --git a/msg.go b/msg.go index 9489e8e..b20bf67 100644 --- a/msg.go +++ b/msg.go @@ -211,18 +211,22 @@ type QuorumCertificate interface { } type ReplicationRequest struct { - Seqs []uint64 // sequences we are requesting - LatestRound uint64 // latest round that we are aware of + Seqs []uint64 // sequences we are requesting + Rounds []uint64 // rounds we are requesting + LatestRound uint64 // latest round that we are aware of + LatestFinalizedSeq uint64 // latest finalized sequence that we are aware of } type ReplicationResponse struct { Data []QuorumRound LatestRound *QuorumRound + LatestSeq *QuorumRound } type VerifiedReplicationResponse struct { - Data []VerifiedQuorumRound - LatestRound *VerifiedQuorumRound + Data []VerifiedQuorumRound + LatestRound *VerifiedQuorumRound + LatestFinalizedSeq *VerifiedQuorumRound } // QuorumRound represents a round that has achieved quorum on either @@ -302,7 +306,7 @@ func (q *QuorumRound) String() string { if err != nil { return fmt.Sprintf("QuorumRound{Error: %s}", err) } else { - return fmt.Sprintf("QuorumRound{Round: %d, Seq: %d}", q.GetRound(), q.GetSequence()) + return fmt.Sprintf("QuorumRound{Round: %d, Seq: %d, Finalized: %t}", q.GetRound(), q.GetSequence(), q.Finalization != nil) } } diff --git a/replication.go b/replication.go deleted file mode 100644 index fb5d2d3..0000000 --- a/replication.go +++ /dev/null @@ -1,345 +0,0 @@ -// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved.
-// See the file LICENSE for licensing terms. - -package simplex - -import ( - "fmt" - "math" - "slices" - "sync" - "time" - - "go.uber.org/zap" -) - -// signedSequence is a sequence that has been signed by a quorum certificate. -// it essentially is a quorum round without the enforcement of needing a block with a -// finalization or notarization. -type signedSequence struct { - seq uint64 - signers NodeIDs -} - -func newSignedSequenceFromRound(round QuorumRound) (*signedSequence, error) { - ss := &signedSequence{} - switch { - case round.Finalization != nil: - ss.signers = round.Finalization.QC.Signers() - ss.seq = round.Finalization.Finalization.Seq - case round.Notarization != nil: - ss.signers = round.Notarization.QC.Signers() - ss.seq = round.Notarization.Vote.Seq - case round.EmptyNotarization != nil: - return nil, fmt.Errorf("should not create signed sequence from empty notarization") - default: - return nil, fmt.Errorf("round does not contain a finalization, empty notarization, or notarization") - } - - return ss, nil -} - -type ReplicationState struct { - lock *sync.Mutex - logger Logger - enabled bool - maxRoundWindow uint64 - comm Communication - id NodeID - - // latest seq requested - lastSequenceRequested uint64 - - // highest sequence we have received - highestSequenceObserved *signedSequence - - // receivedQuorumRounds maps rounds to quorum rounds - receivedQuorumRounds map[uint64]QuorumRound - - // request iterator - requestIterator int - - timeoutHandler *TimeoutHandler -} - -func NewReplicationState(logger Logger, comm Communication, id NodeID, maxRoundWindow uint64, enabled bool, start time.Time, lock *sync.Mutex) *ReplicationState { - return &ReplicationState{ - lock: lock, - logger: logger, - enabled: enabled, - comm: comm, - id: id, - maxRoundWindow: maxRoundWindow, - receivedQuorumRounds: make(map[uint64]QuorumRound), - timeoutHandler: NewTimeoutHandler(logger, start, comm.Nodes()), - } -} - -func (r *ReplicationState) AdvanceTime(now time.Time) { - r.timeoutHandler.Tick(now) -} - -// isReplicationComplete returns true if we have finished the replication process. -// The process is considered finished once [currentRound] has caught up to the highest round received. -func (r *ReplicationState) isReplicationComplete(nextSeqToCommit uint64, currentRound uint64) bool { - if r.highestSequenceObserved == nil { - return true - } - - return nextSeqToCommit > r.highestSequenceObserved.seq && currentRound > r.highestKnownRound() -} - -func (r *ReplicationState) collectMissingSequences(observedSignedSeq *signedSequence, nextSeqToCommit uint64) { - observedSeq := observedSignedSeq.seq - // Node is behind, but we've already sent messages to collect future finalizations - if r.lastSequenceRequested >= observedSeq && r.highestSequenceObserved != nil { - return - } - - if r.highestSequenceObserved == nil || observedSeq > r.highestSequenceObserved.seq { - r.highestSequenceObserved = observedSignedSeq - } - - startSeq := math.Max(float64(nextSeqToCommit), float64(r.lastSequenceRequested)) - // Don't exceed the max round window - endSeq := math.Min(float64(observedSeq), float64(r.maxRoundWindow+nextSeqToCommit)) - - r.logger.Debug("Node is behind, requesting missing finalizations", zap.Uint64("seq", observedSeq), zap.Uint64("startSeq", uint64(startSeq)), zap.Uint64("endSeq", uint64(endSeq))) - r.sendReplicationRequests(uint64(startSeq), uint64(endSeq)) -} - -// sendReplicationRequests sends requests for missing sequences for the -// range of sequences [start, end] <- inclusive. 
It does so by splitting the -// range of sequences equally amount the nodes that have signed [highestSequenceObserved]. -func (r *ReplicationState) sendReplicationRequests(start uint64, end uint64) { - // it's possible our node has signed [highestSequenceObserved]. - // For example this may happen if our node has sent a finalization - // for [highestSequenceObserved] and has not received the - // finalization from the network. - nodes := r.highestSequenceObserved.signers.Remove(r.id) - numNodes := len(nodes) - - seqRequests := DistributeSequenceRequests(start, end, numNodes) - - r.logger.Debug("Distributing replication requests", zap.Uint64("start", start), zap.Uint64("end", end), zap.Stringer("nodes", NodeIDs(nodes))) - for i, seqs := range seqRequests { - index := (i + r.requestIterator) % numNodes - r.sendRequestToNode(seqs.Start, seqs.End, nodes, index) - } - - r.lastSequenceRequested = end - // next time we send requests, we start with a different permutation - r.requestIterator++ -} - -// sendRequestToNode requests the sequences [start, end] from nodes[index]. -// In case the nodes[index] does not respond, we create a timeout that will -// re-send the request. -func (r *ReplicationState) sendRequestToNode(start uint64, end uint64, nodes []NodeID, index int) { - r.logger.Debug("Requesting missing finalizations ", - zap.Stringer("from", nodes[index]), - zap.Uint64("start", start), - zap.Uint64("end", end)) - seqs := make([]uint64, (end+1)-start) - for i := start; i <= end; i++ { - seqs[i-start] = i - } - request := &ReplicationRequest{ - Seqs: seqs, - LatestRound: r.highestSequenceObserved.seq, - } - msg := &Message{ReplicationRequest: request} - - task := r.createReplicationTimeoutTask(start, end, nodes, index) - - r.timeoutHandler.AddTask(task) - - r.comm.Send(msg, nodes[index]) -} - -func (r *ReplicationState) createReplicationTimeoutTask(start, end uint64, nodes []NodeID, index int) *TimeoutTask { - taskFunc := func() { - r.lock.Lock() - defer r.lock.Unlock() - r.sendRequestToNode(start, end, nodes, (index+1)%len(nodes)) - } - timeoutTask := &TimeoutTask{ - Start: start, - End: end, - NodeID: nodes[index], - TaskID: getTimeoutID(start, end), - Task: taskFunc, - Deadline: r.timeoutHandler.GetTime().Add(DefaultReplicationRequestTimeout), - } - - return timeoutTask -} - -// receivedReplicationResponse notifies the task handler a response was received. If the response -// was incomplete(meaning our timeout expected more seqs), then we will create a new timeout -// for the missing sequences and send the request to a different node. 
-func (r *ReplicationState) receivedReplicationResponse(data []QuorumRound, node NodeID) { - seqs := make([]uint64, 0, len(data)) - - // remove all sequences where we expect a finalization but only received a notarization - highestSeq := r.highestSequenceObserved.seq - for _, qr := range data { - if qr.GetSequence() <= highestSeq && qr.Finalization == nil && qr.Notarization != nil { - r.logger.Debug("Received notarization without finalization, skipping", zap.Stringer("from", node), zap.Uint64("seq", qr.GetSequence())) - continue - } - - seqs = append(seqs, qr.GetSequence()) - } - - slices.Sort(seqs) - - task := FindReplicationTask(r.timeoutHandler, node, seqs) - if task == nil { - r.logger.Debug("Could not find a timeout task associated with the replication response", zap.Stringer("from", node), zap.Any("seqs", seqs)) - return - } - r.timeoutHandler.RemoveTask(node, task.TaskID) - - // we found the timeout, now make sure all seqs were returned - missing := findMissingNumbersInRange(task.Start, task.End, seqs) - if len(missing) == 0 { - return - } - - // if not all sequences were returned, create new timeouts - r.logger.Debug("Received missing sequences in the replication response", zap.Stringer("from", node), zap.Any("missing", missing)) - nodes := r.highestSequenceObserved.signers.Remove(r.id) - numNodes := len(nodes) - segments := CompressSequences(missing) - for i, seqs := range segments { - index := i % numNodes - newTask := r.createReplicationTimeoutTask(seqs.Start, seqs.End, nodes, index) - r.timeoutHandler.AddTask(newTask) - } -} - -// findMissingNumbersInRange finds numbers in an array constructed by [start...end] that are not in [nums] -// ex. (3, 10, [1,2,3,4,5,6]) -> [7,8,9,10] -func findMissingNumbersInRange(start, end uint64, nums []uint64) []uint64 { - numMap := make(map[uint64]struct{}) - for _, num := range nums { - numMap[num] = struct{}{} - } - - var result []uint64 - - for i := start; i <= end; i++ { - if _, exists := numMap[i]; !exists { - result = append(result, i) - } - } - - return result -} - -func (r *ReplicationState) replicateBlocks(finalization *Finalization, nextSeqToCommit uint64) { - if !r.enabled { - return - } - - signedSequence := &signedSequence{ - seq: finalization.Finalization.Seq, - signers: finalization.QC.Signers(), - } - - r.collectMissingSequences(signedSequence, nextSeqToCommit) -} - -// maybeCollectFutureSequences attempts to collect future sequences if -// there are more to be collected and the round has caught up for us to send the request. 
-func (r *ReplicationState) maybeCollectFutureSequences(nextSequenceToCommit uint64) { - if !r.enabled { - return - } - - if r.lastSequenceRequested >= r.highestSequenceObserved.seq { - return - } - - // we send out more requests once our seq has caught up to 1/2 of the maxRoundWindow - if nextSequenceToCommit+r.maxRoundWindow/2 > r.lastSequenceRequested { - r.collectMissingSequences(r.highestSequenceObserved, nextSequenceToCommit) - } -} - -func (r *ReplicationState) StoreQuorumRound(round QuorumRound) { - if _, ok := r.receivedQuorumRounds[round.GetRound()]; ok { - // maybe this quorum round was behind - if r.receivedQuorumRounds[round.GetRound()].Finalization == nil && round.Finalization != nil { - r.receivedQuorumRounds[round.GetRound()] = round - } - return - } - - if round.EmptyNotarization == nil && round.GetSequence() > r.highestSequenceObserved.seq { - signedSeq, err := newSignedSequenceFromRound(round) - if err != nil { - // should never be here since we already checked the QuorumRound was valid - r.logger.Error("Error creating signed sequence from round", zap.Error(err)) - return - } - - r.highestSequenceObserved = signedSeq - } - - r.logger.Debug("Stored quorum round ", zap.Stringer("qr", &round)) - r.receivedQuorumRounds[round.GetRound()] = round -} - -func (r *ReplicationState) GetFinalizedBlockForSequence(seq uint64) (Block, Finalization, bool) { - for _, round := range r.receivedQuorumRounds { - if round.GetSequence() == seq { - if round.Block == nil || round.Finalization == nil { - // this could be an empty notarization - continue - } - return round.Block, *round.Finalization, true - } - } - return nil, Finalization{}, false -} - -func (r *ReplicationState) highestKnownRound() uint64 { - var highestRound uint64 - for round := range r.receivedQuorumRounds { - if round > highestRound { - highestRound = round - } - } - return highestRound -} - -func (r *ReplicationState) GetQuorumRoundWithSeq(seq uint64) *QuorumRound { - for _, round := range r.receivedQuorumRounds { - if round.GetSequence() == seq { - return &round - } - } - return nil -} - -// FindReplicationTask returns a TimeoutTask assigned to [node] that contains the lowest sequence in [seqs]. -// A sequence is considered "contained" if it falls between a task's Start (inclusive) and End (inclusive). 
-func FindReplicationTask(t *TimeoutHandler, node NodeID, seqs []uint64) *TimeoutTask { - var lowestTask *TimeoutTask - - t.forEach(string(node), func(tt *TimeoutTask) { - for _, seq := range seqs { - if seq >= tt.Start && seq <= tt.End { - if lowestTask == nil { - lowestTask = tt - } else if seq < lowestTask.Start { - lowestTask = tt - } - } - } - }) - - return lowestTask -} diff --git a/replication_request_test.go b/replication_request_test.go index db271c8..8c65bab 100644 --- a/replication_request_test.go +++ b/replication_request_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "testing" + "time" "github.com/ava-labs/simplex" "github.com/ava-labs/simplex/testutil" @@ -43,6 +44,7 @@ func TestReplicationRequestIndexedBlocks(t *testing.T) { msg := <-comm.in resp := msg.VerifiedReplicationResponse require.Nil(t, resp.LatestRound) + require.Nil(t, resp.LatestFinalizedSeq) require.Equal(t, len(sequences), len(resp.Data)) for i, data := range resp.Data { @@ -60,9 +62,7 @@ func TestReplicationRequestIndexedBlocks(t *testing.T) { err = e.HandleMessage(req, nodes[1]) require.NoError(t, err) - msg = <-comm.in - resp = msg.VerifiedReplicationResponse - require.Zero(t, len(resp.Data)) + require.Never(t, func() bool { return len(comm.in) > 0 }, 5*time.Second, 100*time.Millisecond) } // TestReplicationRequestNotarizations tests replication requests for notarized blocks. @@ -97,8 +97,9 @@ func TestReplicationRequestNotarizations(t *testing.T) { } req := &simplex.Message{ ReplicationRequest: &simplex.ReplicationRequest{ - Seqs: seqs, - LatestRound: 0, + Seqs: seqs, + LatestRound: 1, + LatestFinalizedSeq: 0, }, } @@ -109,7 +110,30 @@ func TestReplicationRequestNotarizations(t *testing.T) { resp := msg.VerifiedReplicationResponse require.NoError(t, err) require.NotNil(t, resp) + require.NotNil(t, resp.LatestRound) + require.Nil(t, resp.LatestFinalizedSeq) require.Equal(t, *resp.LatestRound, rounds[numBlocks-1]) + + for _, round := range resp.Data { + require.Nil(t, round.EmptyNotarization) + notarizedBlock, ok := rounds[round.VerifiedBlock.BlockHeader().Round] + require.True(t, ok) + require.Equal(t, notarizedBlock.VerifiedBlock, round.VerifiedBlock) + require.Equal(t, notarizedBlock.Notarization, round.Notarization) + } + + // now ask for the notarizations as rounds + req = &simplex.Message{ + ReplicationRequest: &simplex.ReplicationRequest{ + Rounds: seqs, + }, + } + + err = e.HandleMessage(req, nodes[1]) + require.NoError(t, err) + + msg = <-comm.in + resp = msg.VerifiedReplicationResponse for _, round := range resp.Data { require.Nil(t, round.EmptyNotarization) notarizedBlock, ok := rounds[round.VerifiedBlock.BlockHeader().Round] @@ -134,6 +158,9 @@ func TestReplicationRequestMixed(t *testing.T) { numBlocks := uint64(8) rounds := make(map[uint64]simplex.VerifiedQuorumRound) + + numExpectedRounds := 0 + tailNotarizations := 0 // only produce a notarization for blocks we are the leader, otherwise produce an empty notarization for i := range numBlocks { leaderForRound := bytes.Equal(simplex.LeaderForRound(nodes, uint64(i)), e.ID) @@ -147,6 +174,7 @@ func TestReplicationRequestMixed(t *testing.T) { rounds[i] = simplex.VerifiedQuorumRound{ EmptyNotarization: emptyNotarization, } + tailNotarizations++ continue } block, notarization := advanceRoundFromNotarization(t, e, bb) @@ -155,18 +183,77 @@ func TestReplicationRequestMixed(t *testing.T) { VerifiedBlock: block, Notarization: notarization, } + + numExpectedRounds++ + tailNotarizations = 0 } + numExpectedRounds += tailNotarizations require.Equal(t, 
uint64(numBlocks), e.Metadata().Round) - seqs := make([]uint64, 0, len(rounds)) + roundsRequested := make([]uint64, 0, len(rounds)) for k := range rounds { - seqs = append(seqs, k) + roundsRequested = append(roundsRequested, k) + } + + req := &simplex.Message{ + ReplicationRequest: &simplex.ReplicationRequest{ + Rounds: roundsRequested, + LatestRound: 1, + }, + } + + err = e.HandleMessage(req, nodes[1]) + require.NoError(t, err) + + msg := <-comm.in + resp := msg.VerifiedReplicationResponse + require.Equal(t, *resp.LatestRound, rounds[numBlocks-1]) + require.Equal(t, numExpectedRounds, len(resp.Data)) + + for _, round := range resp.Data { + notarizedBlock, ok := rounds[round.GetRound()] + require.True(t, ok) + require.Equal(t, notarizedBlock.VerifiedBlock, round.VerifiedBlock) + require.Equal(t, notarizedBlock.Notarization, round.Notarization) + require.Equal(t, notarizedBlock.EmptyNotarization, round.EmptyNotarization) + } +} + +func TestReplicationRequestTailingEmptyNotarizations(t *testing.T) { + bb := &testutil.TestBlockBuilder{Out: make(chan *testutil.TestBlock, 1)} + nodes := []simplex.NodeID{{1}, {2}, {3}, {4}} + comm := NewListenerComm(nodes) + conf, wal, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[0], comm, bb) + conf.ReplicationEnabled = true + + e, err := simplex.NewEpoch(conf) + require.NoError(t, err) + require.NoError(t, e.Start()) + + numBlocks := uint64(8) + rounds := make(map[uint64]simplex.VerifiedQuorumRound) + // only produce a notarization for blocks we are the leader, otherwise produce an empty notarization + for i := range numBlocks { + emptyNotarization := testutil.NewEmptyNotarization(nodes, uint64(i)) + e.HandleMessage(&simplex.Message{ + EmptyNotarization: emptyNotarization, + }, nodes[1]) + wal.AssertNotarization(uint64(i)) + rounds[i] = simplex.VerifiedQuorumRound{ + EmptyNotarization: emptyNotarization, + } + } + + require.Equal(t, uint64(numBlocks), e.Metadata().Round) + roundsRequested := make([]uint64, 0, len(rounds)) + for k := range rounds { + roundsRequested = append(roundsRequested, k) } req := &simplex.Message{ ReplicationRequest: &simplex.ReplicationRequest{ - Seqs: seqs, - LatestRound: 0, + Rounds: roundsRequested, + LatestRound: 1, }, } @@ -177,6 +264,7 @@ func TestReplicationRequestMixed(t *testing.T) { resp := msg.VerifiedReplicationResponse require.Equal(t, *resp.LatestRound, rounds[numBlocks-1]) + require.Equal(t, len(roundsRequested), len(resp.Data)) for _, round := range resp.Data { notarizedBlock, ok := rounds[round.GetRound()] require.True(t, ok) @@ -186,6 +274,31 @@ func TestReplicationRequestMixed(t *testing.T) { } } +func TestReplicationRequestUnknownSeqsAndRounds(t *testing.T) { + bb := &testutil.TestBlockBuilder{Out: make(chan *testutil.TestBlock, 1)} + nodes := []simplex.NodeID{{1}, {2}, {3}, {4}} + comm := NewListenerComm(nodes) + conf, _, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[0], comm, bb) + conf.ReplicationEnabled = true + + e, err := simplex.NewEpoch(conf) + require.NoError(t, err) + require.NoError(t, e.Start()) + + req := &simplex.Message{ + ReplicationRequest: &simplex.ReplicationRequest{ + Rounds: []uint64{100, 101, 102}, + Seqs: []uint64{200, 201, 202}, + LatestRound: 1, + }, + } + + err = e.HandleMessage(req, nodes[1]) + require.NoError(t, err) + + require.Never(t, func() bool { return len(comm.in) > 0 }, 5*time.Second, 100*time.Millisecond) +} + func TestNilReplicationResponse(t *testing.T) { nodes := []simplex.NodeID{{1}, {2}, {3}, {4}} net := testutil.NewInMemNetwork(t, nodes) @@ -202,7 +315,7 @@ func 
TestNilReplicationResponse(t *testing.T) { } // TestMalformedReplicationResponse tests that a malformed replication response is handled correctly. -// This replication response is malformeds since it must also include a notarization or +// This replication response is malformed since it must also include a notarization or // finalization. func TestMalformedReplicationResponse(t *testing.T) { nodes := []simplex.NodeID{{1}, {2}, {3}, {4}} diff --git a/replication_state.go b/replication_state.go new file mode 100644 index 0000000..32896a6 --- /dev/null +++ b/replication_state.go @@ -0,0 +1,160 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package simplex + +import ( + "crypto/rand" + "math/big" + "sync" + "time" + + "go.uber.org/zap" +) + +type ReplicationState struct { + enabled bool + logger Logger + sequenceReplicator *replicator + roundReplicator *replicator +} + +func NewReplicationState(logger Logger, comm Communication, id NodeID, maxRoundWindow uint64, enabled bool, start time.Time, lock *sync.Mutex) *ReplicationState { + if !enabled { + return &ReplicationState{ + enabled: enabled, + logger: logger, + } + } + + return &ReplicationState{ + enabled: enabled, + sequenceReplicator: newReplicator(logger, comm, id, maxRoundWindow, start, lock), + roundReplicator: newReplicator(logger, comm, id, maxRoundWindow, start, lock), + logger: logger, + } +} + +func (r *ReplicationState) AdvanceTime(now time.Time) { + if !r.enabled { + return + } + r.sequenceReplicator.advanceTime(now) + r.roundReplicator.advanceTime(now) +} + +// isReplicationComplete returns true if we have finished the replication process. +// The process is considered finished once [nextSeqToCommit] and [currentRound] have caught up to the highest sequence and round received. +func (r *ReplicationState) isReplicationComplete(nextSeqToCommit uint64, currentRound uint64) bool { + if !r.enabled { + return true + } + + return r.sequenceReplicator.isReplicationComplete(nextSeqToCommit) && r.roundReplicator.isReplicationComplete(currentRound) +} + +// maybeAdvancedState updates both replicators with the latest committed sequence and +// current round, sending further replication requests if more remains to be collected.
+func (r *ReplicationState) maybeAdvancedState(nextSequenceToCommit uint64, currentRound uint64) { + if !r.enabled { + return + } + + r.sequenceReplicator.updateState(nextSequenceToCommit) + r.roundReplicator.updateState(currentRound) +} + +func (r *ReplicationState) storeQuorumRound(round QuorumRound, from NodeID) { + if round.Finalization != nil { + r.sequenceReplicator.storeQuorumRound(round, from, round.Finalization.Finalization.Seq) + r.roundReplicator.removeOldValues(round.Finalization.Finalization.Round) + return + } + + // otherwise we are storing a round without finalization + // don't bother storing rounds that are older than the highest finalized round we know + // todo: grab a lock for sequence replicator + if r.sequenceReplicator.getHighestRound() >= round.GetRound() { + return + } + + r.roundReplicator.storeQuorumRound(round, from, round.GetRound()) +} + +func (r *ReplicationState) getFinalizedBlockForSequence(seq uint64) (Block, Finalization, bool) { + qr, ok := r.sequenceReplicator.retrieveQuorumRound(seq) + if !ok || qr.Finalization == nil || qr.Block == nil { + return nil, Finalization{}, false + } + + return qr.Block, *qr.Finalization, true +} + +func (r *ReplicationState) getBlockWithSeq(seq uint64) (Block, bool) { + qr, ok := r.sequenceReplicator.retrieveQuorumRound(seq) + if ok && qr.Block != nil { + return qr.Block, true + } + + // check notarization replicator + qr, ok = r.roundReplicator.retrieveQuorumRoundBySeq(seq) + if ok && qr.Block != nil { + return qr.Block, true + } + + return nil, false +} + +func (r *ReplicationState) resendFinalizationRequest(seq uint64, signers []NodeID) error { + if !r.enabled { + return nil + } + + numSigners := int64(len(signers)) + index, err := rand.Int(rand.Reader, big.NewInt(numSigners)) + if err != nil { + return err + } + + // because we are resending because the block failed to verify, we should remove the stored quorum round + // so that we can try to get a new block & finalization + delete(r.sequenceReplicator.receivedQuorumRounds, seq) + r.sequenceReplicator.sendRequestToNode(seq, seq, signers[index.Int64()]) + return nil +} + +func (r *ReplicationState) getNonFinalizedQuorumRound(round uint64) *QuorumRound { + qr, ok := r.roundReplicator.retrieveQuorumRound(round) + if ok { + return qr + } + return nil +} + +// receivedFutureFinalization processes a finalization that was created in a future round. 
+func (r *ReplicationState) receivedFutureFinalization(finalization *Finalization, nextSeqToCommit uint64) { + if !r.enabled { + return + } + + signedSequence := newSignedRoundOrSeqFromFinalization(finalization, r.sequenceReplicator.myNodeID) + + // maybe this finalization was for a round that we initially thought only had notarizations + // remove from the round replicator since we now have a finalization for this round + r.roundReplicator.removeOldValues(finalization.Finalization.BlockHeader.Round) + r.sequenceReplicator.maybeSendMoreReplicationRequests(signedSequence, nextSeqToCommit) +} + +func (r *ReplicationState) receivedFutureRound(round uint64, signers []NodeID, currentRound uint64) { + if !r.enabled { + return + } + + if r.sequenceReplicator.getHighestRound() >= round { + r.logger.Debug("Ignoring round replication for a future round since we have a finalization for a higher round", zap.Uint64("round", round)) + return + } + + signedSequence := newSignedRoundOrSeqFromRound(round, signers, r.roundReplicator.myNodeID) + r.roundReplicator.maybeSendMoreReplicationRequests(signedSequence, currentRound) +} diff --git a/replication_test.go b/replication_test.go index b07ff53..33a3331 100644 --- a/replication_test.go +++ b/replication_test.go @@ -70,7 +70,6 @@ func testReplication(t *testing.T, startSeq uint64, nodes []simplex.NodeID) { for _, n := range net.Instances { n.Storage.WaitForBlockCommit(startSeq) } - assertEqualLedgers(t, net) } // TestReplicationAdversarialNode tests the replication process of a node that @@ -287,7 +286,6 @@ func testReplicationEmptyNotarizations(t *testing.T, nodes []simplex.NodeID, end NewSimplexNode(t, nodes[3], net, newNodeConfig(nodes[3])) NewSimplexNode(t, nodes[4], net, newNodeConfig(nodes[4])) laggingNode := NewSimplexNode(t, nodes[5], net, newNodeConfig(nodes[5])) - for _, n := range net.Instances { require.Equal(t, uint64(0), n.Storage.NumBlocks()) startTimes = append(startTimes, n.E.StartTime) @@ -333,7 +331,19 @@ func testReplicationEmptyNotarizations(t *testing.T, nodes []simplex.NodeID, end net.SetAllNodesMessageFilter(AllowAllMessages) net.Connect(laggingNode.E.ID) net.TriggerLeaderBlockBuilder(endRound) - for _, n := range net.Instances { + for i, n := range net.Instances { + if n.E.ID.Equals(laggingNode.E.ID) { + // the lagging node may have requested finalizations from a node that does not have them, so we may need to resend the request + for { + if n.Storage.NumBlocks() == 2 { + break + } + time.Sleep(10 * time.Millisecond) + startTimes[i] = startTimes[i].Add(2 * simplex.DefaultMaxProposalWaitTime) + n.E.AdvanceTime(startTimes[i]) + } + continue + } n.Storage.WaitForBlockCommit(1) } @@ -599,7 +609,6 @@ func TestReplicationStuckInProposingBlock(t *testing.T) { tbb := &TestBlockBuilder{Out: make(chan *TestBlock, 1), BlockShouldBeBuilt: make(chan struct{}, 1), In: make(chan *TestBlock, 1)} bb := NewTestControlledBlockBuilder(t) bb.TestBlockBuilder = *tbb - storage := NewInMemStorage() nodes := []simplex.NodeID{{1}, {2}, {3}, {4}} blocks := createBlocks(t, nodes, 5) @@ -1108,3 +1117,221 @@ func TestReplicationVerifyEmptyNotarization(t *testing.T) { return wal.ContainsEmptyNotarization(0) }, time.Millisecond*500, time.Millisecond*10, "Did not expect an empty notarization with a corrupt QC to be written to the WAL") } + +// almostFinalizeBlocks is a message filter that allows all messages except for finalized votes +// and finalizations, unless the message is from node 1. This way each node will have 2 finalized votes, +// which is one short of a quorum.
+func almostFinalizeBlocks(msg *simplex.Message, from, _ simplex.NodeID) bool { + // block finalized votes and finalizations + if msg.Finalization != nil || msg.FinalizeVote != nil { + return from.Equals(simplex.NodeID{1}) + } + return true +} + +// TestReplicationVotesForNotarizations tests that a lagging node will replicate +// finalizations and notarizations. It ensures the node sends finalized votes for rounds +// without finalizations. +func TestReplicationVotesForNotarizations(t *testing.T) { + nodes := []simplex.NodeID{{1}, {2}, {3}, {4}} + + // TODO: numFinalized and numNotarized could be parameterized to test different scenarios + numFinalizedBlocks := uint64(5) + // number of notarized blocks after the finalized blocks + numNotarizedBlocks := uint64(11) + net := NewInMemNetwork(t, nodes) + + storageData := createBlocks(t, nodes, numFinalizedBlocks) + nodeConfig := func(from simplex.NodeID) *TestNodeConfig { + comm := NewTestComm(from, net, almostFinalizeBlocks) + return &TestNodeConfig{ + InitialStorage: storageData, + Comm: comm, + ReplicationEnabled: true, + } + } + + n1 := NewSimplexNode(t, nodes[0], net, nodeConfig(nodes[0])) + n2 := NewSimplexNode(t, nodes[1], net, nodeConfig(nodes[1])) + adversary := NewSimplexNode(t, nodes[2], net, nodeConfig(nodes[2])) + laggingNode := NewSimplexNode(t, nodes[3], net, &TestNodeConfig{ + ReplicationEnabled: true, + }) + + startTimes := make([]time.Time, 0, len(nodes)) + for _, n := range net.Instances { + startTimes = append(startTimes, n.E.StartTime) + if n.E.ID.Equals(laggingNode.E.ID) { + require.Equal(t, uint64(0), n.Storage.NumBlocks()) + continue + } + require.Equal(t, numFinalizedBlocks, n.Storage.NumBlocks()) + } + + // lagging node should be disconnected while nodes create notarizations without finalizations + net.Disconnect(laggingNode.E.ID) + + net.StartInstances() + + missedSeqs := uint64(0) + // normal nodes continue to make progress + for round := numFinalizedBlocks; round < numFinalizedBlocks+numNotarizedBlocks; round++ { + emptyRound := bytes.Equal(simplex.LeaderForRound(nodes, round), laggingNode.E.ID) + if emptyRound { + missedSeqs++ + net.AdvanceWithoutLeader(startTimes, round, laggingNode.E.ID) + } else { + net.TriggerLeaderBlockBuilder(round) + for _, n := range net.Instances { + if n.E.ID.Equals(laggingNode.E.ID) { + continue + } + n.WAL.AssertNotarization(round) + } + } + } + + // all nodes should be on round [numFinalizedBlocks + numNotarizedBlocks - 1] + for _, n := range net.Instances { + if n.E.ID.Equals(laggingNode.E.ID) { + require.Equal(t, uint64(0), n.Storage.NumBlocks()) + require.Equal(t, uint64(0), n.E.Metadata().Round) + continue + } + require.Equal(t, numFinalizedBlocks, n.Storage.NumBlocks()) + require.Equal(t, numFinalizedBlocks+numNotarizedBlocks, n.E.Metadata().Round) + } + + // at this point in time, the adversarial node will disconnect + // since each node has sent 2 finalized votes, which is one short of a quorum + // the lagging node will need to replicate the finalizations, and then send votes for notarizations + net.Disconnect(adversary.E.ID) + net.Connect(laggingNode.E.ID) + net.SetAllNodesMessageFilter(AllowAllMessages) + + // the adversary should not be the leader (to simplify the test) + isAdversaryLeader := bytes.Equal(simplex.LeaderForRound(nodes, numFinalizedBlocks+numNotarizedBlocks), adversary.E.ID) + require.False(t, isAdversaryLeader) + + // lagging node should not be leader + isLaggingNodeLeader := bytes.Equal(simplex.LeaderForRound(nodes, numFinalizedBlocks+numNotarizedBlocks),
laggingNode.E.ID) + require.False(t, isLaggingNodeLeader) + + // trigger block building, but we only have 2 connected nodes so the nodes will time out + net.TriggerLeaderBlockBuilder(numFinalizedBlocks + numNotarizedBlocks) + + // ensure time out on required nodes + n1.TimeoutOnRound(numFinalizedBlocks + numNotarizedBlocks) + n2.TimeoutOnRound(numFinalizedBlocks + numNotarizedBlocks) + require.Equal(t, uint64(0), laggingNode.E.Metadata().Round) + laggingNode.TimeoutOnRound(0) + + expectedNumBlocks := numFinalizedBlocks + numNotarizedBlocks - missedSeqs + // because the adversarial node is offline, we may need to send replication requests many times + for { + time.Sleep(time.Millisecond * 100) + if laggingNode.Storage.NumBlocks() == expectedNumBlocks { + break + } + + startTimes[3] = startTimes[3].Add(simplex.DefaultReplicationRequestTimeout) + laggingNode.E.AdvanceTime(startTimes[3]) + } + + for _, n := range net.Instances { + if n.E.ID.Equals(adversary.E.ID) { + continue + } + n.Storage.WaitForBlockCommit(expectedNumBlocks - 1) // subtract 1 because seqs start at 0 + } + + laggingNode.TimeoutOnRound(numFinalizedBlocks + numNotarizedBlocks) + for _, n := range net.Instances { + if n.E.ID.Equals(adversary.E.ID) { + continue + } + WaitToEnterRound(t, n.E, numFinalizedBlocks+numNotarizedBlocks+1) + require.True(t, n.WAL.ContainsEmptyNotarization(numFinalizedBlocks+numNotarizedBlocks)) + } +} + +// TestReplicationEmptyNotarizationsTail ensures a lagging node will properly replicate +// a tail of empty notarizations. +func TestReplicationEmptyNotarizationsTail(t *testing.T) { + nodes := []simplex.NodeID{{1}, {2}, {3}, {4}, {5}, {6}} + + for endRound := uint64(2); endRound <= 2*simplex.DefaultMaxRoundWindow; endRound++ { + isLaggingNodeLeader := bytes.Equal(simplex.LeaderForRound(nodes, endRound), nodes[5]) + if isLaggingNodeLeader { + continue + } + + testName := fmt.Sprintf("Empty_notarizations_end_round%d", endRound) + t.Run(testName, func(t *testing.T) { + t.Parallel() + testReplicationEmptyNotarizationsTail(t, nodes, endRound) + }) + } +} + +func testReplicationEmptyNotarizationsTail(t *testing.T, nodes []simplex.NodeID, endRound uint64) { + net := NewInMemNetwork(t, nodes) + newNodeConfig := func(from simplex.NodeID) *TestNodeConfig { + comm := NewTestComm(from, net, AllowAllMessages) + return &TestNodeConfig{ + Comm: comm, + ReplicationEnabled: true, + } + } + + startTimes := make([]time.Time, 0, len(nodes)) + NewSimplexNode(t, nodes[0], net, newNodeConfig(nodes[0])) + NewSimplexNode(t, nodes[1], net, newNodeConfig(nodes[1])) + NewSimplexNode(t, nodes[2], net, newNodeConfig(nodes[2])) + NewSimplexNode(t, nodes[3], net, newNodeConfig(nodes[3])) + NewSimplexNode(t, nodes[4], net, newNodeConfig(nodes[4])) + laggingNode := NewSimplexNode(t, nodes[5], net, newNodeConfig(nodes[5])) + for _, n := range net.Instances { + require.Equal(t, uint64(0), n.Storage.NumBlocks()) + startTimes = append(startTimes, n.E.StartTime) + } + + net.StartInstances() + + net.Disconnect(laggingNode.E.ID) + net.SetAllNodesMessageFilter(onlyAllowEmptyRoundMessages) + + // normal nodes continue to make progress + for i := uint64(0); i < endRound; i++ { + leader := simplex.LeaderForRound(nodes, i) + if !leader.Equals(laggingNode.E.ID) { + net.TriggerLeaderBlockBuilder(i) + } + + net.AdvanceWithoutLeader(startTimes, i, laggingNode.E.ID) + } + + for _, n := range net.Instances { + if n.E.ID.Equals(laggingNode.E.ID) { + require.Equal(t, uint64(0), n.Storage.NumBlocks()) + require.Equal(t, uint64(0), n.E.Metadata().Round) +
+		}
+
+		// assert metadata
+		require.Equal(t, uint64(endRound), n.E.Metadata().Round)
+		require.Equal(t, uint64(0), n.E.Metadata().Seq)
+		require.Equal(t, uint64(0), n.Storage.NumBlocks())
+	}
+
+	net.Connect(laggingNode.E.ID)
+	net.SetAllNodesMessageFilter(AllowAllMessages)
+
+	// have the lagging node time out to trigger replication
+	laggingNode.E.AdvanceTime(time.Now().Add(laggingNode.E.MaxProposalWait))
+
+	for _, n := range net.Instances {
+		WaitToEnterRound(t, n.E, endRound)
+		require.Equal(t, uint64(endRound), n.E.Metadata().Round)
+	}
+}
diff --git a/replication_timeout_test.go b/replication_timeout_test.go
index 084d684..b02ac64 100644
--- a/replication_timeout_test.go
+++ b/replication_timeout_test.go
@@ -121,16 +121,17 @@ func TestReplicationRequestTimeoutCancels(t *testing.T) {
 	// all blocks except the lagging node start at round 8, seq 8.
 	// lagging node starts at round 0, seq 0.
 	// this asserts that the lagging node catches up to the latest round
-	for i := 0; i <= int(startSeq); i++ {
-		for _, n := range net.Instances {
-			n.Storage.WaitForBlockCommit(uint64(startSeq))
-		}
+	for _, n := range net.Instances {
+		n.Storage.WaitForBlockCommit(startSeq)
 	}
 
 	// ensure lagging node doesn't resend requests
 	mf := &testTimeoutMessageFilter{
 		t: t,
 	}
+
+	// allow the replication state to cancel the request before setting the filter
+	time.Sleep(100 * time.Millisecond)
 	laggingNode.E.Comm.(*testutil.TestComm).SetFilter(mf.failOnReplicationRequest)
 
 	laggingNode.E.AdvanceTime(laggingNode.E.StartTime.Add(simplex.DefaultReplicationRequestTimeout * 2))
@@ -367,7 +368,10 @@ func (c *collectNotarizationComm) removeFinalizationsFromReplicationResponses(ms
 			newData = append(newData, qr)
 		}
 		msg.VerifiedReplicationResponse.Data = newData
-		c.replicationResponses <- struct{}{}
+		select {
+		case c.replicationResponses <- struct{}{}:
+		default:
+		}
 	}
 
 	if msg.Finalization != nil && msg.Finalization.Finalization.Round == 0 {
diff --git a/replicator.go b/replicator.go
new file mode 100644
index 0000000..065b5c9
--- /dev/null
+++ b/replicator.go
@@ -0,0 +1,310 @@
+// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved.
+// See the file LICENSE for licensing terms.
+
+package simplex
+
+import (
+	"fmt"
+	"math"
+	"slices"
+	"sync"
+	"time"
+
+	"go.uber.org/zap"
+)
+
+// signedRoundOrSeq is a round or sequence that has been signed by a quorum certificate.
+type signedRoundOrSeq struct {
+	round   uint64
+	seq     uint64
+	signers NodeIDs
+	isRound bool
+}
+
+func newSignedRoundOrSeq(round QuorumRound, myNodeID NodeID) (*signedRoundOrSeq, error) {
+	ss := &signedRoundOrSeq{}
+	switch {
+	case round.Finalization != nil:
+		ss.signers = round.Finalization.QC.Signers()
+		ss.round = round.GetRound()
+		ss.seq = round.GetSequence()
+		ss.isRound = false
+	case round.Notarization != nil:
+		ss.signers = round.Notarization.QC.Signers()
+		ss.round = round.GetRound()
+		ss.seq = round.GetSequence()
+		ss.isRound = true
+	case round.EmptyNotarization != nil:
+		ss.signers = round.EmptyNotarization.QC.Signers()
+		ss.round = round.GetRound()
+		ss.seq = 0
+		ss.isRound = true
+	default:
+		return nil, fmt.Errorf("round does not contain a finalization, empty notarization, or notarization")
+	}
+
+	// it's possible our node is among the signers of this round or sequence.
+	// For example, this may happen if our node has sent a finalize vote
+	// for this round but has not yet received the
+	// finalization from the network.
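+	// The signer set is later used as the pool of nodes that replication
+	// requests are directed at, so it must not include our own node.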
+	ss.signers = ss.signers.Remove(myNodeID)
+	return ss, nil
+}
+
+func newSignedRoundOrSeqFromFinalization(finalization *Finalization, myNodeID NodeID) *signedRoundOrSeq {
+	return &signedRoundOrSeq{
+		round:   finalization.Finalization.Round,
+		seq:     finalization.Finalization.Seq,
+		signers: NodeIDs(finalization.QC.Signers()).Remove(myNodeID),
+		isRound: false,
+	}
+}
+
+func newSignedRoundOrSeqFromRound(round uint64, signers NodeIDs, myNodeID NodeID) *signedRoundOrSeq {
+	ss := &signedRoundOrSeq{
+		round:   round,
+		seq:     0, // seq is not needed when replicating rounds
+		signers: signers.Remove(myNodeID),
+		isRound: true,
+	}
+	return ss
+}
+
+// roundOrSeq returns either the round or the sequence, depending on whether the
+// replicator is replicating rounds or sequences.
+func (s *signedRoundOrSeq) roundOrSeq() uint64 {
+	if s.isRound {
+		return s.round
+	}
+	return s.seq
+}
+
+type sender interface {
+	// Send sends a message to the given destination node
+	Send(msg *Message, destination NodeID)
+}
+
+// replicator manages the state for replicating sequences or rounds up to highestObserved.
+type replicator struct {
+	sender         sender
+	myNodeID       NodeID
+	logger         Logger
+	maxRoundWindow uint64
+	epochLock      *sync.Mutex
+
+	// highest sequence or round we have requested. Ensures we don't request the
+	// same value multiple times, and allows us to limit the number of
+	// outstanding requests to be at most [maxRoundWindow] ahead of the current
+	// round or the next sequence to commit
+	highestRequested uint64
+
+	// highest round or sequence we have observed
+	highestObserved *signedRoundOrSeq
+
+	// receivedQuorumRounds maps either sequences or rounds to quorum rounds
+	receivedQuorumRounds map[uint64]QuorumRound
+
+	// requestIterator rotates which signer the next batch of requests starts with
+	requestIterator int
+
+	timeoutHandler *TimeoutHandler
+}
+
+func newReplicator(logger Logger, sender sender, ourNodeID NodeID, maxRoundWindow uint64, start time.Time, lock *sync.Mutex) *replicator {
+	r := &replicator{
+		receivedQuorumRounds: make(map[uint64]QuorumRound),
+		sender:               sender,
+		myNodeID:             ourNodeID,
+		logger:               logger,
+		maxRoundWindow:       maxRoundWindow,
+		epochLock:            lock,
+	}
+
+	r.timeoutHandler = NewTimeoutHandler(logger, start, DefaultReplicationRequestTimeout, r.resendReplicationRequests)
+	return r
+}
+
+func (r *replicator) advanceTime(now time.Time) {
+	r.timeoutHandler.Tick(now)
+}
+
+func (r *replicator) resendReplicationRequests(missingIds []uint64) {
+	// this function runs in the timeout handler goroutine, so we take the
+	// epoch lock to avoid concurrent access to highestObserved
+	r.epochLock.Lock()
+	defer r.epochLock.Unlock()
+
+	nodes := r.highestObserved.signers
+	numNodes := len(nodes)
+	slices.Sort(missingIds)
+	segments := CompressSequences(missingIds)
+	for i, seqsOrRounds := range segments {
+		index := (i + r.requestIterator) % numNodes
+		r.sendRequestToNode(seqsOrRounds.Start, seqsOrRounds.End, nodes[index])
+	}
+
+	r.requestIterator++
+}
+
+// isReplicationComplete returns true if we have finished the replication process.
+// The process is considered finished once the target (either nextSeqToCommit or
+// currentRound) has advanced past the highest observed round or sequence.
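+//
+// For example (illustrative values): when replicating sequences with a
+// nextSeqToCommit of 5, an observed finalized sequence of 7 means replication
+// is still in progress; once nextSeqToCommit advances past 7, it is complete.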
+func (r *replicator) isReplicationComplete(target uint64) bool {
+	if r.highestObserved != nil && r.highestObserved.roundOrSeq() >= target {
+		// we have observed a value at or beyond the target, so there is still work to do
+		return false
+	}
+	return true
+}
+
+func (r *replicator) getHighestRound() uint64 {
+	if r.highestObserved != nil {
+		return r.highestObserved.round
+	}
+	return 0
+}
+
+// maybeSendMoreReplicationRequests checks if we need to send more replication requests given an observed round or sequence.
+// It limits the number of outstanding requests to be at most [maxRoundWindow] ahead of [currentRoundOrNextSequence], which is
+// either nextSeqToCommit or currentRound, depending on whether we are replicating sequences or rounds.
+func (r *replicator) maybeSendMoreReplicationRequests(observed *signedRoundOrSeq, currentRoundOrNextSequence uint64) {
+	observedRoundOrSeq := observed.roundOrSeq()
+
+	// we've observed something we've already requested
+	if r.highestRequested >= observedRoundOrSeq && r.highestObserved != nil {
+		r.logger.Debug("Already requested observed value, skipping", zap.Uint64("value", observedRoundOrSeq), zap.Bool("isRound", observed.isRound))
+		return
+	}
+
+	// if this is the highest observed sequence or round, update our state
+	if r.highestObserved == nil || observedRoundOrSeq > r.highestObserved.roundOrSeq() {
+		r.highestObserved = observed
+	}
+
+	start := math.Max(float64(currentRoundOrNextSequence), float64(r.highestRequested))
+	// we limit the number of outstanding requests to be at most [maxRoundWindow] ahead of [currentRoundOrNextSequence]
+	end := math.Min(float64(observedRoundOrSeq), float64(r.maxRoundWindow+currentRoundOrNextSequence))
+
+	r.logger.Debug("Node is behind, attempting to request missing values", zap.Uint64("value", observedRoundOrSeq), zap.Uint64("start", uint64(start)), zap.Uint64("end", uint64(end)), zap.Bool("isRound", observed.isRound))
+	r.sendReplicationRequests(uint64(start), uint64(end))
+}
+
+func (r *replicator) updateState(currentRoundOrNextSeq uint64) {
+	r.removeOldValues(currentRoundOrNextSeq)
+
+	// we send out more requests once we are within half of [maxRoundWindow] of the highest value requested so far
+	if currentRoundOrNextSeq+r.maxRoundWindow/2 > r.highestRequested && r.highestObserved != nil {
+		r.maybeSendMoreReplicationRequests(r.highestObserved, currentRoundOrNextSeq)
+	}
+}
+
+func (r *replicator) removeOldValues(newValue uint64) {
+	r.timeoutHandler.RemoveOldTasks(newValue)
+
+	for storedRound := range r.receivedQuorumRounds {
+		if storedRound < newValue {
+			delete(r.receivedQuorumRounds, storedRound)
+		}
+	}
+}
+
+// sendReplicationRequests sends requests for the missing values in the
+// inclusive range [start, end]. It does so by splitting the range equally
+// among the nodes that have signed [highestObserved].
+func (r *replicator) sendReplicationRequests(start uint64, end uint64) {
+	// it's possible our node has signed [highestObserved].
+	// For example, this may happen if our node has sent a finalize vote
+	// for [highestObserved] but has not yet received the
+	// finalization from the network.
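+	// Illustrative example (the exact split boundaries depend on
+	// DistributeSequenceRequests): with start=0, end=8 and three signers, the
+	// range could be split into [0,2], [3,5] and [6,8], each sub-range being
+	// requested from a different signer.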
+	nodes := r.highestObserved.signers
+	numNodes := len(nodes)
+
+	seqRequests := DistributeSequenceRequests(start, end, numNodes)
+
+	r.logger.Debug("Distributing replication requests", zap.Uint64("start", start), zap.Uint64("end", end), zap.Stringer("nodes", NodeIDs(nodes)))
+	for i, seqsOrRounds := range seqRequests {
+		index := (i + r.requestIterator) % numNodes
+		r.sendRequestToNode(seqsOrRounds.Start, seqsOrRounds.End, nodes[index])
+	}
+
+	// next time we send requests, we start with a different permutation
+	r.requestIterator++
+}
+
+// sendRequestToNode requests the rounds or sequences [start, end] from the given node.
+// In case the node does not respond, we register a timeout task that will
+// re-send the request.
+func (r *replicator) sendRequestToNode(start uint64, end uint64, node NodeID) {
+	roundsOrSeqs := make([]uint64, (end+1)-start)
+	for i := start; i <= end; i++ {
+		roundsOrSeqs[i-start] = i
+		// ensure we set a timeout for this round or sequence
+		r.timeoutHandler.AddTask(i)
+	}
+
+	val := r.highestObserved.roundOrSeq()
+	if r.highestRequested < end {
+		r.highestRequested = end
+	}
+
+	request := &ReplicationRequest{}
+	if r.highestObserved.isRound {
+		request.LatestRound = val
+		request.Rounds = roundsOrSeqs
+	} else {
+		request.LatestFinalizedSeq = val
+		request.Seqs = roundsOrSeqs
+	}
+
+	msg := &Message{ReplicationRequest: request}
+
+	r.logger.Debug("Requesting missing rounds/sequences",
+		zap.Stringer("from", node),
+		zap.Uint64("start", start),
+		zap.Uint64("end", end),
+		zap.Bool("isRound", r.highestObserved.isRound),
+		zap.Uint64("latestRound", request.LatestRound),
+		zap.Uint64("latestSeq", request.LatestFinalizedSeq),
+	)
+	r.sender.Send(msg, node)
+}
+
+func (r *replicator) storeQuorumRound(round QuorumRound, from NodeID, roundOrSeq uint64) {
+	// check if this is the highest round or seq we have seen
+	if r.highestObserved == nil || roundOrSeq > r.highestObserved.roundOrSeq() {
+		signedSeq, err := newSignedRoundOrSeq(round, r.myNodeID)
+		if err != nil {
+			// should never happen, since we already checked that the QuorumRound was valid
+			r.logger.Error("Error creating signed sequence from round", zap.Error(err))
+			return
+		}
+
+		r.highestObserved = signedSeq
+	}
+
+	if _, exists := r.receivedQuorumRounds[roundOrSeq]; exists {
+		// we've already stored this round
+		return
+	}
+
+	r.receivedQuorumRounds[roundOrSeq] = round
+
+	// we received this round or sequence, remove its timeout task
+	r.timeoutHandler.RemoveTask(roundOrSeq)
+	r.logger.Debug("Stored quorum round", zap.Stringer("qr", &round), zap.Stringer("from", from))
+}
+
+func (r *replicator) retrieveQuorumRound(key uint64) (*QuorumRound, bool) {
+	qr, ok := r.receivedQuorumRounds[key]
+	if ok {
+		return &qr, true
+	}
+	return nil, false
+}
+
+func (r *replicator) retrieveQuorumRoundBySeq(seq uint64) (*QuorumRound, bool) {
+	for _, qr := range r.receivedQuorumRounds {
+		if qr.Block != nil && qr.Block.BlockHeader().Seq == seq {
+			return &qr, true
+		}
+	}
+	return nil, false
+}
diff --git a/testutil/comm.go b/testutil/comm.go
index da0d38a..07cbe0e 100644
--- a/testutil/comm.go
+++ b/testutil/comm.go
@@ -108,21 +108,8 @@ func (c *TestComm) maybeTranslateOutoingToIncomingMessageTypes(msg *simplex.Mess
 		data = append(data, quorumRound)
 	}
 
-	var latestRound *simplex.QuorumRound
-	if msg.VerifiedReplicationResponse.LatestRound != nil {
-		if msg.VerifiedReplicationResponse.LatestRound.EmptyNotarization != nil {
-			latestRound = &simplex.QuorumRound{
-				EmptyNotarization: msg.VerifiedReplicationResponse.LatestRound.EmptyNotarization,
-			}
-		} else {
-			latestRound = &simplex.QuorumRound{
-				Block:             msg.VerifiedReplicationResponse.LatestRound.VerifiedBlock.(simplex.Block),
-				Notarization:      msg.VerifiedReplicationResponse.LatestRound.Notarization,
-				Finalization:      msg.VerifiedReplicationResponse.LatestRound.Finalization,
-				EmptyNotarization: msg.VerifiedReplicationResponse.LatestRound.EmptyNotarization,
-			}
-		}
-	}
+	latestRound := verifiedQRtoQR(msg.VerifiedReplicationResponse.LatestRound)
+	latestSeq := verifiedQRtoQR(msg.VerifiedReplicationResponse.LatestFinalizedSeq)
 
 	require.Nil(
 		c.net.t,
@@ -133,6 +120,7 @@ func (c *TestComm) maybeTranslateOutoingToIncomingMessageTypes(msg *simplex.Mess
 	msg.ReplicationResponse = &simplex.ReplicationResponse{
 		Data:        data,
 		LatestRound: latestRound,
+		LatestSeq:   latestSeq,
 	}
 }
 
@@ -145,6 +133,24 @@ func (c *TestComm) maybeTranslateOutoingToIncomingMessageTypes(msg *simplex.Mess
 	}
 }
 
+func verifiedQRtoQR(qr *simplex.VerifiedQuorumRound) *simplex.QuorumRound {
+	if qr == nil {
+		return nil
+	}
+
+	if qr.EmptyNotarization != nil {
+		return &simplex.QuorumRound{
+			EmptyNotarization: qr.EmptyNotarization,
+		}
+	}
+
+	return &simplex.QuorumRound{
+		Block:        qr.VerifiedBlock.(simplex.Block),
+		Notarization: qr.Notarization,
+		Finalization: qr.Finalization,
+	}
+}
+
 func (c *TestComm) isMessagePermitted(msg *simplex.Message, destination simplex.NodeID) bool {
 	c.lock.RLock()
 	defer c.lock.RUnlock()
diff --git a/testutil/node.go b/testutil/node.go
index 7d5df3a..4dd8945 100644
--- a/testutil/node.go
+++ b/testutil/node.go
@@ -117,3 +117,26 @@ func (t *TestNode) handleMessages() {
 		}
 	}
 }
+
+// TimeoutOnRound advances time until the node times out of the given round.
+func (t *TestNode) TimeoutOnRound(round uint64) {
+	startTime := time.UnixMilli(t.currentTime.Load())
+	for {
+		currentRound := t.E.Metadata().Round
+		if currentRound > round {
+			return
+		}
+		if len(t.BB.BlockShouldBeBuilt) == 0 {
+			t.BB.BlockShouldBeBuilt <- struct{}{}
+		}
+		startTime = startTime.Add(t.E.MaxProposalWait)
+		t.E.AdvanceTime(startTime)
+
+		// check the WAL for an empty vote for that round
+		if hasVote := t.WAL.ContainsEmptyVote(round); hasVote {
+			return
+		}
+
+		time.Sleep(50 * time.Millisecond)
+	}
+}
diff --git a/timeout_handler.go b/timeout_handler.go
index bf6d4de..8ae4dd8 100644
--- a/timeout_handler.go
+++ b/timeout_handler.go
@@ -4,35 +4,25 @@
 package simplex
 
 import (
-	"container/heap"
-	"fmt"
 	"sync"
 	"time"
 
 	"go.uber.org/zap"
 )
 
-type TimeoutTask struct {
-	NodeID   NodeID
-	TaskID   string
-	Task     func()
-	Deadline time.Time
-
-	// for replication tasks
-	Start uint64
-	End   uint64
-
-	index int // for heap to work more efficiently
-}
+type timeoutRunner func(ids []uint64)
 
 type TimeoutHandler struct {
-	lock sync.Mutex
+	// how often to run through the tasks
+	runInterval time.Duration
+	// function to run tasks
+	taskRunner timeoutRunner
+	lock       sync.Mutex
 
 	ticks chan time.Time
 	close chan struct{}
 
-	// nodeids -> range -> task
-	tasks map[string]map[string]*TimeoutTask
-	heap  TaskHeap
+	// the set of pending task IDs
+	tasks map[uint64]struct{}
 
 	now time.Time
 	log Logger
@@ -40,36 +30,34 @@
 
 // NewTimeoutHandler returns a TimeoutHandler and starts a new goroutine that
-// listens for ticks and executes TimeoutTasks.
+// listens for ticks and periodically invokes the task runner with all pending task IDs.
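+//
+// A minimal usage sketch (illustrative; logger, resend and now are placeholders):
+//
+//	h := NewTimeoutHandler(logger, time.Now(), DefaultReplicationRequestTimeout, resend)
+//	h.AddTask(7)    // track an outstanding request for round/seq 7
+//	h.Tick(now)     // once runInterval has elapsed, resend([]uint64{7}) is invoked
+//	h.RemoveTask(7) // stop tracking once the response arrives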
-func NewTimeoutHandler(log Logger, startTime time.Time, nodes []NodeID) *TimeoutHandler {
-	tasks := make(map[string]map[string]*TimeoutTask)
-	for _, node := range nodes {
-		tasks[string(node)] = make(map[string]*TimeoutTask)
-	}
-
+func NewTimeoutHandler(log Logger, startTime time.Time, runInterval time.Duration, taskRunner timeoutRunner) *TimeoutHandler {
 	t := &TimeoutHandler{
-		now:   startTime,
-		tasks: tasks,
-		ticks: make(chan time.Time, 1),
-		close: make(chan struct{}),
-		log:   log,
+		now:         startTime,
+		tasks:       make(map[uint64]struct{}),
+		ticks:       make(chan time.Time, 1),
+		close:       make(chan struct{}),
+		runInterval: runInterval,
+		taskRunner:  taskRunner,
+		log:         log,
 	}
 
-	go t.run()
+	go t.run(startTime)
 
 	return t
 }
 
-func (t *TimeoutHandler) GetTime() time.Time {
-	t.lock.Lock()
-	defer t.lock.Unlock()
-
-	return t.now
-}
+func (t *TimeoutHandler) run(startTime time.Time) {
+	lastTickTime := startTime
 
-func (t *TimeoutHandler) run() {
 	for t.shouldRun() {
 		select {
 		case now := <-t.ticks:
+			if now.Sub(lastTickTime) < t.runInterval {
+				continue
+			}
+			lastTickTime = now
+
+			// update the current time
 			t.lock.Lock()
 			t.now = now
 			t.lock.Unlock()
@@ -83,25 +71,20 @@
 
 func (t *TimeoutHandler) maybeRunTasks() {
-	// go through the heap executing relevant tasks
-	for {
-		t.lock.Lock()
-		if t.heap.Len() == 0 {
-			t.lock.Unlock()
-			break
-		}
+	// grab the IDs of all pending tasks and hand them to the task runner
+	ids := make([]uint64, 0, len(t.tasks))
 
-		next := t.heap[0]
-		if next.Deadline.After(t.now) {
-			t.lock.Unlock()
-			break
-		}
+	t.lock.Lock()
+	for id := range t.tasks {
+		ids = append(ids, id)
+	}
+	t.lock.Unlock()
 
-		heap.Pop(&t.heap)
-		delete(t.tasks[string(next.NodeID)], next.TaskID)
-		t.lock.Unlock()
-		t.log.Debug("Executing timeout task", zap.String("taskid", next.TaskID))
-		next.Task()
+	if len(ids) == 0 {
+		return
 	}
+
+	t.taskRunner(ids)
 }
 
 func (t *TimeoutHandler) shouldRun() bool {
@@ -124,57 +107,37 @@
 	}
 }
 
-func (t *TimeoutHandler) AddTask(task *TimeoutTask) {
+func (t *TimeoutHandler) AddTask(id uint64) {
 	t.lock.Lock()
 	defer t.lock.Unlock()
 
-	if _, ok := t.tasks[string(task.NodeID)]; !ok {
-		t.log.Debug("Attempting to add a task for an unknown node", zap.Stringer("from", task.NodeID))
-		return
-	}
-
-	// adds a task to the heap and the tasks map
-	if _, ok := t.tasks[string(task.NodeID)][task.TaskID]; ok {
-		t.log.Debug("Trying to add an already included task", zap.Stringer("from", task.NodeID), zap.String("Task ID", task.TaskID))
-		return
-	}
-
-	t.tasks[string(task.NodeID)][task.TaskID] = task
-	t.log.Debug("Adding timeout task", zap.Stringer("from", task.NodeID), zap.String("taskid", task.TaskID))
-	heap.Push(&t.heap, task)
+	t.tasks[id] = struct{}{}
+	t.log.Debug("Adding timeout task", zap.Uint64("id", id))
 }
 
-func (t *TimeoutHandler) RemoveTask(nodeID NodeID, ID string) {
+func (t *TimeoutHandler) RemoveTask(ID uint64) {
 	t.lock.Lock()
 	defer t.lock.Unlock()
 
-	if _, ok := t.tasks[string(nodeID)]; !ok {
-		t.log.Debug("Attempting to remove a task for an unknown node", zap.Stringer("from", nodeID))
-		return
-	}
-
-	if _, ok := t.tasks[string(nodeID)][ID]; !ok {
+	if _, ok := t.tasks[ID]; !ok {
 		return
 	}
 
-	// find the task using the task map
-	// remove it from the heap using the index
-	t.log.Debug("Removing timeout task", zap.Stringer("from", nodeID), zap.String("taskid", ID))
-	heap.Remove(&t.heap, t.tasks[string(nodeID)][ID].index)
-	delete(t.tasks[string(nodeID)], ID)
+	t.log.Debug("Removing timeout task", zap.Uint64("id", ID))
+	delete(t.tasks, ID)
 }
 
-func (t *TimeoutHandler) forEach(nodeID string, f func(tt *TimeoutTask)) {
+func (t *TimeoutHandler) RemoveOldTasks(cutoff uint64) {
 	t.lock.Lock()
 	defer t.lock.Unlock()
 
-	tasks, exists := t.tasks[nodeID]
-	if !exists {
-		return
-	}
-
-	for _, task := range tasks {
-		f(task)
+	for id := range t.tasks {
+		if id < cutoff {
+			t.log.Debug("Removing old timeout task", zap.Uint64("id", id))
+			delete(t.tasks, id)
+		}
 	}
 }
 
@@ -186,40 +149,3 @@ func (t *TimeoutHandler) Close() {
 		close(t.close)
 	}
 }
-
-const delimiter = "_"
-
-func getTimeoutID(start, end uint64) string {
-	return fmt.Sprintf("%d%s%d", start, delimiter, end)
-}
-
-// ----------------------------------------------------------------------
-type TaskHeap []*TimeoutTask
-
-func (h *TaskHeap) Len() int { return len(*h) }
-
-// Less returns if the task at index [i] has a lower timeout than the task at index [j]
-func (h *TaskHeap) Less(i, j int) bool { return (*h)[i].Deadline.Before((*h)[j].Deadline) }
-
-// Swap swaps the values at index [i] and [j]
-func (h *TaskHeap) Swap(i, j int) {
-	(*h)[i], (*h)[j] = (*h)[j], (*h)[i]
-	(*h)[i].index = i
-	(*h)[j].index = j
-}
-
-func (h *TaskHeap) Push(x any) {
-	task := x.(*TimeoutTask)
-	task.index = h.Len()
-	*h = append(*h, task)
-}
-
-func (h *TaskHeap) Pop() any {
-	old := *h
-	len := h.Len()
-	task := old[len-1]
-	old[len-1] = nil
-	*h = old[0 : len-1]
-	task.index = -1
-	return task
-}
diff --git a/timeout_handler_test.go b/timeout_handler_test.go
index 4c56126..111a672 100644
--- a/timeout_handler_test.go
+++ b/timeout_handler_test.go
@@ -4,8 +4,6 @@
 package simplex_test
 
 import (
-	"sync"
-	"sync/atomic"
 	"testing"
 	"time"
 
@@ -15,317 +13,93 @@ import (
 	"github.com/stretchr/testify/require"
 )
 
+const testRunInterval = 1 * time.Second
+
 func TestAddAndRunTask(t *testing.T) {
 	start := time.Now()
 	l := testutil.MakeLogger(t, 1)
-	nodes := []simplex.NodeID{{1}, {2}}
-	handler := simplex.NewTimeoutHandler(l, start, nodes)
-	defer handler.Close()
-
-	sent := make(chan struct{}, 1)
-	var count atomic.Int64
-
-	task := &simplex.TimeoutTask{
-		NodeID:   nodes[0],
-		TaskID:   "simplerun",
-		Deadline: start.Add(5 * time.Second),
-		Task: func() {
-			sent <- struct{}{}
-			count.Add(1)
-		},
+	ran := make(chan uint64, 1)
+	runner := func(ids []uint64) {
+		require.Len(t, ids, 1)
+		ran <- ids[0]
 	}
 
-	handler.AddTask(task)
-	handler.Tick(start.Add(2 * time.Second))
-	time.Sleep(10 * time.Millisecond)
-
-	require.Zero(t, len(sent))
-	handler.Tick(start.Add(6 * time.Second))
-	<-sent
-	require.Equal(t, int64(1), count.Load())
-
-	// test we only execute task once
-	handler.Tick(start.Add(12 * time.Second))
-	time.Sleep(10 * time.Millisecond)
-	require.Equal(t, int64(1), count.Load())
-}
-
-func TestRemoveTask(t *testing.T) {
-	start := time.Now()
-	l := testutil.MakeLogger(t, 1)
-	nodes := []simplex.NodeID{{1}, {2}}
-	handler := simplex.NewTimeoutHandler(l, start, nodes)
+	handler := simplex.NewTimeoutHandler(l, start, testRunInterval, runner)
 	defer handler.Close()
 
-	var ran bool
-	task := &simplex.TimeoutTask{
-		NodeID:   nodes[0],
-		TaskID:   "task2",
-		Deadline: start.Add(1 * time.Second),
-		Task: func() {
-			ran = true
-		},
-	}
+	handler.AddTask(1)
+	handler.Tick(start.Add(testRunInterval))
+	value := <-ran
+	require.Equal(t, uint64(1), value)
 
-	handler.AddTask(task)
-	handler.RemoveTask(nodes[0], "task2")
-	handler.Tick(start.Add(2 * time.Second))
-	require.False(t, ran)
+	// if we don't remove the task, it should run again
+	handler.Tick(start.Add(2 * testRunInterval))
+	value = <-ran
+	require.Equal(t, uint64(1), value)
 
-	// 
ensure no panic - handler.RemoveTask(nodes[1], "task-doesn't-exist") + handler.RemoveTask(1) + handler.Tick(start.Add(3 * testRunInterval)) + time.Sleep(100 * time.Millisecond) // give some time for the task to run if it was going to + require.Empty(t, ran) } -func TestTaskOrder(t *testing.T) { +func TestRemoveTask(t *testing.T) { start := time.Now() l := testutil.MakeLogger(t, 1) - nodes := []simplex.NodeID{{1}, {2}} - handler := simplex.NewTimeoutHandler(l, start, nodes) + runner := func(ids []uint64) { + require.Fail(t, "shouldn't run") + } + handler := simplex.NewTimeoutHandler(l, start, testRunInterval, runner) defer handler.Close() - finished := make(chan struct{}) - - var mu sync.Mutex - var results []string - - handler.AddTask(&simplex.TimeoutTask{ - NodeID: nodes[0], - TaskID: "first", - Deadline: start.Add(1 * time.Second), - Task: func() { - mu.Lock() - results = append(results, "first") - finished <- struct{}{} - mu.Unlock() - }, - }) - - handler.AddTask(&simplex.TimeoutTask{ - NodeID: nodes[1], - TaskID: "second", - Deadline: start.Add(2 * time.Second), - Task: func() { - mu.Lock() - results = append(results, "second") - finished <- struct{}{} - mu.Unlock() - }, - }) - - handler.AddTask(&simplex.TimeoutTask{ - NodeID: nodes[0], - TaskID: "noruntask", - Deadline: start.Add(4 * time.Second), - Task: func() { - mu.Lock() - results = append(results, "norun") - mu.Unlock() - }, - }) - - handler.Tick(start.Add(3 * time.Second)) - - <-finished - <-finished - - mu.Lock() - defer mu.Unlock() - - require.Equal(t, 2, len(results)) - require.Equal(t, results[0], "first") - require.Equal(t, results[1], "second") + handler.AddTask(1) + handler.Tick(start.Add(testRunInterval / 2)) + handler.RemoveTask(1) + handler.Tick(start.Add(testRunInterval)) + time.Sleep(100 * time.Millisecond) // give some time for the task to run if it was going to } -func TestAddTasksOutOfOrder(t *testing.T) { +func TestMultipleTasks(t *testing.T) { start := time.Now() l := testutil.MakeLogger(t, 1) - nodes := []simplex.NodeID{{1}, {2}} - handler := simplex.NewTimeoutHandler(l, start, nodes) - defer handler.Close() - - finished := make(chan struct{}) - var mu sync.Mutex - var results []string - - handler.AddTask(&simplex.TimeoutTask{ - NodeID: nodes[0], - TaskID: "third", - Deadline: start.Add(3 * time.Second), - Task: func() { - mu.Lock() - results = append(results, "third") - finished <- struct{}{} - mu.Unlock() - }, - }) - - handler.AddTask(&simplex.TimeoutTask{ - NodeID: nodes[0], - TaskID: "second", - Deadline: start.Add(2 * time.Second), - Task: func() { - mu.Lock() - results = append(results, "second") - finished <- struct{}{} - mu.Unlock() - }, - }) - - handler.AddTask(&simplex.TimeoutTask{ - NodeID: nodes[1], - TaskID: "fourth", - Deadline: start.Add(4 * time.Second), - Task: func() { - mu.Lock() - results = append(results, "fourth") - finished <- struct{}{} - mu.Unlock() - }, - }) - - handler.AddTask(&simplex.TimeoutTask{ - NodeID: nodes[0], - TaskID: "first", - Deadline: start.Add(1 * time.Second), - Task: func() { - mu.Lock() - results = append(results, "first") - finished <- struct{}{} - mu.Unlock() - }, - }) - - handler.Tick(start.Add(1 * time.Second)) - <-finished - mu.Lock() - require.Equal(t, 1, len(results)) - require.Equal(t, results[0], "first") - mu.Unlock() + ran := make(chan uint64, 2) + runner := func(ids []uint64) { + for _, id := range ids { + ran <- id + } + } - handler.Tick(start.Add(3 * time.Second)) - <-finished - <-finished - mu.Lock() - require.Equal(t, 3, len(results)) - 
require.Equal(t, results[1], "second") - require.Equal(t, results[2], "third") - mu.Unlock() + handler := simplex.NewTimeoutHandler(l, start, testRunInterval, runner) + defer handler.Close() - handler.Tick(start.Add(4 * time.Second)) - <-finished - mu.Lock() - require.Equal(t, 4, len(results)) - require.Equal(t, results[3], "fourth") - mu.Unlock() + handler.AddTask(1) + handler.AddTask(2) + handler.Tick(start.Add(testRunInterval)) + value1 := <-ran + value2 := <-ran + require.Contains(t, []uint64{1, 2}, value1) + require.Contains(t, []uint64{1, 2}, value2) + require.NotEqual(t, value1, value2) + + // remove one task, the other should still run + handler.RemoveTask(value1) + handler.Tick(start.Add(2 * testRunInterval)) + value := <-ran + require.Equal(t, value2, value) + require.Empty(t, ran) } -func TestFindTask(t *testing.T) { - // Setup a mock logger +func TestClosed(t *testing.T) { + start := time.Now() l := testutil.MakeLogger(t, 1) - nodes := []simplex.NodeID{{1}, {2}} - startTime := time.Now() - - handler := simplex.NewTimeoutHandler(l, startTime, nodes) - defer handler.Close() - - // Create some test tasks - task1 := &simplex.TimeoutTask{ - TaskID: "task1", - NodeID: nodes[0], - Start: 5, - End: 10, - } - - taskSameRangeDiffNode := &simplex.TimeoutTask{ - TaskID: "taskSameDiff", - NodeID: nodes[1], - Start: 5, - End: 10, + runner := func(ids []uint64) { + require.Fail(t, "shouldn't run") } - task3 := &simplex.TimeoutTask{ - TaskID: "task3", - NodeID: nodes[1], - Start: 25, - End: 30, - } + handler := simplex.NewTimeoutHandler(l, start, testRunInterval, runner) - task4 := &simplex.TimeoutTask{ - TaskID: "task4", - NodeID: nodes[1], - Start: 31, - End: 36, - } - - // Add tasks to handler - handler.AddTask(task1) - handler.AddTask(taskSameRangeDiffNode) - handler.AddTask(task3) - handler.AddTask(task4) - - tests := []struct { - name string - node simplex.NodeID - seqs []uint64 - expected *simplex.TimeoutTask - }{ - { - name: "Find task with sequence in middle of range", - node: nodes[0], - seqs: []uint64{7, 8, 9}, - expected: task1, - }, - { - name: "Find task with sequence at boundary (inclusive)", - node: nodes[0], - seqs: []uint64{5, 7}, - expected: task1, - }, - { - name: "Find task with mixed sequences (first valid sequence)", - node: nodes[0], - seqs: []uint64{3, 4, 5, 11}, - expected: task1, // 5 is in range - }, - { - name: "Same sequences, but different node", - node: nodes[1], - seqs: []uint64{7, 8, 9}, - expected: taskSameRangeDiffNode, - }, - { - name: "No sequences in range", - node: nodes[0], - seqs: []uint64{1, 2, 3, 4, 11, 12, 13, 14}, - expected: nil, - }, - { - name: "Span across many tasks", - node: nodes[1], - seqs: []uint64{26, 27, 30, 31, 33}, - expected: task3, - }, - { - name: "Unknown node", - node: simplex.NodeID("unknown"), - seqs: []uint64{5, 15, 25}, - expected: nil, - }, - { - name: "Empty sequence list", - node: nodes[1], - seqs: []uint64{}, - expected: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := simplex.FindReplicationTask(handler, tt.node, tt.seqs) - if tt.expected != result { - require.Fail(t, "not equal") - } - require.Equal(t, tt.expected, result) - }) - } + handler.Close() + handler.AddTask(1) + handler.Tick(start.Add(testRunInterval)) + time.Sleep(100 * time.Millisecond) // give some time for the task to run if it was going to } diff --git a/util.go b/util.go index 6ccea53..49d0329 100644 --- a/util.go +++ b/util.go @@ -86,13 +86,13 @@ func VerifyQC(qc QuorumCertificate, logger Logger, messageType 
string, quorumSiz } // GetLatestVerifiedQuorumRound returns the latest verified quorum round given -// a round, empty notarization, and last block. If all are nil, it returns nil. -func GetLatestVerifiedQuorumRound(round *Round, emptyNotarization *EmptyNotarization, lastBlock *VerifiedFinalizedBlock) *VerifiedQuorumRound { +// a round and empty notarization. If both are nil, it returns nil. +func GetLatestVerifiedQuorumRound(round *Round, emptyNotarization *EmptyNotarization) *VerifiedQuorumRound { var verifiedQuorumRound *VerifiedQuorumRound var highestRound uint64 var exists bool - if round != nil { + if round != nil && (round.finalization != nil || round.notarization != nil) { highestRound = round.num verifiedQuorumRound = &VerifiedQuorumRound{ VerifiedBlock: round.block, @@ -108,15 +108,6 @@ func GetLatestVerifiedQuorumRound(round *Round, emptyNotarization *EmptyNotariza verifiedQuorumRound = &VerifiedQuorumRound{ EmptyNotarization: emptyNotarization, } - highestRound = emptyNotarization.Vote.Round - exists = true - } - } - - if lastBlock != nil && (lastBlock.VerifiedBlock.BlockHeader().Round > highestRound || !exists) { - verifiedQuorumRound = &VerifiedQuorumRound{ - VerifiedBlock: lastBlock.VerifiedBlock, - Finalization: &lastBlock.Finalization, } } diff --git a/util_test.go b/util_test.go index 2668684..c86dd93 100644 --- a/util_test.go +++ b/util_test.go @@ -185,7 +185,6 @@ func TestGetHighestQuorumRound(t *testing.T) { block10 := testutil.NewTestBlock(ProtocolMetadata{Seq: 10, Round: 10}, emptyBlacklist) notarization10, err := testutil.NewNotarization(l, signatureAggregator, block10, nodes) require.NoError(t, err) - finalization10, _ := testutil.NewFinalizationRecord(t, l, signatureAggregator, block10, nodes) tests := []struct { name string @@ -202,18 +201,11 @@ func TestGetHighestQuorumRound(t *testing.T) { }, }, { - name: "only last block", - lastBlock: &VerifiedFinalizedBlock{ - VerifiedBlock: block1, - Finalization: finalization1, - }, - expectedQr: &VerifiedQuorumRound{ - VerifiedBlock: block1, - Finalization: &finalization1, - }, + name: "nothing", + expectedQr: nil, }, { - name: "round", + name: "round with finalization", round: SetRound(block1, nil, &finalization1), expectedQr: &VerifiedQuorumRound{ VerifiedBlock: block1, @@ -229,36 +221,17 @@ func TestGetHighestQuorumRound(t *testing.T) { }, }, { - name: "higher notarized round than indexed", + name: "higher round than empty notarization", round: SetRound(block10, ¬arization10, nil), - lastBlock: &VerifiedFinalizedBlock{ - VerifiedBlock: block1, - Finalization: finalization1, - }, + eNote: testutil.NewEmptyNotarization(nodes, 1), expectedQr: &VerifiedQuorumRound{ VerifiedBlock: block10, Notarization: ¬arization10, }, }, - { - name: "higher indexed than in round", - round: SetRound(block1, ¬arization1, nil), - lastBlock: &VerifiedFinalizedBlock{ - VerifiedBlock: block10, - Finalization: finalization10, - }, - expectedQr: &VerifiedQuorumRound{ - VerifiedBlock: block10, - Finalization: &finalization10, - }, - }, { name: "higher empty notarization", eNote: testutil.NewEmptyNotarization(nodes, 100), - lastBlock: &VerifiedFinalizedBlock{ - VerifiedBlock: block1, - Finalization: finalization1, - }, round: SetRound(block10, ¬arization10, nil), expectedQr: &VerifiedQuorumRound{ EmptyNotarization: testutil.NewEmptyNotarization(nodes, 100), @@ -268,7 +241,7 @@ func TestGetHighestQuorumRound(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - qr := GetLatestVerifiedQuorumRound(tt.round, tt.eNote, 
tt.lastBlock) + qr := GetLatestVerifiedQuorumRound(tt.round, tt.eNote) require.Equal(t, tt.expectedQr, qr) }) }