Skip to content

Commit

Permalink
checkpoint: complete tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kanishkatn committed Dec 14, 2023
1 parent fa5d3bd commit 03fac30
Show file tree
Hide file tree
Showing 10 changed files with 390 additions and 66 deletions.
8 changes: 4 additions & 4 deletions dot/parachain/dispute/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ type runtimeInfo struct {
runtime parachain.RuntimeInstance
}

func (r runtimeInfo) ParachainHostPersistedValidationData(parachaidID uint32, assumption parachainTypes.OccupiedCoreAssumption) (*parachainTypes.PersistedValidationData, error) {
return r.runtime.ParachainHostPersistedValidationData(parachaidID, assumption)
func (r runtimeInfo) ParachainHostPersistedValidationData(parachainID uint32, assumption parachainTypes.OccupiedCoreAssumption) (*parachainTypes.PersistedValidationData, error) {
return r.runtime.ParachainHostPersistedValidationData(parachainID, assumption)
}

func (r runtimeInfo) ParachainHostValidationCode(parachaidID uint32, assumption parachainTypes.OccupiedCoreAssumption) (*parachainTypes.ValidationCode, error) {
return r.runtime.ParachainHostValidationCode(parachaidID, assumption)
func (r runtimeInfo) ParachainHostValidationCode(parachainID uint32, assumption parachainTypes.OccupiedCoreAssumption) (*parachainTypes.ValidationCode, error) {
return r.runtime.ParachainHostValidationCode(parachainID, assumption)
}

func (r runtimeInfo) ParachainHostCheckValidationOutputs(parachainID uint32, outputs parachainTypes.CandidateCommitments) (bool, error) {
Expand Down
2 changes: 1 addition & 1 deletion dot/parachain/dispute/comm.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"time"
)

const timeout = 30 * time.Second
const timeout = 10 * time.Second

// sendMessage sends the given message to the given channel with a timeout
func sendMessage(channel chan<- any, message any) error {
Expand Down
58 changes: 58 additions & 0 deletions dot/parachain/dispute/comm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package dispute

import (
"context"
"github.com/stretchr/testify/require"
"testing"
)

func TestSendMessage(t *testing.T) {
t.Run("successful send", func(t *testing.T) {
ch := make(chan any, 1)
defer close(ch)

err := sendMessage(ch, "test")
require.NoError(t, err)
})
t.Run("timeout", func(t *testing.T) {
ch := make(chan any)
defer close(ch)

err := sendMessage(ch, "test")
require.NoError(t, err)
})
}

func TestCall(t *testing.T) {
t.Run("successful call", func(t *testing.T) {
receiver := make(chan any)
response := make(chan any)
defer close(receiver)
defer close(response)

go func() {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
select {
case <-receiver:
response <- "pong"
case <-ctx.Done():
require.NoError(t, ctx.Err())
}
}()

res, err := call(receiver, "ping", response)
require.NoError(t, err)
require.Equal(t, "pong", res)
})
t.Run("timeout", func(t *testing.T) {
receiver := make(chan any)
response := make(chan any)
defer close(receiver)
defer close(response)

res, err := call(receiver, "ping", response)
require.Error(t, err)
require.Nil(t, res)
})
}
50 changes: 50 additions & 0 deletions dot/parachain/dispute/common/helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package common

import (
"github.com/stretchr/testify/require"
"testing"
)

func TestGetByzantineThreshold(t *testing.T) {
cases := []struct {
n, expected int
}{
{0, 0},
{1, 0},
{2, 0},
{3, 0},
{4, 1},
{5, 1},
{6, 1},
{9, 2},
{10, 3},
// Additional cases can be added here
}

for _, c := range cases {
got := GetByzantineThreshold(c.n)
require.Equal(t, c.expected, got)
}
}

func TestGetSuperMajorityThreshold(t *testing.T) {
cases := []struct {
n, expected int
}{
{0, 0},
{1, 1},
{2, 2},
{3, 3},
{4, 3},
{5, 4},
{6, 5},
{9, 7},
{10, 7},
// Additional cases can be added here
}

for _, c := range cases {
got := GetSuperMajorityThreshold(c.n)
require.Equal(t, c.expected, got)
}
}
12 changes: 10 additions & 2 deletions dot/parachain/dispute/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ const Window = 6

var logger = log.NewFromGlobal(log.AddContext("parachain", "disputes"))

// Coordinator implements the CoordinatorSubsystem interface.
// Coordinator is the disputes coordinator
type Coordinator struct {
keystore keystore.Keystore
store *overlayBackend
Expand All @@ -28,6 +28,7 @@ type Coordinator struct {
receiver chan any
}

// startupResult is the result of the startup phase
type startupResult struct {
participation []ParticipationData
votes []parachainTypes.ScrapedOnChainVotes
Expand All @@ -37,13 +38,15 @@ type startupResult struct {
gapsInCache bool
}

// initializeResult is the result of the initialization phase
type initializeResult struct {
participation []ParticipationData
votes []parachainTypes.ScrapedOnChainVotes
activatedLeaf *overseer.ActivatedLeaf
initialized *Initialized
}

// sendDisputeMessages sends the dispute message to the given receiver
func (d *Coordinator) sendDisputeMessages(
receiver chan<- any,
env types.CandidateEnvironment,
Expand Down Expand Up @@ -92,6 +95,7 @@ func (d *Coordinator) sendDisputeMessages(
}
}

// waitForFirstLeaf waits for the first active leaf update
func (d *Coordinator) waitForFirstLeaf() (*overseer.ActivatedLeaf, error) {
for {
select {
Expand All @@ -109,6 +113,7 @@ func (d *Coordinator) waitForFirstLeaf() (*overseer.ActivatedLeaf, error) {
}
}

// initialize initializes the dispute coordinator
func (d *Coordinator) initialize(sender chan<- any) (
*initializeResult,
error,
Expand Down Expand Up @@ -156,6 +161,7 @@ func (d *Coordinator) initialize(sender chan<- any) (
}, nil
}

// handleStartup handles the startup phase
func (d *Coordinator) handleStartup(sender chan<- any, initialHead *overseer.ActivatedLeaf) (
*startupResult,
error,
Expand Down Expand Up @@ -276,13 +282,14 @@ func (d *Coordinator) handleStartup(sender chan<- any, initialHead *overseer.Act
}, nil
}

// Run runs the dispute coordinator
func (d *Coordinator) Run(sender chan<- any) error {
initResult, err := d.initialize(sender)
if err != nil {
return fmt.Errorf("initialize dispute coordinator: %w", err)
}

initData := InitialData{
initData := initialData{
Participation: initResult.participation,
Votes: initResult.votes,
Leaf: initResult.activatedLeaf,
Expand All @@ -291,6 +298,7 @@ func (d *Coordinator) Run(sender chan<- any) error {
return nil
}

// NewDisputesCoordinator returns a new dispute coordinator
func NewDisputesCoordinator(db *badger.DB, receiver chan any) (*Coordinator, error) {
dbBackend := NewDBBackend(db)
backend := newOverlayBackend(dbBackend)
Expand Down
133 changes: 133 additions & 0 deletions dot/parachain/dispute/coordinator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2989,7 +2989,140 @@ func TestDisputesCoordinator(t *testing.T) {
ts.conclude(t)
})
t.Run("informs_chain_selection_when_dispute_concluded_against", func(t *testing.T) {
t.Parallel()
ts := newTestState(t)
session := parachaintypes.SessionIndex(1)
initialised := false
restarted := false
sessionEvents, err := parachaintypes.NewCandidateEvents()
require.NoError(t, err)
ts.mockRuntimeCalls(t, session, nil, &sessionEvents, nil, &initialised, &restarted)

wg := sync.WaitGroup{}
wg.Add(2)
go func() {
defer wg.Done()
ts.mockResumeSync(t, &session)
}()
go func() {
defer wg.Done()
ts.run(t)
}()
wg.Wait()
initialised = true

candidateReceipt := getValidCandidateReceipt(t)
candidateHash, err := candidateReceipt.Hash()
require.NoError(t, err)
parent1Number := uint32(1)
parent2Number := uint32(2)
block1Header := types.Header{
ParentHash: ts.lastBlock,
Number: uint(parent1Number),
StateRoot: common.Hash{},
ExtrinsicsRoot: common.Hash{},
Digest: types.NewDigest(),
}
parent1Hash := block1Header.Hash()
event := getCandidateIncludedEvent(t, candidateReceipt)
value, err := event.Value()
require.NoError(t, err)
err = sessionEvents.Add(value)
require.NoError(t, err)
ts.activateLeafAtSession(t, session, uint(parent1Number))

block2Header := types.Header{
ParentHash: ts.lastBlock,
Number: uint(parent2Number),
StateRoot: common.Hash{},
ExtrinsicsRoot: common.Hash{},
Digest: types.NewDigest(),
}
parent2Hash := block2Header.Hash()
ts.activateLeafAtSession(t, session, uint(parent2Number))

byzantineThreshold := disputesCommon.GetByzantineThreshold(len(ts.validators))
validVote, invalidVote := ts.generateOpposingVotesPair(t,
2,
1,
candidateHash,
session,
ExplicitVote,
)
wg = sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
ts.handleApprovalVoteRequest(t, candidateHash, []overseer.ApprovalSignature{})
}()
statements := []disputetypes.Statement{
{
SignedDisputeStatement: validVote,
ValidatorIndex: 2,
},
{
SignedDisputeStatement: invalidVote,
ValidatorIndex: 1,
},
}
importResult := ts.sendImportStatementsMessage(t, candidateReceipt, session, statements, make(chan any))
require.Equal(t, ValidImport, importResult)
wg.Wait()

// Use a different expected commitments hash to ensure the candidate validation returns
// invalid.
handleParticipationWithDistribution(t, ts.mockOverseer, ts.runtime, candidateHash, common.Hash{1})

statements = []disputetypes.Statement{}
for i := 3; i < byzantineThreshold+3; i++ {
vote := ts.issueExplicitStatementWithIndex(t, parachaintypes.ValidatorIndex(i), candidateHash, session, false)
statements = append(statements, disputetypes.Statement{
SignedDisputeStatement: vote,
ValidatorIndex: parachaintypes.ValidatorIndex(i),
})
}
importResult = ts.sendImportStatementsMessage(t, candidateReceipt, session, statements, make(chan any))
require.Equal(t, ValidImport, importResult)

ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
select {
case msg := <-ts.mockOverseer:
switch data := msg.(type) {
case overseer.ChainSelectionMessage[overseer.RevertBlocks]:
parent1Exists := false
parent2Exists := false
for _, b := range data.Message.Blocks {
if b.Hash == parent1Hash && b.Number == parent1Number {
parent1Exists = true
}
if b.Hash == parent2Hash && b.Number == parent2Number {
parent2Exists = true
}
}
require.True(t, parent1Exists)
require.True(t, parent2Exists)
default:
err := fmt.Errorf("unexpected message type: %T", msg)
require.NoError(t, err)
}
case <-ctx.Done():
err := fmt.Errorf("timeout waiting for chain selection message")
require.NoError(t, err)
}

// One more import which should not trigger reversion
// Validator index is `byzantineThreshold + 4`
vote := ts.issueExplicitStatementWithIndex(t, parachaintypes.ValidatorIndex(byzantineThreshold+4), candidateHash, session, false)
statements = []disputetypes.Statement{
{
SignedDisputeStatement: vote,
ValidatorIndex: parachaintypes.ValidatorIndex(byzantineThreshold + 4),
},
}
_ = ts.sendImportStatementsMessage(t, candidateReceipt, session, statements, make(chan any))
ts.awaitConclude(t)
ts.conclude(t)
})
}

Expand Down
6 changes: 6 additions & 0 deletions dot/parachain/dispute/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ type ImportResult interface {
) (ImportResult, error)
// IntoUpdatedVotes returns the updated votes after the import
IntoUpdatedVotes() *types.CandidateVotes
// HasFreshByzantineThresholdAgainst returns true if there are byzantineThreshold + 1 invalid votes
HasFreshByzantineThresholdAgainst() bool
}

// ImportResultHandler implements ImportResult interface
Expand Down Expand Up @@ -207,6 +209,10 @@ func (i ImportResultHandler) IntoUpdatedVotes() *types.CandidateVotes {
return &i.newState.Votes
}

func (i ImportResultHandler) HasFreshByzantineThresholdAgainst() bool {
return !i.oldState.ByzantineThresholdAgainst && i.newState.ByzantineThresholdAgainst
}

var _ ImportResult = (*ImportResultHandler)(nil)

func NewImportResultFromStatements(
Expand Down
Loading

0 comments on commit 03fac30

Please sign in to comment.