From 5bf7fcf3b5fe52689fbf4d9269d367a26a17e732 Mon Sep 17 00:00:00 2001 From: Timothy Wu Date: Tue, 3 Oct 2023 13:27:05 -0400 Subject: [PATCH] lib(pkg): standalone `finality-grandpa` package (#3235) --- pkg/finality-grandpa/README | 1 + pkg/finality-grandpa/bitfield.go | 155 +++ pkg/finality-grandpa/bitfield_test.go | 145 +++ pkg/finality-grandpa/bridge_state.go | 101 ++ pkg/finality-grandpa/bridge_state_test.go | 57 ++ pkg/finality-grandpa/context.go | 139 +++ pkg/finality-grandpa/context_test.go | 96 ++ pkg/finality-grandpa/dummy_chain_test.go | 193 ++++ pkg/finality-grandpa/environment_test.go | 353 +++++++ pkg/finality-grandpa/lib.go | 400 ++++++++ pkg/finality-grandpa/lib_test.go | 179 ++++ pkg/finality-grandpa/logger.go | 61 ++ pkg/finality-grandpa/past_rounds.go | 364 +++++++ pkg/finality-grandpa/report.go | 34 + pkg/finality-grandpa/round.go | 692 +++++++++++++ pkg/finality-grandpa/round_test.go | 323 ++++++ pkg/finality-grandpa/vote_graph.go | 698 +++++++++++++ pkg/finality-grandpa/vote_graph_test.go | 350 +++++++ pkg/finality-grandpa/voter.go | 1125 +++++++++++++++++++++ pkg/finality-grandpa/voter_set.go | 185 ++++ pkg/finality-grandpa/voter_set_test.go | 110 ++ pkg/finality-grandpa/voter_test.go | 703 +++++++++++++ pkg/finality-grandpa/voting_round.go | 843 +++++++++++++++ pkg/finality-grandpa/weights.go | 24 + 24 files changed, 7331 insertions(+) create mode 100644 pkg/finality-grandpa/README create mode 100644 pkg/finality-grandpa/bitfield.go create mode 100644 pkg/finality-grandpa/bitfield_test.go create mode 100644 pkg/finality-grandpa/bridge_state.go create mode 100644 pkg/finality-grandpa/bridge_state_test.go create mode 100644 pkg/finality-grandpa/context.go create mode 100644 pkg/finality-grandpa/context_test.go create mode 100644 pkg/finality-grandpa/dummy_chain_test.go create mode 100644 pkg/finality-grandpa/environment_test.go create mode 100644 pkg/finality-grandpa/lib.go create mode 100644 pkg/finality-grandpa/lib_test.go create mode 100644 pkg/finality-grandpa/logger.go create mode 100644 pkg/finality-grandpa/past_rounds.go create mode 100644 pkg/finality-grandpa/report.go create mode 100644 pkg/finality-grandpa/round.go create mode 100644 pkg/finality-grandpa/round_test.go create mode 100644 pkg/finality-grandpa/vote_graph.go create mode 100644 pkg/finality-grandpa/vote_graph_test.go create mode 100644 pkg/finality-grandpa/voter.go create mode 100644 pkg/finality-grandpa/voter_set.go create mode 100644 pkg/finality-grandpa/voter_set_test.go create mode 100644 pkg/finality-grandpa/voter_test.go create mode 100644 pkg/finality-grandpa/voting_round.go create mode 100644 pkg/finality-grandpa/weights.go diff --git a/pkg/finality-grandpa/README b/pkg/finality-grandpa/README new file mode 100644 index 0000000000..284cc4b76f --- /dev/null +++ b/pkg/finality-grandpa/README @@ -0,0 +1 @@ +# finality-grandpa \ No newline at end of file diff --git a/pkg/finality-grandpa/bitfield.go b/pkg/finality-grandpa/bitfield.go new file mode 100644 index 0000000000..b6fbf23cdf --- /dev/null +++ b/pkg/finality-grandpa/bitfield.go @@ -0,0 +1,155 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +// A dynamically sized, write-once (per bit), lazily allocating bitfield. +type bitfield struct { + bits []uint64 +} + +// newBitfield creates a new empty bitfield. +func newBitfield() bitfield { + return bitfield{ + bits: make([]uint64, 0), + } +} + +// IsBlank returns Whether the bitfield is blank or empty. +func (b *bitfield) IsBlank() bool { //skipcq: GO-W1029 + return len(b.bits) == 0 +} + +// Merge another bitfield into this bitfield. +// +// As a result, this bitfield has all bits set that are set in either bitfield. +// +// This function only allocates if this bitfield is shorter than the other +// bitfield, in which case it is resized accordingly to accommodate for all +// bits of the other bitfield. +func (b *bitfield) Merge(other bitfield) *bitfield { //skipcq: GO-W1029 + if len(b.bits) < len(other.bits) { + b.bits = append(b.bits, make([]uint64, len(other.bits)-len(b.bits))...) + } + for i, word := range other.bits { + b.bits[i] |= word + } + return b +} + +// SetBit will set a bit in the bitfield at the specified position. +// +// If the bitfield is not large enough to accommodate for a bit set +// at the specified position, it is resized accordingly. +func (b *bitfield) SetBit(position uint) { //skipcq: GO-W1029 + wordOff := position / 64 + bitOff := position % 64 + + if wordOff >= uint(len(b.bits)) { + newLen := wordOff + 1 + b.bits = append(b.bits, make([]uint64, newLen-uint(len(b.bits)))...) + } + b.bits[wordOff] |= 1 << (63 - bitOff) +} + +// iter1s will get an iterator over all bits that are set (i.e. 1) in the bitfield, +// starting at bit position `start` and moving in steps of size `2^step` +// per word. +func (b *bitfield) iter1s(start, step uint) (bit1s []bit1) { //skipcq: GO-W1029 + return iter1s(b.bits, start, step) +} + +// Iter1sEven will get an iterator over all bits that are set (i.e. 1) at even bit positions. +func (b *bitfield) Iter1sEven() []bit1 { //skipcq: GO-W1029 + return b.iter1s(0, 1) +} + +// Iter1sOdd will get an iterator over all bits that are set (i.e. 1) at odd bit positions. +func (b *bitfield) Iter1sOdd() []bit1 { //skipcq: GO-W1029 + return b.iter1s(1, 1) +} + +// iter1sMerged will get an iterator over all bits that are set (i.e. 1) when merging +// this bitfield with another bitfield, without modifying either +// bitfield, starting at bit position `start` and moving in steps +// of size `2^step` per word. +func (b *bitfield) iter1sMerged(other bitfield, start, step uint) []bit1 { //skipcq: GO-W1029 + switch { + case len(b.bits) == len(other.bits): + zipped := make([]uint64, len(b.bits)) + for i, a := range b.bits { + b := other.bits[i] + zipped[i] = a | b + } + return iter1s(zipped, start, step) + case len(b.bits) < len(other.bits): + zipped := make([]uint64, len(other.bits)) + for i, bit := range other.bits { + var a uint64 + if i < len(b.bits) { + a = b.bits[i] + } + zipped[i] = a | bit + } + return iter1s(zipped, start, step) + case len(b.bits) > len(other.bits): + zipped := make([]uint64, len(b.bits)) + for i, a := range b.bits { + var b uint64 + if i < len(other.bits) { + b = other.bits[i] + } + zipped[i] = a | b + } + return iter1s(zipped, start, step) + default: + panic("unreachable") + } +} + +// Iter1sMergedEven will get an iterator over all bits that are set (i.e. 1) at even bit positions +// when merging this bitfield with another bitfield, without modifying +// either bitfield. +func (b *bitfield) Iter1sMergedEven(other bitfield) []bit1 { //skipcq: GO-W1029 + return b.iter1sMerged(other, 0, 1) +} + +// Iter1sMergedOdd will get an iterator over all bits that are set (i.e. 1) at odd bit positions +// when merging this bitfield with another bitfield, without modifying +// either bitfield. +func (b *bitfield) Iter1sMergedOdd(other bitfield) []bit1 { //skipcq: GO-W1029 + return b.iter1sMerged(other, 1, 1) +} + +// Turn an iterator over u64 words into an iterator over bits that +// are set (i.e. `1`) in these words, starting at bit position `start` +// and moving in steps of size `2^step` per word. +func iter1s(iter []uint64, start, step uint) (bit1s []bit1) { + if !(start < 64 && step < 7) { + panic("invalid start and step") + } + steps := (64 >> step) - (start >> step) + for i, word := range iter { + if word == 0 { + continue + } + for j := uint(0); j < steps; j++ { + bitPos := start + (j << step) + if testBit(word, bitPos) { + bit1s = append(bit1s, bit1{uint(i)*64 + bitPos}) + } + } + } + return bit1s +} + +func testBit(word uint64, position uint) bool { + mask := uint64(1 << (63 - position)) + return word&mask == mask +} + +// A bit that is set (i.e. 1) in a `bitfield`. +type bit1 struct { + // The position of the bit in the bitfield. + position uint +} diff --git a/pkg/finality-grandpa/bitfield_test.go b/pkg/finality-grandpa/bitfield_test.go new file mode 100644 index 0000000000..be1a47c4f9 --- /dev/null +++ b/pkg/finality-grandpa/bitfield_test.go @@ -0,0 +1,145 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "math" + "math/rand" + "reflect" + "testing" + "testing/quick" + + "github.com/stretchr/testify/assert" +) + +// Generate is used by testing/quick to genereate +func (bitfield) Generate(rand *rand.Rand, size int) reflect.Value { //skipcq: GO-W1029 + n := rand.Int() % size + bits := make([]uint64, n) + for i := range bits { + bits[i] = rand.Uint64() + } + + // we need to make sure we don't add empty words at the end of the + // bitfield otherwise it would break equality on some of the tests + // below. + for len(bits) > 0 && bits[len(bits)-1] == 0 { + bits = bits[:len(bits)-2] + } + return reflect.ValueOf(bitfield{ + bits: bits, + }) +} + +// Test if the bit at the specified position is set. +func (b *bitfield) testBit(position uint) bool { //skipcq: GO-W1029 + wordOff := position / 64 + if wordOff >= uint(len(b.bits)) { + return false + } + return testBit(b.bits[wordOff], position%64) +} + +func Test_SetBit(t *testing.T) { + f := func(a bitfield, idx uint) bool { + // let's bound the max bitfield index at 2^24. this is needed because when calling + // `set_bit` we will extend the backing vec to accommodate the given bitfield size, this + // way we restrict the maximum allocation size to 16MB. + idx = uint(math.Min(float64(idx), 1<<24)) + a.SetBit(idx) + return a.testBit(idx) + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } +} + +// translated from bitor test in +// https://github.com/paritytech/finality-grandpa/blob/fbe2404574f74713bccddfe4104d60c2a32d1fe6/src/bitfield.rs#L243 +func Test_Merge(t *testing.T) { + f := func(a, b bitfield) bool { + c := newBitfield() + copy(a.bits, c.bits) + cBits := c.iter1s(0, 0) + for _, bit := range cBits { + if !(a.testBit(bit.position) || b.testBit(bit.position)) { + return false + } + } + return true + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } +} + +func Test_iter1s(t *testing.T) { + t.Run("all", func(t *testing.T) { + f := func(a bitfield) bool { + b := newBitfield() + for _, bit1 := range a.iter1s(0, 0) { + b.SetBit(bit1.position) + } + return assert.Equal(t, a, b) + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } + }) + + t.Run("even_odd", func(t *testing.T) { + f := func(a bitfield) bool { + b := newBitfield() + for _, bit1 := range a.Iter1sEven() { + assert.True(t, !b.testBit(bit1.position)) + assert.True(t, bit1.position%2 == 0) + b.SetBit(bit1.position) + } + for _, bit1 := range a.Iter1sOdd() { + assert.True(t, !b.testBit(bit1.position)) + assert.True(t, bit1.position%2 == 1) + b.SetBit(bit1.position) + } + return assert.Equal(t, a, b) + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } + }) +} + +func Test_iter1sMerged(t *testing.T) { + t.Run("all", func(t *testing.T) { + f := func(a, b bitfield) bool { + c := newBitfield() + for _, bit1 := range a.iter1sMerged(b, 0, 0) { + c.SetBit(bit1.position) + } + return assert.Equal(t, &c, a.Merge(b)) + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } + }) + + t.Run("even_odd", func(t *testing.T) { + f := func(a, b bitfield) bool { + c := newBitfield() + for _, bit1 := range a.Iter1sMergedEven(b) { + assert.True(t, !c.testBit(bit1.position)) + assert.True(t, bit1.position%2 == 0) + c.SetBit(bit1.position) + } + for _, bit1 := range a.Iter1sMergedOdd(b) { + assert.True(t, !c.testBit(bit1.position)) + assert.True(t, bit1.position%2 == 1) + c.SetBit(bit1.position) + } + return assert.Equal(t, &c, a.Merge(b)) + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } + }) +} diff --git a/pkg/finality-grandpa/bridge_state.go b/pkg/finality-grandpa/bridge_state.go new file mode 100644 index 0000000000..7e68084901 --- /dev/null +++ b/pkg/finality-grandpa/bridge_state.go @@ -0,0 +1,101 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "sync" +) + +type waker struct { + mtx sync.RWMutex + wakeCh chan any +} + +func newWaker() *waker { + return &waker{wakeCh: make(chan any, 1000)} +} + +func (w *waker) wake() { + w.mtx.RLock() + defer w.mtx.RUnlock() + if w.wakeCh == nil { + return + } + go func() { + select { + case w.wakeCh <- nil: + default: + } + }() +} + +func (w *waker) channel() chan any { + return w.wakeCh +} + +func (w *waker) register(waker *waker) { + w.mtx.Lock() + defer w.mtx.Unlock() + w.wakeCh = waker.wakeCh +} + +// round state bridged across rounds. +type bridged[Hash, Number any] struct { + inner RoundState[Hash, Number] + // registered map[chan State[Hash, Number]]any + waker *waker + sync.RWMutex +} + +func (b *bridged[H, N]) update(new RoundState[H, N]) { + b.Lock() + b.inner = new + b.waker.wake() + b.Unlock() +} + +func (b *bridged[H, N]) get(waker *waker) RoundState[H, N] { + b.RLock() + defer b.RUnlock() + b.waker.register(waker) + return b.inner +} + +// A prior view of a round-state. +type priorView[Hash, Number any] struct { + bridged *bridged[Hash, Number] +} + +// Push an update to the latter view. +func (pv *priorView[H, N]) update(new RoundState[H, N]) { //skipcq: RVV-B0001 + pv.bridged.update(new) +} + +// A latter view of a round-state. +type latterView[Hash, Number any] struct { + bridged *bridged[Hash, Number] +} + +// Fetch a handle to the last round-state. +func (lv *latterView[H, N]) get(waker *waker) (state RoundState[H, N]) { //skipcq: RVV-B0001 + return lv.bridged.get(waker) +} + +// Constructs two views of a bridged round-state. +// +// The prior view is held by a round which produces the state and pushes updates to a latter view. +// When updating, the latter view's task is updated. +// +// The latter view is held by the subsequent round, which blocks certain activity +// while waiting for events on an older round. +func bridgeState[Hash, Number any](initial RoundState[Hash, Number]) ( + priorView[Hash, Number], + latterView[Hash, Number], +) { + br := bridged[Hash, Number]{ + inner: initial, + waker: newWaker(), + } + return priorView[Hash, Number]{&br}, latterView[Hash, Number]{&br} +} diff --git a/pkg/finality-grandpa/bridge_state_test.go b/pkg/finality-grandpa/bridge_state_test.go new file mode 100644 index 0000000000..e3ad3ab5d3 --- /dev/null +++ b/pkg/finality-grandpa/bridge_state_test.go @@ -0,0 +1,57 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "sync" + "testing" + "time" +) + +func TestBridgeState(_ *testing.T) { + initial := RoundState[string, int32]{} + + prior, latter := bridgeState(initial) + + barrier := make(chan any) + var wg sync.WaitGroup + + waker := &waker{ + wakeCh: make(chan any), + } + + var waitForFinality = func() bool { + return latter.get(waker).Finalized != nil + } + + wg.Add(2) + go func() { + defer wg.Done() + <-barrier + time.Sleep(5 * time.Millisecond) + prior.update(RoundState[string, int32]{ + PrevoteGHOST: &HashNumber[string, int32]{"5", 5}, + Finalized: &HashNumber[string, int32]{"1", 1}, + Estimate: &HashNumber[string, int32]{"3", 3}, + Completable: true, + }) + }() + + // block_on + go func() { + defer wg.Done() + <-barrier + if waitForFinality() { + return + } + for range waker.wakeCh { + if waitForFinality() { + return + } + } + }() + + close(barrier) + wg.Wait() +} diff --git a/pkg/finality-grandpa/context.go b/pkg/finality-grandpa/context.go new file mode 100644 index 0000000000..ee9e5087c7 --- /dev/null +++ b/pkg/finality-grandpa/context.go @@ -0,0 +1,139 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "golang.org/x/exp/constraints" +) + +// The context of a `Round` in which vote weights are calculated. +type context[ID constraints.Ordered] struct { + voters VoterSet[ID] + equivocations bitfield +} + +// newContext will create a new context for a round with the given set of voters. +func newContext[ID constraints.Ordered](voters VoterSet[ID]) context[ID] { + return context[ID]{ + voters: voters, + equivocations: newBitfield(), + } +} + +// Voters will return the set of voters. +func (c context[ID]) Voters() VoterSet[ID] { + return c.voters +} + +// EquivocationWeight returns the weight of observed equivocations in phase `p`. +func (c context[ID]) EquivocationWeight(p Phase) VoteWeight { + switch p { + case PrevotePhase: + return weight(c.equivocations.Iter1sEven(), c.voters) + case PrecommitPhase: + return weight(c.equivocations.Iter1sOdd(), c.voters) + default: + panic("invalid Phase") + } +} + +// Equivocated will record voter `v` as an equivocator in phase `p`. +func (c *context[ID]) Equivocated(v VoterInfo, p Phase) { + c.equivocations.SetBit(newVote[ID](v, p).bit.position) +} + +// Weight computes the vote weight on node `n` in phase `p`, taking into account +// equivocations. +func (c context[ID]) Weight(n voteNode[ID], p Phase) VoteWeight { + if c.equivocations.IsBlank() { + switch p { + case PrevotePhase: + return weight(n.bits.Iter1sEven(), c.voters) + case PrecommitPhase: + return weight(n.bits.Iter1sOdd(), c.voters) + default: + panic("invalid Phase") + } + } else { + switch p { + case PrevotePhase: + bits := n.bits.Iter1sMergedEven(c.equivocations) + return weight(bits, c.voters) + case PrecommitPhase: + bits := n.bits.Iter1sMergedOdd(c.equivocations) + return weight(bits, c.voters) + default: + panic("invalid Phase") + } + } +} + +// A single vote that can be incorporated into a `voteNode`. +type vote[ID constraints.Ordered] struct { + bit bit1 +} + +// NewVote will create a new vote cast by voter `v` in phase `p`. +func newVote[ID constraints.Ordered](v VoterInfo, p Phase) vote[ID] { + switch p { + case PrevotePhase: + return vote[ID]{ + bit: bit1{ + position: v.position * 2, + }, + } + case PrecommitPhase: + return vote[ID]{ + bit: bit1{ + position: v.position*2 + 1, + }, + } + default: + panic("invalid Phase") + } +} + +// Get the voter who cast the vote from the given voter set, +// if it is contained in that set. +func (v vote[ID]) voter(vs VoterSet[ID]) *IDVoterInfo[ID] { + return vs.Nth(v.bit.position / 2) +} + +func weight[ID constraints.Ordered](bits []bit1, voters VoterSet[ID]) (total VoteWeight) { //skipcq: RVV-B0001 + for _, bit := range bits { + vote := vote[ID]{bit} + ivi := vote.voter(voters) + if ivi != nil { + total = total + VoteWeight(ivi.VoterInfo.weight) + } + } + return +} + +type voteNodeI[voteNode, Vote any] interface { + Add(other voteNode) + AddVote(other Vote) + Copy() voteNode +} + +type voteNode[ID constraints.Ordered] struct { + bits bitfield +} + +func (vn *voteNode[ID]) Add(other *voteNode[ID]) { + vn.bits.Merge(other.bits) +} + +func (vn *voteNode[ID]) AddVote(vote vote[ID]) { + vn.bits.SetBit(vote.bit.position) +} + +func (vn *voteNode[ID]) Copy() *voteNode[ID] { + copiedBits := newBitfield() + copiedBits.bits = make([]uint64, len(vn.bits.bits)) + copy(copiedBits.bits, vn.bits.bits) + return &voteNode[ID]{ + bits: copiedBits, + } +} diff --git a/pkg/finality-grandpa/context_test.go b/pkg/finality-grandpa/context_test.go new file mode 100644 index 0000000000..f190835ae8 --- /dev/null +++ b/pkg/finality-grandpa/context_test.go @@ -0,0 +1,96 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "math/rand" + "reflect" + "testing" + "testing/quick" + + "github.com/stretchr/testify/assert" +) + +func (Phase) Generate(rand *rand.Rand, _ int) reflect.Value { + index := rand.Intn(2) + return reflect.ValueOf([]Phase{PrevotePhase, PrecommitPhase}[index]) +} + +func (context[ID]) Generate(rand *rand.Rand, size int) reflect.Value { + vs := VoterSet[ID]{}.Generate(rand, size).Interface().(VoterSet[ID]) + + n := rand.Int() % len(vs.voters) + equivocators := make([]VoterInfo, n+1) + for i := 0; i <= n; i++ { + ivi := vs.NthMod(uint(rand.Uint64())) + equivocators[i] = ivi.VoterInfo + } + + c := context[ID]{ + voters: vs, + } + for _, v := range equivocators { + c.Equivocated(v, Phase(0).Generate(rand, size).Interface().(Phase)) + } + return reflect.ValueOf(c) +} + +func TestVote_voter(t *testing.T) { + f := func(vs VoterSet[uint], phase Phase) bool { + for _, idv := range vs.Iter() { + id := idv.ID + v := idv.VoterInfo + eq := assert.Equal(t, &IDVoterInfo[uint]{id, v}, newVote[uint](v, phase).voter(vs)) + if !eq { + return false + } + } + return true + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } +} + +func TestWeights(t *testing.T) { + f := func(ctx context[uint], phase Phase, voters []uint) bool { + ew := ctx.EquivocationWeight(phase) + tw := ctx.voters.TotalWeight() + + // The equivocation weight must never be larger than the total + // voter weight. + if !assert.True(t, uint64(ew) <= uint64(tw)) { + return false + } + + // Let a random subset of voters cast a vote, whether already + // an equivocator or not. + n := voteNode[uint]{} + expected := ew + for _, v := range voters { + idvi := ctx.voters.NthMod(v) + vote := newVote[uint](idvi.VoterInfo, phase) + + // We only expect the weight to increase if the voter did not + // start out as an equivocator and did not yet vote. + if !ctx.equivocations.testBit(vote.bit.position) && !n.bits.testBit(vote.bit.position) { + expected = expected + VoteWeight(idvi.VoterInfo.weight) + } + n.AddVote(vote) + } + + // Let the context compute the weight. + w := ctx.Weight(n, phase) + + // A vote-node weight must never be greater than the total voter weight. + if !assert.True(t, uint64(w) <= uint64(tw)) { + return false + } + + return assert.Equal(t, expected, w) + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } +} diff --git a/pkg/finality-grandpa/dummy_chain_test.go b/pkg/finality-grandpa/dummy_chain_test.go new file mode 100644 index 0000000000..5dba35b64f --- /dev/null +++ b/pkg/finality-grandpa/dummy_chain_test.go @@ -0,0 +1,193 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/exp/slices" +) + +const GenesisHash = "genesis" +const nullHash = "NULL" + +type blockRecord struct { + hash string + number uint32 + parent string +} + +// dummyChain is translation of finality_grandpa::testing::chain::DummyChain +type dummyChain struct { + inner map[string]blockRecord + leaves []blockRecord + finalized struct { + hash string + number uint32 + } +} + +func newDummyChain() *dummyChain { + dc := &dummyChain{ + inner: make(map[string]blockRecord), + leaves: make([]blockRecord, 0), + } + dc.inner[GenesisHash] = blockRecord{ + number: 1, + parent: nullHash, + hash: GenesisHash, + } + dc.leaves = append(dc.leaves, dc.inner[GenesisHash]) + dc.finalized.hash = GenesisHash + dc.finalized.number = 1 + return dc +} + +func (dc *dummyChain) Ancestry(base, block string) (ancestors []string, err error) { + ancestors = make([]string, 0) +loop: + for { + br, ok := dc.inner[block] + if !ok { + return nil, fmt.Errorf("Block not descendent of base") + } + block = br.parent + + switch block { + case nullHash: + return nil, fmt.Errorf("Block not descendent of base") + case base: + break loop + } + ancestors = append(ancestors, block) + } + return ancestors, nil +} + +func (dc *dummyChain) IsEqualOrDescendantOf(base, block string) bool { + if base == block { + return true + } + + _, err := dc.Ancestry(base, block) + return err == nil +} + +func (dc *dummyChain) PushBlocks(parent string, blocks []string) { + br, ok := dc.inner[parent] + if !ok { + panic("could not find parent hash") + } + baseNumber := br.number + 1 + + for i, leaf := range dc.leaves { + if leaf.hash == parent { + dc.leaves = append(dc.leaves[:i], dc.leaves[i+1:]...) + } + } + + for i, descendant := range blocks { + dc.inner[descendant] = blockRecord{ + hash: descendant, + number: baseNumber + uint32(i), + parent: parent, + } + parent = descendant + } + + newLeafHash := blocks[len(blocks)-1] + newLeaf := dc.inner[newLeafHash] + insertionIndex, _ := slices.BinarySearchFunc(dc.leaves, newLeaf, func(a, b blockRecord) int { + switch { + case a.number == b.number: + return 0 + case a.number > b.number: + return -1 + case b.number > a.number: + return 1 + default: + panic("huh?") + } + }) + + switch { + case len(dc.leaves) == 0 && insertionIndex == 0: + dc.leaves = append(dc.leaves, newLeaf) + case insertionIndex == len(dc.leaves): + dc.leaves = append(dc.leaves, newLeaf) + default: + dc.leaves = append( + dc.leaves[:insertionIndex], + append([]blockRecord{newLeaf}, dc.leaves[insertionIndex:]...)...) + } +} + +func (dc *dummyChain) Number(hash string) uint32 { + e, ok := dc.inner[hash] + if !ok { + panic("huh?") + } + return e.number +} + +func (dc *dummyChain) LastFinalized() (string, uint32) { + return dc.finalized.hash, dc.finalized.number +} + +func (dc *dummyChain) SetLastFinalized(hash string, number uint32) { + dc.finalized.hash = hash + dc.finalized.number = number +} + +func (dc *dummyChain) BestChainContaining(base string) *HashNumber[string, uint32] { + baseRecord, ok := dc.inner[base] + if !ok { + return nil + } + baseNumber := baseRecord.number + + for _, leaf := range dc.leaves { + // leaves are in descending order. + leafNumber := leaf.number + if leafNumber < baseNumber { + break + } + + if leaf.hash == base { + return &HashNumber[string, uint32]{leaf.hash, leafNumber} + } + + _, err := dc.Ancestry(base, leaf.hash) + if err == nil { + return &HashNumber[string, uint32]{leaf.hash, leafNumber} + } + } + + return nil +} + +func TestDummyGraphPushBlocks(t *testing.T) { + c := newDummyChain() + c.PushBlocks(GenesisHash, []string{ + "A", "B", "C", + }) + c.PushBlocks(GenesisHash, []string{ + "A'", "B'", "C'", + }) + assert.Equal(t, []blockRecord{ + {hash: "C'", number: 4, parent: "B'"}, + {hash: "C", number: 4, parent: "B"}, + }, c.leaves) + assert.Equal(t, c.inner, map[string]blockRecord{ + GenesisHash: {GenesisHash, 1, nullHash}, + "A": {"A", 2, GenesisHash}, + "A'": {"A'", 2, GenesisHash}, + "B": {"B", 3, "A"}, + "B'": {"B'", 3, "A'"}, + "C": {"C", 4, "B"}, + "C'": {"C'", 4, "B'"}, + }) +} diff --git a/pkg/finality-grandpa/environment_test.go b/pkg/finality-grandpa/environment_test.go new file mode 100644 index 0000000000..07698cd3f5 --- /dev/null +++ b/pkg/finality-grandpa/environment_test.go @@ -0,0 +1,353 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "fmt" + "sync" + "time" + + "golang.org/x/exp/rand" +) + +type ID uint32 + +type Signature uint32 + +type timer struct { + wakerChan *wakerChan[error] + expired bool +} + +func newTimer(in <-chan time.Time) *timer { + inErr := make(chan error) + wc := newWakerChan(inErr) + t := timer{wakerChan: wc} + go func() { + <-in + inErr <- nil + t.expired = true + }() + return &t +} + +func (t *timer) SetWaker(waker *waker) { + t.wakerChan.setWaker(waker) +} + +func (t *timer) Elapsed() (bool, error) { + return t.expired, nil +} + +type listenerItem struct { + Hash string + Number uint32 + Commit Commit[string, uint32, Signature, ID] +} + +type environment struct { + chain *dummyChain + localID ID + network *Network + listeners []chan listenerItem + lastCompleteAndConcluded [2]uint64 + mtx sync.Mutex +} + +func newEnvironment(network *Network, localID ID) environment { + return environment{ + chain: newDummyChain(), + localID: localID, + network: network, + } +} + +func (e *environment) WithChain(f func(*dummyChain)) { + e.mtx.Lock() + defer e.mtx.Unlock() + f(e.chain) +} + +func (e *environment) FinalizedStream() chan listenerItem { + e.mtx.Lock() + defer e.mtx.Unlock() + ch := make(chan listenerItem) + e.listeners = append(e.listeners, ch) + return ch +} + +func (e *environment) LastCompletedAndConcluded() [2]uint64 { + e.mtx.Lock() + defer e.mtx.Unlock() + return e.lastCompleteAndConcluded +} + +func (e *environment) Ancestry(base, block string) (ancestors []string, err error) { + return e.chain.Ancestry(base, block) +} + +func (e *environment) IsEqualOrDescendantOf(base, block string) bool { + return e.chain.IsEqualOrDescendantOf(base, block) +} + +func (e *environment) BestChainContaining(base string) BestChain[string, uint32] { + e.mtx.Lock() + defer e.mtx.Unlock() + + ch := make(chan BestChainOutput[string, uint32], 1) + ch <- BestChainOutput[string, uint32]{Value: e.chain.BestChainContaining(base)} + return ch +} + +func (e *environment) RoundData( + round uint64, + outgoing Output[string, uint32], +) RoundData[string, uint32, Signature, ID] { + incoming := e.network.MakeRoundComms(round, e.localID, outgoing) + + rd := RoundData[string, uint32, Signature, ID]{ + VoterID: &e.localID, + PrevoteTimer: newTimer(time.NewTimer(500 * time.Millisecond).C), + PrecommitTimer: newTimer(time.NewTimer(1000 * time.Millisecond).C), + Incoming: incoming, + } + return rd +} + +func (*environment) RoundCommitTimer() Timer { + inner := time.NewTimer(time.Duration(rand.Int63n(1000)) * time.Millisecond).C + timer := newTimer(inner) + return timer +} + +func (e *environment) Completed( + round uint64, + _ RoundState[string, uint32], + _ HashNumber[string, uint32], + _ HistoricalVotes[string, uint32, Signature, ID], +) error { + e.mtx.Lock() + defer e.mtx.Unlock() + e.lastCompleteAndConcluded[0] = round + return nil +} + +func (e *environment) Concluded( + round uint64, + _ RoundState[string, uint32], + _ HashNumber[string, uint32], + _ HistoricalVotes[string, uint32, Signature, ID], +) error { + e.mtx.Lock() + defer e.mtx.Unlock() + e.lastCompleteAndConcluded[1] = round + return nil +} + +func (e *environment) FinalizeBlock( + hash string, + number uint32, + _ uint64, + commit Commit[string, uint32, Signature, ID], +) error { + e.mtx.Lock() + defer e.mtx.Unlock() + + lastFinalizedHash, lastFinalizedNumber := e.chain.LastFinalized() + if number <= lastFinalizedNumber { + panic("Attempted to finalize backwards") + } + + if _, err := e.chain.Ancestry(lastFinalizedHash, hash); err != nil { + panic("Safety violation: reverting finalized block.") + } + + e.chain.SetLastFinalized(hash, number) + for _, listener := range e.listeners { + listener <- listenerItem{ + hash, number, commit, + } + } + return nil +} + +func (*environment) Proposed(_ uint64, _ PrimaryPropose[string, uint32]) error { + return nil +} + +func (*environment) Prevoted(_ uint64, _ Prevote[string, uint32]) error { + return nil +} + +func (*environment) Precommitted(_ uint64, _ Precommit[string, uint32]) error { + return nil +} + +func (*environment) PrevoteEquivocation( + round uint64, + equivocation Equivocation[ID, Prevote[string, uint32], Signature], +) { + panic(fmt.Errorf("Encountered equivocation in round %v: %v", round, equivocation)) +} + +// Note that an equivocation in prevotes has occurred. +func (*environment) PrecommitEquivocation( + round uint64, + equivocation Equivocation[ID, Precommit[string, uint32], Signature], +) { + panic(fmt.Errorf("Encountered equivocation in round %v: %v", round, equivocation)) +} + +// p2p network data for a round. +type BroadcastNetwork[M, N any] struct { + receiver chan M + senders []chan M + history []M + routing bool +} + +func NewBroadcastNetwork[M, N any]() BroadcastNetwork[M, N] { + bn := BroadcastNetwork[M, N]{ + receiver: make(chan M, 10000), + } + return bn +} + +func (bm *BroadcastNetwork[M, N]) SendMessage(message M) { + bm.receiver <- message +} + +func (bm *BroadcastNetwork[M, N]) AddNode(f func(N) M, out chan N) (in chan M) { + // buffer to 100 messages for now + in = make(chan M, 10000) + + // get history to the node. + for _, priorMessage := range bm.history { + in <- priorMessage + } + + bm.senders = append(bm.senders, in) + + if !bm.routing { + bm.routing = true + go bm.route() + } + + go func() { + for n := range out { + bm.receiver <- f(n) + } + }() + return in +} + +func (bm *BroadcastNetwork[M, N]) route() { + for msg := range bm.receiver { + bm.history = append(bm.history, msg) + for _, sender := range bm.senders { + sender <- msg + } + } +} + +type RoundNetwork struct { + BroadcastNetwork[SignedMessageError[string, uint32, Signature, ID], Message[string, uint32]] +} + +func NewRoundNetwork() *RoundNetwork { + bn := NewBroadcastNetwork[SignedMessageError[string, uint32, Signature, ID], Message[string, uint32]]() + rn := RoundNetwork{bn} + return &rn +} + +func (rn *RoundNetwork) AddNode( + f func(Message[string, uint32]) SignedMessageError[string, uint32, Signature, ID], + out chan Message[string, uint32], +) (in chan SignedMessageError[string, uint32, Signature, ID]) { + return rn.BroadcastNetwork.AddNode(f, out) +} + +type GlobalMessageNetwork struct { + BroadcastNetwork[globalInItem, CommunicationOut] +} + +func NewGlobalMessageNetwork() *GlobalMessageNetwork { + bn := NewBroadcastNetwork[globalInItem, CommunicationOut]() + gmn := GlobalMessageNetwork{bn} + return &gmn +} + +func (gmn *GlobalMessageNetwork) AddNode( + f func(CommunicationOut) globalInItem, + out chan CommunicationOut, +) (in chan globalInItem) { + return gmn.BroadcastNetwork.AddNode(f, out) +} + +// A test network. Instantiate this with `make_network`, +type Network struct { + rounds map[uint64]*RoundNetwork + globalMessages GlobalMessageNetwork + mtx sync.Mutex +} + +func NewNetwork() *Network { + return &Network{ + rounds: make(map[uint64]*RoundNetwork), + globalMessages: *NewGlobalMessageNetwork(), + } +} + +func (n *Network) MakeRoundComms( + roundNumber uint64, + nodeID ID, + out chan Message[string, uint32], +) (in chan SignedMessageError[string, uint32, Signature, ID]) { + n.mtx.Lock() + defer n.mtx.Unlock() + + round, ok := n.rounds[roundNumber] + if !ok { + round = NewRoundNetwork() + n.rounds[roundNumber] = round + } + return round.AddNode(func(message Message[string, uint32]) SignedMessageError[string, uint32, Signature, ID] { + return SignedMessageError[string, uint32, Signature, ID]{ + SignedMessage: SignedMessage[string, uint32, Signature, ID]{ + Message: message, + Signature: Signature(nodeID), + ID: nodeID, + }, + } + }, out, + ) +} + +func (n *Network) MakeGlobalComms(out chan CommunicationOut) chan globalInItem { + n.mtx.Lock() + defer n.mtx.Unlock() + + return n.globalMessages.AddNode(func(message CommunicationOut) globalInItem { + if message.variant == nil { + panic("nil message variant") + } + switch message := message.variant.(type) { + case CommunicationOutCommit[string, uint32, Signature, ID]: + ci := newCommunicationIn[string, uint32, Signature, ID](CommunicationInCommit[string, uint32, Signature, ID]{ + Number: message.Number, + CompactCommit: message.Commit.CompactCommit(), + Callback: nil, + }) + return globalInItem{ + CommunicationIn: ci, + } + default: + panic("invalid CommunicationOut variant") + } + }, out) +} + +func (n *Network) SendMessage(message CommunicationIn) { + n.globalMessages.SendMessage(globalInItem{message, nil}) +} diff --git a/pkg/finality-grandpa/lib.go b/pkg/finality-grandpa/lib.go new file mode 100644 index 0000000000..39246c4985 --- /dev/null +++ b/pkg/finality-grandpa/lib.go @@ -0,0 +1,400 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "github.com/tidwall/btree" + "golang.org/x/exp/constraints" + "golang.org/x/exp/slices" +) + +// HashNumber contains a block hash and block number +type HashNumber[Hash, Number any] struct { + Hash Hash + Number Number +} + +type targetHashTargetNumber[Hash, Number any] struct { + TargetHash Hash + TargetNumber Number +} + +// Prevote is a prevote for a block and its ancestors. +type Prevote[Hash, Number any] targetHashTargetNumber[Hash, Number] + +// Precommit is a precommit for a block and its ancestors. +type Precommit[Hash, Number any] targetHashTargetNumber[Hash, Number] + +// PrimaryPropose is a primary proposed block, this is a broadcast of the last round's estimate. +type PrimaryPropose[Hash, Number any] targetHashTargetNumber[Hash, Number] + +// Chain context necessary for implementation of the finality gadget. +type Chain[Hash, Number comparable] interface { + // Get the ancestry of a block up to but not including the base hash. + // Should be in reverse order from `block`'s parent. + // + // If the block is not a descendent of `base`, returns an error. + Ancestry(base, block Hash) ([]Hash, error) + // Returns true if `block` is a descendent of or equal to the given `base`. + IsEqualOrDescendantOf(base, block Hash) bool +} + +// Equivocation is an equivocation (double-vote) in a given round. +type Equivocation[ID constraints.Ordered, Vote, Signature comparable] struct { + // The round number equivocated in. + RoundNumber uint64 + // The identity of the equivocator. + Identity ID + // The first vote in the equivocation. + First voteSignature[Vote, Signature] + // The second vote in the equivocation. + Second voteSignature[Vote, Signature] +} + +// Message is a protocol message or vote. +type Message[Hash, Number any] struct { + value any +} + +// Target returns the target block of the vote. +func (m Message[H, N]) Target() HashNumber[H, N] { + switch message := m.value.(type) { + case Prevote[H, N]: + return HashNumber[H, N]{ + message.TargetHash, + message.TargetNumber, + } + case Precommit[H, N]: + return HashNumber[H, N]{ + message.TargetHash, + message.TargetNumber, + } + case PrimaryPropose[H, N]: + return HashNumber[H, N]{ + message.TargetHash, + message.TargetNumber, + } + default: + panic("unsupported Message type") + } +} + +// Value returns the message constrained by `Messages` +func (m Message[H, N]) Value() any { + return m.value +} + +// Messages is the interface constraint for `Message` +type Messages[Hash, Number any] interface { + Prevote[Hash, Number] | Precommit[Hash, Number] | PrimaryPropose[Hash, Number] +} + +func setMessage[Hash, Number any, T Messages[Hash, Number]](m *Message[Hash, Number], val T) { + m.value = val +} + +func newMessage[Hash, Number any, T Messages[Hash, Number]](val T) (m Message[Hash, Number]) { + msg := Message[Hash, Number]{} + setMessage(&msg, val) + return msg +} + +// SignedMessage is a signed message. +type SignedMessage[Hash, Number, Signature, ID any] struct { + // The internal message which has been signed. + Message Message[Hash, Number] + // The signature on the message. + Signature Signature + // The Id of the signer + ID ID +} + +// Commit is a commit message which is an aggregate of precommits. +type Commit[Hash, Number, Signature, ID any] struct { + // The target block's hash. + TargetHash Hash + // The target block's number. + TargetNumber Number + // Precommits for target block or any block after it that justify this commit. + Precommits []SignedPrecommit[Hash, Number, Signature, ID] +} + +func (c Commit[Hash, Number, Signature, ID]) CompactCommit() CompactCommit[Hash, Number, Signature, ID] { + precommits := make([]Precommit[Hash, Number], len(c.Precommits)) + authData := make(MultiAuthData[Signature, ID], len(c.Precommits)) + for i, signed := range c.Precommits { + precommits[i] = signed.Precommit + authData[i] = struct { + Signature Signature + ID ID + }{signed.Signature, signed.ID} + } + return CompactCommit[Hash, Number, Signature, ID]{ + TargetHash: c.TargetHash, + TargetNumber: c.TargetNumber, + Precommits: precommits, + AuthData: authData, + } +} + +// SignedPrevote is a signed prevote message. +type SignedPrevote[Hash, Number, Signature, ID any] struct { + // The prevote message which has been signed. + Prevote Prevote[Hash, Number] + // The signature on the message. + Signature Signature + // The ID of the signer. + ID ID +} + +// SignedPrecommit is a signed precommit message. +type SignedPrecommit[Hash, Number, Signature, ID any] struct { + // The precommit message which has been signed. + Precommit Precommit[Hash, Number] + // The signature on the message. + Signature Signature + // The ID of the signer. + ID ID +} + +// CompactCommit is a commit message with compact representation of authentication data. +type CompactCommit[Hash, Number, Signature, ID any] struct { + TargetHash Hash + TargetNumber Number + Precommits []Precommit[Hash, Number] + AuthData MultiAuthData[Signature, ID] +} + +func (cc CompactCommit[Hash, Number, Signature, ID]) Commit() Commit[Hash, Number, Signature, ID] { + signedPrecommits := make([]SignedPrecommit[Hash, Number, Signature, ID], len(cc.Precommits)) + for i, precommit := range cc.Precommits { + signedPrecommits[i] = SignedPrecommit[Hash, Number, Signature, ID]{ + Precommit: precommit, + Signature: cc.AuthData[i].Signature, + ID: cc.AuthData[i].ID, + } + } + return Commit[Hash, Number, Signature, ID]{ + TargetHash: cc.TargetHash, + TargetNumber: cc.TargetNumber, + Precommits: signedPrecommits, + } +} + +// CatchUp is a catch-up message, which is an aggregate of prevotes and precommits necessary +// to complete a round. +// +// This message contains a "base", which is a block all of the vote-targets are +// a descendent of. +type CatchUp[Hash, Number, Signature, ID any] struct { + // Round number. + RoundNumber uint64 + // Prevotes for target block or any block after it that justify this catch-up. + Prevotes []SignedPrevote[Hash, Number, Signature, ID] + // Precommits for target block or any block after it that justify this catch-up. + Precommits []SignedPrecommit[Hash, Number, Signature, ID] + // The base hash. See struct docs. + BaseHash Hash + // The base number. See struct docs. + BaseNumber Number +} + +// MultiAuthData contains authentication data for a set of many messages, currently a set of precommit signatures but +// in the future could be optimised with BLS signature aggregation. +type MultiAuthData[Signature, ID any] []struct { + Signature Signature + ID ID +} + +// CommitValidationResult is type returned from `ValidateCommit` with information +// about the validation result. +type CommitValidationResult struct { + valid bool + numPrecommits uint + numDuplicatedPrecommits uint + numEquivocations uint + numInvalidVoters uint +} + +// Valid returns `true` if the commit is valid, which implies that the target +// block in the commit is finalized. +func (cvr CommitValidationResult) Valid() bool { + return cvr.valid +} + +// NumPrecommits returns the number of precommits in the commit. +func (cvr CommitValidationResult) NumPrecommits() uint { + return cvr.numPrecommits +} + +// NumDuplicatedPrecommits returns the number of duplicate precommits in the commit. +func (cvr CommitValidationResult) NumDuplicatedPrecommits() uint { + return cvr.numDuplicatedPrecommits +} + +// NumEquiovcations returns the number of equivocated precommits in the commit. +func (cvr CommitValidationResult) NumEquiovcations() uint { + return cvr.numEquivocations +} + +// NumInvalidVoters returns the number of invalid voters in the commit, i.e. votes from +// identities that are not part of the voter set. +func (cvr CommitValidationResult) NumInvalidVoters() uint { + return cvr.numInvalidVoters +} + +// ValidateCommit validates a GRANDPA commit message. +// +// For a commit to be valid the round ghost is calculated using the precommits +// in the commit message, making sure that it exists and that it is the same +// as the commit target. The precommit with the lowest block number is used as +// the round base. +// +// Signatures on precommits are assumed to have been checked. +// +// Duplicate votes or votes from voters not in the voter-set will be ignored, +// but it is recommended for the caller of this function to remove those at +// signature-verification time. +func ValidateCommit[ //skipcq: GO-R1005 + Hash constraints.Ordered, + Number constraints.Unsigned, + Signature comparable, + ID constraints.Ordered, +]( + commit Commit[Hash, Number, Signature, ID], + voters VoterSet[ID], + chain Chain[Hash, Number], +) (CommitValidationResult, error) { + validationResult := CommitValidationResult{ + numPrecommits: uint(len(commit.Precommits)), + } + + // filter any precommits by voters that are not part of the set + var validPrecommits []SignedPrecommit[Hash, Number, Signature, ID] + for _, signed := range commit.Precommits { + if !voters.Contains(signed.ID) { + validationResult.numInvalidVoters++ + continue + } + validPrecommits = append(validPrecommits, signed) + } + + // the base of the round should be the lowest block for which we can find a + // precommit (any vote would only have been accepted if it was targeting a + // block higher or equal to the round base) + var base HashNumber[Hash, Number] + var targets []HashNumber[Hash, Number] + for _, signed := range validPrecommits { + targets = append(targets, HashNumber[Hash, Number]{ + Hash: signed.Precommit.TargetHash, + Number: signed.Precommit.TargetNumber, + }) + } + slices.SortFunc(targets, func(a HashNumber[Hash, Number], b HashNumber[Hash, Number]) int { + return int(a.Number - b.Number) + }) + if len(targets) == 0 { + return validationResult, nil + } + base = targets[0] + + // check that all precommits are for blocks that are equal to or descendants + // of the round base + var allPrecommitsHigherThanBase bool + for i, signed := range validPrecommits { + if chain.IsEqualOrDescendantOf(base.Hash, signed.Precommit.TargetHash) { + if i == len(validPrecommits)-1 { + allPrecommitsHigherThanBase = true + } + continue + } + break + } + + if !allPrecommitsHigherThanBase { + return validationResult, nil + } + + equivocated := &btree.Set[ID]{} + + // add all precommits to the round with correct counting logic + round := NewRound[ID, Hash, Number, Signature]( + RoundParams[ID, Hash, Number]{ + RoundNumber: 0, // doesn't matter here + Voters: voters, + Base: base, + }, + ) + + for _, signedPrecommit := range validPrecommits { + importResult, err := round.importPrecommit( + chain, + signedPrecommit.Precommit, + signedPrecommit.ID, + signedPrecommit.Signature, + ) + if err != nil { + return CommitValidationResult{}, err + } + switch { + case importResult.Equivocation != nil: + validationResult.numEquivocations++ + // allow only one equivocation per voter, as extras are redundant. + if equivocated.Contains(signedPrecommit.ID) { + return validationResult, nil + } + equivocated.Insert(signedPrecommit.ID) + default: + if importResult.Duplicated { + validationResult.numDuplicatedPrecommits++ + } + } + } + + // for the commit to be valid, then a precommit ghost must be found for the + // round and it must be equal to the commit target + precommitGHOST := round.PrecommitGHOST() + switch { + case precommitGHOST != nil: + if precommitGHOST.Hash == commit.TargetHash && precommitGHOST.Number == commit.TargetNumber { + validationResult.valid = true + } + default: + } + + return validationResult, nil +} + +// HistoricalVotes are the historical votes seen in a round. +type HistoricalVotes[Hash, Number, Signature, ID any] struct { + seen []SignedMessage[Hash, Number, Signature, ID] + prevoteIdx *uint64 + precommitIdx *uint64 +} + +// NewHistoricalVotes creates a new HistoricalVotes. +func NewHistoricalVotes[Hash, Number, Signature, ID any]() HistoricalVotes[Hash, Number, Signature, ID] { + return HistoricalVotes[Hash, Number, Signature, ID]{ + seen: make([]SignedMessage[Hash, Number, Signature, ID], 0), + prevoteIdx: nil, + precommitIdx: nil, + } +} + +// PushVote pushes a vote into the list. +func (hv *HistoricalVotes[Hash, Number, Signature, ID]) PushVote(msg SignedMessage[Hash, Number, Signature, ID]) { + hv.seen = append(hv.seen, msg) +} + +// SetPrevotedIdx sets the number of messages seen before prevoting. +func (hv *HistoricalVotes[Hash, Number, Signature, ID]) SetPrevotedIdx() { + pi := uint64(len(hv.seen)) + hv.prevoteIdx = &pi +} + +// SetPrecommittedIdx sets the number of messages seen before precommiting. +func (hv *HistoricalVotes[Hash, Number, Signature, ID]) SetPrecommittedIdx() { + pi := uint64(len(hv.seen)) + hv.precommitIdx = &pi +} diff --git a/pkg/finality-grandpa/lib_test.go b/pkg/finality-grandpa/lib_test.go new file mode 100644 index 0000000000..353d45f537 --- /dev/null +++ b/pkg/finality-grandpa/lib_test.go @@ -0,0 +1,179 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateCommit(t *testing.T) { + chain := newDummyChain() + chain.PushBlocks(GenesisHash, []string{"A"}) + + IDWeights := make([]IDWeight[int32], 0) + for i := 1; i <= 100; i++ { + IDWeights = append(IDWeights, IDWeight[int32]{int32(i), 1}) + } + voters := NewVoterSet(IDWeights) + + makePrecommit := func(targetHash string, targetNumber uint, id int32) SignedPrecommit[string, uint, string, int32] { + return SignedPrecommit[string, uint, string, int32]{ + Precommit: Precommit[string, uint]{ + TargetHash: targetHash, + TargetNumber: targetNumber, + }, + ID: id, + } + } + + var precommits []SignedPrecommit[string, uint, string, int32] + ids := make([]int32, 0) + for i := 1; i < 67; i++ { + ids = append(ids, int32(i)) + } + for _, id := range ids { + precommit := makePrecommit("C", 3, id) + precommits = append(precommits, precommit) + } + + // we have still not reached threshold with 66/100 votes, so the commit + // is not valid. + result, err := ValidateCommit[string, uint, string]( + Commit[string, uint, string, int32]{ + TargetHash: "C", + TargetNumber: 3, + Precommits: precommits, + }, *voters, chain) + assert.NoError(t, err) + + assert.False(t, result.Valid()) + + // after adding one more commit targeting the same block we are over + // the finalisation threshold and the commit should be valid + precommits = append(precommits, makePrecommit("C", 3, 67)) + + result, err = ValidateCommit[string, uint, string]( + Commit[string, uint, string, int32]{ + TargetHash: "C", + TargetNumber: 3, + Precommits: precommits, + }, *voters, chain) + assert.NoError(t, err) + + assert.True(t, result.Valid()) + + // the commit target must be the exact same as the round precommit ghost + // that is calculated with the given precommits for the commit to be valid + result, err = ValidateCommit[string, uint, string]( + Commit[string, uint, string, int32]{ + TargetHash: "B", + TargetNumber: 2, + Precommits: precommits, + }, *voters, chain) + assert.NoError(t, err) + + assert.False(t, result.Valid()) +} + +func TestValidateCommit_WithEquivocation(t *testing.T) { + chain := newDummyChain() + chain.PushBlocks(GenesisHash, []string{"A", "B", "C"}) + + IDWeights := make([]IDWeight[int32], 0) + for i := 1; i <= 100; i++ { + IDWeights = append(IDWeights, IDWeight[int32]{int32(i), 1}) + } + voters := NewVoterSet(IDWeights) + + makePrecommit := func(targetHash string, targetNumber uint, id int32) SignedPrecommit[string, uint, string, int32] { + return SignedPrecommit[string, uint, string, int32]{ + Precommit: Precommit[string, uint]{ + TargetHash: targetHash, + TargetNumber: targetNumber, + }, + ID: id, + } + } + + // we add 66/100 precommits targeting block C + var precommits []SignedPrecommit[string, uint, string, int32] + ids := make([]int32, 0) + for i := 1; i < 67; i++ { + ids = append(ids, int32(i)) + } + for _, id := range ids { + precommit := makePrecommit("C", 3, id) + precommits = append(precommits, precommit) + } + + // we then add two equivocated votes targeting A and B + // from the 67th validator + precommits = append(precommits, makePrecommit("A", 1, 67)) + precommits = append(precommits, makePrecommit("B", 2, 67)) + + // this equivocation is treated as "voting for all blocks", which means + // that block C will now have 67/100 votes and therefore it can be + // finalized. + result, err := ValidateCommit[string, uint, string]( + Commit[string, uint, string, int32]{ + TargetHash: "C", + TargetNumber: 3, + Precommits: precommits, + }, *voters, chain) + assert.NoError(t, err) + + assert.True(t, result.Valid()) + assert.Equal(t, uint(1), result.NumEquiovcations()) +} + +func TestValidateCommit_PrecommitFromUnknownVoterIsIgnored(t *testing.T) { + chain := newDummyChain() + chain.PushBlocks(GenesisHash, []string{"A", "B", "C"}) + + IDWeights := make([]IDWeight[int32], 0) + for i := 1; i <= 100; i++ { + IDWeights = append(IDWeights, IDWeight[int32]{int32(i), 1}) + } + voters := NewVoterSet(IDWeights) + + makePrecommit := func(targetHash string, targetNumber uint, id int32) SignedPrecommit[string, uint, string, int32] { + return SignedPrecommit[string, uint, string, int32]{ + Precommit: Precommit[string, uint]{ + TargetHash: targetHash, + TargetNumber: targetNumber, + }, + ID: id, + } + } + + var precommits []SignedPrecommit[string, uint, string, int32] + + // invalid vote from unknown voter should not influence the base + precommits = append(precommits, makePrecommit("Z", 1, 1000)) + + ids := make([]int32, 0) + for i := 1; i <= 67; i++ { + ids = append(ids, int32(i)) + } + for _, id := range ids { + precommit := makePrecommit("C", 3, id) + precommits = append(precommits, precommit) + } + + result, err := ValidateCommit[string, uint]( + Commit[string, uint, string, int32]{ + TargetHash: "C", + TargetNumber: 3, + Precommits: precommits, + }, *voters, chain) + assert.NoError(t, err) + + // we have threshold votes for block "C" so it should be valid + assert.True(t, result.Valid()) + + // there is one invalid voter in the commit + assert.Equal(t, uint(1), result.NumInvalidVoters()) +} diff --git a/pkg/finality-grandpa/logger.go b/pkg/finality-grandpa/logger.go new file mode 100644 index 0000000000..28f8e70c92 --- /dev/null +++ b/pkg/finality-grandpa/logger.go @@ -0,0 +1,61 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +type logLevel int + +const ( + debug logLevel = iota + trace +) + +type Logger interface { + Warn(l string) + Warnf(format string, values ...any) + + Info(l string) + Infof(format string, values ...any) + + Debug(l string) + Debugf(format string, values ...any) + + Trace(l string) + Tracef(format string, values ...any) +} + +type noopLogger struct{} + +func (noopLogger) Warn(_ string) { +} + +func (noopLogger) Warnf(_ string, _ ...any) { +} + +func (noopLogger) Info(_ string) { +} + +func (noopLogger) Infof(_ string, _ ...any) { +} + +func (noopLogger) Debug(_ string) { +} + +func (noopLogger) Debugf(_ string, _ ...any) { +} + +func (noopLogger) Trace(_ string) { +} + +func (noopLogger) Tracef(_ string, _ ...any) { +} + +var log Logger + +func init() { + log = noopLogger{} +} + +func SetLogger(l Logger) { + log = l +} diff --git a/pkg/finality-grandpa/past_rounds.go b/pkg/finality-grandpa/past_rounds.go new file mode 100644 index 0000000000..49b10a715b --- /dev/null +++ b/pkg/finality-grandpa/past_rounds.go @@ -0,0 +1,364 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "golang.org/x/exp/constraints" +) + +// wraps a voting round with a new future that resolves when the round can +// be discarded from the working set. +// +// that point is when the round-estimate is finalized. +type backgroundRound[ + Hash constraints.Ordered, Number constraints.Unsigned, Signature comparable, + ID constraints.Ordered, E Environment[Hash, Number, Signature, ID], +] struct { + inner votingRound[Hash, Number, Signature, ID, E] + finalizedNumber Number + roundCommitter *roundCommitter[Hash, Number, Signature, ID, E] + + waker *waker +} + +func (br *backgroundRound[Hash, Number, Signature, ID, E]) roundNumber() uint64 { + return br.inner.roundNumber() +} + +func (br *backgroundRound[Hash, Number, Signature, ID, E]) votingRound() votingRound[Hash, Number, Signature, ID, E] { + return br.inner +} + +func (br *backgroundRound[Hash, Number, Signature, ID, E]) isDone() bool { + // no need to listen on a round anymore once the estimate is finalized. + // + // we map `br.roundCommitter == nil` to true because + // - rounds are not backgrounded when incomplete unless we've skipped forward + // - if we skipped forward we may never complete this round and we don't need + // to keep it forever. + var ls = br.roundCommitter == nil + if !ls { + return false + } + var rs = true + estimate := br.inner.roundState().Estimate + if estimate != nil { + rs = estimate.Number <= br.finalizedNumber + } + return ls && rs +} + +func (br *backgroundRound[Hash, Number, Signature, ID, E]) updateFinalized(newFinalized Number) { + switch { + case br.finalizedNumber >= newFinalized: + default: + br.finalizedNumber = newFinalized + } + + // wake up the future to be polled if done. + if br.isDone() { + br.waker.wake() + } +} + +type concluded uint64 +type committed[Hash, Number, Signature, ID any] Commit[Hash, Number, Signature, ID] + +type backgroundRoundChange[Hash, Number, Signature, ID any] struct { + variant any +} + +func (brc backgroundRoundChange[Hash, Number, Signature, ID]) Variant() any { + switch brc.variant.(type) { + case concluded, committed[Hash, Number, Signature, ID]: + default: + panic("unsupported type") + } + return brc.variant +} + +func newBackgroundRoundChange[ + Hash, + Number, + Signature, + ID any, + V backgroundRoundChanges[Hash, Number, Signature, ID], +](variant V) backgroundRoundChange[Hash, Number, Signature, ID] { + change := backgroundRoundChange[Hash, Number, Signature, ID]{} + change.variant = variant + return change +} + +type backgroundRoundChanges[Hash, Number, Signature, ID any] interface { + concluded | committed[Hash, Number, Signature, ID] +} + +func (br *backgroundRound[Hash, Number, Signature, ID, E]) poll(waker *waker) ( + bool, + backgroundRoundChange[Hash, Number, Signature, ID], + error, +) { + br.waker = waker + + _, err := br.inner.poll(waker) + if err != nil { + return true, backgroundRoundChange[Hash, Number, Signature, ID]{}, err + } + + committer := br.roundCommitter + br.roundCommitter = nil + switch committer { + case nil: + default: + ready, commit, err := committer.commit(waker, br.inner) + switch { + case ready && commit == nil && err == nil: + br.roundCommitter = nil + case ready && commit != nil && err == nil: + change := newBackgroundRoundChange[Hash, Number, Signature, ID]( + committed[Hash, Number, Signature, ID](*commit), + ) + return true, change, nil + case !ready: + br.roundCommitter = committer + default: + panic("unreachable") + } + } + + if br.isDone() { + // if this is fully concluded (has committed _and_ estimate finalized) + // we bail for real. + change := newBackgroundRoundChange[Hash, Number, Signature, ID]( + concluded(br.roundNumber()), + ) + return true, change, nil + } + return false, backgroundRoundChange[Hash, Number, Signature, ID]{}, nil +} + +type roundCommitter[ + Hash constraints.Ordered, Number constraints.Unsigned, Signature comparable, + ID constraints.Ordered, E Environment[Hash, Number, Signature, ID], +] struct { + commitTimer Timer + importCommits *wakerChan[Commit[Hash, Number, Signature, ID]] + lastCommit *Commit[Hash, Number, Signature, ID] +} + +func newRoundCommitter[ + Hash constraints.Ordered, Number constraints.Unsigned, Signature comparable, + ID constraints.Ordered, E Environment[Hash, Number, Signature, ID], +]( + commitTimer Timer, + commitReceiver *wakerChan[Commit[Hash, Number, Signature, ID]], +) *roundCommitter[Hash, Number, Signature, ID, E] { + return &roundCommitter[Hash, Number, Signature, ID, E]{ + commitTimer, commitReceiver, nil, + } +} + +func (rc *roundCommitter[Hash, Number, Signature, ID, E]) importCommit( + votingRound votingRound[Hash, Number, Signature, ID, E], commit Commit[Hash, Number, Signature, ID], +) (bool, error) { + // ignore commits for a block lower than we already finalized + if votingRound.finalized() != nil && commit.TargetNumber < votingRound.finalized().Number { + return true, nil + } + + base, err := votingRound.checkAndImportFromCommit(commit) + if err != nil { + return false, err + } + if base == nil { + return true, nil + } + + rc.lastCommit = &commit + + return true, nil +} + +func (rc *roundCommitter[Hash, Number, Signature, ID, E]) commit( + waker *waker, + votingRound votingRound[Hash, Number, Signature, ID, E], +) (bool, *Commit[Hash, Number, Signature, ID], error) { + rc.importCommits.setWaker(waker) +loop: + for { + select { + case commit, ok := <-rc.importCommits.channel(): + if !ok { + break loop + } + imported, err := rc.importCommit(votingRound, commit) + if err != nil { + return true, nil, err + } + if !imported { + log.Trace("Ignoring invalid commit") + } + default: + break loop + } + } + + rc.commitTimer.SetWaker(waker) + elapsed, err := rc.commitTimer.Elapsed() + if elapsed { + if err != nil { + return true, nil, err + } + } else { + return false, nil, nil + } + + lastCommit := rc.lastCommit + rc.lastCommit = nil + finalized := votingRound.finalized() + + switch { + case lastCommit == nil && finalized != nil: + return true, votingRound.finalizingCommit(), nil + case lastCommit != nil && finalized != nil && lastCommit.TargetNumber < finalized.Number: + return true, votingRound.finalizingCommit(), nil + default: + return true, nil, nil + } +} + +// A stream for past rounds, which produces any commit messages from those +// rounds and drives them to completion. +type pastRounds[Hash constraints.Ordered, Number constraints.Unsigned, Signature comparable, + ID constraints.Ordered, E Environment[Hash, Number, Signature, ID], +] struct { + pastRounds []backgroundRound[Hash, Number, Signature, ID, E] + commitSenders map[uint64]chan Commit[Hash, Number, Signature, ID] +} + +func newPastRounds[Hash constraints.Ordered, Number constraints.Unsigned, Signature comparable, + ID constraints.Ordered, E Environment[Hash, Number, Signature, ID]]() *pastRounds[Hash, Number, Signature, ID, E] { + return &pastRounds[Hash, Number, Signature, ID, E]{ + commitSenders: make(map[uint64]chan Commit[Hash, Number, Signature, ID]), + } +} + +// push an old voting round onto this stream. +func (p *pastRounds[Hash, Number, Signature, ID, E]) Push(env E, round votingRound[Hash, Number, Signature, ID, E]) { + roundNumber := round.roundNumber() + // TODO: this is supposed to be an unbounded channel on the producer side. Use buffered in p.commitSenders + // https://github.com/ChainSafe/gossamer/issues/3510 + ch := make(chan Commit[Hash, Number, Signature, ID], 100) + background := backgroundRound[Hash, Number, Signature, ID, E]{ + inner: round, + // this will get updated in a call to pastRounds.UpdateFinalized() on next poll + finalizedNumber: 0, + roundCommitter: newRoundCommitter[Hash, Number, Signature, ID, E](env.RoundCommitTimer(), newWakerChan(ch)), + } + p.pastRounds = append(p.pastRounds, background) + p.commitSenders[roundNumber] = ch +} + +// update the last finalized block. this will lead to +// any irrelevant background rounds being pruned. +func (p *pastRounds[Hash, Number, Signature, ID, E]) UpdateFinalized(fNum Number) { //skipcq: RVV-B0001 + // have the task check if it should be pruned. + // if so, this future will be re-polled + for i := range p.pastRounds { + p.pastRounds[i].updateFinalized(fNum) + } +} + +// Get the underlying `votingRound` items that are being run in the background. +func (p *pastRounds[Hash, Number, Signature, ID, E]) votingRounds() []votingRound[Hash, Number, Signature, ID, E] { + var votingRounds []votingRound[Hash, Number, Signature, ID, E] + for _, bg := range p.pastRounds { + votingRounds = append(votingRounds, bg.votingRound()) + } + return votingRounds +} + +// import the commit into the given backgrounded round. If not possible, +// just return and process the commit. +func (p pastRounds[Hash, Number, Signature, ID, E]) ImportCommit( //skipcq: RVV-B0001 + roundNumber uint64, + commit Commit[Hash, Number, Signature, ID], +) *Commit[Hash, Number, Signature, ID] { + sender, ok := p.commitSenders[roundNumber] + if !ok { + return &commit + } + select { + case sender <- commit: + return nil + default: + return &commit + } +} + +type numberCommit[Hash, Number, Signature, ID any] struct { + Number uint64 + Commit Commit[Hash, Number, Signature, ID] +} + +func (p *pastRounds[Hash, Number, Signature, ID, E]) pollNext(waker *waker) ( + ready bool, + nc *numberCommit[Hash, Number, Signature, ID], + err error, +) { + for { + if len(p.pastRounds) == 0 { + return true, nc, nil + } + br := p.pastRounds[0] + ready, backgroundRoundChange, err := br.poll(waker) + switch { + case ready && err == nil: + v := backgroundRoundChange.Variant() + // empty stream + if v == nil { + return true, nil, nil + } + switch v := v.(type) { + case concluded: + number := v + round := br.inner + err := round.Env().Concluded( + round.roundNumber(), + round.roundState(), + round.dagBase(), + round.historicalVotes(), + ) + if err != nil { + return true, nil, err + } + close(p.commitSenders[uint64(number)]) + delete(p.commitSenders, uint64(number)) + p.pastRounds = p.pastRounds[1:] + case committed[Hash, Number, Signature, ID]: + number := br.roundNumber() + commit := Commit[Hash, Number, Signature, ID](v) + + // reschedule until irrelevant + p.pastRounds = append(p.pastRounds[1:], br) + + log.Debugf( + "Committing: round_number = %v, target_number = %v, target_hash = %v", + number, + commit.TargetNumber, + commit.TargetHash, + ) + + return true, &numberCommit[Hash, Number, Signature, ID]{number, commit}, nil + } + case ready && err != nil: + return true, nc, err + case !ready: + // reschedule until irrelevant + p.pastRounds = append(p.pastRounds[1:], br) + return false, nc, nil + } + } + +} diff --git a/pkg/finality-grandpa/report.go b/pkg/finality-grandpa/report.go new file mode 100644 index 0000000000..e21e1983e8 --- /dev/null +++ b/pkg/finality-grandpa/report.go @@ -0,0 +1,34 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +// RoundStateReport is a basic data struct for the state of a round. +type RoundStateReport[ID comparable] struct { + // Total weight of all votes. + TotalWeight VoterWeight + // The threshold voter weight. + ThresholdWeight VoterWeight + + // Current weight of the prevotes. + PrevoteCurrentWeight VoteWeight + // The identities of nodes that have cast prevotes so far. + PrevoteIDs []ID + + // Current weight of the precommits. + PrecommitCurrentWeight VoteWeight + // The identities of nodes that have cast precommits so far. + PrecommitIDs []ID +} + +// VoterStateReport is a basic data struct for the current state of +// the voter in a form suitable for passing on to other systems. +type VoterStateReport[ID comparable] struct { + // Voting rounds running in the background. + BackgroundRounds map[uint64]RoundStateReport[ID] + // The current best voting round. + BestRound struct { + Number uint64 + RoundState RoundStateReport[ID] + } +} diff --git a/pkg/finality-grandpa/round.go b/pkg/finality-grandpa/round.go new file mode 100644 index 0000000000..f4c8ead520 --- /dev/null +++ b/pkg/finality-grandpa/round.go @@ -0,0 +1,692 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "fmt" + "sync" + + "github.com/tidwall/btree" + "golang.org/x/exp/constraints" +) + +// Phase is the (voting) phases of a round, each corresponding to the type of +// votes cast in that phase. +type Phase uint + +const ( + // The prevote phase in which `Prevote`s are cast. + PrevotePhase Phase = iota + // The precommit phase in which `Precommit`s are cast. + PrecommitPhase +) + +type voteSignature[Vote, Signature comparable] struct { + Vote Vote + Signature Signature +} + +type single[Vote, Signature comparable] voteSignature[Vote, Signature] + +type equivocated[Vote, Signature comparable] [2]voteSignature[Vote, Signature] + +// The observed vote from a single voter. +type voteMultiplicity[Vote, Signature comparable] struct { + value interface{} +} + +// can only use type constraint interfaces as function parameters +type voteMultiplicityValue[Vote, Signature comparable] interface { + single[Vote, Signature] | equivocated[Vote, Signature] +} + +func setvoteMultiplicity[ + Vote, Signature comparable, + T voteMultiplicityValue[Vote, Signature], +](vm *voteMultiplicity[Vote, Signature], val T) { + vm.value = val +} + +func newVoteMultiplicity[ + Vote, Signature comparable, + T voteMultiplicityValue[Vote, Signature], +](val T) (vm voteMultiplicity[Vote, Signature]) { + return voteMultiplicity[Vote, Signature]{ + value: val, + } +} + +func (vm voteMultiplicity[Vote, Signature]) Value() interface{} { + return vm.value +} + +func (vm voteMultiplicity[Vote, Signature]) Contains(vote Vote, sig Signature) bool { + vs := voteSignature[Vote, Signature]{vote, sig} + switch in := vm.Value().(type) { + case single[Vote, Signature]: + return voteSignature[Vote, Signature](in) == vs + case equivocated[Vote, Signature]: + return in[0] == vs || in[1] == vs + default: + panic("invalid voteMultiplicityValue") + } +} + +type voteTracker[ID constraints.Ordered, Vote, Signature comparable] struct { + votes *btree.Map[ID, voteMultiplicity[Vote, Signature]] + currentWeight VoteWeight + mtx sync.RWMutex +} + +func newVoteTracker[ID constraints.Ordered, Vote, Signature comparable]() voteTracker[ID, Vote, Signature] { + return voteTracker[ID, Vote, Signature]{ + votes: btree.NewMap[ID, voteMultiplicity[Vote, Signature]](2), + } +} + +// track a vote, returning a value containing the multiplicity of all votes from this ID +// and a bool indicating if the vote is duplicated. +// if the vote is the first equivocation, returns a value indicating +// it as such (the new vote is always the last in the multiplicity). +// +// if the vote is a further equivocation, it is ignored and there is nothing +// to do. +// +// since this struct doesn't track the round-number of votes, that must be set +// by the caller. +func (vt *voteTracker[ID, Vote, Signature]) addVote( + id ID, + vote Vote, + signature Signature, + weight VoterWeight, +) (*voteMultiplicity[Vote, Signature], bool) { + vt.mtx.Lock() + defer vt.mtx.Unlock() + + var ok bool + vm, ok := vt.votes.Get(id) + if !ok { + // TODO: figure out saturating_add stuff + // https://github.com/ChainSafe/gossamer/issues/3511 + vt.currentWeight = vt.currentWeight + VoteWeight(weight) + multiplicity := newVoteMultiplicity[Vote, Signature]( + single[Vote, Signature]{vote, signature}, + ) + _, exists := vt.votes.Set(id, multiplicity) + if exists { + panic(fmt.Errorf("id %v should not exist in votes", id)) + } + return &multiplicity, false + } + + duplicated := vm.Contains(vote, signature) + if duplicated { + return nil, true + } + + switch in := vm.Value().(type) { + case single[Vote, Signature]: + var eq = equivocated[Vote, Signature]{ + voteSignature[Vote, Signature](in), + { + Vote: vote, + Signature: signature, + }, + } + setvoteMultiplicity(&vm, eq) + vt.votes.Set(id, vm) + return &vm, false + case equivocated[Vote, Signature]: + // ignore further equivocations + return nil, duplicated + default: + panic("invalid voteMultiplicity value") + } +} + +type idVoteSignature[ID, Vote, Signature comparable] struct { + ID ID + voteSignature[Vote, Signature] +} + +func (vt *voteTracker[ID, Vote, Signature]) Votes() (votes []idVoteSignature[ID, Vote, Signature]) { + vt.mtx.RLock() + defer vt.mtx.RUnlock() + + vt.votes.Scan(func(id ID, vm voteMultiplicity[Vote, Signature]) bool { + switch in := vm.Value().(type) { + case single[Vote, Signature]: + votes = append(votes, idVoteSignature[ID, Vote, Signature]{ + ID: id, + voteSignature: voteSignature[Vote, Signature](in), + }) + case equivocated[Vote, Signature]: + for _, vs := range in { + votes = append(votes, idVoteSignature[ID, Vote, Signature]{ + ID: id, + voteSignature: vs, + }) + } + default: + panic("invalid voteMultiplicity value") + } + return true + }) + return +} + +func (vt *voteTracker[ID, Vote, Signature]) participation() (weight VoteWeight, numParticipants int) { + return vt.currentWeight, vt.votes.Len() +} + +// RoundState is the state of the round. +type RoundState[Hash, Number any] struct { + // The prevote-GHOST block. + PrevoteGHOST *HashNumber[Hash, Number] + // The finalized block. + Finalized *HashNumber[Hash, Number] + // The new round-estimate. + Estimate *HashNumber[Hash, Number] + // Whether the round is completable. + Completable bool +} + +// NewRoundState is constructor of `RoundState` from a given genesis state. +func NewRoundState[Hash, Number any](genesis HashNumber[Hash, Number]) RoundState[Hash, Number] { + return RoundState[Hash, Number]{ + PrevoteGHOST: &genesis, + Finalized: &genesis, + Estimate: &genesis, + Completable: true, + } +} + +// RoundParams are the parameters for starting a round. +type RoundParams[ID constraints.Ordered, Hash comparable, Number constraints.Unsigned] struct { + // The round number for votes. + RoundNumber uint64 + // Actors and weights in the round. + Voters VoterSet[ID] + // The base block to build on. + Base HashNumber[Hash, Number] +} + +// Round stores data for a round. +type Round[ID constraints.Ordered, Hash constraints.Ordered, Number constraints.Unsigned, Signature comparable] struct { + number uint64 + context context[ID] + graph VoteGraph[Hash, Number, *voteNode[ID], vote[ID]] // DAG of blocks which have been voted on. + prevotes voteTracker[ID, Prevote[Hash, Number], Signature] // tracks prevotes that have been counted + precommits voteTracker[ID, Precommit[Hash, Number], Signature] // tracks precommits + historicalVotes HistoricalVotes[Hash, Number, Signature, ID] // historical votes + prevoteGhost *HashNumber[Hash, Number] // current memoized prevote-GHOST block + precommitGhost *HashNumber[Hash, Number] // current memoized precommit-GHOST block + finalized *HashNumber[Hash, Number] // best finalized block in this round. + estimate *HashNumber[Hash, Number] // current memoized round-estimate + completable bool // whether the round is completable +} + +// Result of importing a Prevote or Precommit. +type importResult[ID constraints.Ordered, P, Signature comparable] struct { + ValidVoter bool + Duplicated bool + Equivocation *Equivocation[ID, P, Signature] +} + +// NewRound creates a new round accumulator for given round number and with given weight. +func NewRound[ID constraints.Ordered, Hash constraints.Ordered, Number constraints.Unsigned, Signature comparable]( + roundParams RoundParams[ID, Hash, Number], +) *Round[ID, Hash, Number, Signature] { + + var newVoteNode = func() *voteNode[ID] { + return &voteNode[ID]{newBitfield()} + } + return &Round[ID, Hash, Number, Signature]{ + number: roundParams.RoundNumber, + context: newContext(roundParams.Voters), + graph: NewVoteGraph[Hash, Number, *voteNode[ID], vote[ID]]( + roundParams.Base.Hash, + roundParams.Base.Number, + newVoteNode(), + newVoteNode, + ), + prevotes: newVoteTracker[ID, Prevote[Hash, Number], Signature](), + precommits: newVoteTracker[ID, Precommit[Hash, Number], Signature](), + historicalVotes: NewHistoricalVotes[Hash, Number, Signature, ID](), + } +} + +// Number returns the round number. +func (r *Round[ID, H, N, S]) Number() uint64 { + return r.number +} + +// Import a prevote. Returns an equivocation proof, if the vote is an equivocation, +// and a bool indicating if the vote is duplicated (see `ImportResult`). +// +// Ignores duplicate prevotes (not equivocations). +func (r *Round[ID, H, N, S]) importPrevote( + chain Chain[H, N], prevote Prevote[H, N], signer ID, signature S, +) (*importResult[ID, Prevote[H, N], S], error) { + ir := importResult[ID, Prevote[H, N], S]{} + + info := r.context.Voters().Get(signer) + if info == nil { + return &ir, nil + } + + ir.ValidVoter = true + weight := info.weight + + var equivocation *Equivocation[ID, Prevote[H, N], S] + var multiplicity *voteMultiplicity[Prevote[H, N], S] + m, duplicated := r.prevotes.addVote(signer, prevote, signature, weight) + if m != nil { + multiplicity = m + } else { + ir.Duplicated = duplicated + return &ir, nil + } + + switch val := multiplicity.Value().(type) { + case single[Prevote[H, N], S]: + singleVote := val + vote := newVote[ID](*info, PrevotePhase) + err := r.graph.Insert(singleVote.Vote.TargetHash, singleVote.Vote.TargetNumber, vote, chain) + if err != nil { + return nil, err + } + + // Push the vote into HistoricalVotes. + message := Message[H, N]{} + setMessage(&message, prevote) + signedMessage := SignedMessage[H, N, S, ID]{ + Message: message, + Signature: signature, + ID: signer, + } + r.historicalVotes.PushVote(signedMessage) + + case equivocated[Prevote[H, N], S]: + first := val[0] + second := val[1] + + // mark the equivocator as such. no need to "undo" the first vote. + r.context.Equivocated(*info, PrevotePhase) + + // Push the vote into HistoricalVotes. + message := Message[H, N]{} + setMessage(&message, prevote) + signedMessage := SignedMessage[H, N, S, ID]{ + Message: message, + Signature: signature, + ID: signer, + } + r.historicalVotes.PushVote(signedMessage) + equivocation = &Equivocation[ID, Prevote[H, N], S]{ + RoundNumber: r.number, + Identity: signer, + First: first, + Second: second, + } + default: + panic("invalid voteMultiplicity value") + } + + // update prevote-GHOST + threshold := r.context.voters.threshold + if r.prevotes.currentWeight >= VoteWeight(threshold) { + r.prevoteGhost = r.graph.FindGHOST(r.prevoteGhost, func(v *voteNode[ID]) bool { + return r.context.Weight(*v, PrevotePhase) >= VoteWeight(threshold) + }) + } + + r.update() + ir.Equivocation = equivocation + return &ir, nil +} + +// Import a precommit. Returns an equivocation proof, if the vote is an +// equivocation, and a bool indicating if the vote is duplicated (see `ImportResult`). +// +// Ignores duplicate precommits (not equivocations). +func (r *Round[ID, H, N, S]) importPrecommit( + chain Chain[H, N], precommit Precommit[H, N], signer ID, signature S, +) (*importResult[ID, Precommit[H, N], S], error) { + ir := importResult[ID, Precommit[H, N], S]{} + + info := r.context.Voters().Get(signer) + if info == nil { + return &ir, nil + } + + ir.ValidVoter = true + weight := info.weight + + var equivocation *Equivocation[ID, Precommit[H, N], S] + var multiplicity *voteMultiplicity[Precommit[H, N], S] + m, duplicated := r.precommits.addVote(signer, precommit, signature, weight) + if m != nil { + multiplicity = m + } else { + ir.Duplicated = duplicated + return &ir, nil + } + + switch val := multiplicity.Value().(type) { + case single[Precommit[H, N], S]: + singleVote := val + vote := newVote[ID](*info, PrecommitPhase) + err := r.graph.Insert(singleVote.Vote.TargetHash, singleVote.Vote.TargetNumber, vote, chain) + if err != nil { + return nil, err + } + + // Push the vote into HistoricalVotes. + message := Message[H, N]{} + setMessage(&message, precommit) + signedMessage := SignedMessage[H, N, S, ID]{ + Message: message, + Signature: signature, + ID: signer, + } + r.historicalVotes.PushVote(signedMessage) + + case equivocated[Precommit[H, N], S]: + first := val[0] + second := val[1] + + // mark the equivocator as such. no need to "undo" the first vote. + r.context.Equivocated(*info, PrecommitPhase) + + // Push the vote into HistoricalVotes. + message := Message[H, N]{} + setMessage(&message, precommit) + signedMessage := SignedMessage[H, N, S, ID]{ + Message: message, + Signature: signature, + ID: signer, + } + r.historicalVotes.PushVote(signedMessage) + equivocation = &Equivocation[ID, Precommit[H, N], S]{ + RoundNumber: r.number, + Identity: signer, + First: first, + Second: second, + } + default: + panic("invalid voteMultiplicity value") + } + + r.update() + ir.Equivocation = equivocation + return &ir, nil +} + +// update the round-estimate and whether the round is completable. +func (r *Round[ID, H, N, S]) update() { + threshold := r.context.voters.threshold + + if r.prevotes.currentWeight < VoteWeight(threshold) { + return + } + + if r.prevoteGhost == nil { + return + } + + // anything new finalized? finalized blocks are those which have both + // 2/3+ prevote and precommit weight. + currentPrecommits := r.precommits.currentWeight + if currentPrecommits >= VoteWeight(threshold) { + r.finalized = r.graph.FindAncestor(r.prevoteGhost.Hash, r.prevoteGhost.Number, func(v *voteNode[ID]) bool { + return r.context.Weight(*v, PrecommitPhase) >= VoteWeight(threshold) + }) + } + + // figuring out whether a block can still be committed for is + // not straightforward because we have to account for all possible future + // equivocations and thus cannot discount weight from validators who + // have already voted. + var possibleToPrecommit = func(node *voteNode[ID]) bool { + // find how many more equivocations we could still get. + // + // it is only important to consider the voters whose votes + // we have already seen, because we are assuming any votes we + // haven't seen will target this block. + toleratedEquivocations := VoteWeight(r.context.voters.totalWeight - threshold) + currentEquivocations := r.context.EquivocationWeight(PrecommitPhase) + additionalEquiv := toleratedEquivocations - currentEquivocations + remainingCommitVotes := VoteWeight(r.context.voters.totalWeight) - r.precommits.currentWeight + + // total precommits for this block, including equivocations. + precommitedFor := r.context.Weight(*node, PrecommitPhase) + + // equivocations we could still get are out of those who + // have already voted, but not on this block. + var possibleEquivocations VoteWeight + if currentPrecommits-precommitedFor <= additionalEquiv { + possibleEquivocations = currentPrecommits - precommitedFor + } else { + possibleEquivocations = additionalEquiv + } + + // all the votes already applied on this block, + // assuming all remaining actors commit to this block, + // and that we get further equivocations + fullPossibleWeight := precommitedFor + remainingCommitVotes + possibleEquivocations + return fullPossibleWeight >= VoteWeight(threshold) + } + + // until we have threshold precommits, any new block could get supermajority + // precommits because there are at least f + 1 precommits remaining and then + // f equivocations. + // + // once it's at least that level, we only need to consider blocks + // already referenced in the graph, because no new leaf nodes + // could ever have enough precommits. + // + // the round-estimate is the highest block in the chain with head + // `prevote_ghost` that could have supermajority-commits. + if r.precommits.currentWeight >= VoteWeight(threshold) { + r.estimate = r.graph.FindAncestor(r.prevoteGhost.Hash, r.prevoteGhost.Number, possibleToPrecommit) + } else { + r.estimate = &HashNumber[H, N]{r.prevoteGhost.Hash, r.prevoteGhost.Number} + return + } + + if r.estimate != nil { + var ls bool = r.estimate.Hash != r.prevoteGhost.Hash + var rs bool + x := r.graph.FindGHOST(r.estimate, possibleToPrecommit) + if x == nil { + rs = true + } else { + rs = *x == *r.prevoteGhost + } + r.completable = ls || rs + } else { + r.completable = false + } +} + +// State returns the current state. +func (r *Round[ID, H, N, S]) State() RoundState[H, N] { + return RoundState[H, N]{ + PrevoteGHOST: r.prevoteGhost, + Finalized: r.finalized, + Estimate: r.estimate, + Completable: r.completable, + } +} + +// PrecommitGHOST will compute and cache the precommit-GHOST. +func (r *Round[ID, H, N, S]) PrecommitGHOST() *HashNumber[H, N] { + // update precommit-GHOST + var threshold = r.Threshold() + if r.precommits.currentWeight >= VoteWeight(threshold) { + r.precommitGhost = r.graph.FindGHOST(r.precommitGhost, func(v *voteNode[ID]) bool { + return r.context.Weight(*v, PrecommitPhase) >= VoteWeight(threshold) + }) + } + return r.precommitGhost +} + +type yieldVotes[H constraints.Ordered, N constraints.Unsigned, S comparable] struct { + yielded uint + multiplicity voteMultiplicity[Precommit[H, N], S] +} + +func (yv *yieldVotes[H, N, S]) voteSignature() *voteSignature[Precommit[H, N], S] { + switch vm := yv.multiplicity.Value().(type) { + case single[Precommit[H, N], S]: + if yv.yielded == 0 { + yv.yielded++ + return &voteSignature[Precommit[H, N], S]{vm.Vote, vm.Signature} + } + return nil + case equivocated[Precommit[H, N], S]: + a := vm[0] + b := vm[1] + switch yv.yielded { + case 0: + return &a + case 1: + return &b + default: + return nil + } + default: + panic("invalid voteMultiplicity value") + } +} + +// FinalizingPrecommits returns all precommits targeting the finalized hash. +// +// Only returns `nil` if no block has been finalized in this round. +func (r *Round[ID, H, N, S]) FinalizingPrecommits(chain Chain[H, N]) *[]SignedPrecommit[H, N, S, ID] { + type idvoteMultiplicity struct { + ID ID + voteMultiplicity voteMultiplicity[Precommit[H, N], S] + } + + if r.finalized == nil { + return nil + } + fHash := r.finalized.Hash + var filtered []idvoteMultiplicity + var findValidPrecommits []SignedPrecommit[H, N, S, ID] + r.precommits.votes.Scan(func(id ID, multiplicity voteMultiplicity[Precommit[H, N], S]) bool { + switch multiplicityValue := multiplicity.Value().(type) { + case single[Precommit[H, N], S]: + // if there is a single vote from this voter, we only include it + // if it branches off of the target. + if chain.IsEqualOrDescendantOf(fHash, multiplicityValue.Vote.TargetHash) { + filtered = append(filtered, idvoteMultiplicity{id, multiplicity}) + } + default: + // equivocations count for everything, so we always include them. + filtered = append(filtered, idvoteMultiplicity{id, multiplicity}) + } + return true + }) + for _, ivm := range filtered { + yieldVotes := yieldVotes[H, N, S]{ + yielded: 0, + multiplicity: ivm.voteMultiplicity, + } + if vs := yieldVotes.voteSignature(); vs != nil { + findValidPrecommits = append(findValidPrecommits, SignedPrecommit[H, N, S, ID]{ + Precommit: vs.Vote, + Signature: vs.Signature, + ID: ivm.ID, + }) + } + } + return &findValidPrecommits +} + +// Estimate will fetch the "round-estimate": the best block which might have been finalized +// in this round. +// +// Returns `nil` when new new blocks could have been finalized in this round, +// according to our estimate. +func (r *Round[ID, H, N, S]) Estimate() *HashNumber[H, N] { + return r.estimate +} + +// Finalized fetches the most recently finalized block. +func (r *Round[ID, H, N, S]) Finalized() *HashNumber[H, N] { + return r.finalized +} + +// Completable returns `true` when the round is completable. +// +// This is the case when the round-estimate is an ancestor of the prevote-ghost head, +// or when they are the same block _and_ none of its children could possibly have +// enough precommits. +func (r *Round[ID, H, N, S]) Completable() bool { + return r.completable +} + +// Threshold weight for supermajority. +func (r *Round[ID, H, N, S]) Threshold() VoterWeight { + return r.context.voters.threshold +} + +// Base returns the round base. +func (r *Round[ID, H, N, S]) Base() HashNumber[H, N] { + return r.graph.Base() +} + +// Voters returns the round voters and weights. +func (r *Round[ID, H, N, S]) Voters() VoterSet[ID] { + return r.context.voters +} + +// PrimaryVoter returns the primary voter of the round. +func (r *Round[ID, H, N, S]) PrimaryVoter() (ID, VoterInfo) { + IDVoterInfo := r.context.Voters().NthMod(uint(r.number)) + return IDVoterInfo.ID, IDVoterInfo.VoterInfo +} + +// PrevoteParticipation returns the current weight and number of voters who have participated in prevoting. +func (r *Round[ID, H, N, S]) PrevoteParticipation() (weight VoteWeight, numParticipants int) { + return r.prevotes.participation() +} + +// PrecommitParticipation returns the current weight and number of voters who have participated in precommitting. +func (r *Round[ID, H, N, S]) PrecommitParticipation() (weight VoteWeight, numParticipants int) { + return r.precommits.participation() +} + +// Prevotes returns all imported prevotes. +func (r *Round[ID, H, N, S]) Prevotes() []idVoteSignature[ID, Prevote[H, N], S] { + return r.prevotes.Votes() +} + +// Precommits returns all imported precommits. +func (r *Round[ID, H, N, S]) Precommits() []idVoteSignature[ID, Precommit[H, N], S] { + return r.precommits.Votes() +} + +// HistoricalVotes returns all votes for the round (prevotes and precommits), sorted by +// imported order and indicating the indices where we voted. At most two +// prevotes and two precommits per voter are present, further equivocations +// are not stored (as they are redundant). +func (r *Round[ID, H, N, S]) HistoricalVotes() HistoricalVotes[H, N, S, ID] { + return r.historicalVotes +} + +// SetPrevotedIdx will set the number of prevotes and precommits received at the moment of prevoting. +// It should be called inmediatly after prevoting. +func (r *Round[ID, H, N, S]) SetPrevotedIdx() { + r.historicalVotes.SetPrevotedIdx() +} + +// SetPrecommittedIdx will set the number of prevotes and precommits received at the moment of precommiting. +// It should be called inmediatly after precommiting. +func (r *Round[ID, H, N, S]) SetPrecommittedIdx() { + r.historicalVotes.SetPrecommittedIdx() +} diff --git a/pkg/finality-grandpa/round_test.go b/pkg/finality-grandpa/round_test.go new file mode 100644 index 0000000000..5c60cf2cf6 --- /dev/null +++ b/pkg/finality-grandpa/round_test.go @@ -0,0 +1,323 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestVoteMultiplicity_Contains(t *testing.T) { + type headerNumber struct { + Header string + Number uint + } + type signature string + var ( + headerNumber1 = headerNumber{"header1", 1} + signature1 = signature("sig1") + headerNumber2 = headerNumber{"header2", 2} + signature2 = signature("sig2") + ) + tests := []struct { + name string + value interface{} + args voteSignature[headerNumber, signature] + want bool + }{ + { + name: "Single", + value: single[headerNumber, signature]{ + headerNumber1, + signature1, + }, + args: voteSignature[headerNumber, signature]{ + headerNumber1, + signature1, + }, + want: true, + }, + { + name: "Single", + value: single[headerNumber, signature]{ + headerNumber1, + signature1, + }, + args: voteSignature[headerNumber, signature]{ + headerNumber2, + signature2, + }, + want: false, + }, + { + name: "Equivocated", + value: equivocated[headerNumber, signature]{ + {headerNumber1, signature1}, + {headerNumber2, signature2}, + }, + args: voteSignature[headerNumber, signature]{ + headerNumber1, + signature1, + }, + want: true, + }, + { + name: "Equivocated", + value: equivocated[headerNumber, signature]{ + {headerNumber1, signature1}, + {headerNumber2, signature2}, + }, + args: voteSignature[headerNumber, signature]{ + headerNumber2, + signature2, + }, + want: true, + }, + { + name: "Equivocated", + value: equivocated[headerNumber, signature]{ + {headerNumber1, signature1}, + {headerNumber2, signature2}, + }, + args: voteSignature[headerNumber, signature]{ + headerNumber1, + signature1, + }, + want: true, + }, + { + name: "Equivocated", + value: equivocated[headerNumber, signature]{ + {headerNumber1, signature1}, + {headerNumber2, signature2}, + }, + args: voteSignature[headerNumber, signature]{ + headerNumber{"bleh", 99}, + signature("bleh"), + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var vm voteMultiplicity[headerNumber, signature] + switch val := tt.value.(type) { + case equivocated[headerNumber, signature]: + vm = newVoteMultiplicity[headerNumber, signature](val) + case single[headerNumber, signature]: + vm = newVoteMultiplicity[headerNumber, signature](val) + } + got := vm.Contains(tt.args.Vote, tt.args.Signature) + if got != tt.want { + t.Errorf("voteMultiplicity.Contains() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRound_EstimateIsValid(t *testing.T) { + chain := newDummyChain() + chain.PushBlocks(GenesisHash, []string{"A", "B", "C", "D", "E", "F"}) + chain.PushBlocks("E", []string{"EA", "EB", "EC", "ED"}) + chain.PushBlocks("F", []string{"FA", "FB", "FC"}) + voters := NewVoterSet([]IDWeight[string]{{"Alice", 4}, {"Bob", 7}, {"Eve", 3}}) + + round := NewRound[string, string, uint32, string](RoundParams[string, string, uint32]{ + RoundNumber: 1, + Voters: *voters, + Base: HashNumber[string, uint32]{"C", 4}, + }) + + _, err := round.importPrevote(chain, Prevote[string, uint32]{"FC", 10}, "Alice", "Alice") + assert.NoError(t, err) + + _, err = round.importPrevote(chain, Prevote[string, uint32]{"ED", 10}, "Bob", "Bob") + assert.NoError(t, err) + + assert.Equal(t, HashNumber[string, uint32]{"E", 6}, *round.prevoteGhost) + assert.Equal(t, HashNumber[string, uint32]{"E", 6}, *round.estimate) + assert.False(t, round.completable) + + _, err = round.importPrevote(chain, Prevote[string, uint32]{"F", 7}, "Eve", "Eve") + assert.NoError(t, err) + + assert.Equal(t, HashNumber[string, uint32]{"E", 6}, *round.prevoteGhost) + assert.Equal(t, HashNumber[string, uint32]{"E", 6}, *round.estimate) +} + +func TestRound_Finalisation(t *testing.T) { + chain := newDummyChain() + chain.PushBlocks(GenesisHash, []string{"A", "B", "C", "D", "E", "F"}) + chain.PushBlocks("E", []string{"EA", "EB", "EC", "ED"}) + chain.PushBlocks("F", []string{"FA", "FB", "FC"}) + + voters := NewVoterSet([]IDWeight[string]{{"Alice", 4}, {"Bob", 7}, {"Eve", 3}}) + round := NewRound[string, string, uint32, string](RoundParams[string, string, uint32]{ + RoundNumber: 1, + Voters: *voters, + Base: HashNumber[string, uint32]{"C", 4}, + }) + + ir1, err := round.importPrecommit(chain, Precommit[string, uint32]{"FC", 10}, "Alice", "Alice") + assert.NoError(t, err) + assert.NotNil(t, ir1) + + ir1, err = round.importPrecommit(chain, Precommit[string, uint32]{"ED", 10}, "Bob", "Bob") + assert.NoError(t, err) + assert.NotNil(t, ir1) + + assert.Nil(t, round.finalized) + + // import some prevotes. + { + ir, err := round.importPrevote(chain, Prevote[string, uint32]{"FC", 10}, "Alice", "Alice") + assert.NoError(t, err) + assert.NotNil(t, ir) + + ir, err = round.importPrevote(chain, Prevote[string, uint32]{"ED", 10}, "Bob", "Bob") + assert.NoError(t, err) + assert.NotNil(t, ir) + + ir, err = round.importPrevote(chain, Prevote[string, uint32]{"EA", 7}, "Eve", "Eve") + assert.NoError(t, err) + assert.NotNil(t, ir) + + assert.Equal(t, &HashNumber[string, uint32]{"E", 6}, round.finalized) + } + + ir1, err = round.importPrecommit(chain, Precommit[string, uint32]{"EA", 7}, "Eve", "Eve") + assert.NoError(t, err) + assert.NotNil(t, ir1) + + assert.Equal(t, &HashNumber[string, uint32]{"EA", 7}, round.finalized) +} + +func TestRound_EquivocateDoesNotDoubleCount(t *testing.T) { + chain := newDummyChain() + chain.PushBlocks(GenesisHash, []string{"A", "B", "C", "D", "E", "F"}) + chain.PushBlocks("E", []string{"EA", "EB", "EC", "ED"}) + chain.PushBlocks("F", []string{"FA", "FB", "FC"}) + + voters := NewVoterSet([]IDWeight[string]{{"Alice", 4}, {"Bob", 7}, {"Eve", 3}}) + round := NewRound[string, string, uint32, string](RoundParams[string, string, uint32]{ + RoundNumber: 1, + Voters: *voters, + Base: HashNumber[string, uint32]{"C", 4}, + }) + + // first prevote by eve + ir, err := round.importPrevote(chain, Prevote[string, uint32]{"FC", 10}, "Eve", "Eve-1") + assert.NoError(t, err) + assert.NotNil(t, ir) + assert.Nil(t, ir.Equivocation) + + assert.Nil(t, round.prevoteGhost) + + // second prevote by eve: comes with equivocation proof + ir, err = round.importPrevote(chain, Prevote[string, uint32]{"ED", 10}, "Eve", "Eve-2") + assert.NoError(t, err) + assert.NotNil(t, ir) + assert.NotNil(t, ir.Equivocation) + + // third prevote: returns nothing. + ir, err = round.importPrevote(chain, Prevote[string, uint32]{"F", 7}, "Eve", "Eve-2") + assert.NoError(t, err) + assert.NotNil(t, ir) + assert.Nil(t, ir.Equivocation) + + // three eves together would be enough. + assert.Nil(t, round.prevoteGhost) + + ir, err = round.importPrevote(chain, Prevote[string, uint32]{"FA", 8}, "Bob", "Bob-1") + assert.NoError(t, err) + assert.NotNil(t, ir) + assert.Nil(t, ir.Equivocation) + + assert.Equal(t, &HashNumber[string, uint32]{"FA", 8}, round.prevoteGhost) +} + +func TestRound_HistoricalVotesWorks(t *testing.T) { + chain := newDummyChain() + chain.PushBlocks(GenesisHash, []string{"A", "B", "C", "D", "E", "F"}) + chain.PushBlocks("E", []string{"EA", "EB", "EC", "ED"}) + chain.PushBlocks("F", []string{"FA", "FB", "FC"}) + + voters := NewVoterSet([]IDWeight[string]{{"Alice", 4}, {"Bob", 7}, {"Eve", 3}}) + round := NewRound[string, string, uint32, string](RoundParams[string, string, uint32]{ + RoundNumber: 1, + Voters: *voters, + Base: HashNumber[string, uint32]{"C", 4}, + }) + + ir, err := round.importPrevote(chain, Prevote[string, uint32]{"FC", 10}, "Alice", "Alice") + assert.NoError(t, err) + assert.NotNil(t, ir) + + round.historicalVotes.SetPrevotedIdx() + + ir, err = round.importPrevote(chain, Prevote[string, uint32]{"EA", 7}, "Eve", "Eve") + assert.NoError(t, err) + assert.NotNil(t, ir) + + ir1, err := round.importPrecommit(chain, Precommit[string, uint32]{"EA", 7}, "Eve", "Eve") + assert.NoError(t, err) + assert.NotNil(t, ir1) + + ir, err = round.importPrevote(chain, Prevote[string, uint32]{"EC", 10}, "Alice", "Alice") + assert.NoError(t, err) + assert.NotNil(t, ir) + + round.historicalVotes.SetPrecommittedIdx() + + var newUint32 = func(ui uint64) *uint64 { + return &ui + } + assert.Equal(t, HistoricalVotes[string, uint32, string, string]{ + seen: []SignedMessage[string, uint32, string, string]{ + { + Message: Message[string, uint32]{ + value: Prevote[string, uint32]{ + TargetHash: "FC", + TargetNumber: 10, + }, + }, + Signature: "Alice", + ID: "Alice", + }, + { + Message: Message[string, uint32]{ + value: Prevote[string, uint32]{ + TargetHash: "EA", + TargetNumber: 7, + }, + }, + Signature: "Eve", + ID: "Eve", + }, + { + Message: Message[string, uint32]{ + value: Precommit[string, uint32]{ + TargetHash: "EA", + TargetNumber: 7, + }, + }, + Signature: "Eve", + ID: "Eve", + }, + { + Message: Message[string, uint32]{ + value: Prevote[string, uint32]{ + TargetHash: "EC", + TargetNumber: 10, + }, + }, + Signature: "Alice", + ID: "Alice", + }, + }, + prevoteIdx: newUint32(1), + precommitIdx: newUint32(4), + }, round.historicalVotes) +} diff --git a/pkg/finality-grandpa/vote_graph.go b/pkg/finality-grandpa/vote_graph.go new file mode 100644 index 0000000000..2063c56f88 --- /dev/null +++ b/pkg/finality-grandpa/vote_graph.go @@ -0,0 +1,698 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "fmt" + + "github.com/tidwall/btree" + "golang.org/x/exp/constraints" + "golang.org/x/exp/slices" +) + +type voteGraphEntry[ + Hash constraints.Ordered, + Number constraints.Integer, + voteNode voteNodeI[voteNode, Vote], + Vote any, +] struct { + number Number + // ancestor hashes in reverse order, e.g. ancestors[0] is the parent + // and the last entry is the hash of the parent vote-node. + ancestors []Hash + descendants []Hash // descendent vote-nodes + cumulativeVote voteNode +} + +// whether the given hash, number pair is a direct ancestor of this node. +// `None` signifies that the graph must be traversed further back. +func (vge voteGraphEntry[Hash, Number, voteNode, Vote]) inDirectAncestry(hash Hash, num Number) *bool { + h := vge.ancestorBlock(num) + if h == nil { + return nil + } + b := *h == hash + return &b +} + +// Get ancestor block by number. Returns `None` if there is no block +// by that number in the direct ancestry. +func (vge voteGraphEntry[Hash, Number, voteNode, Vote]) ancestorBlock(num Number) (h *Hash) { + if num >= vge.number { + return nil + } + offset := vge.number - num - 1 + if int(offset) >= len(vge.ancestors) { + return nil + } + ancestor := vge.ancestors[int(offset)] + return &ancestor +} + +// get ancestor vote-node. +func (vge voteGraphEntry[Hash, Number, voteNode, Vote]) ancestorNode() *Hash { + if len(vge.ancestors) == 0 { + return nil + } + h := vge.ancestors[len(vge.ancestors)-1] + return &h +} + +// VoteGraph maintains a DAG of blocks in the chain which have votes attached to them, +// and vote data which is accumulated along edges. +type VoteGraph[ + Hash constraints.Ordered, + Number constraints.Unsigned, + voteNode voteNodeI[voteNode, Vote], + Vote any, +] struct { + entries *btree.Map[Hash, voteGraphEntry[Hash, Number, voteNode, Vote]] + heads *btree.Set[Hash] + base Hash + baseNumber Number + newDefaultvoteNode func() voteNode +} + +// NewVoteGraph creates a new `VoteGraph` with base node as given. +func NewVoteGraph[ + Hash constraints.Ordered, + Number constraints.Unsigned, + voteNode voteNodeI[voteNode, Vote], + Vote any, +]( + baseHash Hash, + baseNumber Number, + baseNode voteNode, + newDefaultvoteNode func() voteNode, +) VoteGraph[Hash, Number, voteNode, Vote] { + entries := btree.NewMap[Hash, voteGraphEntry[Hash, Number, voteNode, Vote]](2) + entries.Set(baseHash, voteGraphEntry[Hash, Number, voteNode, Vote]{ + number: baseNumber, + ancestors: make([]Hash, 0), + descendants: make([]Hash, 0), + cumulativeVote: baseNode, + }) + heads := &btree.Set[Hash]{} + heads.Insert(baseHash) + return VoteGraph[Hash, Number, voteNode, Vote]{ + entries: entries, + heads: heads, + base: baseHash, + baseNumber: baseNumber, + newDefaultvoteNode: newDefaultvoteNode, + } +} + +// append a vote-node onto the chain-tree. This should only be called if +// no node in the tree keeps the target anyway. +func (vg *VoteGraph[Hash, Number, voteNode, Vote]) append( + hash Hash, + num Number, + chain Chain[Hash, Number], +) (err error) { + ancestry, err := chain.Ancestry(vg.base, hash) + if err != nil { + return err + } + ancestry = append(ancestry, vg.base) + + var ancestorIndex *int + for i, ancestor := range ancestry { + entry, ok := vg.entries.Get(ancestor) + if ok { + entry.descendants = append(entry.descendants, hash) + vg.entries.Set(ancestor, entry) + if ancestorIndex == nil { + ai := i + ancestorIndex = &ai + break + } + } + } + + if ancestorIndex == nil { + panic(fmt.Errorf("base is kept; chain returns ancestry only if the block is a descendent of base;")) + } + + ancestorHash := ancestry[*ancestorIndex] + ancestry = ancestry[0 : *ancestorIndex+1] + + vg.entries.Set(hash, voteGraphEntry[Hash, Number, voteNode, Vote]{ + number: num, + ancestors: ancestry, + descendants: make([]Hash, 0), + cumulativeVote: vg.newDefaultvoteNode(), + }) + + vg.heads.Delete(ancestorHash) + vg.heads.Insert(hash) + return +} + +// From finality-grandpa: +// introduce a branch to given vote-nodes. +// +// `descendents` is a list of nodes with ancestor-edges containing the given ancestor. +// +// This function panics if any member of `descendents` is not a vote-node +// or does not have ancestor with given hash and number OR if `ancestor_hash` +// is already a known entry. + +func (vg *VoteGraph[Hash, Number, voteNode, Vote]) introduceBranch( + descendants []Hash, + ancestorHash Hash, + ancestorNumber Number, +) { + var producedEntry *struct { + entry voteGraphEntry[Hash, Number, voteNode, Vote] + hash *Hash + } + var maybeEntry *struct { + entry voteGraphEntry[Hash, Number, voteNode, Vote] + hash *Hash + } + for _, descendant := range descendants { + entry, ok := vg.entries.Get(descendant) + if !ok { + panic("this function only invoked with keys of vote-nodes; qed") + } + + ida := entry.inDirectAncestry(ancestorHash, ancestorNumber) + if ida == nil || !*ida { + panic("entry is supposed to be in direct ancestry") + } + + // example: splitting number 10 at ancestor 4 + // before: [9 8 7 6 5 4 3 2 1] + // after: [9 8 7 6 5 4], [3 2 1] + // we ensure the `entry.ancestors` is drained regardless of whether + // the `new_entry` has already been constructed. + { + prevAncestor := entry.ancestorNode() + var offset uint + if ancestorNumber > entry.number { + panic("this function only invoked with direct ancestors; qed") + } else { + offset = uint(entry.number - ancestorNumber) + } + newAncestors := entry.ancestors[offset:len(entry.ancestors)] + entry.ancestors = entry.ancestors[0:offset] + vg.entries.Set(descendant, entry) + + if maybeEntry == nil { + maybeEntry = &struct { + entry voteGraphEntry[Hash, Number, voteNode, Vote] + hash *Hash + }{ + entry: voteGraphEntry[Hash, Number, voteNode, Vote]{ + number: ancestorNumber, + ancestors: newAncestors, + descendants: make([]Hash, 0), + cumulativeVote: vg.newDefaultvoteNode(), + }, + hash: prevAncestor, + } + } + maybeEntry.entry.descendants = append(maybeEntry.entry.descendants, descendant) + maybeEntry.entry.cumulativeVote.Add(entry.cumulativeVote) + } + producedEntry = maybeEntry + } + + if producedEntry != nil { + newEntry := producedEntry.entry + prevAncestor := producedEntry.hash + if prevAncestor != nil { + prevancestorNode, _ := vg.entries.Get(*prevAncestor) + prevancestorNodeDescendants := make([]Hash, 0) + for _, d := range prevancestorNode.descendants { + if !slices.Contains(newEntry.descendants, d) { + prevancestorNodeDescendants = append(prevancestorNodeDescendants, d) + } + } + prevancestorNodeDescendants = append(prevancestorNodeDescendants, ancestorHash) + prevancestorNode.descendants = prevancestorNodeDescendants + vg.entries.Set(*producedEntry.hash, prevancestorNode) + } + vg.entries.Set(ancestorHash, producedEntry.entry) + } +} + +// Insert a vote with given value into the graph at given hash and number. +func (vg *VoteGraph[Hash, Number, voteNode, Vote]) Insert( + hash Hash, + num Number, + vote any, + chain Chain[Hash, Number], +) error { + containing := vg.findContainingNodes(hash, num) + switch { + case containing == nil: + // this entry already exists + case len(containing) == 0: + err := vg.append(hash, num, chain) + if err != nil { + return err + } + default: + vg.introduceBranch(containing, hash, num) + } + + // update cumulative vote data. + // NOTE: below this point, there always exists a node with the given hash and number. + var inspectingHash = hash + for { + activeEntry, ok := vg.entries.Get(inspectingHash) + if !ok { + panic("vote-node and its ancestry always exist after initial phase; qed") + } + switch vote := vote.(type) { + case voteNode: + activeEntry.cumulativeVote.Add(vote) + case Vote: + activeEntry.cumulativeVote.AddVote(vote) + default: + panic(fmt.Errorf("unsupported type to add to cumulativeVote %T", vote)) + } + vg.entries.Set(inspectingHash, activeEntry) + + parent := activeEntry.ancestorNode() + if parent != nil { + inspectingHash = *parent + } else { + break + } + } + return nil +} + +// attempts to find the containing node keys for the given hash and number. +// +// returns `None` if there is a node by that key already, and a vector +// (potentially empty) of nodes with the given block in its ancestor-edge +// otherwise. +func (vg *VoteGraph[Hash, Number, voteNode, Vote]) findContainingNodes(hash Hash, num Number) (hashes []Hash) { + _, ok := vg.entries.Get(hash) + if ok { + return nil + } + + containingKeys := make([]Hash, 0) + visited := make(map[Hash]interface{}) + + for _, head := range vg.heads.Keys() { + var activeEntry voteGraphEntry[Hash, Number, voteNode, Vote] + + for { + e, ok := vg.entries.Get(head) + if !ok { + break + } + activeEntry = e + + _, ok = visited[head] + // if node has been checked already break + if ok { + break + } + visited[head] = nil + + da := activeEntry.inDirectAncestry(hash, num) + switch { + case da == nil: + prev := activeEntry.ancestorNode() + if prev != nil { + head = *prev + continue // iterate backwards + } + case *da: + // set containing node and continue search. + containingKeys = append(containingKeys, head) + case !*da: + // nothing in this branch. continue search. + } + break + } + } + return containingKeys +} + +// a subchain of blocks by hash. +type subChain[Hash comparable, Number constraints.Unsigned] struct { + hashes []Hash // forward order + bestNumber Number +} + +func (sc subChain[H, N]) best() *HashNumber[H, N] { + if len(sc.hashes) == 0 { + return nil + } + return &HashNumber[H, N]{ + sc.hashes[len(sc.hashes)-1], + sc.bestNumber, + } +} + +func (vg *VoteGraph[Hash, Number, voteNode, Vote]) mustGetEntry( + hash Hash, +) voteGraphEntry[Hash, Number, voteNode, Vote] { + entry, ok := vg.entries.Get(hash) + if !ok { + panic("descendents always present in node storage; qed") + } + return entry +} + +type hashvote[Hash constraints.Ordered, voteNode voteNodeI[voteNode, Vote], Vote any] struct { + hash Hash + vote voteNode +} + +// given a key, node pair (which must correspond), assuming this node fulfils the condition, +// this function will find the highest point at which its descendents merge, which may be the +// node itself. +func (vg *VoteGraph[Hash, Number, voteNode, Vote]) ghostFindMergePoint( //skipcq: GO-R1005 + nodeKey Hash, activeNode *voteGraphEntry[Hash, Number, voteNode, Vote], forceConstrain *HashNumber[Hash, Number], + condition func(voteNode) bool) subChain[Hash, Number] { + + var descendantNodes []voteGraphEntry[Hash, Number, voteNode, Vote] + for _, descendant := range activeNode.descendants { + switch { + case forceConstrain == nil: + descendantNodes = append(descendantNodes, vg.mustGetEntry(descendant)) + default: + ida := vg.mustGetEntry(descendant).inDirectAncestry(forceConstrain.Hash, forceConstrain.Number) + switch { + case ida == nil: + case !*ida: + case *ida: + descendantNodes = append(descendantNodes, vg.mustGetEntry(descendant)) + } + + } + } + + baseNumber := activeNode.number + bestNumber := activeNode.number + + descendantBlocks := make([]hashvote[Hash, voteNode, Vote], 0) + hashes := []Hash{nodeKey} + + // TODO: for long ranges of blocks this could get inefficient (copied from rust code) + var offset Number + for { + offset = offset + 1 + + var newBest *Hash + for _, dNode := range descendantNodes { + dBlock := dNode.ancestorBlock(baseNumber + offset) + if dBlock == nil { + continue + } + idx, ok := slices.BinarySearchFunc( + descendantBlocks, + hashvote[Hash, voteNode, Vote]{hash: *dBlock}, + func(a, b hashvote[Hash, voteNode, Vote]) int { + switch { + case a.hash == b.hash: + return 0 + case a.hash > b.hash: + return 1 + case a.hash < b.hash: + return -1 + default: + panic("unreachable") + } + }, + ) + if ok { + descendantBlocks[idx].vote.Add(dNode.cumulativeVote) + if condition(descendantBlocks[idx].vote) { + newBest = dBlock + break + } + } else { + if idx == len(descendantBlocks) { + descendantBlocks = append(descendantBlocks, hashvote[Hash, voteNode, Vote]{ + hash: *dBlock, + vote: dNode.cumulativeVote.Copy(), + }) + } else if idx < len(descendantBlocks) { + descendantBlocks = append( + descendantBlocks[:idx], + append([]hashvote[Hash, voteNode, Vote]{{ + hash: *dBlock, + vote: dNode.cumulativeVote.Copy(), + }}, descendantBlocks[idx:]...)...) + } else { + panic("unreachable") + } + } + } + + if newBest != nil { + bestNumber = bestNumber + 1 + descendantBlocks = make([]hashvote[Hash, voteNode, Vote], 0) + retained := make([]voteGraphEntry[Hash, Number, voteNode, Vote], 0) + for _, descendant := range descendantNodes { + ida := descendant.inDirectAncestry(*newBest, bestNumber) + if ida != nil && *ida { + retained = append(retained, descendant) + } + } + descendantNodes = retained + hashes = append(hashes, *newBest) + } else { + break + } + } + + return subChain[Hash, Number]{ + hashes: hashes, + bestNumber: bestNumber, + } +} + +type hashVoteGraphEntry[ + Hash constraints.Ordered, + Number constraints.Integer, + voteNode voteNodeI[voteNode, Vote], + Vote any, +] struct { + hash Hash + entry voteGraphEntry[Hash, Number, voteNode, Vote] +} + +// FindGHOST will find the best GHOST descendent of the given block. +// Pass a closure used to evaluate the cumulative vote value. +// +// The GHOST (hash, number) returned will be the block with highest number for which the +// cumulative votes of descendents and itself causes the closure to evaluate to true. +// +// This assumes that the evaluation closure is one which returns true for at most a single +// descendent of a block, in that only one fork of a block can be "heavy" +// enough to trigger the threshold. +// +// Returns `nil` when the given `currentBest` does not fulfil the condition. +func (vg *VoteGraph[Hash, Number, voteNode, Vote]) FindGHOST( //skipcq: GO-R1005 + currentBest *HashNumber[Hash, Number], + condition func(voteNode) bool, +) *HashNumber[Hash, Number] { + var getNode = func(hash Hash) *voteGraphEntry[Hash, Number, voteNode, Vote] { + entry, ok := vg.entries.Get(hash) + if !ok { + panic("node either base or referenced by other in graph; qed") + } + return &entry + } + + var nodeKey Hash + var forceConstrain bool + + if currentBest == nil { + nodeKey = vg.base + forceConstrain = false + } else { + containing := vg.findContainingNodes(currentBest.Hash, currentBest.Number) + switch { + case containing == nil: + nodeKey = currentBest.Hash + forceConstrain = false + case len(containing) > 0: + ancestor := getNode(containing[0]).ancestorNode() + if ancestor == nil { + panic("node containing non-node in history always has ancestor; qed") + } + nodeKey = *ancestor + forceConstrain = true + default: + nodeKey = vg.base + forceConstrain = false + } + } + + activeNode := getNode(nodeKey) + + if !condition(activeNode.cumulativeVote) { + return nil + } + + // breadth-first search starting from this node. +loop: + for { + var nextDescendant *hashVoteGraphEntry[Hash, Number, voteNode, Vote] + filteredDescendants := make([]*hashVoteGraphEntry[Hash, Number, voteNode, Vote], 0) + + for _, descendant := range activeNode.descendants { + if forceConstrain && currentBest != nil { + node := getNode(descendant) + ida := node.inDirectAncestry(currentBest.Hash, currentBest.Number) + switch { + case ida == nil: + case !*ida: + case *ida: + filteredDescendants = append(filteredDescendants, &hashVoteGraphEntry[Hash, Number, voteNode, Vote]{ + hash: descendant, + entry: *node, + }) + } + } else { + node := getNode(descendant) + filteredDescendants = append(filteredDescendants, &hashVoteGraphEntry[Hash, Number, voteNode, Vote]{ + hash: descendant, + entry: *node, + }) + } + } + + for _, hvge := range filteredDescendants { + if condition(hvge.entry.cumulativeVote) { + nextDescendant = &hashVoteGraphEntry[Hash, Number, voteNode, Vote]{ + hash: hvge.hash, + entry: hvge.entry, + } + break + } + } + + switch nextDescendant { + case nil: + break loop + default: + forceConstrain = false + nodeKey = nextDescendant.hash + activeNode = &nextDescendant.entry + } + + } + + var hn *HashNumber[Hash, Number] + if forceConstrain { + hn = currentBest + } + + return vg.ghostFindMergePoint(nodeKey, activeNode, hn, condition).best() +} + +// FindAncestor will find the block with the highest block number in the chain with the given head +// which fulfils the given condition. +// +// Returns `nil` if the given head is not in the graph or no node fulfils the +// given condition. +func (vg *VoteGraph[Hash, Number, voteNode, Vote]) FindAncestor( + hash Hash, + number Number, + condition func(voteNode) bool, +) *HashNumber[Hash, Number] { + for { + children := vg.findContainingNodes(hash, number) + if children == nil { + // The block has a vote-node in the graph. + node := vg.mustGetEntry(hash) + // If the weight is sufficient, we are done. + if condition(node.cumulativeVote) { + return &HashNumber[Hash, Number]{hash, number} + } + // Not enough weight, check the parent block. + if len(node.ancestors) == 0 { + return nil + } + hash = node.ancestors[0] + number = node.number - 1 + } else { + // If there are no vote-nodes below the block in the graph, + // the block is not in the graph at all. + if len(children) == 0 { + return nil + } + // The block is "contained" in the graph (i.e. in the ancestry-chain + // of at least one vote-node) but does not itself have a vote-node. + // Check if the accumulated weight on all child vote-nodes is sufficient. + v := vg.newDefaultvoteNode() + for _, c := range children { + e := vg.mustGetEntry(c) + v.Add(e.cumulativeVote) + } + if condition(v) { + return &HashNumber[Hash, Number]{hash, number} + } + + // Not enough weight, check the parent block. + child := children[len(children)-1] + entry := vg.mustGetEntry(child) + offset := int(entry.number - number) + + if offset >= len(entry.ancestors) { + // Reached base without sufficient weight. + return nil + } + parent := entry.ancestors[offset] + + hash = parent + number = number - 1 + } + } +} + +// AdjustBase will adjust the base of the graph. The new base must be an ancestor of the +// old base. +// +// Provide an ancestry proof from the old base to the new. The proof +// should be in reverse order from the old base's parent. +func (vg *VoteGraph[Hash, Number, voteNode, Vote]) AdjustBase(ancestryProof []Hash) { + if len(ancestryProof) == 0 { + return // empty nothing to do + } + newHash := ancestryProof[len(ancestryProof)-1] + + // not a valid ancestry proof. TODO: error? (TODO copied from rust code) + if len(ancestryProof) > int(vg.baseNumber) { + return + } + + newNumber := vg.baseNumber + newNumber = newNumber - Number(len(ancestryProof)) + + oldEntry := vg.mustGetEntry(vg.base) + oldEntry.ancestors = append(oldEntry.ancestors, ancestryProof...) + vg.entries.Set(vg.base, oldEntry) + + entry := voteGraphEntry[Hash, Number, voteNode, Vote]{ + number: newNumber, + ancestors: make([]Hash, 0), + descendants: []Hash{vg.base}, + cumulativeVote: oldEntry.cumulativeVote.Copy(), + } + vg.entries.Set(newHash, entry) + vg.base = newHash + vg.baseNumber = newNumber +} + +// Base returns the base block. +func (vg *VoteGraph[Hash, Number, voteNode, Vote]) Base() HashNumber[Hash, Number] { + return HashNumber[Hash, Number]{ + vg.base, + vg.baseNumber, + } +} diff --git a/pkg/finality-grandpa/vote_graph_test.go b/pkg/finality-grandpa/vote_graph_test.go new file mode 100644 index 0000000000..ffa73674ec --- /dev/null +++ b/pkg/finality-grandpa/vote_graph_test.go @@ -0,0 +1,350 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +type uintvoteNode uint + +func (uvn *uintvoteNode) Add(other *uintvoteNode) { + *uvn += *other +} + +func (uvn *uintvoteNode) AddVote(other int) { + *uvn += uintvoteNode(other) +} + +func (uvn *uintvoteNode) String() string { + return fmt.Sprintf("%+v", *uvn) +} + +func (uvn *uintvoteNode) Copy() *uintvoteNode { + copied := *uvn + return &copied +} + +func createUintVoteNode(i int) *uintvoteNode { + vn := uintvoteNode(i) + return &vn +} + +func newUintVoteNode() *uintvoteNode { + return createUintVoteNode(0) +} + +func TestVoteGraph_GraphForkNotAtNode(t *testing.T) { + c := newDummyChain() + c.PushBlocks(GenesisHash, []string{"A", "B", "C"}) + c.PushBlocks("C", []string{"D1", "E1", "F1"}) + c.PushBlocks("C", []string{"D2", "E2", "F2"}) + + vn := uintvoteNode(0) + vg := NewVoteGraph[string, uint, *uintvoteNode, int](GenesisHash, uint(1), &vn, newUintVoteNode) + assert.NoError(t, vg.Insert("A", 2, createUintVoteNode(100), c)) + assert.NoError(t, vg.Insert("E1", 6, createUintVoteNode(100), c)) + assert.NoError(t, vg.Insert("F2", 7, createUintVoteNode(100), c)) + + assert.Contains(t, vg.heads.Keys(), "E1") + assert.Contains(t, vg.heads.Keys(), "F2") + assert.NotContains(t, vg.heads.Keys(), "A") + + var getEntry = func(key string) voteGraphEntry[string, uint, *uintvoteNode, int] { + entry, _ := vg.entries.Get(key) + return entry + } + + assert.Equal(t, []string{"E1", "F2"}, getEntry("A").descendants) + assert.Equal(t, createUintVoteNode(300), getEntry("A").cumulativeVote) + + assert.Equal(t, "A", *getEntry("E1").ancestorNode()) + assert.Equal(t, createUintVoteNode(100), getEntry("E1").cumulativeVote) + + assert.Equal(t, "A", *getEntry("F2").ancestorNode()) + assert.Equal(t, createUintVoteNode(100), getEntry("F2").cumulativeVote) +} + +func TestVoteGraph_GraphForkNotAtNode1(t *testing.T) { + c := newDummyChain() + c.PushBlocks(GenesisHash, []string{"A", "B", "C"}) + c.PushBlocks("C", []string{"D1", "E1", "F1"}) + c.PushBlocks("C", []string{"D2", "E2", "F2"}) + + vn := uintvoteNode(0) + vg := NewVoteGraph[string, uint, *uintvoteNode, int](GenesisHash, uint(1), &vn, newUintVoteNode) + assert.NoError(t, vg.Insert("A", 2, 100, c)) + assert.NoError(t, vg.Insert("E1", 6, 100, c)) + assert.NoError(t, vg.Insert("F2", 7, 100, c)) + + assert.Contains(t, vg.heads.Keys(), "E1") + assert.Contains(t, vg.heads.Keys(), "F2") + assert.NotContains(t, vg.heads.Keys(), "A") + + var getEntry = func(key string) voteGraphEntry[string, uint, *uintvoteNode, int] { + entry, _ := vg.entries.Get(key) + return entry + } + + assert.Equal(t, []string{"E1", "F2"}, getEntry("A").descendants) + assert.Equal(t, createUintVoteNode(300), getEntry("A").cumulativeVote) + + assert.Equal(t, "A", *getEntry("E1").ancestorNode()) + assert.Equal(t, createUintVoteNode(100), getEntry("E1").cumulativeVote) + + assert.Equal(t, "A", *getEntry("F2").ancestorNode()) + assert.Equal(t, createUintVoteNode(100), getEntry("F2").cumulativeVote) +} + +func TestVoteGraph_GraphForkAtNode(t *testing.T) { + c := newDummyChain() + c.PushBlocks(GenesisHash, []string{"A", "B", "C"}) + c.PushBlocks("C", []string{"D1", "E1", "F1"}) + c.PushBlocks("C", []string{"D2", "E2", "F2"}) + + vn := uintvoteNode(0) + vg1 := NewVoteGraph[string, uint, *uintvoteNode, int](GenesisHash, uint(1), &vn, newUintVoteNode) + assert.NoError(t, vg1.Insert("C", 4, createUintVoteNode(100), c)) + assert.NoError(t, vg1.Insert("E1", 6, createUintVoteNode(100), c)) + assert.NoError(t, vg1.Insert("F2", 7, createUintVoteNode(100), c)) + + vn1 := uintvoteNode(0) + vg2 := NewVoteGraph[string, uint, *uintvoteNode, int](GenesisHash, uint(1), &vn1, newUintVoteNode) + assert.NoError(t, vg2.Insert("E1", 6, createUintVoteNode(100), c)) + assert.NoError(t, vg2.Insert("F2", 7, createUintVoteNode(100), c)) + assert.NoError(t, vg2.Insert("C", 4, createUintVoteNode(100), c)) + + for _, test := range []struct { + name string + VoteGraph[string, uint, *uintvoteNode, int] + }{ + { + name: "vg1", + VoteGraph: vg1, + }, + { + name: "vg2", + VoteGraph: vg1, + }, + } { + t.Run(test.name, func(t *testing.T) { + vg := test.VoteGraph + + var getEntry = func(key string) voteGraphEntry[string, uint, *uintvoteNode, int] { + entry, _ := vg.entries.Get(key) + return entry + } + + assert.Contains(t, vg.heads.Keys(), "E1") + assert.Contains(t, vg.heads.Keys(), "F2") + assert.NotContains(t, vg.heads.Keys(), "C") + + assert.Contains(t, vg.entries.Keys(), "C") + assert.Contains(t, getEntry("C").descendants, "E1") + assert.Contains(t, getEntry("C").descendants, "F2") + assert.Equal(t, GenesisHash, *getEntry("C").ancestorNode()) + assert.Equal(t, createUintVoteNode(300), getEntry("C").cumulativeVote) + + assert.Contains(t, vg.entries.Keys(), "E1") + assert.Equal(t, "C", *getEntry("E1").ancestorNode()) + assert.Equal(t, createUintVoteNode(100), getEntry("E1").cumulativeVote) + + assert.Contains(t, vg.entries.Keys(), "F2") + assert.Equal(t, "C", *getEntry("F2").ancestorNode()) + assert.Equal(t, createUintVoteNode(100), getEntry("F2").cumulativeVote) + }) + } +} + +func TestVoteGraph_GhostMergeAtNode(t *testing.T) { + c := newDummyChain() + c.PushBlocks(GenesisHash, []string{"A", "B", "C"}) + c.PushBlocks("C", []string{"D1", "E1", "F1"}) + c.PushBlocks("C", []string{"D2", "E2", "F2"}) + + vn := uintvoteNode(0) + vg := NewVoteGraph[string, uint, *uintvoteNode, int](GenesisHash, uint(1), &vn, newUintVoteNode) + assert.NoError(t, vg.Insert("B", 3, createUintVoteNode(0), c)) + assert.NoError(t, vg.Insert("C", 4, createUintVoteNode(100), c)) + assert.NoError(t, vg.Insert("E1", 6, createUintVoteNode(100), c)) + assert.NoError(t, vg.Insert("F2", 7, createUintVoteNode(100), c)) + + assert.Equal(t, &HashNumber[string, uint]{"C", 4}, vg.FindGHOST(nil, func(i *uintvoteNode) bool { return *i >= 250 })) + assert.Equal(t, &HashNumber[string, uint]{"C", 4}, + vg.FindGHOST(&HashNumber[string, uint]{"C", 4}, func(i *uintvoteNode) bool { return *i >= 250 })) + assert.Equal(t, &HashNumber[string, uint]{"C", 4}, + vg.FindGHOST(&HashNumber[string, uint]{"B", 3}, func(i *uintvoteNode) bool { return *i >= 250 })) +} + +func TestVoteGraph_GhostMergeNoteAtNodeOneSideWeighted(t *testing.T) { + c := newDummyChain() + c.PushBlocks(GenesisHash, []string{"A", "B", "C", "D", "E", "F"}) + c.PushBlocks("F", []string{"G1", "H1", "I1"}) + c.PushBlocks("F", []string{"G2", "H2", "I2"}) + + vn := uintvoteNode(0) + vg := NewVoteGraph[string, uint, *uintvoteNode, int](GenesisHash, uint(1), &vn, newUintVoteNode) + assert.NoError(t, vg.Insert("B", 3, createUintVoteNode(0), c)) + assert.NoError(t, vg.Insert("G1", 8, createUintVoteNode(100), c)) + assert.NoError(t, vg.Insert("H2", 9, createUintVoteNode(150), c)) + + assert.Equal(t, &HashNumber[string, uint]{"F", 7}, vg.FindGHOST(nil, func(i *uintvoteNode) bool { return *i >= 250 })) + assert.Equal(t, &HashNumber[string, uint]{"F", 7}, + vg.FindGHOST(&HashNumber[string, uint]{"F", 7}, func(i *uintvoteNode) bool { return *i >= 250 })) + assert.Equal(t, &HashNumber[string, uint]{"F", 7}, + vg.FindGHOST(&HashNumber[string, uint]{"C", 4}, func(i *uintvoteNode) bool { return *i >= 250 })) + assert.Equal(t, &HashNumber[string, uint]{"F", 7}, + vg.FindGHOST(&HashNumber[string, uint]{"B", 3}, func(i *uintvoteNode) bool { return *i >= 250 })) +} + +func TestVoteGraph_GhostIntroduceBranch(t *testing.T) { + c := newDummyChain() + c.PushBlocks(GenesisHash, []string{"A", "B", "C", "D", "E", "F"}) + c.PushBlocks("E", []string{"EA", "EB", "EC", "ED"}) + c.PushBlocks("F", []string{"FA", "FB", "FC"}) + + vn := uintvoteNode(0) + vg := NewVoteGraph[string, uint, *uintvoteNode, int](GenesisHash, uint(1), &vn, newUintVoteNode) + assert.NoError(t, vg.Insert("FC", 10, createUintVoteNode(5), c)) + assert.NoError(t, vg.Insert("ED", 10, createUintVoteNode(7), c)) + + var getEntry = func(key string) voteGraphEntry[string, uint, *uintvoteNode, int] { + entry, _ := vg.entries.Get(key) + return entry + } + + assert.Equal(t, &HashNumber[string, uint]{"E", 6}, vg.FindGHOST(nil, func(x *uintvoteNode) bool { return *x >= 10 })) + assert.Equal(t, []string{"FC", "ED"}, getEntry(GenesisHash).descendants) + + // introduce a branch in the middle. + assert.NoError(t, vg.Insert("E", 6, createUintVoteNode(3), c)) + + assert.Equal(t, []string{"E"}, getEntry(GenesisHash).descendants) + assert.Equal(t, 2, len(getEntry("E").descendants)) + assert.Contains(t, getEntry("E").descendants, "ED") + assert.Contains(t, getEntry("E").descendants, "FC") + + assert.Equal(t, &HashNumber[string, uint]{"E", 6}, vg.FindGHOST(nil, func(x *uintvoteNode) bool { return *x >= 10 })) + assert.Equal(t, &HashNumber[string, uint]{"E", 6}, + vg.FindGHOST(&HashNumber[string, uint]{"C", 4}, func(x *uintvoteNode) bool { return *x >= 10 })) + assert.Equal(t, &HashNumber[string, uint]{"E", 6}, + vg.FindGHOST(&HashNumber[string, uint]{"E", 6}, func(x *uintvoteNode) bool { return *x >= 10 })) +} + +func TestVoteGraph_WalkBackFromBlockInEdgeForkBelow(t *testing.T) { + c := newDummyChain() + c.PushBlocks(GenesisHash, []string{"A", "B", "C"}) + c.PushBlocks("C", []string{"D1", "E1", "F1", "G1", "H1", "I1"}) + c.PushBlocks("C", []string{"D2", "E2", "F2", "G2", "H2", "I2"}) + + vn := uintvoteNode(0) + vg := NewVoteGraph[string, uint, *uintvoteNode, int](GenesisHash, uint(1), &vn, newUintVoteNode) + assert.NoError(t, vg.Insert("B", 3, createUintVoteNode(10), c)) + assert.NoError(t, vg.Insert("F1", 7, createUintVoteNode(5), c)) + assert.NoError(t, vg.Insert("G2", 8, createUintVoteNode(5), c)) + + for _, block := range []string{"D1", "D2", "E1", "E2", "F1", "F2", "G2"} { + number := c.Number(block) + assert.Equal(t, &HashNumber[string, uint]{"C", 4}, + vg.FindAncestor(block, uint(number), func(x *uintvoteNode) bool { return *x > 5 })) + } +} + +func TestVoteGraph_WalkBackFromForkBlockNodeBelow(t *testing.T) { + c := newDummyChain() + c.PushBlocks(GenesisHash, []string{"A", "B", "C", "D"}) + c.PushBlocks("D", []string{"E1", "F1", "G1", "H1", "I1"}) + c.PushBlocks("D", []string{"E2", "F2", "G2", "H2", "I2"}) + + vn := uintvoteNode(0) + vg := NewVoteGraph[string, uint, *uintvoteNode, int](GenesisHash, uint(1), &vn, newUintVoteNode) + assert.NoError(t, vg.Insert("B", 3, createUintVoteNode(10), c)) + assert.NoError(t, vg.Insert("F1", 7, createUintVoteNode(5), c)) + assert.NoError(t, vg.Insert("G2", 8, createUintVoteNode(5), c)) + + assert.Equal(t, &HashNumber[string, uint]{"D", 5}, + vg.FindAncestor("G2", 8, func(x *uintvoteNode) bool { return *x > 5 })) + for _, block := range []string{"E1", "E2", "F1", "F2", "G2"} { + number := c.Number(block) + assert.Equal(t, &HashNumber[string, uint]{"D", 5}, + vg.FindAncestor(block, uint(number), func(x *uintvoteNode) bool { return *x > 5 })) + } +} + +func TestVoteGraph_WalkBackAtNode(t *testing.T) { + c := newDummyChain() + c.PushBlocks(GenesisHash, []string{"A", "B", "C"}) + c.PushBlocks("C", []string{"D1", "E1", "F1", "G1", "H1", "I1"}) + c.PushBlocks("C", []string{"D2", "E2", "F2"}) + + vn := uintvoteNode(0) + vg := NewVoteGraph[string, uint, *uintvoteNode, int](GenesisHash, uint(1), &vn, newUintVoteNode) + assert.NoError(t, vg.Insert("C", 4, createUintVoteNode(10), c)) + assert.NoError(t, vg.Insert("F1", 7, createUintVoteNode(5), c)) + assert.NoError(t, vg.Insert("F2", 7, createUintVoteNode(5), c)) + assert.NoError(t, vg.Insert("I1", 10, createUintVoteNode(1), c)) + + for _, block := range []string{"C", "D1", "D2", "E1", "E2", "F1", "F2", "I1"} { + number := c.Number(block) + assert.Equal(t, &HashNumber[string, uint]{"C", 4}, + vg.FindAncestor(block, uint(number), func(x *uintvoteNode) bool { return *x >= 20 })) + } +} + +func TestVoteGraph_AdjustBase(t *testing.T) { + c := newDummyChain() + c.PushBlocks(GenesisHash, []string{"A", "B", "C", "D", "E", "F"}) + c.PushBlocks("E", []string{"EA", "EB", "EC", "ED"}) + c.PushBlocks("F", []string{"FA", "FB", "FC"}) + + vn := uintvoteNode(0) + vg := NewVoteGraph[string, uint, *uintvoteNode, int]("E", uint(6), &vn, newUintVoteNode) + assert.NoError(t, vg.Insert("FC", 10, createUintVoteNode(5), c)) + assert.NoError(t, vg.Insert("ED", 10, createUintVoteNode(7), c)) + + assert.Equal(t, HashNumber[string, uint]{"E", 6}, vg.Base()) + + vg.AdjustBase([]string{"D", "C", "B", "A"}) + + assert.Equal(t, HashNumber[string, uint]{"A", 2}, vg.Base()) + + c.PushBlocks("A", []string{"3", "4", "5"}) + + vg.AdjustBase([]string{GenesisHash}) + assert.Equal(t, HashNumber[string, uint]{GenesisHash, 1}, vg.Base()) + + var getEntry = func(key string) voteGraphEntry[string, uint, *uintvoteNode, int] { + entry, _ := vg.entries.Get(key) + return entry + } + + assert.Equal(t, createUintVoteNode(12), getEntry(GenesisHash).cumulativeVote) + + assert.NoError(t, vg.Insert("5", 5, createUintVoteNode(3), c)) + + assert.Equal(t, int(15), int(*getEntry(GenesisHash).cumulativeVote)) +} + +func TestVoteGraph_FindAncestorIsLargest(t *testing.T) { + c := newDummyChain() + c.PushBlocks(GenesisHash, []string{"A"}) + c.PushBlocks(GenesisHash, []string{"B"}) + c.PushBlocks("A", []string{"A1"}) + c.PushBlocks("A", []string{"A2"}) + c.PushBlocks("B", []string{"B1"}) + c.PushBlocks("B", []string{"B2"}) + + vn := uintvoteNode(0) + vg := NewVoteGraph[string, uint, *uintvoteNode, int](GenesisHash, uint(0), &vn, newUintVoteNode) + assert.NoError(t, vg.Insert("B1", 2, createUintVoteNode(1), c)) + assert.NoError(t, vg.Insert("B2", 2, createUintVoteNode(1), c)) + assert.NoError(t, vg.Insert("A1", 2, createUintVoteNode(1), c)) + assert.NoError(t, vg.Insert("A2", 2, createUintVoteNode(1), c)) + + assert.Equal(t, &HashNumber[string, uint]{"A", 1}, + vg.FindAncestor("A", 1, func(x *uintvoteNode) bool { return *x >= 2 })) +} diff --git a/pkg/finality-grandpa/voter.go b/pkg/finality-grandpa/voter.go new file mode 100644 index 0000000000..fbb37fdbf1 --- /dev/null +++ b/pkg/finality-grandpa/voter.go @@ -0,0 +1,1125 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "fmt" + "sync" + "time" + + "github.com/tidwall/btree" + "golang.org/x/exp/constraints" +) + +type wakerChan[Item any] struct { + in chan Item + out chan Item + waker *waker +} + +func newWakerChan[Item any](in chan Item) *wakerChan[Item] { + wc := &wakerChan[Item]{ + in: in, + out: make(chan Item), + waker: nil, + } + go wc.start() + return wc +} + +func (wc *wakerChan[Item]) start() { + defer close(wc.out) + for item := range wc.in { + if wc.waker != nil { + wc.waker.wake() + } + wc.out <- item + } +} + +func (wc *wakerChan[Item]) setWaker(waker *waker) { + wc.waker = waker +} + +// Chan returns a channel to consume `Item`. Not thread safe, only supports one consumer +func (wc *wakerChan[Item]) channel() chan Item { + return wc.out +} + +// Timer is the associated timer type for the environment +type Timer interface { + SetWaker(waker *waker) + Elapsed() (bool, error) +} + +// Output is the output stream used to communicate with the outside world. +type Output[Hash comparable, Number constraints.Unsigned] chan Message[Hash, Number] + +// Input is the input stream used to communicate with the outside world. +type Input[ + Hash comparable, + Number constraints.Unsigned, + Signature comparable, + ID constraints.Ordered, +] chan SignedMessageError[Hash, Number, Signature, ID] + +// SignedMessageError contains a `SignedMessage“ and error +type SignedMessageError[ + Hash comparable, + Number constraints.Unsigned, + Signature comparable, + ID constraints.Ordered, +] struct { + SignedMessage SignedMessage[Hash, Number, Signature, ID] + Error error +} + +// BestChainOutput is the item type for `BestChain` +type BestChainOutput[Hash comparable, Number constraints.Unsigned] struct { + Value *HashNumber[Hash, Number] + Error error +} + +// BestChain is Associated channel for the environment used when asynchronously computing the +// best chain to vote on. See also `Environment.BestChainContaining`. +type BestChain[Hash comparable, Number constraints.Unsigned] chan BestChainOutput[Hash, Number] + +// Environment is the necessary environment for a voter. +// +// This encapsulates the database and networking layers of the chain. +type Environment[Hash comparable, Number constraints.Unsigned, Signature comparable, ID constraints.Ordered] interface { + Chain[Hash, Number] + // Return a channel that will produce the hash of the best block whose chain + // contains the given block hash, even if that block is `base` itself. + // + // If `base` is unknown the future outputs `nil`. + BestChainContaining(base Hash) BestChain[Hash, Number] + + // Produce data necessary to start a round of voting. This may also be called + // with the round number of the most recently completed round, in which case + // it should yield a valid input stream. + // + // The input stream should provide messages which correspond to known blocks + // only. + // + // The voting logic will push unsigned messages over-eagerly into the + // output stream. It is the job of this stream to determine if those messages + // should be sent (for example, if the process actually controls a permissioned key) + // and then to sign the message, multicast it to peers, and schedule it to be + // returned by the `In` stream. + // + // This allows the voting logic to maintain the invariant that only incoming messages + // may alter the state, and the logic remains the same regardless of whether a node + // is a regular voter, the proposer, or simply an observer. + // + // Furthermore, this means that actual logic of creating and verifying + // signatures is flexible and can be maintained outside this crate. + RoundData( + round uint64, + outgoing Output[Hash, Number], + ) RoundData[Hash, Number, Signature, ID] + + // Return a timer that will be used to delay the broadcast of a commit + // message. This delay should not be static to minimise the amount of + // commit messages that are sent (e.g. random value in [0, 1] seconds). + RoundCommitTimer() Timer + + // Note that we've done a primary proposal in the given round. + Proposed(round uint64, propose PrimaryPropose[Hash, Number]) error + + // Note that we have prevoted in the given round. + Prevoted(round uint64, prevote Prevote[Hash, Number]) error + + // Note that we have precommitted in the given round. + Precommitted(round uint64, precommit Precommit[Hash, Number]) error + + // Note that a round is completed. This is called when a round has been + // voted in and the next round can start. The round may continue to be run + // in the background until _concluded_. + // Should return an error when something fatal occurs. + Completed( + round uint64, + state RoundState[Hash, Number], + base HashNumber[Hash, Number], + votes HistoricalVotes[Hash, Number, Signature, ID], + ) error + + // Note that a round has concluded. This is called when a round has been + // `completed` and additionally, the round's estimate has been finalized. + // + // There may be more votes than when `completed`, and it is the responsibility + // of the `Environment` implementation to deduplicate. However, the caller guarantees + // that the votes passed to `completed` for this round are a prefix of the votes passed here. + Concluded( + round uint64, + state RoundState[Hash, Number], + base HashNumber[Hash, Number], + votes HistoricalVotes[Hash, Number, Signature, ID], + ) error + + // Called when a block should be finalized. + FinalizeBlock( + hash Hash, + number Number, + round uint64, + commit Commit[Hash, Number, Signature, ID], + ) error + + // Note that an equivocation in prevotes has occurred. + PrevoteEquivocation( + round uint64, + equivocation Equivocation[ID, Prevote[Hash, Number], Signature], + ) + + // Note that an equivocation in prevotes has occurred. + PrecommitEquivocation( + round uint64, + equivocation Equivocation[ID, Precommit[Hash, Number], Signature], + ) +} + +type finalizedNotification[Hash, Number, Signature, ID any] struct { + Hash Hash + Number Number + Round uint64 + Commit Commit[Hash, Number, Signature, ID] +} + +// RoundData is the data necessary to participate in a round. +type RoundData[Hash comparable, + Number constraints.Unsigned, + Signature comparable, + ID constraints.Ordered] struct { + // Local voter id (if any.) + VoterID *ID + // Timer before prevotes can be cast. This should be Start + 2T + // where T is the gossip time estimate. + PrevoteTimer Timer + // Timer before precommits can be cast. This should be Start + 4T + PrecommitTimer Timer + // Incoming messages. + // Incoming chan SignedMessageError + Incoming Input[Hash, Number, Signature, ID] +} + +type buffered[I any] struct { + inner chan I + buffer []I + mtx sync.Mutex + readyCh chan any +} + +func newBuffered[I any](inner chan I) *buffered[I] { + b := &buffered[I]{ + inner: inner, + readyCh: make(chan any, 1), + } + // prime the channel + b.readyCh <- nil + return b +} + +func (b *buffered[I]) Push(item I) { + b.mtx.Lock() + defer b.mtx.Unlock() + b.buffer = append(b.buffer, item) +} + +func (b *buffered[I]) Poll(waker *waker) (bool, error) { + return b.flush(waker) +} + +func (b *buffered[I]) flush(waker *waker) (bool, error) { + if b.inner == nil { + return false, fmt.Errorf("inner channel has been closed") + } + + b.mtx.Lock() + defer b.mtx.Unlock() + if len(b.buffer) == 0 { + return true, nil + } + select { + case <-b.readyCh: + defer func() { + b.readyCh <- nil + waker.wake() + }() + + for len(b.buffer) > 0 { + b.inner <- b.buffer[0] + b.buffer = b.buffer[1:] + waker.wake() + } + + default: + } + return false, nil +} + +func (b *buffered[I]) Close() { + b.mtx.Lock() + defer b.mtx.Unlock() + close(b.inner) + b.inner = nil +} + +// Instantiates the given last round, to be backgrounded until its estimate is finalized. +// +// This round must be completable based on the passed votes (and if not, `None` will be returned), +// but it may be the case that there are some more votes to propagate in order to push +// the estimate backwards and conclude the round (i.e. finalize its estimate). +// +// may only be called with non-zero last round. +func instantiateLastRound[ + Hash constraints.Ordered, + Number constraints.Unsigned, + Signature comparable, + ID constraints.Ordered, + E Environment[Hash, Number, Signature, ID], +]( + voters VoterSet[ID], + lastRoundVotes []SignedMessage[Hash, Number, Signature, ID], + lastRoundNumber uint64, + lastRoundBase HashNumber[Hash, Number], + finalizedSender chan finalizedNotification[Hash, Number, Signature, ID], + env E, +) *votingRound[Hash, Number, Signature, ID, E] { + lastRoundTracker := NewRound[ID, Hash, Number, Signature](RoundParams[ID, Hash, Number]{ + Voters: voters, + Base: lastRoundBase, + RoundNumber: lastRoundNumber, + }) + + // start as completed so we don't cast votes. + lastRound := newVotingRoundCompleted(lastRoundTracker, finalizedSender, nil, env) + + for _, vote := range lastRoundVotes { + // bail if any votes are bad. + err := lastRound.handleVote(vote) + if err != nil { + log.Debugf("lastRound.Handlevote error: %v", err) + return nil + } + } + + if lastRound.roundState().Completable { + return &lastRound + } + return nil +} + +// The inner state of a voter aggregating the currently running round state +// (i.e. best and background rounds). This state exists separately since it's +// useful to wrap in a `Arc>` for sharing. +type innerVoterState[ + Hash constraints.Ordered, + Number constraints.Unsigned, + Signature comparable, ID constraints.Ordered, + E Environment[Hash, Number, Signature, ID], +] struct { + bestRound votingRound[Hash, Number, Signature, ID, E] + pastRounds pastRounds[Hash, Number, Signature, ID, E] + sync.Mutex +} + +// CommunicationOut is communication between nodes that is not round-localised. +type CommunicationOut struct { + variant any +} + +// CommuincationOutVariants is interface constraint of `CommunicationOut` +type CommuincationOutVariants[ + Hash constraints.Ordered, + Number constraints.Unsigned, + Signature comparable, + ID constraints.Ordered, +] interface { + CommunicationOutCommit[Hash, Number, Signature, ID] +} + +func newCommunicationOut[ + Hash constraints.Ordered, + Number constraints.Unsigned, + Signature comparable, + ID constraints.Ordered, + T CommuincationOutVariants[Hash, Number, Signature, ID], +](variant T) CommunicationOut { + co := CommunicationOut{} + setCommunicationOut[Hash, Number, Signature, ID](&co, variant) + return co +} + +func setCommunicationOut[ + Hash constraints.Ordered, + Number constraints.Unsigned, + Signature comparable, + ID constraints.Ordered, + T CommuincationOutVariants[Hash, Number, Signature, ID], +](co *CommunicationOut, variant T) { + co.variant = variant +} + +// CommunicationOutCommit is a commit message. +type CommunicationOutCommit[ + Hash constraints.Ordered, + Number constraints.Unsigned, + Signature comparable, + ID constraints.Ordered, +] numberCommit[Hash, Number, Signature, ID] + +// CommitProcessingOutcome is the outcome of processing a commit. +type CommitProcessingOutcome struct { + variant any +} + +// CommitProcessingOutcomeGood means it was beneficial to process this commit. +type CommitProcessingOutcomeGood GoodCommit + +// CommitProcessingOutcomeBad means it wasn't beneficial to process this commit. We wasted resources. +type CommitProcessingOutcomeBad BadCommit + +// GoodCommit is the result of processing for a good commit. +type GoodCommit struct{} + +// BadCommit is the result of processing for a bad commit +type BadCommit struct { + numPrecommits uint + numDuplicatedPrecommits uint + numEquivocations uint + numInvalidVoters uint +} + +// NumPrecommits returns the number of precommits. +func (bc BadCommit) NumPrecommits() uint { + return bc.numPrecommits +} + +// NumDuplicatedPrecommits returns the number of duplicated precommits. +func (bc BadCommit) NumDuplicatedPrecommits() uint { + return bc.numDuplicatedPrecommits +} + +// NumEquiovcations returns the number of equivocations in the precommits +func (bc BadCommit) NumEquiovcations() uint { + return bc.numEquivocations +} + +// NumInvalidVoters returns the number of invalid voters in the precommits +func (bc BadCommit) NumInvalidVoters() uint { + return bc.numInvalidVoters +} + +func newBadCommit(cvr CommitValidationResult) BadCommit { + return BadCommit{ + numPrecommits: cvr.NumPrecommits(), + numDuplicatedPrecommits: cvr.NumDuplicatedPrecommits(), + numEquivocations: cvr.NumEquiovcations(), + numInvalidVoters: cvr.NumInvalidVoters(), + } +} + +// CatchUpProcessingOutcome is the outcome of processing a catch up. +type CatchUpProcessingOutcome struct { + variant any +} + +func newCatchUpProcessingOutcome[T CatchUpProcessingOutcomes](variant T) CatchUpProcessingOutcome { + return CatchUpProcessingOutcome{ + variant: variant, + } +} + +// CatchUpProcessingOutcomes is the interface constraint for `CatchUpProcessingOutcome` +type CatchUpProcessingOutcomes interface { + CatchUpProcessingOutcomeGood | CatchUpProcessingOutcomeBad | CatchUpProcessingOutcomeUseless +} + +// CatchUpProcessingOutcomeGood means it was beneficial to process this catch up. +type CatchUpProcessingOutcomeGood GoodCatchUp + +// CatchUpProcessingOutcomeBad means it wasn't beneficial to process this catch up, it is invalid and we +// wasted resources. +type CatchUpProcessingOutcomeBad BadCatchUp + +// CatchUpProcessingOutcomeUseless mean the catch up wasn't processed because it is useless, e.g. it is for a +// round lower than we're currently in. +type CatchUpProcessingOutcomeUseless struct{} + +// GoodCatchUp is the result of processing for a good catch up. +type GoodCatchUp struct{} + +// BadCatchUp is the result of processing for a bad catch up. +type BadCatchUp struct{} + +type CommunicationIn struct { + variant any +} + +func setCommunicationIn[ + Hash constraints.Ordered, Number constraints.Unsigned, Signature comparable, ID constraints.Ordered, + T CommunicationInVariants[Hash, Number, Signature, ID], +](ci *CommunicationIn, variant T) { + ci.variant = variant +} + +func newCommunicationIn[ + Hash constraints.Ordered, Number constraints.Unsigned, Signature comparable, ID constraints.Ordered, + T CommunicationInVariants[Hash, Number, Signature, ID], +](variant T) CommunicationIn { + ci := CommunicationIn{} + setCommunicationIn[Hash, Number, Signature, ID](&ci, variant) + return ci +} + +type CommunicationInVariants[ + Hash constraints.Ordered, + Number constraints.Unsigned, + Signature comparable, + ID constraints.Ordered, +] interface { + CommunicationInCommit[Hash, Number, Signature, ID] | CommunicationInCatchUp[Hash, Number, Signature, ID] +} +type CommunicationInCommit[ + Hash constraints.Ordered, + Number constraints.Unsigned, + Signature comparable, + ID constraints.Ordered, +] struct { + Number uint64 + CompactCommit CompactCommit[Hash, Number, Signature, ID] + Callback func(CommitProcessingOutcome) +} + +type CommunicationInCatchUp[ + Hash constraints.Ordered, + Number constraints.Unsigned, + Signature comparable, + ID constraints.Ordered, +] struct { + CatchUp CatchUp[Hash, Number, Signature, ID] + Callback func(CatchUpProcessingOutcome) +} + +type globalInItem struct { + CommunicationIn + Error error +} + +// Voter maintains and multiplexes between different rounds, +// and caches votes. +// +// This voter also implements the commit protocol. +// The commit protocol allows a node to broadcast a message that finalises a +// given block and includes a set of precommits as proof. +// +// - When a round is completable and we precommitted we start a commit timer +// and start accepting commit messages; +// - When we receive a commit message if it targets a block higher than what +// we've finalized we validate it and import its precommits if valid; +// - When our commit timer triggers we check if we've received any commit +// message for a block equal to what we've finalized, if we haven't then we +// broadcast a commit. +// +// Additionally, we also listen to commit messages from rounds that aren't +// currently running, we validate the commit and dispatch a finalisation +// notification (if any) to the environment. +type Voter[Hash constraints.Ordered, Number constraints.Unsigned, Signature comparable, ID constraints.Ordered] struct { + env Environment[Hash, Number, Signature, ID] + voters VoterSet[ID] + inner *innerVoterState[Hash, Number, Signature, ID, Environment[Hash, Number, Signature, ID]] + finalizedNotifications *wakerChan[finalizedNotification[Hash, Number, Signature, ID]] + lastFinalizedNumber Number + globalIn *wakerChan[globalInItem] + globalOut *buffered[CommunicationOut] + // the commit protocol might finalize further than the current round (if we're + // behind), we keep track of last finalized in round so we don't violate any + // assumptions from round-to-round. + lastFinalizedInRounds HashNumber[Hash, Number] + + stopTimeout time.Duration + stopChan chan any + wg sync.WaitGroup +} + +// NewVoter creates a new `Voter` tracker with given round number and base block. +// +// Provide data about the last completed round. If there is no +// known last completed round, the genesis state (round number 0, no votes, genesis base), +// should be provided. When available, all messages required to complete +// the last round should be provided. +// +// The input stream for commit messages should provide commits which +// correspond to known blocks only (including all its precommits). It +// is also responsible for validating the signature data in commit +// messages. +func NewVoter[Hash constraints.Ordered, Number constraints.Unsigned, Signature comparable, ID constraints.Ordered]( + env Environment[Hash, Number, Signature, ID], + voters VoterSet[ID], + globalIn chan globalInItem, + lastRoundNumber uint64, + lastRoundVotes []SignedMessage[Hash, Number, Signature, ID], + lastRoundBase HashNumber[Hash, Number], + lastFinalized HashNumber[Hash, Number], +) (*Voter[Hash, Number, Signature, ID], chan CommunicationOut) { + finalizedSender := make(chan finalizedNotification[Hash, Number, Signature, ID], 1) + finalizedNotifications := finalizedSender + lastFinalizedNumber := lastFinalized.Number + + pastRounds := newPastRounds[Hash, Number, Signature, ID, Environment[Hash, Number, Signature, ID]]() + _, lastRoundState := bridgeState(NewRoundState(lastRoundBase)) + + if lastRoundNumber > 0 { + maybeCompletedLastRound := instantiateLastRound( + voters, lastRoundVotes, lastRoundNumber, lastRoundBase, finalizedSender, env) + + if maybeCompletedLastRound != nil { + lastRound := *maybeCompletedLastRound + lastRoundState = *lastRound.bridgeState() + pastRounds.Push(env, lastRound) + } + + // when there is no information about the last completed round, + // the best we can do is assume that the estimate == the given base + // and that it is finalized. This is always the case for the genesis + // round of a set. + } + + bestRound := newVotingRound( + lastRoundNumber+1, + voters, + lastFinalized, + &lastRoundState, + finalizedSender, + env, + ) + + inner := &innerVoterState[Hash, Number, Signature, ID, Environment[Hash, Number, Signature, ID]]{ + bestRound: bestRound, + pastRounds: *pastRounds, + } + globalOut := make(chan CommunicationOut) + return &Voter[Hash, Number, Signature, ID]{ + env: env, + voters: voters, + inner: inner, + finalizedNotifications: newWakerChan(finalizedNotifications), + lastFinalizedNumber: lastFinalizedNumber, + lastFinalizedInRounds: lastFinalized, + globalIn: newWakerChan(globalIn), + globalOut: newBuffered(globalOut), + stopChan: make(chan any), + stopTimeout: 30 * time.Second, + }, globalOut +} + +func (v *Voter[Hash, Number, Signature, ID]) pruneBackgroundRounds(waker *waker) error { + v.inner.Lock() + defer v.inner.Unlock() + +pastRounds: + for { + // Do work on all background rounds, broadcasting any commits generated. + ready, nc, err := v.inner.pastRounds.pollNext(waker) + switch ready { + case true: + if err != nil { + return err + } + if nc != nil { + co := newCommunicationOut(CommunicationOutCommit[Hash, Number, Signature, ID]{nc.Number, nc.Commit}) + v.globalOut.Push(co) + } else { + break pastRounds + } + case false: + break pastRounds + } + } + + v.finalizedNotifications.setWaker(waker) +finalizedNotifications: + for { + select { + case notif := <-v.finalizedNotifications.channel(): + fHash := notif.Hash + fNum := notif.Number + round := notif.Round + commit := notif.Commit + + v.inner.pastRounds.UpdateFinalized(fNum) + if v.setLastFinalizedNumber(fNum) { + err := v.env.FinalizeBlock(fHash, fNum, round, commit) + if err != nil { + return err + } + } + + if fNum > v.lastFinalizedInRounds.Number { + v.lastFinalizedInRounds = HashNumber[Hash, Number]{fHash, fNum} + } + default: + break finalizedNotifications + } + } + + return nil +} + +// Process all incoming messages from other nodes. +// +// Commit messages are handled with extra care. If a commit message references +// a currently backgrounded round, we send it to that round so that when we commit +// on that round, our commit message will be informed by those that we've seen. +// +// Otherwise, we will simply handle the commit and issue a finalisation command +// to the environment. +func (v *Voter[Hash, Number, Signature, ID]) processIncoming(waker *waker) error { //skipcq: GO-R1005 + v.globalIn.setWaker(waker) +loop: + for { + select { + case item := <-v.globalIn.channel(): + if item.Error != nil { + return item.Error + } + switch variant := item.CommunicationIn.variant.(type) { + case CommunicationInCommit[Hash, Number, Signature, ID]: + roundNumber := variant.Number + compactCommit := variant.CompactCommit + processCommitOutcome := variant.Callback + + log.Tracef("Got commit for round_number %+v: target_number: %+v, target_hash: %+v", + roundNumber, + compactCommit.TargetNumber, + compactCommit.TargetHash, + ) + + commit := compactCommit.Commit() + v.inner.Lock() + + // if the commit is for a background round dispatch to round committer. + // that returns Some if there wasn't one. + if imported := v.inner.pastRounds.ImportCommit(roundNumber, commit); imported != nil { + // otherwise validate the commit and signal the finalized block from the + // commit to the environment (if valid and higher than current finalized) + validationResult, err := ValidateCommit(commit, v.voters, v.env.(Chain[Hash, Number])) + if err != nil { + return err + } + if validationResult.Valid() { + lastFinalizedNumber := v.lastFinalizedNumber + + // clean up any background rounds + v.inner.pastRounds.UpdateFinalized(imported.TargetNumber) + + if imported.TargetNumber > lastFinalizedNumber { + v.lastFinalizedNumber = imported.TargetNumber + err := v.env.FinalizeBlock(imported.TargetHash, imported.TargetNumber, roundNumber, *imported) + if err != nil { + v.inner.Unlock() + return err + } + } + + outcome := CommitProcessingOutcome{CommitProcessingOutcomeGood(GoodCommit{})} + if processCommitOutcome != nil { + processCommitOutcome(outcome) + } + } else { + // Failing validation of a commit is bad. + outcome := CommitProcessingOutcome{CommitProcessingOutcomeBad(newBadCommit(validationResult))} + if processCommitOutcome != nil { + processCommitOutcome(outcome) + } + } + } else { + // Import to backgrounded round is good. + outcome := CommitProcessingOutcome{CommitProcessingOutcomeGood(GoodCommit{})} + if processCommitOutcome != nil { + processCommitOutcome(outcome) + } + } + v.inner.Unlock() + case CommunicationInCatchUp[Hash, Number, Signature, ID]: + catchUp := variant.CatchUp + processCatchUpOutcome := variant.Callback + + log.Tracef("Got catch-up message for round %v", catchUp.RoundNumber) + + v.inner.Lock() + + round := validateCatchUp(catchUp, v.env, v.voters, v.inner.bestRound.roundNumber()) + if round == nil { + if processCatchUpOutcome != nil { + processCatchUpOutcome(newCatchUpProcessingOutcome(CatchUpProcessingOutcomeBad{})) + } + return nil + } + + state := round.State() + + // beyond this point, we set this round to the past and + // start voting in the next round. + justCompleted := newVotingRoundCompleted(round, v.inner.bestRound.FinalizedSender(), nil, v.env) + + newBest := newVotingRound( + justCompleted.roundNumber()+1, + v.voters, + v.lastFinalizedInRounds, + justCompleted.bridgeState(), + v.inner.bestRound.FinalizedSender(), + v.env, + ) + + // update last-finalized in rounds _after_ starting new round. + // otherwise the base could be too eagerly set forward. + if state.Finalized != nil { + fNum := state.Finalized.Number + if fNum > v.lastFinalizedInRounds.Number { + v.lastFinalizedInRounds = *state.Finalized + } + } + + err := v.env.Completed( + justCompleted.roundNumber(), + justCompleted.roundState(), + justCompleted.dagBase(), + justCompleted.historicalVotes(), + ) + if err != nil { + v.inner.Unlock() + return err + } + + v.inner.pastRounds.Push(v.env, justCompleted) + + oldBest := v.inner.bestRound + v.inner.bestRound = newBest + v.inner.pastRounds.Push(v.env, oldBest) + + if processCatchUpOutcome != nil { + processCatchUpOutcome(newCatchUpProcessingOutcome(CatchUpProcessingOutcomeGood{})) + } + v.inner.Unlock() + } + default: + break loop + } + } + return nil +} + +// process the logic of the best round. +func (v *Voter[Hash, Number, Signature, ID]) processBestRound(waker *waker) (bool, error) { + // If the current `best_round` is completable and we've already precommitted, + // we start a new round at `best_round + 1`. + { + v.inner.Lock() + + var shouldStartNext bool + completable, err := v.inner.bestRound.poll(waker) + if err != nil { + return true, err + } + + var precomitted bool + state := v.inner.bestRound.State() + if state != nil { + _, precomitted = v.inner.bestRound.State().(statePrecommitted) + } + + shouldStartNext = completable && precomitted + + if !shouldStartNext { + v.inner.Unlock() + return false, nil + } + + log.Tracef("Best round at %v has become completable. Starting new best round at %v", + v.inner.bestRound.roundNumber(), + v.inner.bestRound.roundNumber()+1, + ) + v.inner.Unlock() + } + + err := v.completedBestRound() + if err != nil { + return true, err + } + + // round has been updated. so we need to re-poll. + return v.poll(waker) +} + +func (v *Voter[Hash, Number, Signature, ID]) completedBestRound() error { + v.inner.Lock() + defer v.inner.Unlock() + + err := v.env.Completed( + v.inner.bestRound.roundNumber(), + v.inner.bestRound.roundState(), + v.inner.bestRound.dagBase(), + v.inner.bestRound.historicalVotes(), + ) + if err != nil { + return err + } + + oldRoundNumber := v.inner.bestRound.roundNumber() + + nextRound := newVotingRound( + oldRoundNumber+1, + v.voters, + v.lastFinalizedInRounds, + v.inner.bestRound.bridgeState(), + v.inner.bestRound.FinalizedSender(), + v.env, + ) + + oldBest := v.inner.bestRound + v.inner.bestRound = nextRound + v.inner.pastRounds.Push(v.env, oldBest) + return nil +} + +func (v *Voter[Hash, Number, Signature, ID]) setLastFinalizedNumber(finalizedNumber Number) bool { + if finalizedNumber > v.lastFinalizedNumber { + v.lastFinalizedNumber = finalizedNumber + return true + } + return false +} + +func (v *Voter[Hash, Number, Signature, ID]) Start() error { //skipcq: RVV-B0001 + v.wg.Add(1) + defer v.wg.Done() + waker := newWaker() + for { + ready, err := v.poll(waker) + if err != nil { + return err + } + if ready { + return nil + } + select { + case <-waker.channel(): + case <-v.stopChan: + return fmt.Errorf("early voter stop") + } + } +} + +func (v *Voter[Hash, Number, Signature, ID]) Stop() error { + close(v.stopChan) + v.globalOut.Close() + timeout := time.NewTimer(v.stopTimeout) + wgDone := make(chan any) + go func() { + defer close(wgDone) + v.wg.Wait() + }() + select { + case <-timeout.C: + return fmt.Errorf("timeout for Voter.Stop()") + case <-wgDone: + } + return nil +} + +func (v *Voter[Hash, Number, Signature, ID]) poll(waker *waker) (bool, error) { //skipcq: RVV-B0001 + err := v.processIncoming(waker) + if err != nil { + return true, err + } + err = v.pruneBackgroundRounds(waker) + if err != nil { + return true, err + } + ready, err := v.globalOut.Poll(waker) + if !ready { + return false, nil + } + if err != nil { + return true, err + } + + return v.processBestRound(waker) +} + +type sharedVoteState[ + Hash constraints.Ordered, + Number constraints.Unsigned, + Signature comparable, + ID constraints.Ordered, + E Environment[Hash, Number, Signature, ID], +] struct { + inner *innerVoterState[Hash, Number, Signature, ID, E] + mtx sync.Mutex +} + +func (svs *sharedVoteState[Hash, Number, Signature, ID, E]) Get() VoterStateReport[ID] { + toRoundState := func(votingRound votingRound[Hash, Number, Signature, ID, E]) (uint64, RoundStateReport[ID]) { + return votingRound.roundNumber(), RoundStateReport[ID]{ + TotalWeight: votingRound.voters().TotalWeight(), + ThresholdWeight: votingRound.voters().Threshold(), + PrevoteCurrentWeight: votingRound.preVoteWeight(), + PrevoteIDs: votingRound.prevoteIDs(), + PrecommitCurrentWeight: votingRound.precommitWeight(), + PrecommitIDs: votingRound.precommitIDs(), + } + } + + svs.mtx.Lock() + defer svs.mtx.Unlock() + + bestRoundNum, bestRound := toRoundState(svs.inner.bestRound) + backgroundRounds := svs.inner.pastRounds.votingRounds() + mappedBackgroundRounds := make(map[uint64]RoundStateReport[ID]) + for _, backgroundRound := range backgroundRounds { + num, round := toRoundState(backgroundRound) + mappedBackgroundRounds[num] = round + } + return VoterStateReport[ID]{ + BackgroundRounds: mappedBackgroundRounds, + BestRound: struct { + Number uint64 + RoundState RoundStateReport[ID] + }{ + Number: bestRoundNum, + RoundState: bestRound, + }, + } +} + +// VoterState returns an object allowing to query the voter state. +func (v *Voter[Hash, Number, Signature, ID]) VoterState() VoterState[ID] { + return &sharedVoteState[Hash, Number, Signature, ID, Environment[Hash, Number, Signature, ID]]{ + inner: v.inner, + } +} + +// VoterState interface for querying the state of the voter. Used by `Voter` to return a queryable object +// without exposing too many data types. +type VoterState[ID comparable] interface { + // Returns a plain data type, `report::VoterState`, describing the current state + // of the voter relevant to the voting process. + Get() VoterStateReport[ID] +} + +// Validate the given catch up and return a completed round with all prevotes +// and precommits from the catch up imported. If the catch up is invalid `None` +// is returned instead. +func validateCatchUp[ //skipcq: GO-R1005 + Hash constraints.Ordered, + Number constraints.Unsigned, + Signature comparable, + ID constraints.Ordered, + E Environment[Hash, Number, Signature, ID], +]( + catchUp CatchUp[Hash, Number, Signature, ID], + env E, + voters VoterSet[ID], + bestRoundNumber uint64, +) *Round[ID, Hash, Number, Signature] { + if catchUp.RoundNumber <= bestRoundNumber { + log.Tracef("Ignoring because best round number is %d", bestRoundNumber) + return nil + } + + type prevotedPrecommitted struct { + prevoted bool + precommitted bool + } + // check threshold support in prevotes and precommits. + { + mapped := btree.NewMap[ID, prevotedPrecommitted](2) + + for _, prevote := range catchUp.Prevotes { + if !voters.Contains(prevote.ID) { + log.Tracef("Ignoring invalid catch up, invalid voter: %v", prevote.ID) + return nil + } + + entry, found := mapped.Get(prevote.ID) + if !found { + mapped.Set(prevote.ID, prevotedPrecommitted{true, false}) + } else { + entry.prevoted = true + mapped.Set(prevote.ID, entry) + } + } + + for _, precommit := range catchUp.Precommits { + if !voters.Contains(precommit.ID) { + log.Tracef("Ignoring invalid catch up, invalid voter: %v", precommit.ID) + return nil + } + + entry, found := mapped.Get(precommit.ID) + if !found { + mapped.Set(precommit.ID, prevotedPrecommitted{false, true}) + } else { + entry.precommitted = true + mapped.Set(precommit.ID, entry) + } + } + + var ( + pv VoteWeight + pc VoteWeight + ) + mapped.Scan(func(id ID, pp prevotedPrecommitted) bool { + prevoted := pp.prevoted + precommitted := pp.precommitted + + if vi := voters.Get(id); vi != nil { + if prevoted { + pv = pv + VoteWeight(vi.Weight()) + } + + if precommitted { + pc = pc + VoteWeight(vi.Weight()) + } + } + return true + }) + + threshold := voters.Threshold() + if pv < VoteWeight(threshold) || pc < VoteWeight(threshold) { + log.Tracef("Ignoring invalid catch up, missing voter threshold") + return nil + } + } + + round := NewRound[ID, Hash, Number, Signature](RoundParams[ID, Hash, Number]{ + catchUp.RoundNumber, voters, HashNumber[Hash, Number]{catchUp.BaseHash, catchUp.BaseNumber}, + }) + + // import prevotes first + for _, sp := range catchUp.Prevotes { + _, err := round.importPrevote(env, sp.Prevote, sp.ID, sp.Signature) + if err != nil { + log.Tracef("Ignoring invalid catch up, error importing prevote: %v", err) + return nil + } + } + + // then precommits. + for _, sp := range catchUp.Precommits { + _, err := round.importPrecommit(env, sp.Precommit, sp.ID, sp.Signature) + if err != nil { + log.Tracef("Ignoring invalid catch up, error importing precommit: %v", err) + return nil + } + } + + state := round.State() + if !state.Completable { + return nil + } + + return round +} diff --git a/pkg/finality-grandpa/voter_set.go b/pkg/finality-grandpa/voter_set.go new file mode 100644 index 0000000000..0d0b9ecb65 --- /dev/null +++ b/pkg/finality-grandpa/voter_set.go @@ -0,0 +1,185 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "github.com/tidwall/btree" + "golang.org/x/exp/constraints" + "golang.org/x/exp/slices" +) + +// IDVoterInfo is tuple for ID and VoterInfo +type IDVoterInfo[ID constraints.Ordered] struct { + ID ID + VoterInfo +} + +// VoterSet is a (non-empty) set of voters and associated weights. +// +// A `VoterSet` identifies all voters that are permitted to vote in a round +// of the protocol and their associated weights. A `VoterSet` is furthermore +// equipped with a total order, given by the ordering of the voter's IDs. +type VoterSet[ID constraints.Ordered] struct { + voters []IDVoterInfo[ID] + threshold VoterWeight + totalWeight VoterWeight +} + +// IDWeight is tuple for ID and Weight +type IDWeight[ID constraints.Ordered] struct { + ID ID + Weight VoterWeight +} + +// NewVoterSet creates a voter set from a weight distribution produced by the given iterator. +// +// If the distribution contains multiple weights for the same voter ID, they are +// understood to be partial weights and are accumulated. As a result, the +// order in which the iterator produces the weights is irrelevant. +// +// Returns `None` if the iterator does not yield a valid voter set, which is +// the case if it either produced no non-zero weights or, i.e. the voter set +// would be empty, or if the total voter weight exceeds `u64::MAX`. +func NewVoterSet[ID constraints.Ordered](weights []IDWeight[ID]) *VoterSet[ID] { + var totalWeight VoterWeight + var voters = btree.NewMap[ID, VoterInfo](2) + for _, iw := range weights { + if iw.Weight != 0 { + err := totalWeight.checkedAdd(iw.Weight) + if err != nil { + return nil + } + vi, has := voters.Get(iw.ID) + if !has { + voters.Set(iw.ID, VoterInfo{ + position: 0, // The total order is determined afterwards. + weight: iw.Weight, + }) + } else { + vi.weight = iw.Weight + voters.Set(iw.ID, vi) + } + } + } + + if voters.Len() == 0 { + return nil + } + + var orderedVoters = make([]IDVoterInfo[ID], voters.Len()) + var i uint + voters.Scan(func(id ID, info VoterInfo) bool { + info.position = i + orderedVoters[i] = IDVoterInfo[ID]{id, info} + i++ + return true + }) + + if totalWeight == 0 { + panic("weight can not be zero") + } + + return &VoterSet[ID]{ + voters: orderedVoters, + totalWeight: totalWeight, + threshold: threshold(totalWeight), + } +} + +// Get the voter info for the voter with the given ID, if any. +func (vs VoterSet[ID]) Get(id ID) *VoterInfo { + idx, ok := slices.BinarySearchFunc(vs.voters, IDVoterInfo[ID]{ID: id}, func(a, b IDVoterInfo[ID]) int { + switch { + case a.ID == b.ID: + return 0 + case a.ID > b.ID: + return 1 + case b.ID > a.ID: + return -1 + default: + panic("unreachable") + } + }) + if ok { + return &vs.voters[idx].VoterInfo + } + return nil +} + +// Len returns the size of the set. +func (vs VoterSet[ID]) Len() int { + return len(vs.voters) +} + +// Contains returns whether the set contains a voter with the given ID. +func (vs VoterSet[ID]) Contains(id ID) bool { + return vs.Get(id) != nil +} + +// NthMod gets the nth voter in the set, modulo the size of the set, +// as per the associated total order. +func (vs VoterSet[ID]) NthMod(n uint) IDVoterInfo[ID] { + ivi := vs.Nth(n % uint(len(vs.voters))) + if ivi == nil { + panic("set is nonempty and n % len < len; qed") + } + return *ivi +} + +// Nth gets the nth voter in the set, if any. +// +// Returns `None` if `n >= len`. +func (vs VoterSet[ID]) Nth(n uint) *IDVoterInfo[ID] { + if n >= uint(len(vs.voters)) { + return nil + } + return &IDVoterInfo[ID]{ + vs.voters[n].ID, + vs.voters[n].VoterInfo, + } +} + +// Threshold returns the threshold vote weight required for supermajority +// w.r.t. this set of voters. +func (vs VoterSet[ID]) Threshold() VoterWeight { + return vs.threshold +} + +// TotalWeight returns the total weight of all voters. +func (vs VoterSet[ID]) TotalWeight() VoterWeight { + return vs.totalWeight +} + +// Iter returns the voters in the set, as given by +// the associated total order. +func (vs VoterSet[ID]) Iter() []IDVoterInfo[ID] { + return vs.voters +} + +// VoterInfo is the information about a voter in a `VoterSet`. +type VoterInfo struct { + position uint + weight VoterWeight +} + +func (vi VoterInfo) Position() uint { + return vi.position +} + +func (vi VoterInfo) Weight() VoterWeight { + return vi.weight +} + +// Compute the threshold weight given the total voting weight. +func threshold(totalWeight VoterWeight) VoterWeight { //skipcq: RVV-B0001 + // TODO: implement saturating sub + // https://github.com/ChainSafe/gossamer/issues/3511 + // let faulty = total_weight.get().saturating_sub(1) / 3; + var faulty = (totalWeight - 1) / 3 + vw := totalWeight - faulty + if vw == 0 { + panic("subtrahend > minuend; qed") + } + return totalWeight - faulty +} diff --git a/pkg/finality-grandpa/voter_set_test.go b/pkg/finality-grandpa/voter_set_test.go new file mode 100644 index 0000000000..109aee0ebe --- /dev/null +++ b/pkg/finality-grandpa/voter_set_test.go @@ -0,0 +1,110 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "math" + "math/big" + "math/rand" + "reflect" + "testing" + "testing/quick" + "time" + + "github.com/stretchr/testify/assert" +) + +func (VoterSet[ID]) Generate(rand *rand.Rand, _ int) reflect.Value { + for { + idsValue, ok := quick.Value(reflect.TypeOf(make([]ID, 0)), rand) + if !ok { + panic("unable to generate value") + } + ids := idsValue.Interface().([]ID) + weights := make([]IDWeight[ID], len(ids)) + for i, id := range ids { + u64v, ok := quick.Value(reflect.TypeOf(uint64(0)), rand) + if !ok { + panic("unable to generate value") + } + weights[i] = IDWeight[ID]{ + id, + VoterWeight(u64v.Interface().(uint64)), + } + } + set := NewVoterSet(weights) + if set == nil { + continue + } + return reflect.ValueOf(*set) + } +} + +func TestVoterSet_Equality(t *testing.T) { + f := func(v []IDWeight[uint]) bool { + v1 := NewVoterSet(v) + if v1 != nil { + rand := rand.New(rand.NewSource(time.Now().UnixNano())) //skipcq: GSC-G404 + rand.Shuffle(len(v), func(i, j int) { v[i], v[j] = v[j], v[i] }) + v2 := NewVoterSet(v) + assert.NotNil(t, v1) + return assert.Equal(t, v1, v2) + } + // either no authority has a valid weight + var noValIDWeight = true + for _, iw := range v { + if iw.Weight != 0 { + noValIDWeight = false + break + } + } + if noValIDWeight == true { + return true + } + // or the total weight overflows a u64 + sum := big.NewInt(0) + for _, iw := range v { + sum.Add(sum, new(big.Int).SetUint64(uint64(iw.Weight))) + } + return sum.Cmp(new(big.Int).SetUint64(uint64(math.MaxUint64))) > 0 + + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } +} + +func TestVoterSet_TotalWeight(t *testing.T) { + f := func(v []IDWeight[uint]) bool { + totalWeight := big.NewInt(0) + for _, iw := range v { + totalWeight.Add(totalWeight, new(big.Int).SetUint64(uint64(iw.Weight))) + } + // this validator set is invalid + if totalWeight.Cmp(new(big.Int).SetUint64(uint64(math.MaxUint64))) > 0 { + return true + } + + expected := VoterWeight(totalWeight.Uint64()) + v1 := NewVoterSet(v) + if v1 != nil { + return assert.Equal(t, expected, v1.totalWeight) + } + return assert.Equal(t, expected, VoterWeight(0)) + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } +} + +func TestVoterSet_MinTreshold(t *testing.T) { + f := func(v VoterSet[uint]) bool { + t := v.threshold + w := v.totalWeight + return t >= 2*(w/3)+(w%3) + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } +} diff --git a/pkg/finality-grandpa/voter_test.go b/pkg/finality-grandpa/voter_test.go new file mode 100644 index 0000000000..fd92f9bdc9 --- /dev/null +++ b/pkg/finality-grandpa/voter_test.go @@ -0,0 +1,703 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestVoter_TalkingToMyself(t *testing.T) { + var localID ID = 5 + voters := NewVoterSet([]IDWeight[ID]{ + {localID, 100}, + }) + + network := NewNetwork() + + env := newEnvironment(network, localID) + + // initialize chain + var lastFinalized HashNumber[string, uint32] + env.WithChain(func(chain *dummyChain) { + chain.PushBlocks(GenesisHash, []string{"A", "B", "C", "D", "E"}) + lastFinalized.Hash, lastFinalized.Number = chain.LastFinalized() + }) + + finalized := env.FinalizedStream() + voter, globalOut := NewVoter[string, uint32, Signature, ID]( + &env, + *voters, + make(chan globalInItem), + 0, + nil, + lastFinalized, + lastFinalized, + ) + + globalIn := network.MakeGlobalComms(globalOut) + voter.globalIn = newWakerChan(globalIn) + + done := make(chan any) + go func() { + defer close(done) + err := voter.Start() + // stops early, so this should return an error + assert.Error(t, err) + }() + + <-finalized + err := voter.Stop() + assert.NoError(t, err) + <-done +} + +func TestVoter_FinalizingAtFaultThreshold(t *testing.T) { + weights := make([]IDWeight[ID], 10) + for i := range weights { + weights[i] = IDWeight[ID]{ID(i), 1} + } + voters := NewVoterSet(weights) + + network := NewNetwork() + + var wg sync.WaitGroup + // 3 voters offline. + for i := 0; i < 7; i++ { + localID := ID(i) + // initialize chain + env := newEnvironment(network, localID) + var lastFinalized HashNumber[string, uint32] + env.WithChain(func(chain *dummyChain) { + chain.PushBlocks(GenesisHash, []string{"A", "B", "C", "D", "E"}) + lastFinalized.Hash, lastFinalized.Number = chain.LastFinalized() + }) + + // run voter in background. scheduling it to shut down at the end. + finalized := env.FinalizedStream() + voter, globalOut := NewVoter[string, uint32, Signature, ID]( + &env, + *voters, + make(chan globalInItem), + 0, + nil, + lastFinalized, + lastFinalized, + ) + + globalIn := network.MakeGlobalComms(globalOut) + voter.globalIn = newWakerChan(globalIn) + + wg.Add(1) + go voter.Start() + go func() { + defer wg.Done() + <-finalized + err := voter.Stop() + assert.NoError(t, err) + }() + } + wg.Wait() +} + +func TestVoter_ExposingVoterState(t *testing.T) { + numVoters := 10 + votersOnline := 7 + + weights := make([]IDWeight[ID], numVoters) + for i := range weights { + weights[i] = IDWeight[ID]{ID(i), 1} + } + voterSet := NewVoterSet(weights) + + network := NewNetwork() + + var wg sync.WaitGroup + voters := make([]*Voter[string, uint32, Signature, ID], votersOnline) + voterStates := make([]VoterState[ID], votersOnline) + // some voters offline + for i := 0; i < votersOnline; i++ { + localID := ID(i) + // initialize chain + env := newEnvironment(network, localID) + var lastFinalized HashNumber[string, uint32] + env.WithChain(func(chain *dummyChain) { + chain.PushBlocks(GenesisHash, []string{"A", "B", "C", "D", "E"}) + lastFinalized.Hash, lastFinalized.Number = chain.LastFinalized() + }) + + // run voter in background. scheduling it to shut down at the end. + finalized := env.FinalizedStream() + voter, globalOut := NewVoter[string, uint32, Signature, ID]( + &env, + *voterSet, + make(chan globalInItem), + 0, + nil, + lastFinalized, + lastFinalized, + ) + + globalIn := network.MakeGlobalComms(globalOut) + voter.globalIn = newWakerChan(globalIn) + + voters[i] = voter + voterStates[i] = voter.VoterState() + + wg.Add(1) + go func() { + defer wg.Done() + <-finalized + }() + } + + voterState := voterStates[0] + for _, vs := range voterStates { + assert.Equal(t, vs.Get(), voterState.Get()) + } + + expectedRoundState := RoundStateReport[ID]{ + TotalWeight: VoterWeight(numVoters), + ThresholdWeight: VoterWeight(votersOnline), + PrevoteCurrentWeight: 0, + PrevoteIDs: nil, + PrecommitCurrentWeight: 0, + PrecommitIDs: nil, + } + + assert.Equal(t, + VoterStateReport[ID]{ + BackgroundRounds: make(map[uint64]RoundStateReport[ID]), + BestRound: struct { + Number uint64 + RoundState RoundStateReport[ID] + }{1, expectedRoundState}, + }, + voterState.Get(), + ) + + for _, v := range voters { + go v.Start() + } + wg.Wait() + + assert.Equal(t, + voterState.Get().BestRound, + struct { + Number uint64 + RoundState RoundStateReport[ID] + }{2, expectedRoundState}, + ) + + for _, v := range voters { + err := v.Stop() + assert.NoError(t, err) + } +} + +func TestVoter_BroadcastCommit(t *testing.T) { + localID := ID(5) + voterSet := NewVoterSet([]IDWeight[ID]{{localID, 100}}) + + network := NewNetwork() + + env := newEnvironment(network, localID) + + // initialize chain + var lastFinalized HashNumber[string, uint32] + env.WithChain(func(chain *dummyChain) { + chain.PushBlocks(GenesisHash, []string{"A", "B", "C", "D", "E"}) + lastFinalized.Hash, lastFinalized.Number = chain.LastFinalized() + }) + + // run voter in background. scheduling it to shut down at the end. + voter, globalOut := NewVoter[string, uint32, Signature, ID]( + &env, + *voterSet, + make(chan globalInItem), + 0, + nil, + lastFinalized, + lastFinalized, + ) + + commitsIn := network.MakeGlobalComms(globalOut) + + globalIn := network.MakeGlobalComms(globalOut) + voter.globalIn = newWakerChan(globalIn) + + go voter.Start() + <-commitsIn + + err := voter.Stop() + assert.NoError(t, err) +} + +func TestVoter_BroadcastCommitOnlyIfNewer(t *testing.T) { + localID := ID(5) + testID := ID(42) + voterSet := NewVoterSet([]IDWeight[ID]{{localID, 100}, {testID, 201}}) + + network := NewNetwork() + + commitsOut := make(chan CommunicationOut) + commitsIn := network.MakeGlobalComms(commitsOut) + + roundOut := make(chan Message[string, uint32]) + roundIn := network.MakeRoundComms(1, testID, roundOut) + + prevote := Prevote[string, uint32]{"E", 6} + precommit := Precommit[string, uint32]{"E", 6} + + commit := numberCommit[string, uint32, Signature, ID]{ + 1, Commit[string, uint32, Signature, ID]{ + TargetHash: "E", + TargetNumber: 6, + Precommits: []SignedPrecommit[string, uint32, Signature, ID]{ + { + Precommit: Precommit[string, uint32]{"E", 6}, + Signature: Signature(testID), + ID: testID, + }, + }, + }, + } + + env := newEnvironment(network, localID) + + // initialize chain + var lastFinalized HashNumber[string, uint32] + env.WithChain(func(chain *dummyChain) { + chain.PushBlocks(GenesisHash, []string{"A", "B", "C", "D", "E"}) + lastFinalized.Hash, lastFinalized.Number = chain.LastFinalized() + }) + + // run voter in background. scheduling it to shut down at the end. + voter, globalOut := NewVoter[string, uint32, Signature, ID]( + &env, + *voterSet, + nil, + 0, + nil, + lastFinalized, + lastFinalized, + ) + globalIn := network.MakeGlobalComms(globalOut) + voter.globalIn = newWakerChan(globalIn) + + go func() { + voter.Start() + }() + + item := <-roundIn + // wait for a prevote + assert.NoError(t, item.Error) + assert.IsType(t, Prevote[string, uint32]{}, item.SignedMessage.Message.value) + assert.Equal(t, localID, item.SignedMessage.ID) + + // send our prevote and precommit + votes := []Message[string, uint32]{newMessage(prevote), newMessage(precommit)} + for _, v := range votes { + roundOut <- v + } + +waitForPrecommit: + for { + item = <-roundIn + // wait for a precommit + assert.NoError(t, item.Error) + switch item.SignedMessage.Message.value.(type) { + case Precommit[string, uint32]: + if item.SignedMessage.ID == localID { + break waitForPrecommit + } + } + } + + // send our commit + co := newCommunicationOut(CommunicationOutCommit[string, uint32, Signature, ID](commit)) + commitsOut <- co + + timer := time.NewTimer(500 * time.Millisecond) + var commitCount int +waitForCommits: + for { + select { + case <-commitsIn: + commitCount++ + case <-timer.C: + break waitForCommits + } + } + assert.Equal(t, 1, commitCount) + + err := voter.Stop() + assert.NoError(t, err) +} + +func TestVoter_ImportCommitForAnyRound(t *testing.T) { + localID := ID(5) + testID := ID(42) + voterSet := NewVoterSet([]IDWeight[ID]{{localID, 100}, {testID, 201}}) + + network := NewNetwork() + commitsOut := make(chan CommunicationOut) + _ = network.MakeGlobalComms(commitsOut) + + commit := Commit[string, uint32, Signature, ID]{ + TargetHash: "E", + TargetNumber: 6, + Precommits: []SignedPrecommit[string, uint32, Signature, ID]{ + { + Precommit: Precommit[string, uint32]{"E", 6}, + Signature: Signature(testID), + ID: testID, + }, + }, + } + + env := newEnvironment(network, localID) + + // initialize chain + var lastFinalized HashNumber[string, uint32] + env.WithChain(func(chain *dummyChain) { + chain.PushBlocks(GenesisHash, []string{"A", "B", "C", "D", "E"}) + lastFinalized.Hash, lastFinalized.Number = chain.LastFinalized() + }) + + // run voter in background. scheduling it to shut down at the end. + voter, globalOut := NewVoter[string, uint32, Signature, ID]( + &env, + *voterSet, + nil, + 0, + nil, + lastFinalized, + lastFinalized, + ) + + globalIn := network.MakeGlobalComms(globalOut) + voter.globalIn = newWakerChan(globalIn) + + go func() { + voter.Start() + }() + + // Send the commit message + co := newCommunicationOut(CommunicationOutCommit[string, uint32, Signature, ID]{ + Number: 0, + Commit: commit, + }) + commitsOut <- co + + finalized := <-env.FinalizedStream() + assert.Equal(t, finalized.Commit, commit) + + err := voter.Stop() + assert.NoError(t, err) +} + +func TestVoter_SkipsToLatestRoundAfterCatchUp(t *testing.T) { + voterIDs := make([]ID, 3) + // 3 voters + weights := make([]IDWeight[ID], 3) + for i := range weights { + weights[i] = IDWeight[ID]{ID(i), 1} + voterIDs[i] = ID(i) + } + voterSet := NewVoterSet(weights) + totalWeight := voterSet.TotalWeight() + thresholdWeight := voterSet.Threshold() + + network := NewNetwork() + + // initialize unsynced voter at round 0 + localID := ID(4) + + env := newEnvironment(network, localID) + var lastFinalized HashNumber[string, uint32] + env.WithChain(func(chain *dummyChain) { + chain.PushBlocks(GenesisHash, []string{"A", "B", "C", "D", "E"}) + lastFinalized.Hash, lastFinalized.Number = chain.LastFinalized() + }) + + unsyncedVoter, globalOut := NewVoter[string, uint32, Signature, ID]( + &env, + *voterSet, + nil, + 0, + nil, + lastFinalized, + lastFinalized, + ) + globalIn := network.MakeGlobalComms(globalOut) + unsyncedVoter.globalIn = newWakerChan(globalIn) + + prevote := func(id uint32) SignedPrevote[string, uint32, Signature, ID] { + return SignedPrevote[string, uint32, Signature, ID]{ + Prevote: Prevote[string, uint32]{"C", 4}, + ID: ID(id), + Signature: Signature(99), + } + } + + precommit := func(id uint32) SignedPrecommit[string, uint32, Signature, ID] { + return SignedPrecommit[string, uint32, Signature, ID]{ + Precommit: Precommit[string, uint32]{"C", 4}, + ID: ID(id), + Signature: Signature(99), + } + } + + // send in a catch-up message for round 5. + ci := newCommunicationIn[string, uint32, Signature, ID](CommunicationInCatchUp[string, uint32, Signature, ID]{ + CatchUp: CatchUp[string, uint32, Signature, ID]{ + BaseNumber: 1, + BaseHash: GenesisHash, + RoundNumber: 5, + Prevotes: []SignedPrevote[string, uint32, Signature, ID]{prevote(0), prevote(1), prevote(2)}, + Precommits: []SignedPrecommit[string, uint32, Signature, ID]{precommit(0), precommit(1), precommit(2)}, + }, + }) + network.SendMessage(ci) + + voterState := unsyncedVoter.VoterState() + _, ok := voterState.Get().BackgroundRounds[5] + assert.False(t, ok) + + // spawn the voter in the background + go unsyncedVoter.Start() + + finalized := env.FinalizedStream() + + // wait until it's caught up, it should skip to round 6 and send a + // finality notification for the block that was finalized by catching + // up. + caughtUp := make(chan any) + go func() { + for { + report := voterState.Get() + if report.BestRound.Number == 6 { + close(caughtUp) + return + } + <-time.NewTimer(10 * time.Millisecond).C + } + }() + + <-caughtUp + <-finalized + assert.Equal(t, + struct { + Number uint64 + RoundState RoundStateReport[ID] + }{ + Number: 6, + RoundState: RoundStateReport[ID]{ + TotalWeight: totalWeight, + ThresholdWeight: thresholdWeight, + PrevoteCurrentWeight: 0, + PrevoteIDs: nil, + PrecommitCurrentWeight: 0, + PrecommitIDs: nil, + }, + }, + voterState.Get().BestRound) + + assert.Equal(t, + RoundStateReport[ID]{ + TotalWeight: totalWeight, + ThresholdWeight: thresholdWeight, + PrevoteCurrentWeight: 3, + PrevoteIDs: voterIDs, + PrecommitCurrentWeight: 3, + PrecommitIDs: voterIDs, + }, + voterState.Get().BackgroundRounds[5]) + + err := unsyncedVoter.Stop() + assert.NoError(t, err) +} + +func TestVoter_PickUpFromPriorWithoutGrandparentState(t *testing.T) { + localID := ID(5) + voterSet := NewVoterSet([]IDWeight[ID]{{localID, 100}}) + + network := NewNetwork() + + env := newEnvironment(network, localID) + + // initialize chain + var lastFinalized HashNumber[string, uint32] + env.WithChain(func(chain *dummyChain) { + chain.PushBlocks(GenesisHash, []string{"A", "B", "C", "D", "E"}) + lastFinalized.Hash, lastFinalized.Number = chain.LastFinalized() + }) + + // run voter in background. scheduling it to shut down at the end. + voter, globalOut := NewVoter[string, uint32, Signature, ID]( + &env, + *voterSet, + nil, + 10, + nil, + lastFinalized, + lastFinalized, + ) + globalIn := network.MakeGlobalComms(globalOut) + voter.globalIn = newWakerChan(globalIn) + + go voter.Start() + for finalized := range env.FinalizedStream() { + if finalized.Number >= 6 { + break + } + } + + err := voter.Stop() + assert.NoError(t, err) +} + +func TestVoter_PickUpFromPriorWithGrandparentStatus(t *testing.T) { + localID := ID(99) + weights := make([]IDWeight[ID], 100) + for i := range weights { + weights[i] = IDWeight[ID]{ID(i), 1} + } + voterSet := NewVoterSet(weights) + + network := NewNetwork() + + env := newEnvironment(network, localID) + + // initialize chain + var lastFinalized HashNumber[string, uint32] + env.WithChain(func(chain *dummyChain) { + chain.PushBlocks(GenesisHash, []string{"A", "B", "C", "D", "E"}) + lastFinalized.Hash, lastFinalized.Number = chain.LastFinalized() + }) + + lastRoundVotes := make([]SignedMessage[string, uint32, Signature, ID], 0) + + // round 1 state on disk: 67 prevotes for "E". 66 precommits for "D". 1 precommit "E". + // the round is completable, but the estimate ("E") is not finalized. + for id := 0; id < 67; id++ { + prevote := Prevote[string, uint32]{"E", 6} + var precommit Precommit[string, uint32] + if id < 66 { + precommit = Precommit[string, uint32]{"D", 5} + } else { + precommit = Precommit[string, uint32]{"E", 6} + } + + lastRoundVotes = append(lastRoundVotes, SignedMessage[string, uint32, Signature, ID]{ + Message: newMessage(prevote), + Signature: Signature(id), + ID: ID(id), + }) + + lastRoundVotes = append(lastRoundVotes, SignedMessage[string, uint32, Signature, ID]{ + Message: newMessage(precommit), + Signature: Signature(id), + ID: ID(id), + }) + + // round 2 has the same votes. + // + // this means we wouldn't be able to start round 3 until + // the estimate of round-1 moves backwards. + roundOut := make(chan Message[string, uint32]) + _ = network.MakeRoundComms(2, ID(id), roundOut) + msgs := []Message[string, uint32]{newMessage(prevote), newMessage(precommit)} + for _, msg := range msgs { + roundOut <- msg + } + } + + // round 1 fresh communication. we send one more precommit for "D" so the estimate + // moves backwards. + sender := ID(67) + roundOut := make(chan Message[string, uint32]) + _ = network.MakeRoundComms(1, sender, roundOut) + lastPrecommit := Precommit[string, uint32]{"D", 3} + roundOut <- newMessage(lastPrecommit) + + // run voter in background. scheduling it to shut down at the end. + voter, globalOut := NewVoter[string, uint32, Signature, ID]( + &env, + *voterSet, + nil, + 1, + lastRoundVotes, + lastFinalized, + lastFinalized, + ) + globalIn := network.MakeGlobalComms(globalOut) + voter.globalIn = newWakerChan(globalIn) + go voter.Start() + + // wait until we see a prevote on round 3 from our local ID, + // indicating that the round 3 has started. + roundIn := network.MakeRoundComms(3, ID(1000), nil) +waitForPrevote: + for sme := range roundIn { + if sme.Error != nil { + t.Errorf("wtf?") + } + + msg := sme.SignedMessage.Message.Value() + switch msg.(type) { + case Prevote[string, uint32]: + if sme.SignedMessage.ID == localID { + break waitForPrevote + } + } + } + + assert.Equal(t, [2]uint64{2, 1}, env.LastCompletedAndConcluded()) + + err := voter.Stop() + assert.NoError(t, err) +} + +func TestBuffered(_ *testing.T) { + in := make(chan int32) + buffered := newBuffered(in) + + run := true + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + for run { + buffered.Push(999) + time.Sleep(1 * time.Millisecond) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + for run { + buffered.flush(newWaker()) + time.Sleep(1 * time.Millisecond) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + for range in { + } + }() + + time.Sleep(100 * time.Millisecond) + buffered.Close() + + run = false + wg.Wait() +} diff --git a/pkg/finality-grandpa/voting_round.go b/pkg/finality-grandpa/voting_round.go new file mode 100644 index 0000000000..47998f3134 --- /dev/null +++ b/pkg/finality-grandpa/voting_round.go @@ -0,0 +1,843 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "time" + + "golang.org/x/exp/constraints" +) + +type stateStart[T any] [2]T + +type stateProposed[T any] [2]T + +type statePrevoting[T, W any] struct { + T T + W W +} + +type statePrevoted[T any] [1]T + +type statePrecommitted struct{} + +type states[T, W any] interface { + stateStart[T] | stateProposed[T] | statePrevoting[T, W] | statePrevoted[T] | statePrecommitted +} + +// The state of a voting round. +type state any + +func setState[T, W any, V states[T, W]](s *state, val V) { + *s = val +} + +func newState[T, W any, V states[T, W]](val V) state { + var s state + setState[T, W](&s, val) + return s +} + +type hashBestChain[Hash comparable, Number constraints.Unsigned] struct { + Hash Hash + BestChain BestChain[Hash, Number] +} + +// Whether we should vote in the current round (i.e. push votes to the sink.) +type voting uint + +const ( + // Voting is disabled for the current round. + votingNo voting = iota + // Voting is enabled for the current round (prevotes and precommits.) + votingYes + // Voting is enabled for the current round and we are the primary proposer + // (we can also push primary propose messages). + votingPrimary +) + +// Whether the voter should cast round votes (prevotes and precommits.) +func (v voting) isActive() bool { + return v == votingYes || v == votingPrimary +} + +// Whether the voter is the primary proposer. +func (v voting) isPrimary() bool { + return v == votingPrimary +} + +// Logic for a voter on a specific round. +type votingRound[ + Hash constraints.Ordered, + Number constraints.Unsigned, + Signature comparable, + ID constraints.Ordered, + E Environment[Hash, Number, Signature, ID], +] struct { + env E + voting voting + // this is not an Option in the rust code. Using a pointer for copylocks + votes *Round[ID, Hash, Number, Signature] + incoming *wakerChan[SignedMessageError[Hash, Number, Signature, ID]] + outgoing *buffered[Message[Hash, Number]] + state state + bridgedRoundState *priorView[Hash, Number] + lastRoundState *latterView[Hash, Number] + primaryBlock *HashNumber[Hash, Number] + finalizedSender chan finalizedNotification[Hash, Number, Signature, ID] + bestFinalized *Commit[Hash, Number, Signature, ID] +} + +// Create a new voting round. +func newVotingRound[ + Hash constraints.Ordered, Number constraints.Unsigned, Signature comparable, ID constraints.Ordered, + E Environment[Hash, Number, Signature, ID], +]( + roundNumber uint64, voters VoterSet[ID], base HashNumber[Hash, Number], + lastRoundState *latterView[Hash, Number], + finalizedSender chan finalizedNotification[Hash, Number, Signature, ID], env E, +) votingRound[Hash, Number, Signature, ID, E] { + outgoing := make(chan Message[Hash, Number]) + roundData := env.RoundData(roundNumber, outgoing) + roundParams := RoundParams[ID, Hash, Number]{ + RoundNumber: roundNumber, + Voters: voters, + Base: base, + } + + votes := NewRound[ID, Hash, Number, Signature](roundParams) + + primaryVoterID, _ := votes.PrimaryVoter() + var voting voting //nolint:govet + if roundData.VoterID != nil && *roundData.VoterID == primaryVoterID { + voting = votingPrimary + } else if roundData.VoterID != nil && votes.Voters().Contains(*roundData.VoterID) { + voting = votingYes + } else { + voting = votingNo + } + + return votingRound[Hash, Number, Signature, ID, E]{ + votes: votes, + voting: voting, + incoming: newWakerChan(roundData.Incoming), + outgoing: newBuffered(outgoing), + state: newState[Timer, hashBestChain[Hash, Number]]( + stateStart[Timer]{roundData.PrevoteTimer, roundData.PrecommitTimer}), + bridgedRoundState: nil, + primaryBlock: nil, + bestFinalized: nil, + env: env, + lastRoundState: lastRoundState, + finalizedSender: finalizedSender, + } +} + +// Create a voting round from a completed `Round`. We will not vote further +// in this round. +func newVotingRoundCompleted[ + Hash constraints.Ordered, Number constraints.Unsigned, Signature comparable, ID constraints.Ordered, + E Environment[Hash, Number, Signature, ID], +]( + votes *Round[ID, Hash, Number, Signature], + finalizedSender chan finalizedNotification[Hash, Number, Signature, ID], + lastRoundState *latterView[Hash, Number], + env E, +) votingRound[Hash, Number, Signature, ID, E] { + outgoing := make(chan Message[Hash, Number]) + roundData := env.RoundData(votes.Number(), outgoing) + return votingRound[Hash, Number, Signature, ID, E]{ + votes: votes, + voting: votingNo, + incoming: newWakerChan(roundData.Incoming), + outgoing: newBuffered(outgoing), + state: nil, + bridgedRoundState: nil, + primaryBlock: nil, + bestFinalized: nil, + env: env, + lastRoundState: lastRoundState, + finalizedSender: finalizedSender, + } +} + +// Poll the round. When the round is completable and messages have been flushed, it will return `Poll::Ready` but +// can continue to be polled. +func (vr *votingRound[Hash, Number, Signature, ID, E]) poll(waker *waker) (bool, error) { //skipcq: GO-R1005 + log.Tracef( + "Polling round %d, state = %+v, step = %T", + vr.votes.Number(), + vr.votes.State(), + vr.state, + ) + + preState := vr.votes.State() + err := vr.processIncoming(waker) + if err != nil { + return true, err + } + + // we only cast votes when we have access to the previous round state. + // we might have started this round as a prospect "future" round to + // check whether the voter is lagging behind the current round. + var lastRoundState *RoundState[Hash, Number] + if vr.lastRoundState != nil { + lrr := vr.lastRoundState.get(waker) + lastRoundState = &lrr + } + if lastRoundState != nil { + err := vr.primaryPropose(lastRoundState) + if err != nil { + return true, err + } + err = vr.prevote(waker, lastRoundState) + if err != nil { + return true, err + } + err = vr.precommit(waker, lastRoundState) + if err != nil { + return true, err + } + } + + ready, err := vr.outgoing.Poll(waker) + if !ready { + return false, nil + } + if err != nil { + return true, err + } + err = vr.processIncoming(waker) // in case we got a new message signed locally. + if err != nil { + return true, err + } + + // broadcast finality notifications after attempting to cast votes + postState := vr.votes.State() + vr.notify(preState, postState) + + completable := vr.votes.Completable() + // early exit if the current round is not completable + if !completable { + return false, nil + } + + // make sure that the previous round estimate has been finalized + var lastRoundEstimateFinalized bool + switch { + case lastRoundState != nil && lastRoundState.Estimate != nil && lastRoundState.Finalized != nil: + // either it was already finalized in the previous round + finalizedInLastRound := lastRoundState.Estimate.Number <= lastRoundState.Finalized.Number + + // or it must be finalized in the current round + var finalizedInCurrentRound bool + if vr.finalized() != nil { + finalizedInCurrentRound = lastRoundState.Estimate.Number <= vr.finalized().Number + } + + lastRoundEstimateFinalized = finalizedInLastRound || finalizedInCurrentRound + case lastRoundState == nil: + // NOTE: when we catch up to a round we complete the round + // without any last round state. in this case we already started + // a new round after we caught up so this guard is unneeded. + lastRoundEstimateFinalized = true + default: + lastRoundEstimateFinalized = false + } + + // the previous round estimate must be finalized + if !lastRoundEstimateFinalized { + log.Tracef("Round {} completable but estimate not finalized.", vr.roundNumber()) + vr.logParticipation(trace) + return false, nil + } + + log.Debugf( + "Completed round %d, state = %+v, step = %T", + vr.votes.Number(), + vr.votes.State(), + vr.state, + ) + + vr.logParticipation(debug) + return true, nil +} + +// Inspect the state of this round. +func (vr *votingRound[Hash, Number, Signature, ID, E]) State() any { + return vr.state +} + +// Get access to the underlying environment. +func (vr *votingRound[Hash, Number, Signature, ID, E]) Env() E { + return vr.env +} + +// Get the round number. +func (vr *votingRound[Hash, Number, Signature, ID, E]) roundNumber() uint64 { + return vr.votes.Number() +} + +// Get the round state. +func (vr *votingRound[Hash, Number, Signature, ID, E]) roundState() RoundState[Hash, Number] { + return vr.votes.State() +} + +// Get the base block in the dag. +func (vr *votingRound[Hash, Number, Signature, ID, E]) dagBase() HashNumber[Hash, Number] { + return vr.votes.Base() +} + +// Get the base block in the dag. +func (vr *votingRound[Hash, Number, Signature, ID, E]) voters() VoterSet[ID] { + return vr.votes.Voters() +} + +// Get the best block finalized in this round. +func (vr *votingRound[Hash, Number, Signature, ID, E]) finalized() *HashNumber[Hash, Number] { + return vr.votes.State().Finalized +} + +// Get the current total weight of prevotes. +func (vr *votingRound[Hash, Number, Signature, ID, E]) preVoteWeight() VoteWeight { + weight, _ := vr.votes.PrevoteParticipation() + return weight +} + +// Get the current total weight of precommits. +func (vr *votingRound[Hash, Number, Signature, ID, E]) precommitWeight() VoteWeight { + weight, _ := vr.votes.PrecommitParticipation() + return weight +} + +// Get the current total weight of prevotes. +func (vr *votingRound[Hash, Number, Signature, ID, E]) prevoteIDs() []ID { + var ids []ID + for _, pv := range vr.votes.Prevotes() { + ids = append(ids, pv.ID) + } + return ids +} + +// Get the current total weight of prevotes. +func (vr *votingRound[Hash, Number, Signature, ID, E]) precommitIDs() []ID { + var ids []ID + for _, pv := range vr.votes.Precommits() { + ids = append(ids, pv.ID) + } + return ids +} + +// Check a commit. If it's valid, import all the votes into the round as well. +// Returns the finalized base if it checks out. +func (vr *votingRound[Hash, Number, Signature, ID, E]) checkAndImportFromCommit( + commit Commit[Hash, Number, Signature, ID], +) (*HashNumber[Hash, Number], error) { + cvr, err := ValidateCommit[Hash, Number](commit, vr.voters(), vr.env) + if err != nil { + return nil, err + } + if !cvr.Valid() { + return nil, nil + } + + for _, signedPrecommit := range commit.Precommits { + precommit := signedPrecommit.Precommit + signature := signedPrecommit.Signature + id := signedPrecommit.ID + + importResult, err := vr.votes.importPrecommit(vr.env, precommit, id, signature) + if err != nil { + return nil, err + } + if importResult.Equivocation != nil { + vr.env.PrecommitEquivocation(vr.roundNumber(), *importResult.Equivocation) + } + } + + return &HashNumber[Hash, Number]{commit.TargetHash, commit.TargetNumber}, nil +} + +// Get a clone of the finalized sender. +func (vr *votingRound[Hash, Number, Signature, ID, E]) FinalizedSender() chan finalizedNotification[Hash, Number, Signature, ID] { //nolint:lll + return vr.finalizedSender +} + +// call this when we build on top of a given round in order to get a handle +// to updates to the latest round-state. +func (vr *votingRound[Hash, Number, Signature, ID, E]) bridgeState() *latterView[Hash, Number] { + priorView, latterView := bridgeState(vr.votes.State()) + if vr.bridgedRoundState != nil { + log.Warnf("Bridged state from round %d more than once", vr.votes.Number()) + } + + vr.bridgedRoundState = &priorView + return &latterView +} + +// Get a commit justifying the best finalized block. +func (vr *votingRound[Hash, Number, Signature, ID, E]) finalizingCommit() *Commit[Hash, Number, Signature, ID] { + return vr.bestFinalized +} + +// Return all votes for the round (prevotes and precommits), sorted by +// imported order and indicating the indices where we voted. At most two +// prevotes and two precommits per voter are present, further equivocations +// are not stored (as they are redundant). +func (vr *votingRound[Hash, Number, Signature, ID, E]) historicalVotes() HistoricalVotes[Hash, Number, Signature, ID] { + return vr.votes.HistoricalVotes() +} + +// Handle a vote manually. +func (vr *votingRound[Hash, Number, Signature, ID, E]) handleVote(vote SignedMessage[Hash, Number, Signature, ID]) error { //nolint:lll + message := vote.Message + if !vr.env.IsEqualOrDescendantOf(vr.votes.Base().Hash, message.Target().Hash) { + return nil + } + + switch message := message.Value().(type) { + case Prevote[Hash, Number]: + prevote := message + importResult, err := vr.votes.importPrevote(vr.env, prevote, vote.ID, vote.Signature) + if err != nil { + return err + } + if importResult.Equivocation != nil { + vr.env.PrevoteEquivocation(vr.votes.Number(), *importResult.Equivocation) + } + case Precommit[Hash, Number]: + precommit := message + importResult, err := vr.votes.importPrecommit(vr.env, precommit, vote.ID, vote.Signature) + if err != nil { + return err + } + if importResult.Equivocation != nil { + vr.env.PrecommitEquivocation(vr.votes.Number(), *importResult.Equivocation) + } + case PrimaryPropose[Hash, Number]: + primary := message + primaryID, _ := vr.votes.PrimaryVoter() + // note that id here refers to the party which has cast the vote + // and not the id of the party which has received the vote message. + if vote.ID == primaryID { + vr.primaryBlock = &HashNumber[Hash, Number]{primary.TargetHash, primary.TargetNumber} + } + } + + return nil +} + +func (vr *votingRound[Hash, Number, Signature, ID, E]) logParticipation(level logLevel) { + totalWeight := vr.voters().TotalWeight() + threshold := vr.voters().Threshold() + nVoters := vr.voters().Len() + number := vr.roundNumber() + + preVoteWeight, nPrevotes := vr.votes.PrevoteParticipation() + precommitWeight, nPrecommits := vr.votes.PrecommitParticipation() + + var logf func(format string, values ...any) + switch level { + case debug: + logf = log.Debugf + case trace: + logf = log.Tracef + } + + logf("%s: Round %d: prevotes: %d/%d/%d weight, %d/%d actual", + level, number, preVoteWeight, threshold, totalWeight, nPrevotes, nVoters) + + logf("%s: Round %d: precommits: %d/%d/%d weight, %d/%d actual", + level, number, precommitWeight, threshold, totalWeight, nPrecommits, nVoters) +} + +func (vr *votingRound[Hash, Number, Signature, ID, E]) processIncoming(waker *waker) error { + vr.incoming.setWaker(waker) + var ( + msgCount = 0 + timer *time.Timer + timerChan <-chan time.Time + ) +while: + for { + select { + case incoming := <-vr.incoming.channel(): + log.Tracef("Round %d: Got incoming message", vr.roundNumber()) + if timer != nil { + timer.Stop() + timer = nil + } + if incoming.Error != nil { + return incoming.Error + } + err := vr.handleVote(incoming.SignedMessage) + if err != nil { + return err + } + msgCount++ + case <-timerChan: + if msgCount > 0 { + log.Tracef("processed %d messages", msgCount) + } + break while + default: + if timer == nil { + // delay 1ms before exiting this loop + timer = time.NewTimer(1 * time.Millisecond) + timerChan = timer.C + } + } + } + return nil +} + +func (vr *votingRound[Hash, Number, Signature, ID, E]) primaryPropose(lastRoundState *RoundState[Hash, Number]) error { + state := vr.state + vr.state = nil + + if state == nil { + return nil + } + switch state := state.(type) { + case stateStart[Timer]: + prevoteTimer := state[0] + precommitTimer := state[1] + + maybeEstimate := lastRoundState.Estimate + switch { + case maybeEstimate != nil && vr.voting.isPrimary(): + lastRoundEstimate := maybeEstimate + maybeFinalized := lastRoundState.Finalized + + var shouldSendPrimary = true + if maybeFinalized != nil { + shouldSendPrimary = lastRoundEstimate.Number > maybeFinalized.Number + } + if shouldSendPrimary { + log.Debugf("Sending primary block hint for round %d", vr.votes.Number()) + primary := PrimaryPropose[Hash, Number]{ + TargetHash: lastRoundEstimate.Hash, + TargetNumber: lastRoundEstimate.Number, + } + err := vr.env.Proposed(vr.roundNumber(), primary) + if err != nil { + return err + } + message := newMessage(primary) + vr.outgoing.Push(message) + setState[Timer, hashBestChain[Hash, Number]](&vr.state, stateProposed[Timer]{prevoteTimer, precommitTimer}) + + return nil + } + log.Debugf( + "Last round estimate has been finalized, not sending primary block hint for round %d", + vr.votes.Number(), + ) + + case maybeEstimate == nil && vr.voting.isPrimary(): + log.Debugf("Last round estimate does not exist, not sending primary block hint for round %d", vr.votes.Number()) + default: + } + setState[Timer, hashBestChain[Hash, Number]](&vr.state, stateStart[Timer]{prevoteTimer, precommitTimer}) + default: + vr.state = state + } + return nil +} + +func (vr *votingRound[Hash, Number, Signature, ID, E]) prevote(w *waker, lastRoundState *RoundState[Hash, Number]) error { //nolint:lll //skipcq: GO-R1005 + state := vr.state + vr.state = nil + + var startPrevoting = func(prevoteTimer Timer, precommitTimer Timer, proposed bool, waker *waker) error { + prevoteTimer.SetWaker(waker) + var shouldPrevote bool + elapsed, err := prevoteTimer.Elapsed() + if elapsed { + if err != nil { + return err + } + shouldPrevote = true + } else { + shouldPrevote = vr.votes.Completable() + } + + if shouldPrevote { + if vr.voting.isActive() { + log.Debugf("Constructing prevote for round %d", vr.votes.Number()) + + base, bestChain := vr.constructPrevote(lastRoundState) + + // since we haven't polled the future above yet we need to + // manually schedule the current task to be awoken so the + // `best_chain` future is then polled below after we switch the + // state to `Prevoting`. + waker.wake() + + setState[Timer, hashBestChain[Hash, Number]](&vr.state, statePrevoting[Timer, hashBestChain[Hash, Number]]{ + precommitTimer, hashBestChain[Hash, Number]{base, bestChain}, + }) + } else { + setState[Timer, hashBestChain[Hash, Number]](&vr.state, statePrevoted[Timer]{precommitTimer}) + } + } else if proposed { + setState[Timer, hashBestChain[Hash, Number]](&vr.state, stateProposed[Timer]{prevoteTimer, precommitTimer}) + } else { + setState[Timer, hashBestChain[Hash, Number]](&vr.state, stateStart[Timer]{prevoteTimer, precommitTimer}) + } + + return nil + } + + var finishPrevoting = func(precommitTimer Timer, base Hash, bestChain BestChain[Hash, Number], waker *waker) error { + wakerChan := newWakerChan(bestChain) + wakerChan.setWaker(waker) + var best *HashNumber[Hash, Number] + res := <-wakerChan.channel() + switch { + case res.Error != nil: + return res.Error + case res.Value != nil: + best = res.Value + default: + setState[Timer, hashBestChain[Hash, Number]](&vr.state, statePrevoting[Timer, hashBestChain[Hash, Number]]{ + precommitTimer, hashBestChain[Hash, Number]{base, bestChain}, + }) + return nil + } + + if best != nil { + prevote := Prevote[Hash, Number]{best.Hash, best.Number} + + log.Debugf("Casting prevote for round {}", vr.votes.Number()) + err := vr.env.Prevoted(vr.roundNumber(), prevote) + if err != nil { + return err + } + vr.votes.SetPrevotedIdx() + message := newMessage(prevote) + vr.outgoing.Push(message) + setState[Timer, hashBestChain[Hash, Number]](&vr.state, statePrevoted[Timer]{precommitTimer}) + } else { + log.Warnf("Could not cast prevote: previously known block %v has disappeared", base) + + // when we can't construct a prevote, we shouldn't precommit. + vr.state = nil + vr.voting = votingNo + } + + return nil + } + + if state == nil { + return nil + } + switch state := state.(type) { + case stateStart[Timer]: + return startPrevoting(state[0], state[1], false, w) + case stateProposed[Timer]: + return startPrevoting(state[0], state[1], true, w) + case statePrevoting[Timer, hashBestChain[Hash, Number]]: + return finishPrevoting(state.T, state.W.Hash, state.W.BestChain, w) + default: + vr.state = state + } + + return nil +} + +func (vr *votingRound[Hash, Number, Signature, ID, E]) precommit(waker *waker, lastRoundState *RoundState[Hash, Number]) error { //nolint:lll + state := vr.state + vr.state = nil + if state == nil { + return nil + } + switch state := state.(type) { + case statePrevoted[Timer]: + precommitTimer := state[0] + precommitTimer.SetWaker(waker) + lastRoundEstimate := lastRoundState.Estimate + if lastRoundEstimate == nil { + panic("Rounds only started when prior round completable; qed") + } + + var shouldPrecommit bool + var ls bool + st := vr.votes.State() + pg := st.PrevoteGHOST + if pg != nil { + ls = *pg == *lastRoundEstimate || vr.env.IsEqualOrDescendantOf(lastRoundEstimate.Hash, pg.Hash) + } + var rs bool + elapsed, err := precommitTimer.Elapsed() + if elapsed { + if err != nil { + return err + } else { + rs = true + } + } else { + rs = vr.votes.Completable() + } + shouldPrecommit = ls && rs + + if shouldPrecommit { + if vr.voting.isActive() { + log.Debugf("Casting precommit for round {}", vr.votes.Number()) + precommit := vr.constructPrecommit() + err := vr.env.Precommitted(vr.roundNumber(), precommit) + if err != nil { + return err + } + vr.votes.SetPrecommittedIdx() + message := newMessage(precommit) + vr.outgoing.Push(message) + } + setState[Timer, hashBestChain[Hash, Number]](&vr.state, statePrecommitted{}) + } else { + setState[Timer, hashBestChain[Hash, Number]](&vr.state, statePrevoted[Timer]{precommitTimer}) + } + default: + vr.state = state + } + + return nil +} + +// construct a prevote message based on local state. +func (vr *votingRound[Hash, Number, Signature, ID, E]) constructPrevote(lastRoundState *RoundState[Hash, Number]) (h Hash, bc BestChain[Hash, Number]) { //nolint:lll + lastRoundEstimate := lastRoundState.Estimate + if lastRoundEstimate == nil { + panic("Rounds only started when prior round completable; qed") + } + + var findDescendentOf Hash + switch primaryBlock := vr.primaryBlock; primaryBlock { + case nil: + // vote for best chain containing prior round-estimate. + findDescendentOf = lastRoundEstimate.Hash + default: + // we will vote for the best chain containing `p_hash` iff + // the last round's prevote-GHOST included that block and + // that block is a strict descendent of the last round-estimate that we are + // aware of. + lastPrevoteG := lastRoundState.PrevoteGHOST + if lastPrevoteG == nil { + panic("Rounds only started when prior round completable; qed") + } + + // if the blocks are equal, we don't check ancestry. + if *primaryBlock == *lastPrevoteG { + findDescendentOf = primaryBlock.Hash + } else if primaryBlock.Hash >= lastPrevoteG.Hash { + findDescendentOf = lastRoundEstimate.Hash + } else { + // from this point onwards, the number of the primary-broadcasted + // block is less than the last prevote-GHOST's number. + // if the primary block is in the ancestry of p-G we vote for the + // best chain containing it. + pHash := primaryBlock.Hash + pNum := primaryBlock.Number + ancestry, err := vr.env.Ancestry(lastRoundEstimate.Hash, lastPrevoteG.Hash) + if err != nil { + // This is only possible in case of massive equivocation + log.Warnf( + "Possible case of massive equivocation: last round prevote GHOST: %v"+ + " is not a descendant of last round estimate: %v", + lastPrevoteG, + lastRoundEstimate, + ) + findDescendentOf = lastRoundEstimate.Hash + } else { + toSub := pNum + 1 + + var offset uint + if lastPrevoteG.Number < toSub { + offset = 0 + } else { + offset = uint(lastPrevoteG.Number - toSub) + } + + if offset >= uint(len(ancestry)) { + findDescendentOf = lastRoundEstimate.Hash + } else { + if ancestry[offset] == pHash { + findDescendentOf = pHash + } else { + findDescendentOf = lastRoundEstimate.Hash + } + } + } + } + } + + return findDescendentOf, vr.env.BestChainContaining(findDescendentOf) +} + +// construct a precommit message based on local state. +func (vr *votingRound[Hash, Number, Signature, ID, E]) constructPrecommit() Precommit[Hash, Number] { + var t HashNumber[Hash, Number] + switch target := vr.votes.State().PrevoteGHOST; target { + case nil: + t = vr.votes.Base() + default: + t = *target + } + return Precommit[Hash, Number]{t.Hash, t.Number} +} + +// notify when new blocks are finalized or when the round-estimate is updated +func (vr *votingRound[Hash, Number, Signature, ID, E]) notify( + lastState RoundState[Hash, Number], + newState RoundState[Hash, Number], +) { + // `RoundState` attributes have pointers to values so comparison here is on pointer address. + // It's assumed that the `Round` attributes will use a new address for new values. + // Given the caller of this function, we know that new values will use new addresses + // so no need for deep value comparison. + if lastState != newState { + if vr.bridgedRoundState != nil { + vr.bridgedRoundState.update(newState) + } + } + + // send notification only when the round is completable and we've cast votes. + // this is a workaround that ensures when we re-instantiate the voter after + // a shutdown, we never re-create the same round with a base that was finalized + // in this round or after. + // we try to notify if either the round state changed or if we haven't + // sent any notification yet (this is to guard against seeing enough + // votes to finalize before having precommited) + stateChanged := lastState.Finalized != newState.Finalized + sentFinalityNotifications := vr.bestFinalized != nil + + if newState.Completable && (stateChanged || !sentFinalityNotifications) { + _, precommited := vr.state.(statePrecommitted) + // we only cast votes when we have access to the previous round state, + // which won't be the case whenever we catch up to a later round. + cantVote := vr.lastRoundState == nil + + if precommited || cantVote { + if newState.Finalized != nil { + precommits := vr.votes.FinalizingPrecommits(vr.env) + if precommits == nil { + panic("always returns none if something was finalized; this is checked above; qed") + } + commit := Commit[Hash, Number, Signature, ID]{ + TargetHash: newState.Finalized.Hash, + TargetNumber: newState.Finalized.Number, + Precommits: *precommits, + } + vr.finalizedSender <- finalizedNotification[Hash, Number, Signature, ID]{ + Hash: newState.Finalized.Hash, + Number: newState.Finalized.Number, + Round: vr.votes.Number(), + Commit: commit, + } + vr.bestFinalized = &commit + } + } + } + +} diff --git a/pkg/finality-grandpa/weights.go b/pkg/finality-grandpa/weights.go new file mode 100644 index 0000000000..6d861d56c8 --- /dev/null +++ b/pkg/finality-grandpa/weights.go @@ -0,0 +1,24 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "fmt" + "math" + "math/big" +) + +type VoteWeight uint64 + +type VoterWeight uint64 + +func (vw *VoterWeight) checkedAdd(add VoterWeight) (err error) { + sum := new(big.Int).SetUint64(uint64(*vw)) + sum.Add(sum, new(big.Int).SetUint64(uint64(add))) + if sum.Cmp(new(big.Int).SetUint64(uint64(math.MaxUint64))) > 0 { + return fmt.Errorf("VoterWeight overflow for CheckedAdd") + } + *vw = VoterWeight(sum.Uint64()) + return nil +}