From 67e7b60c33370f7b46575c1ea587cd094c9e785d Mon Sep 17 00:00:00 2001 From: Kanishka Date: Thu, 28 Sep 2023 21:24:06 +0200 Subject: [PATCH] handle incoming --- dot/parachain/dispute/backend.go | 12 +- dot/parachain/dispute/backend_test.go | 8 +- dot/parachain/dispute/coordinator.go | 2 +- dot/parachain/dispute/db.go | 2 +- dot/parachain/dispute/db_test.go | 8 +- dot/parachain/dispute/initialized.go | 167 +++++++++++++- dot/parachain/dispute/types/dispute.go | 12 +- dot/parachain/dispute/types/dispute_test.go | 4 +- dot/parachain/dispute/types/message.go | 230 +++++++++++++++++++- dot/parachain/dispute/types/status.go | 16 ++ dot/parachain/dispute/types/vote.go | 55 +++-- dot/parachain/types/types.go | 5 + dot/parachain/types/types_test.go | 2 +- lib/babe/inherents/parachain_inherents.go | 35 +++ 14 files changed, 498 insertions(+), 60 deletions(-) diff --git a/dot/parachain/dispute/backend.go b/dot/parachain/dispute/backend.go index a38d27f1cd..9edfb2e486 100644 --- a/dot/parachain/dispute/backend.go +++ b/dot/parachain/dispute/backend.go @@ -37,7 +37,7 @@ type OverlayBackend interface { // WriteToDB writes the given dispute to the database. WriteToDB() error // GetActiveDisputes returns the active disputes. - GetActiveDisputes(now int64) (*btree.BTree, error) + GetActiveDisputes(now uint64) (*btree.BTree, error) // NoteEarliestSession prunes data in the DB based on the provided session index. NoteEarliestSession(session parachainTypes.SessionIndex) error } @@ -68,7 +68,7 @@ type syncedRecentDisputes struct { func newSyncedRecentDisputes() syncedRecentDisputes { return syncedRecentDisputes{ - BTree: btree.New(types.DisputeComparator), + BTree: btree.New(types.CompareDisputes), } } @@ -161,12 +161,12 @@ func (b *overlayBackend) SetCandidateVotes(session parachainTypes.SessionIndex, const ActiveDuration = 180 * time.Second // GetActiveDisputes returns the active disputes, if any. -func (b *overlayBackend) GetActiveDisputes(now int64) (*btree.BTree, error) { +func (b *overlayBackend) GetActiveDisputes(now uint64) (*btree.BTree, error) { b.recentDisputes.RLock() recentDisputes := b.recentDisputes.Copy() b.recentDisputes.RUnlock() - activeDisputes := btree.New(types.DisputeComparator) + activeDisputes := btree.New(types.CompareDisputes) recentDisputes.Ascend(nil, func(i interface{}) bool { dispute, ok := i.(*types.Dispute) if !ok { @@ -180,7 +180,7 @@ func (b *overlayBackend) GetActiveDisputes(now int64) (*btree.BTree, error) { return true } - if concludedAt != nil && *concludedAt+uint64(ActiveDuration.Seconds()) > uint64(now) { + if concludedAt != nil && *concludedAt+uint64(ActiveDuration.Seconds()) > now { activeDisputes.Set(dispute) } @@ -213,7 +213,7 @@ func (b *overlayBackend) NoteEarliestSession(session parachainTypes.SessionIndex } // determine new recent disputes - newRecentDisputes := btree.New(types.DisputeComparator) + newRecentDisputes := btree.New(types.CompareDisputes) recentDisputes.Ascend(nil, func(item interface{}) bool { dispute := item.(*types.Dispute) if dispute.Comparator.SessionIndex >= session { diff --git a/dot/parachain/dispute/backend_test.go b/dot/parachain/dispute/backend_test.go index 8b28f14630..e39547be47 100644 --- a/dot/parachain/dispute/backend_test.go +++ b/dot/parachain/dispute/backend_test.go @@ -40,7 +40,7 @@ func TestOverlayBackend_RecentDisputes(t *testing.T) { // with db, err := badger.Open(badger.DefaultOptions(t.TempDir())) require.NoError(t, err) - disputes := btree.New(types.DisputeComparator) + disputes := btree.New(types.CompareDisputes) dispute1, err := types.DummyDispute(1, common.Hash{1}, types.DisputeStatusActive) require.NoError(t, err) @@ -88,7 +88,7 @@ func TestOverlayBackend_GetActiveDisputes(t *testing.T) { // with db, err := badger.Open(badger.DefaultOptions(t.TempDir())) require.NoError(t, err) - disputes := btree.New(types.DisputeComparator) + disputes := btree.New(types.CompareDisputes) dispute1, err := types.DummyDispute(1, common.Hash{1}, types.DisputeStatusActive) require.NoError(t, err) @@ -105,7 +105,7 @@ func TestOverlayBackend_GetActiveDisputes(t *testing.T) { require.NoError(t, err) // then - activeDisputes, err := backend.GetActiveDisputes(time.Now().Unix()) + activeDisputes, err := backend.GetActiveDisputes(uint64(time.Now().Unix())) require.NoError(t, err) require.True(t, compareBTrees(disputes, activeDisputes)) } @@ -160,7 +160,7 @@ func TestOverlayBackend_Concurrency(t *testing.T) { defer wg.Done() for j := 0; j < numIterations; j++ { - disputes := btree.New(types.DisputeComparator) + disputes := btree.New(types.CompareDisputes) dispute1, err := types.DummyDispute(parachainTypes.SessionIndex(j), common.Hash{byte(j)}, diff --git a/dot/parachain/dispute/coordinator.go b/dot/parachain/dispute/coordinator.go index 832ce0181f..ecac71b95b 100644 --- a/dot/parachain/dispute/coordinator.go +++ b/dot/parachain/dispute/coordinator.go @@ -116,7 +116,7 @@ func (d *disputeCoordinator) handleStartup(context overseer.Context, initialHead error, ) { var now = time.Now().Unix() - activeDisputes, err := d.store.GetActiveDisputes(now) + activeDisputes, err := d.store.GetActiveDisputes(uint64(now)) if err != nil { return nil, fmt.Errorf("get active disputes: %w", err) } diff --git a/dot/parachain/dispute/db.go b/dot/parachain/dispute/db.go index 2c95b95463..4dae0faa32 100644 --- a/dot/parachain/dispute/db.go +++ b/dot/parachain/dispute/db.go @@ -73,7 +73,7 @@ func (b *BadgerBackend) GetEarliestSession() (*parachainTypes.SessionIndex, erro } func (b *BadgerBackend) GetRecentDisputes() (*btree.BTree, error) { - recentDisputes := btree.New(types.DisputeComparator) + recentDisputes := btree.New(types.CompareDisputes) if err := b.db.View(func(txn *badger.Txn) error { opts := badger.DefaultIteratorOptions diff --git a/dot/parachain/dispute/db_test.go b/dot/parachain/dispute/db_test.go index ad07148204..cdcafd4489 100644 --- a/dot/parachain/dispute/db_test.go +++ b/dot/parachain/dispute/db_test.go @@ -65,7 +65,7 @@ func TestDBBackend_SetRecentDisputes(t *testing.T) { // with db, err := badger.Open(badger.DefaultOptions(t.TempDir())) require.NoError(t, err) - disputes := btree.New(types.DisputeComparator) + disputes := btree.New(types.CompareDisputes) dispute1, err := types.DummyDispute(1, common.Hash{1}, types.DisputeStatusActive) require.NoError(t, err) disputes.Set(dispute1) @@ -110,7 +110,7 @@ func TestDBBackend_Write(t *testing.T) { db, err := badger.Open(badger.DefaultOptions(t.TempDir())) require.NoError(t, err) earliestSession := getSessionIndex(1) - disputes := btree.New(types.DisputeComparator) + disputes := btree.New(types.CompareDisputes) dispute1, err := types.DummyDispute(1, common.Hash{1}, types.DisputeStatusActive) require.NoError(t, err) disputes.Set(dispute1) @@ -219,7 +219,7 @@ func BenchmarkBadgerBackend_SetRecentDisputes(b *testing.B) { require.NoError(b, err) backend := NewDBBackend(db) - disputes := btree.New(types.DisputeComparator) + disputes := btree.New(types.CompareDisputes) for i := 0; i < 10000; i++ { dispute, err := types.DummyDispute(parachainTypes.SessionIndex(i), common.Hash{byte(i)}, types.DisputeStatusActive) require.NoError(b, err) @@ -267,7 +267,7 @@ func BenchmarkBadgerBackend_GetRecentDisputes(b *testing.B) { require.NoError(b, err) backend := NewDBBackend(db) - disputes := btree.New(types.DisputeComparator) + disputes := btree.New(types.CompareDisputes) for i := 0; i < 1000; i++ { dispute, err := types.DummyDispute(parachainTypes.SessionIndex(i), common.Hash{byte(1)}, types.DisputeStatusActive) require.NoError(b, err) diff --git a/dot/parachain/dispute/initialized.go b/dot/parachain/dispute/initialized.go index f179b2700a..01383ee33e 100644 --- a/dot/parachain/dispute/initialized.go +++ b/dot/parachain/dispute/initialized.go @@ -14,6 +14,13 @@ import ( "time" ) +type ImportStatementResult uint + +const ( + InvalidImport ImportStatementResult = iota + ValidImport +) + const ChainImportMaxBatchSize = 6 type Initialized struct { @@ -349,13 +356,13 @@ func (i *Initialized) ProcessOnChainVotes( // Importantly, handling import statements for backing votes also // clears spam slots for any newly backed candidates - if err := i.HandleImportStatements(context, + if outcome := i.HandleImportStatements(context, backend, backingValidators.CandidateReceipt, votes.Session, statements, now, - ); err != nil { + ); outcome == InvalidImport { logger.Errorf("attempted import of on-chain backing votes failed. session: %v, relayParent: %v", votes.Session, backingValidators.CandidateReceipt.Descriptor.RelayParent, @@ -417,13 +424,13 @@ func (i *Initialized) ProcessOnChainVotes( continue } - if err := i.HandleImportStatements(context, + if outcome := i.HandleImportStatements(context, backend, backingValidators.CandidateReceipt, votes.Session, filteredStatements, now, - ); err != nil { + ); outcome == InvalidImport { logger.Errorf("attempted import of on-chain dispute votes failed. "+ "session: %v, candidateHash: %v", votes.Session, @@ -447,20 +454,107 @@ func (i *Initialized) HandleIncoming( backend OverlayBackend, message types.DisputeCoordinatorMessage, now uint64, -) error { +) (func() error, error) { switch { case message.ImportStatements != nil: - logger.Tracef("in HandleIncoming::ImportStatements") + logger.Tracef("HandleIncoming::ImportStatements") + outcome := i.HandleImportStatements(context, + backend, + message.ImportStatements.CandidateReceipt, + message.ImportStatements.Session, + message.ImportStatements.Statements, + now, + ) + + report := func() error { + if message.ImportStatements.PendingConfirmation != nil { + if err := message.ImportStatements.PendingConfirmation.SendMessage(outcome); err != nil { + return fmt.Errorf("confirm import statements: %w", err) + } + } + + return nil + } + + if outcome == InvalidImport { + return nil, report() + } + + return report, nil case message.RecentDisputes != nil: + logger.Tracef("HandleIncoming::RecentDisputes") + recentDisputes, err := backend.GetRecentDisputes() + if err != nil { + return nil, fmt.Errorf("get recent disputes: %w", err) + } + + if err := message.RecentDisputes.Sender.SendMessage(recentDisputes); err != nil { + return nil, fmt.Errorf("send recent disputes: %w", err) + } case message.ActiveDisputes != nil: + logger.Tracef("HandleIncoming::ActiveDisputes") + activeDisputes, err := backend.GetActiveDisputes(now) + if err != nil { + return nil, fmt.Errorf("get active disputes: %w", err) + } + + if err := message.ActiveDisputes.Sender.SendMessage(activeDisputes); err != nil { + return nil, fmt.Errorf("send active disputes: %w", err) + } case message.QueryCandidateVotes != nil: + logger.Tracef("HandleIncoming::QueryCandidateVotes") + + var queryOutput []types.QueryCandidateVotesResponse + for _, query := range message.QueryCandidateVotes.Queries { + candidateVotes, err := backend.GetCandidateVotes(query.Session, query.CandidateHash) + if err != nil { + logger.Debugf("no candidate votes found for query. session: %v, candidateHash: %v", + query.Session, + query.CandidateHash, + ) + return nil, fmt.Errorf("get candidate votes: %w", err) + } + + queryOutput = append(queryOutput, types.QueryCandidateVotesResponse{ + Session: query.Session, + CandidateHash: query.CandidateHash, + Votes: *candidateVotes, + }) + } + + if err := message.QueryCandidateVotes.Sender.SendMessage(queryOutput); err != nil { + return nil, fmt.Errorf("send candidate votes: %w", err) + } case message.IssueLocalStatement != nil: + logger.Tracef("HandleIncoming::IssueLocalStatement") + if err := i.IssueLocalStatement(context, + backend, + message.IssueLocalStatement.CandidateHash, + message.IssueLocalStatement.CandidateReceipt, + message.IssueLocalStatement.Session, + message.IssueLocalStatement.Valid, + now, + ); err != nil { + return nil, fmt.Errorf("issue local statement: %w", err) + } case message.DetermineUndisputedChain != nil: + logger.Tracef("HandleIncoming::DetermineUndisputedChain") + undisputedChain, err := i.determineUndisputedChain(backend, + message.DetermineUndisputedChain.Base, + message.DetermineUndisputedChain.BlockDescriptions, + ) + if err != nil { + return nil, fmt.Errorf("determine undisputed chain: %w", err) + } + + if err := message.DetermineUndisputedChain.Tx.SendMessage(undisputedChain); err != nil { + return nil, fmt.Errorf("send undisputed chain: %w", err) + } default: - return fmt.Errorf("unknown dispute coordinator message") + return nil, fmt.Errorf("unknown dispute coordinator message") } - return nil + return nil, nil } func (i *Initialized) HandleImportStatements( @@ -470,12 +564,10 @@ func (i *Initialized) HandleImportStatements( session parachainTypes.SessionIndex, statements []types.Statement, now uint64, -) error { +) ImportStatementResult { logger.Tracef("in HandleImportStatements") - if i.sessionIsAncient(session) { - - } + return InvalidImport } func (i *Initialized) IssueLocalStatement( @@ -496,6 +588,57 @@ func (i *Initialized) sessionIsAncient(session parachainTypes.SessionIndex) bool return session < diff || session < i.HighestSessionSeen } +func (i *Initialized) determineUndisputedChain(backend OverlayBackend, + baseBlock types.Block, + blockDescriptions []types.BlockDescription, +) (types.Block, error) { + last := types.NewBlock(baseBlock.BlockNumber+uint32(len(blockDescriptions)), + blockDescriptions[len(blockDescriptions)-1].BlockHash, + ) + + recentDisputes, err := backend.GetRecentDisputes() + if err != nil { + return types.Block{}, fmt.Errorf("get recent disputes: %w", err) + } + + if recentDisputes == nil || recentDisputes.Len() == 0 { + return last, nil + } + + isPossiblyInvalid := func(session parachainTypes.SessionIndex, candidateHash common.Hash) bool { + disputeStatus := recentDisputes.Get(types.NewDisputeComparator(session, candidateHash)) + status, ok := disputeStatus.(types.DisputeStatus) + if !ok { + logger.Errorf("cast to dispute status. Expected types.DisputeStatus, got %T", disputeStatus) + return false + } + + isPossiblyInvalid, err := status.IsPossiblyInvalid() + if err != nil { + logger.Errorf("is possibly invalid: %s", err) + return false + } + + return isPossiblyInvalid + } + + for i, blockDescription := range blockDescriptions { + for _, candidate := range blockDescription.Candidates { + if isPossiblyInvalid(blockDescription.Session, candidate.Value) { + if i == 0 { + return baseBlock, nil + } else { + return types.NewBlock(baseBlock.BlockNumber+uint32(i-1), + blockDescriptions[i-1].BlockHash, + ), nil + } + } + } + } + + return last, nil +} + func NewInitializedState(sender overseer.Sender, runtime parachainRuntime.RuntimeInstance, spamSlots SpamSlots, diff --git a/dot/parachain/dispute/types/dispute.go b/dot/parachain/dispute/types/dispute.go index f96f92a635..5c89d878c6 100644 --- a/dot/parachain/dispute/types/dispute.go +++ b/dot/parachain/dispute/types/dispute.go @@ -15,6 +15,14 @@ type Comparator struct { CandidateHash common.Hash `scale:"2"` } +// NewDisputeComparator creates a new dispute comparator. +func NewDisputeComparator(sessionIndex parachainTypes.SessionIndex, candidateHash common.Hash) Comparator { + return Comparator{ + SessionIndex: sessionIndex, + CandidateHash: candidateHash, + } +} + // Dispute is a dispute for a candidate. // It is used as an item in the btree.BTree ordered by the Comparator. type Dispute struct { @@ -35,8 +43,8 @@ func NewDispute() (*Dispute, error) { }, nil } -// DisputeComparator compares two disputes. -func DisputeComparator(a, b any) bool { +// CompareDisputes compares two disputes. +func CompareDisputes(a, b any) bool { d1, d2 := a.(*Dispute), b.(*Dispute) if d1.Comparator.SessionIndex == d2.Comparator.SessionIndex { diff --git a/dot/parachain/dispute/types/dispute_test.go b/dot/parachain/dispute/types/dispute_test.go index a55bb92a97..475cc14b53 100644 --- a/dot/parachain/dispute/types/dispute_test.go +++ b/dot/parachain/dispute/types/dispute_test.go @@ -72,8 +72,8 @@ func TestDispute_Comparator(t *testing.T) { } // when - less12 := DisputeComparator(&dispute1, &dispute2) - less23 := DisputeComparator(&dispute2, &dispute3) + less12 := CompareDisputes(&dispute1, &dispute2) + less23 := CompareDisputes(&dispute2, &dispute3) // then require.True(t, less12) diff --git a/dot/parachain/dispute/types/message.go b/dot/parachain/dispute/types/message.go index 8a2d4d3146..ecb3b34176 100644 --- a/dot/parachain/dispute/types/message.go +++ b/dot/parachain/dispute/types/message.go @@ -4,10 +4,12 @@ import ( "fmt" "github.com/ChainSafe/gossamer/dot/parachain/dispute/overseer" parachainTypes "github.com/ChainSafe/gossamer/dot/parachain/types" + "github.com/ChainSafe/gossamer/lib/babe/inherents" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/pkg/scale" ) +// UncheckedDisputeMessage is a dispute message where signatures of statements have not yet been checked. type UncheckedDisputeMessage struct { candidateReceipt parachainTypes.CandidateReceipt sessionIndex parachainTypes.SessionIndex @@ -15,12 +17,15 @@ type UncheckedDisputeMessage struct { validVote Vote } +// Index returns the index of the UncheckedDisputeMessage enum func (UncheckedDisputeMessage) Index() uint { return 0 } +// DisputeMessage is a dispute message. type DisputeMessage scale.VaryingDataType +// Set will set a VaryingDataTypeValue using the underlying VaryingDataType func (dm *DisputeMessage) Set(val scale.VaryingDataTypeValue) (err error) { vdt := scale.VaryingDataType(*dm) err = vdt.Set(val) @@ -31,6 +36,7 @@ func (dm *DisputeMessage) Set(val scale.VaryingDataTypeValue) (err error) { return nil } +// Value returns the value from the underlying VaryingDataType func (dm *DisputeMessage) Value() (val scale.VaryingDataTypeValue, err error) { vdt := scale.VaryingDataType(*dm) val, err = vdt.Value() @@ -40,19 +46,200 @@ func (dm *DisputeMessage) Value() (val scale.VaryingDataTypeValue, err error) { return val, nil } -func NewDisputeMessage(info parachainTypes.SessionInfo, - votes CandidateVotes, - ourVote *SignedDisputeStatement, - ourIndex parachainTypes.ValidatorIndex, +// NewDisputeMessageFromSignedStatements build a `SignedDisputeMessage` and check what can be checked. +// +// This function checks that: +// +// - both statements concern the same candidate +// - both statements concern the same session +// - the invalid statement is indeed an invalid one +// - the valid statement is indeed a valid one +// - The passed `CandidateReceipt` has the correct hash (as signed in the statements). +// - the given validator indices match with the given `ValidatorId`s in the statements, +// given a `SessionInfo`. +// +// We don't check whether the given `SessionInfo` matches the `SessionIndex` in the +// statements, because we can't without doing a runtime query. Nevertheless, this smart +// constructor gives relative strong guarantees that the resulting `SignedDisputeStatement` is +// valid and good. Even the passed `SessionInfo` is most likely right if this function +// returns `Some`, because otherwise the passed `ValidatorId`s in the `SessionInfo` at +// their given index would very likely not match the `ValidatorId`s in the statements. +func NewDisputeMessageFromSignedStatements( + validStatement SignedDisputeStatement, + validIndex parachainTypes.ValidatorIndex, + invalidStatement SignedDisputeStatement, + invalidIndex parachainTypes.ValidatorIndex, + candidateReceipt parachainTypes.CandidateReceipt, + sessionInfo parachainTypes.SessionInfo, ) (DisputeMessage, error) { + candidateHash := validStatement.CandidateHash + + // check that both statements concern the same candidate + if candidateHash != invalidStatement.CandidateHash { + return DisputeMessage{}, fmt.Errorf("candidate hashes do not match") + } + + sessionIndex := validStatement.SessionIndex + + // check that both statements concern the same session + if sessionIndex != invalidStatement.SessionIndex { + return DisputeMessage{}, fmt.Errorf("session indices do not match") + } + + if validIndex > parachainTypes.ValidatorIndex(len(sessionInfo.Validators)) { + return DisputeMessage{}, fmt.Errorf("invalid validator index") + } + validID := sessionInfo.Validators[validIndex] + if validID != validStatement.ValidatorPublic { + return DisputeMessage{}, fmt.Errorf("valid validator ID does not match") + } + + if invalidIndex > parachainTypes.ValidatorIndex(len(sessionInfo.Validators)) { + return DisputeMessage{}, fmt.Errorf("invalid validator index") + } + invalidID := sessionInfo.Validators[invalidIndex] + if invalidID != invalidStatement.ValidatorPublic { + return DisputeMessage{}, fmt.Errorf("invalid validator ID does not match") + } + + candidateReceiptHash, err := candidateReceipt.Hash() + if err != nil { + return DisputeMessage{}, fmt.Errorf("hash candidate receipt: %w", err) + } + + // check that the passed `CandidateReceipt` has the correct hash (as signed in the statements) + if candidateReceiptHash != candidateHash { + return DisputeMessage{}, fmt.Errorf("candidate receipt hash does not match") + } + + kind, err := validStatement.DisputeStatement.Value() + if err != nil { + return DisputeMessage{}, fmt.Errorf("get valid dispute statement value: %w", err) + } + validKind, ok := kind.(inherents.ValidDisputeStatementKind) + if !ok { + return DisputeMessage{}, fmt.Errorf("valid dispute statement kind has invalid type") + } + + kind, err = invalidStatement.DisputeStatement.Value() + if err != nil { + return DisputeMessage{}, fmt.Errorf("get invalid dispute statement value: %w", err) + } + invalidKind, ok := kind.(inherents.InvalidDisputeStatementKind) + if !ok { + return DisputeMessage{}, fmt.Errorf("invalid dispute statement kind has valid type") + } + + validVote := Vote{ + ValidatorIndex: validIndex, + DisputeStatement: inherents.DisputeStatement(validKind), + ValidatorSignature: validStatement.ValidatorSignature, + } + invalidVote := Vote{ + ValidatorIndex: invalidIndex, + DisputeStatement: inherents.DisputeStatement(invalidKind), + ValidatorSignature: invalidStatement.ValidatorSignature, + } + vdt, err := scale.NewVaryingDataType(UncheckedDisputeMessage{}) if err != nil { return DisputeMessage{}, fmt.Errorf("failed to create varying data type: %w", err) } + disputeMessage := UncheckedDisputeMessage{ + candidateReceipt: candidateReceipt, + sessionIndex: sessionIndex, + invalidVote: invalidVote, + validVote: validVote, + } + if err := vdt.Set(disputeMessage); err != nil { + return DisputeMessage{}, fmt.Errorf("set dispute message: %w", err) + } return DisputeMessage(vdt), nil } +// NewDisputeMessage creates a new dispute message. +func NewDisputeMessage( + info parachainTypes.SessionInfo, + votes CandidateVotes, + ourVote *SignedDisputeStatement, + ourIndex parachainTypes.ValidatorIndex, +) (DisputeMessage, error) { + disputeStatement, err := ourVote.DisputeStatement.Value() + if err != nil { + return DisputeMessage{}, fmt.Errorf("get dispute statement value: %w", err) + } + + var ( + validStatement SignedDisputeStatement + validIndex parachainTypes.ValidatorIndex + invalidStatement SignedDisputeStatement + invalidIndex parachainTypes.ValidatorIndex + ) + + var firstVote Vote + _, ok := disputeStatement.(inherents.ValidDisputeStatementKind) + if ok { + votes.Invalid.Descend(nil, func(i interface{}) bool { + firstVote, ok = i.(Vote) + return ok + }) + + invalidDisputeStatement := inherents.NewInvalidDisputeStatement() + + if firstVote.ValidatorIndex > parachainTypes.ValidatorIndex(len(info.Validators)) { + return DisputeMessage{}, fmt.Errorf("invalid validator index") + } + + validStatement = *ourVote + validIndex = ourIndex + + validatorPublicKey := info.Validators[firstVote.ValidatorIndex] + invalidStatement = NewSignedDisputeStatement( + invalidDisputeStatement, + ourVote.CandidateHash, + ourVote.SessionIndex, + validatorPublicKey, + ourVote.ValidatorSignature, + ) + invalidIndex = firstVote.ValidatorIndex + } else { + votes.Valid.Value.Descend(nil, func(i interface{}) bool { + firstVote, ok = i.(Vote) + return ok + }) + + validDisputeStatement := inherents.NewValidDisputeStatement() + + if firstVote.ValidatorIndex > parachainTypes.ValidatorIndex(len(info.Validators)) { + return DisputeMessage{}, fmt.Errorf("invalid validator index") + } + + validIndex = firstVote.ValidatorIndex + validatorPublicKey := info.Validators[firstVote.ValidatorIndex] + validStatement = NewSignedDisputeStatement( + validDisputeStatement, + ourVote.CandidateHash, + ourVote.SessionIndex, + validatorPublicKey, + ourVote.ValidatorSignature, + ) + + invalidStatement = *ourVote + invalidIndex = ourIndex + } + + return NewDisputeMessageFromSignedStatements( + validStatement, + validIndex, + invalidStatement, + invalidIndex, + votes.CandidateReceipt, + info, + ) +} + +// ImportStatementsMessage import statements by validators about a candidate type ImportStatementsMessage struct { CandidateReceipt parachainTypes.CandidateReceipt Session parachainTypes.SessionIndex @@ -60,48 +247,72 @@ type ImportStatementsMessage struct { PendingConfirmation overseer.Sender } +// RecentDisputesMessage message to request recent disputes type RecentDisputesMessage struct { Sender overseer.Sender } +// ActiveDisputesMessage message to request active disputes type ActiveDisputesMessage struct { Sender overseer.Sender } +// CandidateVotesMessage message to request candidate votes type CandidateVotesMessage struct { Session parachainTypes.SessionIndex - CandidateHash parachainTypes.CandidateHash + CandidateHash common.Hash } +// QueryCandidateVotesMessage message to request candidate votes type QueryCandidateVotesMessage struct { Sender overseer.Sender Queries []CandidateVotesMessage } +// QueryCandidateVotesResponse response to a candidate votes query +type QueryCandidateVotesResponse struct { + Session parachainTypes.SessionIndex + CandidateHash common.Hash + Votes CandidateVotes +} + +// IssueLocalStatementMessage message to issue a local statement type IssueLocalStatementMessage struct { Session parachainTypes.SessionIndex - CandidateHash parachainTypes.CandidateHash + CandidateHash common.Hash CandidateReceipt parachainTypes.CandidateReceipt Valid bool } +// Block represents a block type Block struct { BlockNumber uint32 Hash common.Hash } +// NewBlock creates a new block +func NewBlock(blockNumber uint32, hash common.Hash) Block { + return Block{ + BlockNumber: blockNumber, + Hash: hash, + } +} + +// BlockDescription describes a block with its session and candidates type BlockDescription struct { BlockHash common.Hash Session parachainTypes.SessionIndex Candidates []parachainTypes.CandidateHash } +// DetermineUndisputedChainMessage message to determine the undisputed chain type DetermineUndisputedChainMessage struct { - Base Block - BlockDescription BlockDescription - Tx overseer.Sender + Base Block + BlockDescriptions []BlockDescription + Tx overseer.Sender } +// DisputeCoordinatorMessage messages received by the dispute coordinator subsystem type DisputeCoordinatorMessage struct { ImportStatements *ImportStatementsMessage RecentDisputes *RecentDisputesMessage @@ -111,6 +322,7 @@ type DisputeCoordinatorMessage struct { DetermineUndisputedChain *DetermineUndisputedChainMessage } +// OverseerSignal signals received by the overseer subsystem type OverseerSignal struct { ActiveLeaves *overseer.ActiveLeavesUpdate BlockFinalised *Block diff --git a/dot/parachain/dispute/types/status.go b/dot/parachain/dispute/types/status.go index e5254df103..bc49e5b49c 100644 --- a/dot/parachain/dispute/types/status.go +++ b/dot/parachain/dispute/types/status.go @@ -217,6 +217,22 @@ func (ds *DisputeStatus) IsConcludedAgainst() (bool, error) { return false, nil } +// IsPossiblyInvalid returns true if the dispute is possibly invalid. +func (ds *DisputeStatus) IsPossiblyInvalid() (bool, error) { + vdt := scale.VaryingDataType(*ds) + val, err := vdt.Value() + if err != nil { + return false, fmt.Errorf("getting value from DisputeStatus vdt: %w", err) + } + + switch val.(type) { + case ActiveStatus, ConfirmedStatus, ConcludedAgainstStatus: + return true, nil + default: + return false, nil + } +} + // NewDisputeStatus returns a new DisputeStatus. func NewDisputeStatus() (DisputeStatus, error) { vdt, err := scale.NewVaryingDataType(ActiveStatus{}, diff --git a/dot/parachain/dispute/types/vote.go b/dot/parachain/dispute/types/vote.go index ebd4675f06..ceb148fe4b 100644 --- a/dot/parachain/dispute/types/vote.go +++ b/dot/parachain/dispute/types/vote.go @@ -3,6 +3,7 @@ package types import ( "fmt" "github.com/emirpasic/gods/sets/treeset" + "github.com/tidwall/btree" parachainTypes "github.com/ChainSafe/gossamer/dot/parachain/types" "github.com/ChainSafe/gossamer/lib/babe/inherents" @@ -224,7 +225,7 @@ func NewCandidateVoteState(votes CandidateVotes, now uint64) (CandidateVoteState // TODO: get supermajority threshold superMajorityThreshold := 0 - isDisputed := !(len(votes.Invalid) == 0) && !(len(votes.Valid) == 0) + isDisputed := !(votes.Invalid.Len() == 0) && !(votes.Valid.Value.Len() == 0) if isDisputed { status, err = NewDisputeStatus() if err != nil { @@ -234,21 +235,21 @@ func NewCandidateVoteState(votes CandidateVotes, now uint64) (CandidateVoteState // TODO: get byzantine threshold byzantineThreshold := 0 - isConfirmed := len(votes.Valid) > byzantineThreshold + isConfirmed := votes.Valid.Value.Len() > byzantineThreshold if isConfirmed { if err := status.Confirm(); err != nil { return CandidateVoteState{}, fmt.Errorf("failed to confirm dispute status: %w", err) } } - isConcludedFor := len(votes.Valid) > superMajorityThreshold + isConcludedFor := votes.Valid.Value.Len() > superMajorityThreshold if isConcludedFor { if err := status.ConcludeFor(now); err != nil { return CandidateVoteState{}, fmt.Errorf("failed to conclude dispute status for: %w", err) } } - isConcludedAgainst := len(votes.Invalid) >= superMajorityThreshold + isConcludedAgainst := votes.Invalid.Len() >= superMajorityThreshold if isConcludedAgainst { if err := status.ConcludeAgainst(now); err != nil { return CandidateVoteState{}, fmt.Errorf("failed to conclude dispute status against: %w", err) @@ -278,12 +279,16 @@ func NewCandidateVoteStateFromReceipt(receipt parachainTypes.CandidateReceipt) ( } // ValidCandidateVotes is a list of valid votes for a candidate. -type ValidCandidateVotes map[parachainTypes.ValidatorIndex]Vote +type ValidCandidateVotes struct { + VotedValidators map[parachainTypes.ValidatorIndex]struct{} + Value *btree.BTree +} func (vcv ValidCandidateVotes) InsertVote(vote Vote) (bool, error) { - existingVote, ok := vcv[vote.ValidatorIndex] + existingVote, ok := vcv.Value.Get(vote.ValidatorIndex).(Vote) if !ok { - vcv[vote.ValidatorIndex] = vote + vcv.Value.Set(vote) + vcv.VotedValidators[vote.ValidatorIndex] = struct{}{} return true, nil } @@ -298,7 +303,8 @@ func (vcv ValidCandidateVotes) InsertVote(vote Vote) (bool, error) { case inherents.ExplicitValidDisputeStatementKind, inherents.ExplicitInvalidDisputeStatementKind, inherents.ApprovalChecking: - vcv[vote.ValidatorIndex] = vote + vcv.Value.Set(vote) + vcv.VotedValidators[vote.ValidatorIndex] = struct{}{} return true, nil default: return false, fmt.Errorf("invalid dispute statement type: %T", disputeStatement) @@ -309,27 +315,40 @@ func (vcv ValidCandidateVotes) InsertVote(vote Vote) (bool, error) { type CandidateVotes struct { CandidateReceipt parachainTypes.CandidateReceipt `scale:"1"` // TODO: check if we need to use btree for this in the future - Valid ValidCandidateVotes `scale:"2"` - Invalid map[parachainTypes.ValidatorIndex]Vote `scale:"3"` + Valid ValidCandidateVotes `scale:"2"` + Invalid *btree.BTree `scale:"3"` } func (cv *CandidateVotes) VotedIndices() *treeset.Set { votedIndices := treeset.NewWithIntComparator() - for validatorIndex := range cv.Valid { - votedIndices.Add(validatorIndex) - } + cv.Valid.Value.Ascend(nil, func(i interface{}) bool { + vote, ok := i.(Vote) + if ok { + votedIndices.Add(vote.ValidatorIndex) + } - for validatorIndex := range cv.Invalid { - votedIndices.Add(validatorIndex) - } + return true + }) + + cv.Invalid.Ascend(nil, func(i interface{}) bool { + vote, ok := i.(Vote) + if ok { + votedIndices.Add(vote.ValidatorIndex) + } + + return true + }) return votedIndices } func NewCandidateVotes() *CandidateVotes { return &CandidateVotes{ - Valid: make(map[parachainTypes.ValidatorIndex]Vote), - Invalid: make(map[parachainTypes.ValidatorIndex]Vote), + Valid: ValidCandidateVotes{ + VotedValidators: make(map[parachainTypes.ValidatorIndex]struct{}), + Value: btree.New(parachainTypes.CompareValidatorIndices), + }, + Invalid: btree.New(parachainTypes.CompareValidatorIndices), } } diff --git a/dot/parachain/types/types.go b/dot/parachain/types/types.go index 04a29aa7de..3675d79a47 100644 --- a/dot/parachain/types/types.go +++ b/dot/parachain/types/types.go @@ -20,6 +20,11 @@ import ( // ValidatorIndex Index of the validator. Used as a lightweight replacement of the `ValidatorId` when appropriate type ValidatorIndex uint32 +// CompareValidatorIndices compares two validator indices. +func CompareValidatorIndices(a, b any) bool { + return a.(ValidatorIndex) < b.(ValidatorIndex) +} + // ValidatorID The public key of a validator. type ValidatorID [sr25519.PublicKeyLength]byte diff --git a/dot/parachain/types/types_test.go b/dot/parachain/types/types_test.go index 48df89d08a..395550dcea 100644 --- a/dot/parachain/types/types_test.go +++ b/dot/parachain/types/types_test.go @@ -378,7 +378,7 @@ func TestCandidateReceipt_Hash(t *testing.T) { candidateHash, err := receipt.Hash() require.NoError(t, err) - require.Equal(t, testData["expectedCandidateReceipt"], common.BytesToHex(candidateHash.Value[:])) + require.Equal(t, testData["expectedCandidateReceipt"], common.BytesToHex(candidateHash[:])) } func mustHexTo32BArray(t *testing.T, inputHex string) (outputArray [32]byte) { diff --git a/lib/babe/inherents/parachain_inherents.go b/lib/babe/inherents/parachain_inherents.go index 1b349f595d..7c49fee480 100644 --- a/lib/babe/inherents/parachain_inherents.go +++ b/lib/babe/inherents/parachain_inherents.go @@ -283,6 +283,41 @@ func NewDisputeStatement() DisputeStatement { //skipcq return DisputeStatement(vdt) } +// NewInvalidDisputeStatement create a new DisputeStatement varying data type. +func NewInvalidDisputeStatement() DisputeStatement { //skipcq + disputeStatement := NewDisputeStatement() + + idsKind, err := scale.NewVaryingDataType(ExplicitInvalidDisputeStatementKind{}) + if err != nil { + panic(err) + } + + err = disputeStatement.Set(InvalidDisputeStatementKind(idsKind)) + if err != nil { + panic(err) + } + + return disputeStatement +} + +// NewValidDisputeStatement create a new DisputeStatement varying data type. +func NewValidDisputeStatement() DisputeStatement { //skipcq + disputeStatement := NewDisputeStatement() + + vdsKind, err := scale.NewVaryingDataType( + ExplicitValidDisputeStatementKind{}, BackingSeconded{}, BackingValid{}, ApprovalChecking{}) + if err != nil { + panic(err) + } + + err = disputeStatement.Set(ValidDisputeStatementKind(vdsKind)) + if err != nil { + panic(err) + } + + return disputeStatement +} + // collatorID is the collator's relay-chain account ID type collatorID sr25519.PublicKey