From 730fb75a85449792ef4d5af45c18c22ea2cf6b9d Mon Sep 17 00:00:00 2001 From: Bogdan Rosianu Date: Mon, 25 Sep 2023 15:28:58 +0300 Subject: [PATCH] map counters holder component --- .../availabilityProvider.go | 5 + .../availabilityProvider_test.go | 15 ++ observer/circularQueueNodesProvider.go | 45 ++--- observer/circularQueueNodesProvider_test.go | 26 +-- observer/mapCounters/mapCounter.go | 56 ++++++ observer/mapCounters/mapCounter_test.go | 107 +++++++++++ observer/mapCounters/mapCountersHolder.go | 63 +++++++ .../mapCounters/mapCountersHolder_test.go | 178 ++++++++++++++++++ 8 files changed, 451 insertions(+), 44 deletions(-) create mode 100644 observer/mapCounters/mapCounter.go create mode 100644 observer/mapCounters/mapCounter_test.go create mode 100644 observer/mapCounters/mapCountersHolder.go create mode 100644 observer/mapCounters/mapCountersHolder_test.go diff --git a/observer/availabilityCommon/availabilityProvider.go b/observer/availabilityCommon/availabilityProvider.go index 1d6e41db..2ae8b142 100644 --- a/observer/availabilityCommon/availabilityProvider.go +++ b/observer/availabilityCommon/availabilityProvider.go @@ -46,3 +46,8 @@ func (ap *AvailabilityProvider) GetDescriptionForAvailability(availability data. return "N/A" } } + +// GetAllAvailabilityTypes returns all data availability types +func (ap *AvailabilityProvider) GetAllAvailabilityTypes() []data.ObserverDataAvailabilityType { + return []data.ObserverDataAvailabilityType{data.AvailabilityAll, data.AvailabilityRecent} +} diff --git a/observer/availabilityCommon/availabilityProvider_test.go b/observer/availabilityCommon/availabilityProvider_test.go index c1e7f725..18b08a38 100644 --- a/observer/availabilityCommon/availabilityProvider_test.go +++ b/observer/availabilityCommon/availabilityProvider_test.go @@ -10,6 +10,8 @@ import ( ) func TestAvailabilityForAccountQueryOptions(t *testing.T) { + t.Parallel() + ap := &AvailabilityProvider{} // Test with historical coordinates set @@ -22,6 +24,8 @@ func TestAvailabilityForAccountQueryOptions(t *testing.T) { } func TestAvailabilityForVmQuery(t *testing.T) { + t.Parallel() + ap := &AvailabilityProvider{} // Test with BlockNonce set @@ -38,6 +42,8 @@ func TestAvailabilityForVmQuery(t *testing.T) { } func TestIsNodeValid(t *testing.T) { + t.Parallel() + ap := &AvailabilityProvider{} // Test with AvailabilityRecent and snapshotless node @@ -58,9 +64,18 @@ func TestIsNodeValid(t *testing.T) { } func TestGetDescriptionForAvailability(t *testing.T) { + t.Parallel() + ap := &AvailabilityProvider{} require.Equal(t, "regular nodes", ap.GetDescriptionForAvailability(data.AvailabilityAll)) require.Equal(t, "snapshotless nodes", ap.GetDescriptionForAvailability(data.AvailabilityRecent)) require.Equal(t, "N/A", ap.GetDescriptionForAvailability("invalid")) // Invalid value } + +func TestAvailabilityProvider_GetAllAvailabilityTypes(t *testing.T) { + t.Parallel() + + ap := &AvailabilityProvider{} + require.Equal(t, []data.ObserverDataAvailabilityType{data.AvailabilityAll, data.AvailabilityRecent}, ap.GetAllAvailabilityTypes()) +} diff --git a/observer/circularQueueNodesProvider.go b/observer/circularQueueNodesProvider.go index ec5f1a16..40dda1c7 100644 --- a/observer/circularQueueNodesProvider.go +++ b/observer/circularQueueNodesProvider.go @@ -1,18 +1,15 @@ package observer import ( - "sync" - "github.com/multiversx/mx-chain-proxy-go/data" + "github.com/multiversx/mx-chain-proxy-go/observer/mapCounters" ) // circularQueueNodesProvider will handle the providing of observers in a circular queue way, guaranteeing the // balancing of them type circularQueueNodesProvider struct { *baseNodeProvider - countersMap map[uint32]uint32 - counterForAllNodes uint32 - mutCounters sync.RWMutex + positionsHolder *mapCounters.MapCountersHolder } // NewCircularQueueNodesProvider returns a new instance of circularQueueNodesProvider @@ -26,11 +23,9 @@ func NewCircularQueueNodesProvider(observers []*data.NodeData, configurationFile return nil, err } - countersMap := make(map[uint32]uint32) return &circularQueueNodesProvider{ - baseNodeProvider: bop, - countersMap: countersMap, - counterForAllNodes: 0, + baseNodeProvider: bop, + positionsHolder: mapCounters.NewMapCountersHolder(), }, nil } @@ -44,7 +39,11 @@ func (cqnp *circularQueueNodesProvider) GetNodesByShardId(shardId uint32, dataAv return nil, err } - position := cqnp.computeCounterForShard(shardId, uint32(len(syncedNodesForShard))) + position, err := cqnp.positionsHolder.ComputeShardPosition(dataAvailability, shardId, uint32(len(syncedNodesForShard))) + if err != nil { + return nil, err + } + sliceToRet := append(syncedNodesForShard[position:], syncedNodesForShard[:position]...) return sliceToRet, nil @@ -60,32 +59,16 @@ func (cqnp *circularQueueNodesProvider) GetAllNodes(dataAvailability data.Observ return nil, err } - position := cqnp.computeCounterForAllNodes(uint32(len(allNodes))) + position, err := cqnp.positionsHolder.ComputeAllNodesPosition(dataAvailability, uint32(len(allNodes))) + if err != nil { + return nil, err + } + sliceToRet := append(allNodes[position:], allNodes[:position]...) return sliceToRet, nil } -func (cqnp *circularQueueNodesProvider) computeCounterForShard(shardID uint32, lenNodes uint32) uint32 { - cqnp.mutCounters.Lock() - defer cqnp.mutCounters.Unlock() - - cqnp.countersMap[shardID]++ - cqnp.countersMap[shardID] %= lenNodes - - return cqnp.countersMap[shardID] -} - -func (cqnp *circularQueueNodesProvider) computeCounterForAllNodes(lenNodes uint32) uint32 { - cqnp.mutCounters.Lock() - defer cqnp.mutCounters.Unlock() - - cqnp.counterForAllNodes++ - cqnp.counterForAllNodes %= lenNodes - - return cqnp.counterForAllNodes -} - // IsInterfaceNil returns true if there is no value under the interface func (cqnp *circularQueueNodesProvider) IsInterfaceNil() bool { return cqnp == nil diff --git a/observer/circularQueueNodesProvider_test.go b/observer/circularQueueNodesProvider_test.go index 209978d7..41372a28 100644 --- a/observer/circularQueueNodesProvider_test.go +++ b/observer/circularQueueNodesProvider_test.go @@ -52,7 +52,7 @@ func TestCircularQueueObserversProvider_GetObserversByShardIdShouldWork(t *testi cfg := getDummyConfig() cqop, _ := NewCircularQueueNodesProvider(cfg.Observers, "path") - res, err := cqop.GetNodesByShardId(shardId, "") + res, err := cqop.GetNodesByShardId(shardId, data.AvailabilityAll) assert.Nil(t, err) assert.Equal(t, 1, len(res)) } @@ -79,14 +79,14 @@ func TestCircularQueueObserversProvider_GetObserversByShardIdShouldBalanceObserv } cqop, _ := NewCircularQueueNodesProvider(cfg.Observers, "path") - res1, _ := cqop.GetNodesByShardId(shardId, "") - res2, _ := cqop.GetNodesByShardId(shardId, "") + res1, _ := cqop.GetNodesByShardId(shardId, data.AvailabilityAll) + res2, _ := cqop.GetNodesByShardId(shardId, data.AvailabilityAll) assert.NotEqual(t, res1, res2) // there are 3 observers. so after 3 steps, the queue should be the same as the original - _, _ = cqop.GetNodesByShardId(shardId, "") + _, _ = cqop.GetNodesByShardId(shardId, data.AvailabilityAll) - res4, _ := cqop.GetNodesByShardId(shardId, "") + res4, _ := cqop.GetNodesByShardId(shardId, data.AvailabilityAll) assert.Equal(t, res1, res4) } @@ -96,7 +96,7 @@ func TestCircularQueueObserversProvider_GetAllObserversShouldWork(t *testing.T) cfg := getDummyConfig() cqop, _ := NewCircularQueueNodesProvider(cfg.Observers, "path") - res, err := cqop.GetAllNodes("") + res, err := cqop.GetAllNodes(data.AvailabilityAll) assert.NoError(t, err) assert.Equal(t, 2, len(res)) } @@ -122,14 +122,14 @@ func TestCircularQueueObserversProvider_GetAllObserversShouldWorkAndBalanceObser } cqop, _ := NewCircularQueueNodesProvider(cfg.Observers, "path") - res1, _ := cqop.GetAllNodes("") - res2, _ := cqop.GetAllNodes("") + res1, _ := cqop.GetAllNodes(data.AvailabilityAll) + res2, _ := cqop.GetAllNodes(data.AvailabilityAll) assert.NotEqual(t, res1, res2) // there are 3 observers. so after 3 steps, the queue should be the same as the original - _, _ = cqop.GetAllNodes("") + _, _ = cqop.GetAllNodes(data.AvailabilityAll) - res4, _ := cqop.GetAllNodes("") + res4, _ := cqop.GetAllNodes(data.AvailabilityAll) assert.Equal(t, res1, res4) } @@ -172,7 +172,7 @@ func TestCircularQueueObserversProvider_GetAllObservers_ConcurrentSafe(t *testin for i := 0; i < numOfGoRoutinesToStart; i++ { for j := 0; j < numOfTimesToCallForEachRoutine; j++ { go func(mutMap *sync.RWMutex, mapCalledObs map[string]int) { - obs, _ := cqop.GetAllNodes("") + obs, _ := cqop.GetAllNodes(data.AvailabilityAll) mutMap.Lock() mapCalledObs[obs[0].Address]++ mutMap.Unlock() @@ -232,8 +232,8 @@ func TestCircularQueueObserversProvider_GetObserversByShardId_ConcurrentSafe(t * for i := 0; i < numOfGoRoutinesToStart; i++ { for j := 0; j < numOfTimesToCallForEachRoutine; j++ { go func(mutMap *sync.RWMutex, mapCalledObs map[string]int) { - obsSh0, _ := cqop.GetNodesByShardId(shardId0, "") - obsSh1, _ := cqop.GetNodesByShardId(shardId1, "") + obsSh0, _ := cqop.GetNodesByShardId(shardId0, data.AvailabilityAll) + obsSh1, _ := cqop.GetNodesByShardId(shardId1, data.AvailabilityAll) mutMap.Lock() mapCalledObs[obsSh0[0].Address]++ mapCalledObs[obsSh1[0].Address]++ diff --git a/observer/mapCounters/mapCounter.go b/observer/mapCounters/mapCounter.go new file mode 100644 index 00000000..96f498f5 --- /dev/null +++ b/observer/mapCounters/mapCounter.go @@ -0,0 +1,56 @@ +package mapCounters + +import "sync" + +type mapCounter struct { + positions map[uint32]uint32 + allNodesCount uint32 + allNodesPosition uint32 + mut sync.RWMutex +} + +// newMapCounter returns a new instance of a mapCounter +func newMapCounter() *mapCounter { + return &mapCounter{ + positions: make(map[uint32]uint32), + allNodesPosition: 0, + } +} + +func (mc *mapCounter) computePositionForShard(shardID uint32, numNodes uint32) uint32 { + mc.mut.Lock() + defer mc.mut.Unlock() + + mc.initShardPositionIfNeededUnprotected(shardID) + + mc.positions[shardID]++ + mc.positions[shardID] %= numNodes + + return mc.positions[shardID] +} + +func (mc *mapCounter) computePositionForAllNodes(numNodes uint32) uint32 { + mc.mut.Lock() + defer mc.mut.Unlock() + + mc.initAllNodesPositionIfNeededUnprotected(numNodes) + + mc.allNodesPosition++ + mc.allNodesPosition %= numNodes + + return mc.allNodesPosition +} + +func (mc *mapCounter) initShardPositionIfNeededUnprotected(shardID uint32) { + _, shardExists := mc.positions[shardID] + if !shardExists { + mc.positions[shardID] = 0 + } +} + +func (mc *mapCounter) initAllNodesPositionIfNeededUnprotected(numNodes uint32) { + if numNodes != mc.allNodesCount { + mc.allNodesCount = numNodes + mc.allNodesPosition = 0 + } +} diff --git a/observer/mapCounters/mapCounter_test.go b/observer/mapCounters/mapCounter_test.go new file mode 100644 index 00000000..b3b77ad7 --- /dev/null +++ b/observer/mapCounters/mapCounter_test.go @@ -0,0 +1,107 @@ +package mapCounters + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewMapCounter(t *testing.T) { + t.Parallel() + + mc := newMapCounter() + require.NotNil(t, mc) + require.NotNil(t, mc.positions) +} + +func TestMapCounter_ComputeShardPositionShouldWorkWithDifferentNumOfNodes(t *testing.T) { + t.Parallel() + + mc := newMapCounter() + computeShardPosAndAssert(t, mc, 3, 1) + computeShardPosAndAssert(t, mc, 3, 2) + computeShardPosAndAssert(t, mc, 3, 0) + computeShardPosAndAssert(t, mc, 3, 1) + // change num nodes + computeShardPosAndAssert(t, mc, 2, 0) + computeShardPosAndAssert(t, mc, 2, 1) + computeShardPosAndAssert(t, mc, 2, 0) + // change num nodes again + computeShardPosAndAssert(t, mc, 5, 1) + computeShardPosAndAssert(t, mc, 5, 2) + computeShardPosAndAssert(t, mc, 5, 3) +} + +func TestMapCounter_ComputeShardPositionShouldWorkMultiShard(t *testing.T) { + t.Parallel() + + mc := newMapCounter() + computeShardPosAndAssertForShard(t, mc, 0, 3, 1) + computeShardPosAndAssertForShard(t, mc, 1, 4, 1) + + computeShardPosAndAssertForShard(t, mc, 0, 3, 2) + computeShardPosAndAssertForShard(t, mc, 1, 4, 2) + + computeShardPosAndAssertForShard(t, mc, 0, 3, 0) + computeShardPosAndAssertForShard(t, mc, 1, 4, 3) + + computeShardPosAndAssertForShard(t, mc, 0, 3, 1) + computeShardPosAndAssertForShard(t, mc, 1, 4, 0) + +} + +func computeShardPosAndAssertForShard(t *testing.T, mc *mapCounter, shardID uint32, numNodes uint32, expectedPos uint32) { + actualPos := mc.computePositionForShard(shardID, numNodes) + require.Equal(t, expectedPos, actualPos) +} + +func computeShardPosAndAssert(t *testing.T, mc *mapCounter, numNodes uint32, expectedPos uint32) { + computeShardPosAndAssertForShard(t, mc, 0, numNodes, expectedPos) +} + +func TestMapCounter_ComputeAllNodesPosition(t *testing.T) { + t.Parallel() + + mc := newMapCounter() + computeAllNodesPosAndAssert(t, mc, 3, 1) + computeAllNodesPosAndAssert(t, mc, 3, 2) + computeAllNodesPosAndAssert(t, mc, 3, 0) + computeAllNodesPosAndAssert(t, mc, 3, 1) + // change num nodes - should reset + computeAllNodesPosAndAssert(t, mc, 5, 1) + computeAllNodesPosAndAssert(t, mc, 5, 2) + computeAllNodesPosAndAssert(t, mc, 5, 3) + // change num nodes again - should reset + computeAllNodesPosAndAssert(t, mc, 2, 1) + computeAllNodesPosAndAssert(t, mc, 2, 0) + computeAllNodesPosAndAssert(t, mc, 2, 1) +} + +func computeAllNodesPosAndAssert(t *testing.T, mc *mapCounter, numNodes uint32, expectedPos uint32) { + actualPos := mc.computePositionForAllNodes(numNodes) + require.Equal(t, expectedPos, actualPos) +} + +func TestMapCounter_ConcurrentOperations(t *testing.T) { + t.Parallel() + + mc := newMapCounter() + + numOperations := 10_000 + wg := sync.WaitGroup{} + wg.Add(numOperations) + for i := 0; i < numOperations; i++ { + go func(idx int) { + switch idx { + case 0: + mc.computePositionForShard(uint32(idx), uint32(10+idx)) + case 1: + mc.computePositionForAllNodes(uint32(10 + idx)) + } + wg.Done() + }(i % 2) + } + + wg.Wait() +} diff --git a/observer/mapCounters/mapCountersHolder.go b/observer/mapCounters/mapCountersHolder.go new file mode 100644 index 00000000..be9722b1 --- /dev/null +++ b/observer/mapCounters/mapCountersHolder.go @@ -0,0 +1,63 @@ +package mapCounters + +import ( + "errors" + + "github.com/multiversx/mx-chain-proxy-go/data" + "github.com/multiversx/mx-chain-proxy-go/observer/availabilityCommon" +) + +var ( + errInvalidAvailability = errors.New("invalid data availability type") + errNumNodesMustBeGreaterThanZero = errors.New("the number of nodes must be greater than 0") +) + +// MapCountersHolder handles multiple counters map based on the data availability +type MapCountersHolder struct { + countersMap map[data.ObserverDataAvailabilityType]*mapCounter +} + +// NewMapCountersHolder populates the initial map and returns a new instance of MapCountersHolder +func NewMapCountersHolder() *MapCountersHolder { + availabilityProvider := availabilityCommon.AvailabilityProvider{} + dataAvailabilityTypes := availabilityProvider.GetAllAvailabilityTypes() + + countersMap := make(map[data.ObserverDataAvailabilityType]*mapCounter) + for _, availability := range dataAvailabilityTypes { + countersMap[availability] = newMapCounter() + } + + return &MapCountersHolder{ + countersMap: countersMap, + } +} + +// ComputeShardPosition returns the shard position based on the availability and the shard +func (mch *MapCountersHolder) ComputeShardPosition(availability data.ObserverDataAvailabilityType, shardID uint32, numNodes uint32) (uint32, error) { + counterMap, exists := mch.countersMap[availability] + if !exists { + return 0, errInvalidAvailability + } + + if numNodes == 0 { + return 0, errNumNodesMustBeGreaterThanZero + } + + position := counterMap.computePositionForShard(shardID, numNodes) + return position, nil +} + +// ComputeAllNodesPosition returns the all nodes position based on the availability +func (mch *MapCountersHolder) ComputeAllNodesPosition(availability data.ObserverDataAvailabilityType, numNodes uint32) (uint32, error) { + counterMap, exists := mch.countersMap[availability] + if !exists { + return 0, errInvalidAvailability + } + + if numNodes == 0 { + return 0, errNumNodesMustBeGreaterThanZero + } + + position := counterMap.computePositionForAllNodes(numNodes) + return position, nil +} diff --git a/observer/mapCounters/mapCountersHolder_test.go b/observer/mapCounters/mapCountersHolder_test.go new file mode 100644 index 00000000..55fae73d --- /dev/null +++ b/observer/mapCounters/mapCountersHolder_test.go @@ -0,0 +1,178 @@ +package mapCounters + +import ( + "sync" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-proxy-go/data" + "github.com/stretchr/testify/require" +) + +func TestNewMapCountersHolder(t *testing.T) { + t.Parallel() + + mch := NewMapCountersHolder() + require.NotNil(t, mch) + require.Len(t, mch.countersMap, 2) +} + +func TestMapCountersHolder_ComputeShardPositionShouldFailDueToInvalidAvailability(t *testing.T) { + t.Parallel() + + mch := NewMapCountersHolder() + + pos, err := mch.ComputeShardPosition("invalid", 0, 10) + require.Equal(t, errInvalidAvailability, err) + require.Empty(t, pos) +} + +func TestMapCountersHolder_ComputeShardPositionShouldFailDueToZeroNumNodes(t *testing.T) { + t.Parallel() + + mch := NewMapCountersHolder() + + pos, err := mch.ComputeShardPosition(data.AvailabilityAll, 0, 0) + require.Equal(t, errNumNodesMustBeGreaterThanZero, err) + require.Empty(t, pos) +} + +func TestMapCountersHolder_ComputeShardPositionShouldWorkWhileChangingNumNodes(t *testing.T) { + t.Parallel() + + mch := NewMapCountersHolder() + + calculatePosAndAssert(t, mch, 0, 3, 1) + calculatePosAndAssert(t, mch, 1, 3, 1) + calculatePosAndAssert(t, mch, 2, 3, 1) + calculatePosAndAssert(t, mch, core.MetachainShardId, 3, 1) + + calculatePosAndAssert(t, mch, 0, 2, 0) + calculatePosAndAssert(t, mch, 1, 2, 0) + calculatePosAndAssert(t, mch, 2, 2, 0) + calculatePosAndAssert(t, mch, core.MetachainShardId, 2, 0) + + calculatePosAndAssert(t, mch, 0, 2, 1) + calculatePosAndAssert(t, mch, 1, 2, 1) + calculatePosAndAssert(t, mch, 2, 2, 1) + calculatePosAndAssert(t, mch, core.MetachainShardId, 2, 1) + + calculatePosAndAssert(t, mch, 0, 5, 2) + calculatePosAndAssert(t, mch, 1, 4, 2) + calculatePosAndAssert(t, mch, 2, 3, 2) + calculatePosAndAssert(t, mch, core.MetachainShardId, 2, 0) +} + +func TestMapCountersHolder_ComputeShardPositionShouldWorkForMultipleAvailabilities(t *testing.T) { + t.Parallel() + + mch := NewMapCountersHolder() + + calculatePosAndAssertForShard(t, mch, data.AvailabilityRecent, 0, 3, 1) + calculatePosAndAssertForShard(t, mch, data.AvailabilityAll, 0, 3, 1) + calculatePosAndAssertForShard(t, mch, data.AvailabilityRecent, 1, 3, 1) + calculatePosAndAssertForShard(t, mch, data.AvailabilityAll, 1, 3, 1) + + calculatePosAndAssertForShard(t, mch, data.AvailabilityRecent, 0, 2, 0) + calculatePosAndAssertForShard(t, mch, data.AvailabilityAll, 0, 3, 2) + calculatePosAndAssertForShard(t, mch, data.AvailabilityRecent, 1, 2, 0) + calculatePosAndAssertForShard(t, mch, data.AvailabilityAll, 1, 5, 2) + + calculatePosAndAssertForShard(t, mch, data.AvailabilityRecent, 0, 3, 1) + calculatePosAndAssertForShard(t, mch, data.AvailabilityAll, 0, 3, 0) + calculatePosAndAssertForShard(t, mch, data.AvailabilityRecent, 1, 3, 1) + calculatePosAndAssertForShard(t, mch, data.AvailabilityAll, 1, 3, 0) + + calculatePosAndAssertForShard(t, mch, data.AvailabilityRecent, 0, 3, 2) + calculatePosAndAssertForShard(t, mch, data.AvailabilityAll, 0, 3, 1) + calculatePosAndAssertForShard(t, mch, data.AvailabilityRecent, 1, 3, 2) + calculatePosAndAssertForShard(t, mch, data.AvailabilityAll, 1, 3, 1) +} + +func calculatePosAndAssertForShard( + t *testing.T, + mch *MapCountersHolder, + availability data.ObserverDataAvailabilityType, + shardID uint32, + numNodes uint32, + expectedPos uint32, +) { + pos, err := mch.ComputeShardPosition(availability, shardID, numNodes) + require.NoError(t, err) + require.Equal(t, expectedPos, pos) +} + +func calculatePosAndAssert(t *testing.T, mch *MapCountersHolder, shardID uint32, numNodes uint32, expectedPos uint32) { + calculatePosAndAssertForShard(t, mch, data.AvailabilityRecent, shardID, numNodes, expectedPos) +} + +func TestMapCountersHolder_ComputeAllNodesPositionShouldFailDueToInvalidAvailaility(t *testing.T) { + t.Parallel() + + mch := NewMapCountersHolder() + pos, err := mch.ComputeAllNodesPosition("invalid", 10) + require.Equal(t, errInvalidAvailability, err) + require.Empty(t, pos) +} + +func TestMapCountersHolder_ComputeAllNodesPositionShouldFailDueToZeroNumNodes(t *testing.T) { + t.Parallel() + + mch := NewMapCountersHolder() + + pos, err := mch.ComputeAllNodesPosition(data.AvailabilityAll, 0) + require.Equal(t, errNumNodesMustBeGreaterThanZero, err) + require.Empty(t, pos) +} + +func TestMapCountersHolder_ComputeAllNodesPositionShouldWork(t *testing.T) { + t.Parallel() + + mch := NewMapCountersHolder() + + calculateAllNodesPosAndAssert(t, mch, 3, 1) + calculateAllNodesPosAndAssert(t, mch, 3, 2) + calculateAllNodesPosAndAssert(t, mch, 3, 0) + calculateAllNodesPosAndAssert(t, mch, 3, 1) + + calculateAllNodesPosAndAssert(t, mch, 5, 1) + calculateAllNodesPosAndAssert(t, mch, 5, 2) + calculateAllNodesPosAndAssert(t, mch, 5, 3) + calculateAllNodesPosAndAssert(t, mch, 5, 4) + calculateAllNodesPosAndAssert(t, mch, 5, 0) + calculateAllNodesPosAndAssert(t, mch, 5, 1) + + calculateAllNodesPosAndAssert(t, mch, 2, 1) + calculateAllNodesPosAndAssert(t, mch, 2, 0) +} + +func calculateAllNodesPosAndAssert(t *testing.T, mch *MapCountersHolder, numNodes uint32, expectedPos uint32) { + pos, err := mch.ComputeAllNodesPosition(data.AvailabilityRecent, numNodes) + require.NoError(t, err) + require.Equal(t, expectedPos, pos) +} + +func TestMapCountersHolder_ConcurrentOperations(t *testing.T) { + t.Parallel() + + numOperations := 10_000 + mch := NewMapCountersHolder() + + wg := sync.WaitGroup{} + wg.Add(numOperations) + + for i := 0; i < numOperations; i++ { + go func(idx int) { + switch idx { + case 0: + _, _ = mch.ComputeShardPosition(data.AvailabilityRecent, uint32(idx), uint32(10+idx)) + case 1: + _, _ = mch.ComputeShardPosition(data.AvailabilityAll, uint32(idx), uint32(10+idx)) + case 2: + _, _ = mch.ComputeAllNodesPosition(data.AvailabilityRecent, uint32(10+idx)) + case 3: + _, _ = mch.ComputeAllNodesPosition(data.AvailabilityAll, uint32(10+idx)) + } + }(i % 2) + } +}