diff --git a/pkg/finality-grandpa/environment_test.go b/pkg/finality-grandpa/environment_test.go index 07698cd3f5..94a9480d57 100644 --- a/pkg/finality-grandpa/environment_test.go +++ b/pkg/finality-grandpa/environment_test.go @@ -205,13 +205,14 @@ type BroadcastNetwork[M, N any] struct { senders []chan M history []M routing bool + wg sync.WaitGroup } -func NewBroadcastNetwork[M, N any]() BroadcastNetwork[M, N] { +func NewBroadcastNetwork[M, N any]() *BroadcastNetwork[M, N] { bn := BroadcastNetwork[M, N]{ receiver: make(chan M, 10000), } - return bn + return &bn } func (bm *BroadcastNetwork[M, N]) SendMessage(message M) { @@ -231,6 +232,7 @@ func (bm *BroadcastNetwork[M, N]) AddNode(f func(N) M, out chan N) (in chan M) { if !bm.routing { bm.routing = true + bm.wg.Add(1) go bm.route() } @@ -243,6 +245,7 @@ func (bm *BroadcastNetwork[M, N]) AddNode(f func(N) M, out chan N) (in chan M) { } func (bm *BroadcastNetwork[M, N]) route() { + defer bm.wg.Done() for msg := range bm.receiver { bm.history = append(bm.history, msg) for _, sender := range bm.senders { @@ -251,8 +254,13 @@ func (bm *BroadcastNetwork[M, N]) route() { } } +func (bm *BroadcastNetwork[M, N]) Stop() { + close(bm.receiver) + bm.wg.Wait() +} + type RoundNetwork struct { - BroadcastNetwork[SignedMessageError[string, uint32, Signature, ID], Message[string, uint32]] + *BroadcastNetwork[SignedMessageError[string, uint32, Signature, ID], Message[string, uint32]] } func NewRoundNetwork() *RoundNetwork { @@ -269,7 +277,7 @@ func (rn *RoundNetwork) AddNode( } type GlobalMessageNetwork struct { - BroadcastNetwork[globalInItem, CommunicationOut] + *BroadcastNetwork[globalInItem, CommunicationOut] } func NewGlobalMessageNetwork() *GlobalMessageNetwork { @@ -299,6 +307,13 @@ func NewNetwork() *Network { } } +func (n *Network) Stop() { + for _, rn := range n.rounds { + rn.Stop() + } + n.globalMessages.Stop() +} + func (n *Network) MakeRoundComms( roundNumber uint64, nodeID ID, diff --git a/pkg/finality-grandpa/voter_test.go b/pkg/finality-grandpa/voter_test.go index 013d8b5e15..5117c6cb12 100644 --- a/pkg/finality-grandpa/voter_test.go +++ b/pkg/finality-grandpa/voter_test.go @@ -18,6 +18,7 @@ func TestVoter_TalkingToMyself(t *testing.T) { }) network := NewNetwork() + defer network.Stop() env := newEnvironment(network, localID) @@ -64,6 +65,7 @@ func TestVoter_FinalizingAtFaultThreshold(t *testing.T) { voters := NewVoterSet(weights) network := NewNetwork() + defer network.Stop() var wg sync.WaitGroup // 3 voters offline. @@ -115,6 +117,7 @@ func TestVoter_ExposingVoterState(t *testing.T) { voterSet := NewVoterSet(weights) network := NewNetwork() + defer network.Stop() var wg sync.WaitGroup voters := make([]*Voter[string, uint32, Signature, ID], votersOnline) @@ -204,6 +207,7 @@ func TestVoter_BroadcastCommit(t *testing.T) { voterSet := NewVoterSet([]IDWeight[ID]{{localID, 100}}) network := NewNetwork() + defer network.Stop() env := newEnvironment(network, localID) @@ -243,6 +247,7 @@ func TestVoter_BroadcastCommitOnlyIfNewer(t *testing.T) { voterSet := NewVoterSet([]IDWeight[ID]{{localID, 100}, {testID, 201}}) network := NewNetwork() + defer network.Stop() commitsOut := make(chan CommunicationOut) commitsIn := network.MakeGlobalComms(commitsOut) @@ -345,6 +350,8 @@ func TestVoter_ImportCommitForAnyRound(t *testing.T) { voterSet := NewVoterSet([]IDWeight[ID]{{localID, 100}, {testID, 201}}) network := NewNetwork() + defer network.Stop() + commitsOut := make(chan CommunicationOut) _ = network.MakeGlobalComms(commitsOut) @@ -414,6 +421,7 @@ func TestVoter_SkipsToLatestRoundAfterCatchUp(t *testing.T) { thresholdWeight := voterSet.Threshold() network := NewNetwork() + defer network.Stop() // initialize unsynced voter at round 0 localID := ID(4) @@ -528,6 +536,7 @@ func TestVoter_PickUpFromPriorWithoutGrandparentState(t *testing.T) { voterSet := NewVoterSet([]IDWeight[ID]{{localID, 100}}) network := NewNetwork() + defer network.Stop() env := newEnvironment(network, localID) @@ -571,6 +580,7 @@ func TestVoter_PickUpFromPriorWithGrandparentStatus(t *testing.T) { voterSet := NewVoterSet(weights) network := NewNetwork() + defer network.Stop() env := newEnvironment(network, localID)