Skip to content

Commit

Permalink
handle incoming
Browse files Browse the repository at this point in the history
  • Loading branch information
kanishkatn committed Nov 1, 2023
1 parent a08a13d commit 67e7b60
Show file tree
Hide file tree
Showing 14 changed files with 498 additions and 60 deletions.
12 changes: 6 additions & 6 deletions dot/parachain/dispute/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type OverlayBackend interface {
// WriteToDB writes the given dispute to the database.
WriteToDB() error
// GetActiveDisputes returns the active disputes.
GetActiveDisputes(now int64) (*btree.BTree, error)
GetActiveDisputes(now uint64) (*btree.BTree, error)
// NoteEarliestSession prunes data in the DB based on the provided session index.
NoteEarliestSession(session parachainTypes.SessionIndex) error
}
Expand Down Expand Up @@ -68,7 +68,7 @@ type syncedRecentDisputes struct {

func newSyncedRecentDisputes() syncedRecentDisputes {
return syncedRecentDisputes{
BTree: btree.New(types.DisputeComparator),
BTree: btree.New(types.CompareDisputes),
}
}

Expand Down Expand Up @@ -161,12 +161,12 @@ func (b *overlayBackend) SetCandidateVotes(session parachainTypes.SessionIndex,
const ActiveDuration = 180 * time.Second

// GetActiveDisputes returns the active disputes, if any.
func (b *overlayBackend) GetActiveDisputes(now int64) (*btree.BTree, error) {
func (b *overlayBackend) GetActiveDisputes(now uint64) (*btree.BTree, error) {
b.recentDisputes.RLock()
recentDisputes := b.recentDisputes.Copy()
b.recentDisputes.RUnlock()

activeDisputes := btree.New(types.DisputeComparator)
activeDisputes := btree.New(types.CompareDisputes)
recentDisputes.Ascend(nil, func(i interface{}) bool {
dispute, ok := i.(*types.Dispute)
if !ok {
Expand All @@ -180,7 +180,7 @@ func (b *overlayBackend) GetActiveDisputes(now int64) (*btree.BTree, error) {
return true
}

if concludedAt != nil && *concludedAt+uint64(ActiveDuration.Seconds()) > uint64(now) {
if concludedAt != nil && *concludedAt+uint64(ActiveDuration.Seconds()) > now {
activeDisputes.Set(dispute)
}

Expand Down Expand Up @@ -213,7 +213,7 @@ func (b *overlayBackend) NoteEarliestSession(session parachainTypes.SessionIndex
}

// determine new recent disputes
newRecentDisputes := btree.New(types.DisputeComparator)
newRecentDisputes := btree.New(types.CompareDisputes)
recentDisputes.Ascend(nil, func(item interface{}) bool {
dispute := item.(*types.Dispute)
if dispute.Comparator.SessionIndex >= session {
Expand Down
8 changes: 4 additions & 4 deletions dot/parachain/dispute/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestOverlayBackend_RecentDisputes(t *testing.T) {
// with
db, err := badger.Open(badger.DefaultOptions(t.TempDir()))
require.NoError(t, err)
disputes := btree.New(types.DisputeComparator)
disputes := btree.New(types.CompareDisputes)

dispute1, err := types.DummyDispute(1, common.Hash{1}, types.DisputeStatusActive)
require.NoError(t, err)
Expand Down Expand Up @@ -88,7 +88,7 @@ func TestOverlayBackend_GetActiveDisputes(t *testing.T) {
// with
db, err := badger.Open(badger.DefaultOptions(t.TempDir()))
require.NoError(t, err)
disputes := btree.New(types.DisputeComparator)
disputes := btree.New(types.CompareDisputes)

dispute1, err := types.DummyDispute(1, common.Hash{1}, types.DisputeStatusActive)
require.NoError(t, err)
Expand All @@ -105,7 +105,7 @@ func TestOverlayBackend_GetActiveDisputes(t *testing.T) {
require.NoError(t, err)

// then
activeDisputes, err := backend.GetActiveDisputes(time.Now().Unix())
activeDisputes, err := backend.GetActiveDisputes(uint64(time.Now().Unix()))
require.NoError(t, err)
require.True(t, compareBTrees(disputes, activeDisputes))
}
Expand Down Expand Up @@ -160,7 +160,7 @@ func TestOverlayBackend_Concurrency(t *testing.T) {
defer wg.Done()

for j := 0; j < numIterations; j++ {
disputes := btree.New(types.DisputeComparator)
disputes := btree.New(types.CompareDisputes)

dispute1, err := types.DummyDispute(parachainTypes.SessionIndex(j),
common.Hash{byte(j)},
Expand Down
2 changes: 1 addition & 1 deletion dot/parachain/dispute/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (d *disputeCoordinator) handleStartup(context overseer.Context, initialHead
error,
) {
var now = time.Now().Unix()
activeDisputes, err := d.store.GetActiveDisputes(now)
activeDisputes, err := d.store.GetActiveDisputes(uint64(now))
if err != nil {
return nil, fmt.Errorf("get active disputes: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion dot/parachain/dispute/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (b *BadgerBackend) GetEarliestSession() (*parachainTypes.SessionIndex, erro
}

func (b *BadgerBackend) GetRecentDisputes() (*btree.BTree, error) {
recentDisputes := btree.New(types.DisputeComparator)
recentDisputes := btree.New(types.CompareDisputes)

if err := b.db.View(func(txn *badger.Txn) error {
opts := badger.DefaultIteratorOptions
Expand Down
8 changes: 4 additions & 4 deletions dot/parachain/dispute/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestDBBackend_SetRecentDisputes(t *testing.T) {
// with
db, err := badger.Open(badger.DefaultOptions(t.TempDir()))
require.NoError(t, err)
disputes := btree.New(types.DisputeComparator)
disputes := btree.New(types.CompareDisputes)
dispute1, err := types.DummyDispute(1, common.Hash{1}, types.DisputeStatusActive)
require.NoError(t, err)
disputes.Set(dispute1)
Expand Down Expand Up @@ -110,7 +110,7 @@ func TestDBBackend_Write(t *testing.T) {
db, err := badger.Open(badger.DefaultOptions(t.TempDir()))
require.NoError(t, err)
earliestSession := getSessionIndex(1)
disputes := btree.New(types.DisputeComparator)
disputes := btree.New(types.CompareDisputes)
dispute1, err := types.DummyDispute(1, common.Hash{1}, types.DisputeStatusActive)
require.NoError(t, err)
disputes.Set(dispute1)
Expand Down Expand Up @@ -219,7 +219,7 @@ func BenchmarkBadgerBackend_SetRecentDisputes(b *testing.B) {
require.NoError(b, err)
backend := NewDBBackend(db)

disputes := btree.New(types.DisputeComparator)
disputes := btree.New(types.CompareDisputes)
for i := 0; i < 10000; i++ {
dispute, err := types.DummyDispute(parachainTypes.SessionIndex(i), common.Hash{byte(i)}, types.DisputeStatusActive)
require.NoError(b, err)
Expand Down Expand Up @@ -267,7 +267,7 @@ func BenchmarkBadgerBackend_GetRecentDisputes(b *testing.B) {
require.NoError(b, err)
backend := NewDBBackend(db)

disputes := btree.New(types.DisputeComparator)
disputes := btree.New(types.CompareDisputes)
for i := 0; i < 1000; i++ {
dispute, err := types.DummyDispute(parachainTypes.SessionIndex(i), common.Hash{byte(1)}, types.DisputeStatusActive)
require.NoError(b, err)
Expand Down
167 changes: 155 additions & 12 deletions dot/parachain/dispute/initialized.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ import (
"time"
)

type ImportStatementResult uint

const (
InvalidImport ImportStatementResult = iota
ValidImport
)

const ChainImportMaxBatchSize = 6

type Initialized struct {
Expand Down Expand Up @@ -349,13 +356,13 @@ func (i *Initialized) ProcessOnChainVotes(

// Importantly, handling import statements for backing votes also
// clears spam slots for any newly backed candidates
if err := i.HandleImportStatements(context,
if outcome := i.HandleImportStatements(context,
backend,
backingValidators.CandidateReceipt,
votes.Session,
statements,
now,
); err != nil {
); outcome == InvalidImport {
logger.Errorf("attempted import of on-chain backing votes failed. session: %v, relayParent: %v",
votes.Session,
backingValidators.CandidateReceipt.Descriptor.RelayParent,
Expand Down Expand Up @@ -417,13 +424,13 @@ func (i *Initialized) ProcessOnChainVotes(
continue
}

if err := i.HandleImportStatements(context,
if outcome := i.HandleImportStatements(context,
backend,
backingValidators.CandidateReceipt,
votes.Session,
filteredStatements,
now,
); err != nil {
); outcome == InvalidImport {
logger.Errorf("attempted import of on-chain dispute votes failed. "+
"session: %v, candidateHash: %v",
votes.Session,
Expand All @@ -447,20 +454,107 @@ func (i *Initialized) HandleIncoming(
backend OverlayBackend,
message types.DisputeCoordinatorMessage,
now uint64,
) error {
) (func() error, error) {
switch {
case message.ImportStatements != nil:
logger.Tracef("in HandleIncoming::ImportStatements")
logger.Tracef("HandleIncoming::ImportStatements")
outcome := i.HandleImportStatements(context,
backend,
message.ImportStatements.CandidateReceipt,
message.ImportStatements.Session,
message.ImportStatements.Statements,
now,
)

report := func() error {
if message.ImportStatements.PendingConfirmation != nil {
if err := message.ImportStatements.PendingConfirmation.SendMessage(outcome); err != nil {
return fmt.Errorf("confirm import statements: %w", err)
}
}

return nil
}

if outcome == InvalidImport {
return nil, report()
}

return report, nil
case message.RecentDisputes != nil:
logger.Tracef("HandleIncoming::RecentDisputes")
recentDisputes, err := backend.GetRecentDisputes()
if err != nil {
return nil, fmt.Errorf("get recent disputes: %w", err)
}

if err := message.RecentDisputes.Sender.SendMessage(recentDisputes); err != nil {
return nil, fmt.Errorf("send recent disputes: %w", err)
}
case message.ActiveDisputes != nil:
logger.Tracef("HandleIncoming::ActiveDisputes")
activeDisputes, err := backend.GetActiveDisputes(now)
if err != nil {
return nil, fmt.Errorf("get active disputes: %w", err)
}

if err := message.ActiveDisputes.Sender.SendMessage(activeDisputes); err != nil {
return nil, fmt.Errorf("send active disputes: %w", err)
}
case message.QueryCandidateVotes != nil:
logger.Tracef("HandleIncoming::QueryCandidateVotes")

var queryOutput []types.QueryCandidateVotesResponse
for _, query := range message.QueryCandidateVotes.Queries {
candidateVotes, err := backend.GetCandidateVotes(query.Session, query.CandidateHash)
if err != nil {
logger.Debugf("no candidate votes found for query. session: %v, candidateHash: %v",
query.Session,
query.CandidateHash,
)
return nil, fmt.Errorf("get candidate votes: %w", err)
}

queryOutput = append(queryOutput, types.QueryCandidateVotesResponse{
Session: query.Session,
CandidateHash: query.CandidateHash,
Votes: *candidateVotes,
})
}

if err := message.QueryCandidateVotes.Sender.SendMessage(queryOutput); err != nil {
return nil, fmt.Errorf("send candidate votes: %w", err)
}
case message.IssueLocalStatement != nil:
logger.Tracef("HandleIncoming::IssueLocalStatement")
if err := i.IssueLocalStatement(context,
backend,
message.IssueLocalStatement.CandidateHash,
message.IssueLocalStatement.CandidateReceipt,
message.IssueLocalStatement.Session,
message.IssueLocalStatement.Valid,
now,
); err != nil {
return nil, fmt.Errorf("issue local statement: %w", err)
}
case message.DetermineUndisputedChain != nil:
logger.Tracef("HandleIncoming::DetermineUndisputedChain")
undisputedChain, err := i.determineUndisputedChain(backend,
message.DetermineUndisputedChain.Base,
message.DetermineUndisputedChain.BlockDescriptions,
)
if err != nil {
return nil, fmt.Errorf("determine undisputed chain: %w", err)
}

if err := message.DetermineUndisputedChain.Tx.SendMessage(undisputedChain); err != nil {
return nil, fmt.Errorf("send undisputed chain: %w", err)
}
default:
return fmt.Errorf("unknown dispute coordinator message")
return nil, fmt.Errorf("unknown dispute coordinator message")
}

return nil
return nil, nil
}

func (i *Initialized) HandleImportStatements(
Expand All @@ -470,12 +564,10 @@ func (i *Initialized) HandleImportStatements(
session parachainTypes.SessionIndex,
statements []types.Statement,
now uint64,
) error {
) ImportStatementResult {
logger.Tracef("in HandleImportStatements")

if i.sessionIsAncient(session) {

}
return InvalidImport
}

func (i *Initialized) IssueLocalStatement(
Expand All @@ -496,6 +588,57 @@ func (i *Initialized) sessionIsAncient(session parachainTypes.SessionIndex) bool
return session < diff || session < i.HighestSessionSeen
}

func (i *Initialized) determineUndisputedChain(backend OverlayBackend,
baseBlock types.Block,
blockDescriptions []types.BlockDescription,
) (types.Block, error) {
last := types.NewBlock(baseBlock.BlockNumber+uint32(len(blockDescriptions)),
blockDescriptions[len(blockDescriptions)-1].BlockHash,
)

recentDisputes, err := backend.GetRecentDisputes()
if err != nil {
return types.Block{}, fmt.Errorf("get recent disputes: %w", err)
}

if recentDisputes == nil || recentDisputes.Len() == 0 {
return last, nil
}

isPossiblyInvalid := func(session parachainTypes.SessionIndex, candidateHash common.Hash) bool {
disputeStatus := recentDisputes.Get(types.NewDisputeComparator(session, candidateHash))
status, ok := disputeStatus.(types.DisputeStatus)
if !ok {
logger.Errorf("cast to dispute status. Expected types.DisputeStatus, got %T", disputeStatus)
return false
}

isPossiblyInvalid, err := status.IsPossiblyInvalid()
if err != nil {
logger.Errorf("is possibly invalid: %s", err)
return false
}

return isPossiblyInvalid
}

for i, blockDescription := range blockDescriptions {
for _, candidate := range blockDescription.Candidates {
if isPossiblyInvalid(blockDescription.Session, candidate.Value) {
if i == 0 {
return baseBlock, nil
} else {
return types.NewBlock(baseBlock.BlockNumber+uint32(i-1),
blockDescriptions[i-1].BlockHash,
), nil
}
}
}
}

return last, nil
}

func NewInitializedState(sender overseer.Sender,
runtime parachainRuntime.RuntimeInstance,
spamSlots SpamSlots,
Expand Down
12 changes: 10 additions & 2 deletions dot/parachain/dispute/types/dispute.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ type Comparator struct {
CandidateHash common.Hash `scale:"2"`
}

// NewDisputeComparator creates a new dispute comparator.
func NewDisputeComparator(sessionIndex parachainTypes.SessionIndex, candidateHash common.Hash) Comparator {
return Comparator{
SessionIndex: sessionIndex,
CandidateHash: candidateHash,
}
}

// Dispute is a dispute for a candidate.
// It is used as an item in the btree.BTree ordered by the Comparator.
type Dispute struct {
Expand All @@ -35,8 +43,8 @@ func NewDispute() (*Dispute, error) {
}, nil
}

// DisputeComparator compares two disputes.
func DisputeComparator(a, b any) bool {
// CompareDisputes compares two disputes.
func CompareDisputes(a, b any) bool {
d1, d2 := a.(*Dispute), b.(*Dispute)

if d1.Comparator.SessionIndex == d2.Comparator.SessionIndex {
Expand Down
Loading

0 comments on commit 67e7b60

Please sign in to comment.