diff --git a/fanout_test.go b/fanout_test.go index 62bbc9e..1fd3891 100644 --- a/fanout_test.go +++ b/fanout_test.go @@ -21,6 +21,7 @@ package fanout import ( "context" "fmt" + "math/rand" "net" "os" "strings" @@ -29,15 +30,13 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - "github.com/coredns/caddy" "github.com/coredns/coredns/plugin/pkg/dnstest" - "github.com/stretchr/testify/suite" - "github.com/coredns/coredns/plugin/test" "github.com/miekg/dns" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "go.uber.org/goleak" ) const testQuery = "example1." @@ -390,7 +389,11 @@ func (t *fanoutTestSuite) TestServerCount() { c1 := NewClient(s1.addr, t.network) c2 := NewClient(s2.addr, t.network) f := New() - f.serverSelectionPolicy = &weightedPolicy{loadFactor: []int{50, 100}} + f.serverSelectionPolicy = &weightedPolicy{ + loadFactor: []int{50, 100}, + //nolint:gosec // init rand with constant seed to get predefined result + r: rand.New(rand.NewSource(1)), + } f.net = t.network f.from = "." f.addClient(c1) diff --git a/internal/selector/rand.go b/internal/selector/rand.go index 4f186ad..40d13a3 100644 --- a/internal/selector/rand.go +++ b/internal/selector/rand.go @@ -19,7 +19,6 @@ package selector import ( "math/rand" - "time" ) // WeightedRand selector picks elements randomly based on their weights @@ -31,13 +30,12 @@ type WeightedRand[T any] struct { } // NewWeightedRandSelector inits WeightedRand by copying source values and calculating total weight -func NewWeightedRandSelector[T any](values []T, weights []int) *WeightedRand[T] { +func NewWeightedRandSelector[T any](values []T, weights []int, r *rand.Rand) *WeightedRand[T] { wrs := &WeightedRand[T]{ values: make([]T, len(values)), weights: make([]int, len(weights)), totalWeight: 0, - //nolint:gosec // it's overhead to use crypto/rand here - r: rand.New(rand.NewSource(time.Now().UnixNano())), + r: r, } // copy the underlying array values as we're going to modify content of slices copy(wrs.values, values) diff --git a/internal/selector/rand_test.go b/internal/selector/rand_test.go index 81e03a1..30462c3 100644 --- a/internal/selector/rand_test.go +++ b/internal/selector/rand_test.go @@ -64,10 +64,10 @@ func TestWeightedRand_Pick(t *testing.T) { } for name, tc := range testCases { t.Run(name, func(t *testing.T) { - wrs := NewWeightedRandSelector(tc.values, tc.weights) - // init rand with constant seed to get predefined result - //nolint:gosec - wrs.r = rand.New(rand.NewSource(1)) + //nolint:gosec // init rand with constant seed to get predefined result + r := rand.New(rand.NewSource(1)) + + wrs := NewWeightedRandSelector(tc.values, tc.weights, r) actual := make([]string, 0, tc.picksCount) for i := 0; i < tc.picksCount; i++ { diff --git a/policy.go b/policy.go index f715158..63ed49d 100644 --- a/policy.go +++ b/policy.go @@ -16,7 +16,11 @@ package fanout -import "github.com/networkservicemesh/fanout/internal/selector" +import ( + "math/rand" + + "github.com/networkservicemesh/fanout/internal/selector" +) type policy interface { selector(clients []Client) clientSelector @@ -38,9 +42,10 @@ func (p *sequentialPolicy) selector(clients []Client) clientSelector { // weightedPolicy is used to select clients randomly based on its loadFactor (weights) type weightedPolicy struct { loadFactor []int + r *rand.Rand } // creates new weighted random selector of provided clients based on loadFactor func (p *weightedPolicy) selector(clients []Client) clientSelector { - return selector.NewWeightedRandSelector(clients, p.loadFactor) + return selector.NewWeightedRandSelector(clients, p.loadFactor, p.r) } diff --git a/setup.go b/setup.go index f5ab666..fcc701f 100644 --- a/setup.go +++ b/setup.go @@ -19,6 +19,7 @@ package fanout import ( + "math/rand" "os" "path/filepath" "strconv" @@ -173,7 +174,11 @@ func initServerSelectionPolicy(f *Fanout) error { f.serverSelectionPolicy = &sequentialPolicy{} if f.policyType == policyWeightedRandom { - f.serverSelectionPolicy = &weightedPolicy{loadFactor: loadFactor} + f.serverSelectionPolicy = &weightedPolicy{ + loadFactor: loadFactor, + //nolint:gosec // it's overhead to use crypto/rand here + r: rand.New(rand.NewSource(time.Now().UnixNano())), + } } return nil