diff --git a/blacklist.go b/blacklist.go index 53599e9..f804249 100644 --- a/blacklist.go +++ b/blacklist.go @@ -428,12 +428,12 @@ func (bl *Blacklist) FromBytes(buff []byte) error { return nil } -func (bl *Blacklist) VerifyProposedBlacklist(candidateBlacklist Blacklist, nodeCount int, round uint64) error { - if candidateBlacklist.NodeCount != uint16(nodeCount) { - return fmt.Errorf("%s, expected %d, got %d", errBlacklistInvalidNodeCount, nodeCount, candidateBlacklist.NodeCount) +func (bl *Blacklist) VerifyProposedBlacklist(candidateBlacklist Blacklist, round uint64) error { + if candidateBlacklist.NodeCount != bl.NodeCount { + return fmt.Errorf("%s, expected %d, got %d", errBlacklistInvalidNodeCount, bl.NodeCount, candidateBlacklist.NodeCount) } // 1) First thing we check that the updates even make sense. - if err := bl.verifyBlacklistUpdates(candidateBlacklist.Updates, nodeCount); err != nil { + if err := bl.verifyBlacklistUpdates(candidateBlacklist.Updates); err != nil { return fmt.Errorf("%s: %w", errBlacklistInvalidUpdates, err) } updates := candidateBlacklist.Updates @@ -447,15 +447,15 @@ func (bl *Blacklist) VerifyProposedBlacklist(candidateBlacklist Blacklist, nodeC return nil } -func (bl *Blacklist) verifyBlacklistUpdates(updates []BlacklistUpdate, nodeCount int) error { +func (bl *Blacklist) verifyBlacklistUpdates(updates []BlacklistUpdate) error { seen := make(map[uint16]struct{}) - if len(updates) > nodeCount { - return fmt.Errorf("%w: %d, only %d nodes exist", errBlacklistTooManyUpdates, len(updates), nodeCount) + if len(updates) > int(bl.NodeCount) { + return fmt.Errorf("%w: %d, only %d nodes exist", errBlacklistTooManyUpdates, len(updates), bl.NodeCount) } for _, update := range updates { - if int(update.NodeIndex) >= nodeCount { + if update.NodeIndex >= bl.NodeCount { return fmt.Errorf("%w: %d, needs to be in [%d, %d]", - errBlacklistInvalidNodeIndex, update.NodeIndex, 0, nodeCount-1) + errBlacklistInvalidNodeIndex, update.NodeIndex, 0, bl.NodeCount-1) } if _, exists := seen[update.NodeIndex]; exists { diff --git a/blacklist_test.go b/blacklist_test.go index 4d2f403..f47fb5e 100644 --- a/blacklist_test.go +++ b/blacklist_test.go @@ -91,7 +91,7 @@ func TestBlacklistVerifyProposedBlacklist(t *testing.T) { }, } { t.Run(testCase.name, func(t *testing.T) { - err := testCase.blacklist.VerifyProposedBlacklist(testCase.proposedBlacklist, testCase.nodeCount, testCase.round) + err := testCase.blacklist.VerifyProposedBlacklist(testCase.proposedBlacklist, testCase.round) require.ErrorContains(t, err, testCase.expectedErr.Error()) }) } @@ -581,6 +581,10 @@ func TestUpdateBytesEqualsLen(t *testing.T) { } func TestVerifyBlacklistUpdates(t *testing.T) { + testBlacklist := Blacklist{ + NodeCount: 4, + } + for _, testCase := range []struct { name string Blacklist Blacklist @@ -604,6 +608,7 @@ func TestVerifyBlacklistUpdates(t *testing.T) { {Type: 3, NodeIndex: 1}, }, expectedErr: errBlacklistInvalidUpdateType, + Blacklist: testBlacklist, }, { name: "invalid index", @@ -611,6 +616,7 @@ func TestVerifyBlacklistUpdates(t *testing.T) { {Type: BlacklistOpType_NodeRedeemed, NodeIndex: 4}, }, expectedErr: errBlacklistInvalidNodeIndex, + Blacklist: testBlacklist, }, { name: "double vote", @@ -620,6 +626,7 @@ func TestVerifyBlacklistUpdates(t *testing.T) { {Type: BlacklistOpType_NodeSuspected, NodeIndex: 3}, }, expectedErr: errBlacklistNodeIndexAlreadyUpdated, + Blacklist: testBlacklist, }, { name: "already blacklisted", @@ -646,7 +653,7 @@ func TestVerifyBlacklistUpdates(t *testing.T) { }, } { t.Run(testCase.name, func(t *testing.T) { - err := testCase.Blacklist.verifyBlacklistUpdates(testCase.updates, 4) + err := testCase.Blacklist.verifyBlacklistUpdates(testCase.updates) require.ErrorContains(t, err, testCase.expectedErr.Error()) }) } @@ -797,7 +804,7 @@ func simulateRound(t *testing.T, blrsi blacklistRoundSimulationInput) Blacklist newBlacklist := prevBlacklist.ApplyUpdates(updates, round) - err := prevBlacklist.VerifyProposedBlacklist(newBlacklist, nodeCount, round) + err := prevBlacklist.VerifyProposedBlacklist(newBlacklist, round) require.NoError(t, err, "round %d", round) return newBlacklist diff --git a/epoch.go b/epoch.go index d4e1af9..14150f6 100644 --- a/epoch.go +++ b/epoch.go @@ -1938,7 +1938,7 @@ func (e *Epoch) verifyProposalMetadataAndBlacklist(block Block) bool { prevBlacklist = prevBlock.Blacklist() } - if err := prevBlacklist.VerifyProposedBlacklist(block.Blacklist(), len(e.nodes), e.round); err != nil { + if err := prevBlacklist.VerifyProposedBlacklist(block.Blacklist(), e.round); err != nil { e.Logger.Debug("Block contains an invalid blacklist", zap.Error(err)) return false }