diff --git a/dbft.go b/dbft.go index 810acee9..995d5da7 100644 --- a/dbft.go +++ b/dbft.go @@ -596,19 +596,12 @@ func (d *DBFT[H]) onCommit(msg ConsensusPayload[H]) { func (d *DBFT[H]) onRecoveryRequest(msg ConsensusPayload[H]) { if !d.CommitSent() && (!d.isAntiMEVExtensionEnabled() || !d.PreCommitSent()) { - // Limit recoveries to be sent from no more than F nodes - // TODO replace loop with a single if - shouldSend := false - - for i := 1; i <= d.F()+1; i++ { - ind := (int(msg.ValidatorIndex()) + i) % len(d.Validators) - if ind == d.MyIndex { - shouldSend = true - break - } - } + // Ignore the message if our index is not in F+1 range of the + // next (%N) ones from the sender. This limits recovery + // messages to be broadcasted through the network and F+1 + // guarantees that at least one node responds. - if !shouldSend { + if (d.MyIndex-int(msg.ValidatorIndex())+d.N()-1)%d.N() > d.F() { return } } diff --git a/dbft_test.go b/dbft_test.go index a08195fe..f6ef2b75 100644 --- a/dbft_test.go +++ b/dbft_test.go @@ -399,6 +399,53 @@ func TestDBFT_OnReceiveRecoveryRequest(t *testing.T) { }) } +func TestDBFT_OnReceiveRecoveryRequestResponds(t *testing.T) { + type recoveryset struct { + nodes int + sender int + receiver int + replies bool + } + var params []recoveryset + + for _, nodes := range []int{4, 5, 7, 10} { // 5 is a bad BFT number, but we want to test the logic anyway. + for sender := 0; sender < nodes; sender++ { + for recv := 0; recv < nodes; recv++ { + params = append(params, recoveryset{nodes, sender, recv, false}) + + for i := 1; i <= ((nodes-1)/3)+1; i++ { + ind := (sender + i) % nodes + if ind == recv { + params[len(params)-1].replies = true + break + } + } + } + } + } + + for _, param := range params { + t.Run(fmt.Sprintf("%d nodes, %d sender, %d receiver", param.nodes, param.sender, param.receiver), func(t *testing.T) { + s := newTestState(param.receiver, param.nodes) + s.currHeight = 1 + service, _ := dbft.New[crypto.Uint256](s.getOptions()...) + service.Start(uint64(param.receiver)) + + _ = s.tryRecv() // Flush the queue if primary. + + rr := s.getRecoveryRequest(uint16(param.sender)) + service.OnReceive(rr) + rm := s.tryRecv() + if param.replies { + require.NotNil(t, rm) + require.Equal(t, dbft.RecoveryMessageType, rm.Type()) + } else { + require.Nil(t, rm) + } + }) + } +} + func TestDBFT_OnReceiveChangeView(t *testing.T) { s := newTestState(2, 4) t.Run("change view correctly", func(t *testing.T) {