From a6446d332890bfc68812beef5a565b9b27f4137c Mon Sep 17 00:00:00 2001
From: Gleb Kogtev <gleb.kogtev@gmail.com>
Date: Tue, 20 Aug 2024 21:03:57 +0300
Subject: [PATCH] - add policy support (sequential/weighted-random) - rename
 simple selector to sequential - update docs + add examples

Signed-off-by: Gleb Kogtev <gleb.kogtev@gmail.com>
---
 README.md                                     | 26 +++++-
 const.go                                      | 26 +++---
 fanout.go                                     | 53 +++++------
 fanout_test.go                                |  1 +
 .../selector/{simple.go => sequential.go}     | 12 +--
 .../{simple_test.go => sequential_test.go}    |  2 +-
 policy.go                                     | 46 ++++++++++
 setup.go                                      | 92 ++++++++++++++-----
 setup_test.go                                 | 43 ++++++---
 9 files changed, 209 insertions(+), 92 deletions(-)
 rename internal/selector/{simple.go => sequential.go} (74%)
 rename internal/selector/{simple_test.go => sequential_test.go} (97%)
 create mode 100644 policy.go

diff --git a/README.md b/README.md
index a3e3cbf..c324ad5 100644
--- a/README.md
+++ b/README.md
@@ -24,8 +24,11 @@ Each incoming DNS query that hits the CoreDNS fanout plugin will be replicated i
   (Cloudflare) will not work.
 
 * `worker-count` is the number of parallel queries per request. By default equals to count of IP list. Use this only for reducing parallel queries per request.
-* `server-count` is the number of DNS servers to be requested. Equals to the number of specified IPs by default. If this parameter is lower than the number of specified IP addresses, servers are randomly selected based on the `load-factor` parameter.
-* `load-factor` - the probability of selecting a server. This is specified in the order of the list of IP addresses and takes values between 1 and 100. By default, all servers have an equal probability of 100.
+* `policy` - specifies the policy of DNS server selection mechanism. The default is `sequential`.
+  * `sequential` - select DNS servers one-by-one based on its order
+  * `weighted-random` - select DNS servers randomly based on `server-count` and `load-factor` params:
+    * `server-count` is the number of DNS servers to be requested. Equals to the number of specified IPs by default.
+    * `load-factor` - the probability of selecting a server. This is specified in the order of the list of IP addresses and takes values between 1 and 100. By default, all servers have an equal probability of 100.
 * `network` is a specific network protocol. Could be `tcp`, `udp`, `tcp-tls`.
 * `except` is a list is a space-separated list of domains to exclude from proxying.
 * `except-file` is the path to file with line-separated list of domains to exclude from proxying.
@@ -112,3 +115,22 @@ If `race` is enable, we will get `NXDOMAIN` result quickly, otherwise we will ge
     }
 }
 ~~~
+
+Sends parallel requests between two randomly selected resolvers. Note, that `127.0.0.1:9007` would be selected more frequently as it has the highest `load-factor`.
+~~~ corefile
+example.org {
+    fanout . 127.0.0.1:9005 127.0.0.1:9006 127.0.0.1:9007
+    policy weighted-random {
+      server-count 2
+      load-factor 50 70 100
+    }
+}
+~~~
+
+Sends parallel requests between three resolver sequentially (default mode).
+~~~ corefile
+example.org {
+    fanout . 127.0.0.1:9005 127.0.0.1:9006 127.0.0.1:9007
+    policy sequential
+}
+~~~
diff --git a/const.go b/const.go
index 074d35e..ae8b79c 100644
--- a/const.go
+++ b/const.go
@@ -19,16 +19,18 @@ package fanout
 import "time"
 
 const (
-	maxIPCount     = 100
-	maxLoadFactor  = 100
-	minLoadFactor  = 1
-	maxWorkerCount = 32
-	minWorkerCount = 2
-	maxTimeout     = 2 * time.Second
-	defaultTimeout = 30 * time.Second
-	readTimeout    = 2 * time.Second
-	attemptDelay   = time.Millisecond * 100
-	tcptls         = "tcp-tls"
-	tcp            = "tcp"
-	udp            = "udp"
+	maxIPCount           = 100
+	maxLoadFactor        = 100
+	minLoadFactor        = 1
+	policyWeightedRandom = "weighted-random"
+	policySequential     = "sequential"
+	maxWorkerCount       = 32
+	minWorkerCount       = 2
+	maxTimeout           = 2 * time.Second
+	defaultTimeout       = 30 * time.Second
+	readTimeout          = 2 * time.Second
+	attemptDelay         = time.Millisecond * 100
+	tcptls               = "tcp-tls"
+	tcp                  = "tcp"
+	udp                  = "udp"
 )
diff --git a/fanout.go b/fanout.go
index 94c095f..4e3dc88 100644
--- a/fanout.go
+++ b/fanout.go
@@ -31,7 +31,6 @@ import (
 	clog "github.com/coredns/coredns/plugin/pkg/log"
 	"github.com/coredns/coredns/request"
 	"github.com/miekg/dns"
-	"github.com/networkservicemesh/fanout/internal/selector"
 	"github.com/pkg/errors"
 )
 
@@ -39,36 +38,36 @@ var log = clog.NewWithPlugin("fanout")
 
 // Fanout represents a plugin instance that can do async requests to list of DNS servers.
 type Fanout struct {
-	clients        []Client
-	tlsConfig      *tls.Config
-	excludeDomains Domain
-	tlsServerName  string
-	timeout        time.Duration
-	race           bool
-	net            string
-	from           string
-	attempts       int
-	workerCount    int
-	serverCount    int
-	loadFactor     []int
-	tapPlugin      *dnstap.Dnstap
-	Next           plugin.Handler
+	clients               []Client
+	tlsConfig             *tls.Config
+	excludeDomains        Domain
+	tlsServerName         string
+	timeout               time.Duration
+	race                  bool
+	net                   string
+	from                  string
+	attempts              int
+	workerCount           int
+	serverCount           int
+	serverSelectionPolicy policy
+	tapPlugin             *dnstap.Dnstap
+	Next                  plugin.Handler
 }
 
 // New returns reference to new Fanout plugin instance with default configs.
 func New() *Fanout {
 	return &Fanout{
-		tlsConfig:      new(tls.Config),
-		net:            "udp",
-		attempts:       3,
-		timeout:        defaultTimeout,
-		excludeDomains: NewDomain(),
+		tlsConfig:             new(tls.Config),
+		net:                   "udp",
+		attempts:              3,
+		timeout:               defaultTimeout,
+		excludeDomains:        NewDomain(),
+		serverSelectionPolicy: &sequentialPolicy{}, // default policy
 	}
 }
 
 func (f *Fanout) addClient(p Client) {
 	f.clients = append(f.clients, p)
-	f.loadFactor = append(f.loadFactor, maxLoadFactor)
 	f.workerCount++
 	f.serverCount++
 }
@@ -110,18 +109,8 @@ func (f *Fanout) ServeDNS(ctx context.Context, w dns.ResponseWriter, m *dns.Msg)
 	return 0, nil
 }
 
-type clientSelector interface {
-	Pick() Client
-}
-
 func (f *Fanout) runWorkers(ctx context.Context, req *request.Request) chan *response {
-	var sel clientSelector
-	if f.serverCount == len(f.clients) {
-		sel = selector.NewSimpleSelector(f.clients)
-	} else {
-		sel = selector.NewWeightedRandSelector(f.clients, f.loadFactor)
-	}
-
+	sel := f.serverSelectionPolicy.selector(f.clients)
 	workerCh := make(chan Client, f.workerCount)
 	responseCh := make(chan *response, f.serverCount)
 	go func() {
diff --git a/fanout_test.go b/fanout_test.go
index 284fb7e..62bbc9e 100644
--- a/fanout_test.go
+++ b/fanout_test.go
@@ -390,6 +390,7 @@ func (t *fanoutTestSuite) TestServerCount() {
 	c1 := NewClient(s1.addr, t.network)
 	c2 := NewClient(s2.addr, t.network)
 	f := New()
+	f.serverSelectionPolicy = &weightedPolicy{loadFactor: []int{50, 100}}
 	f.net = t.network
 	f.from = "."
 	f.addClient(c1)
diff --git a/internal/selector/simple.go b/internal/selector/sequential.go
similarity index 74%
rename from internal/selector/simple.go
rename to internal/selector/sequential.go
index 2dc0ac2..67436ed 100644
--- a/internal/selector/simple.go
+++ b/internal/selector/sequential.go
@@ -16,15 +16,15 @@
 
 package selector
 
-// Simple selector acts like a queue and picks elements one-by-one starting from the first element
-type Simple[T any] struct {
+// Sequential selector acts like a queue and picks elements one-by-one starting from the first element
+type Sequential[T any] struct {
 	values []T
 	idx    int
 }
 
-// NewSimpleSelector inits Simple selector with default starting index 0
-func NewSimpleSelector[T any](values []T) *Simple[T] {
-	return &Simple[T]{
+// NewSequentialSelector inits Sequential selector with default starting index 0
+func NewSequentialSelector[T any](values []T) *Sequential[T] {
+	return &Sequential[T]{
 		values: values,
 		idx:    0,
 	}
@@ -32,7 +32,7 @@ func NewSimpleSelector[T any](values []T) *Simple[T] {
 
 // Pick returns next available element from values array if exists.
 // Returns default value of type T otherwise
-func (s *Simple[T]) Pick() T {
+func (s *Sequential[T]) Pick() T {
 	var result T
 	if s.idx >= len(s.values) {
 		return result
diff --git a/internal/selector/simple_test.go b/internal/selector/sequential_test.go
similarity index 97%
rename from internal/selector/simple_test.go
rename to internal/selector/sequential_test.go
index c3e58db..ddcdedf 100644
--- a/internal/selector/simple_test.go
+++ b/internal/selector/sequential_test.go
@@ -47,7 +47,7 @@ func TestSimple_Pick(t *testing.T) {
 	}
 	for name, tc := range testCases {
 		t.Run(name, func(t *testing.T) {
-			wrs := NewSimpleSelector(tc.values)
+			wrs := NewSequentialSelector(tc.values)
 
 			actual := make([]string, 0, tc.picksCount)
 			for i := 0; i < tc.picksCount; i++ {
diff --git a/policy.go b/policy.go
new file mode 100644
index 0000000..f715158
--- /dev/null
+++ b/policy.go
@@ -0,0 +1,46 @@
+// Copyright (c) 2024 MWS and/or its affiliates.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at:
+//
+//	http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fanout
+
+import "github.com/networkservicemesh/fanout/internal/selector"
+
+type policy interface {
+	selector(clients []Client) clientSelector
+}
+
+type clientSelector interface {
+	Pick() Client
+}
+
+// sequentialPolicy is used to select clients based on its sequential order
+type sequentialPolicy struct {
+}
+
+// creates new sequential selector of provided clients
+func (p *sequentialPolicy) selector(clients []Client) clientSelector {
+	return selector.NewSequentialSelector(clients)
+}
+
+// weightedPolicy is used to select clients randomly based on its loadFactor (weights)
+type weightedPolicy struct {
+	loadFactor []int
+}
+
+// creates new weighted random selector of provided clients based on loadFactor
+func (p *weightedPolicy) selector(clients []Client) clientSelector {
+	return selector.NewWeightedRandSelector(clients, p.loadFactor)
+}
diff --git a/setup.go b/setup.go
index ff233ca..69f4977 100644
--- a/setup.go
+++ b/setup.go
@@ -122,22 +122,14 @@ func parsefanoutStanza(c *caddyfile.Dispenser) (*Fanout, error) {
 		return f, err
 	}
 	for c.NextBlock() {
-		err = parseValue(strings.ToLower(c.Val()), f, c)
+		err = parseValue(strings.ToLower(c.Val()), f, c, toHosts)
 		if err != nil {
 			return nil, err
 		}
 	}
 	initClients(f, toHosts)
-	if f.serverCount > len(toHosts) || f.serverCount == 0 {
-		f.serverCount = len(toHosts)
-	}
-	if len(f.loadFactor) == 0 {
-		for i := 0; i < len(toHosts); i++ {
-			f.loadFactor = append(f.loadFactor, maxLoadFactor)
-		}
-	}
-	if len(f.loadFactor) != len(toHosts) {
-		return nil, errors.New("load-factor params count must be the same as the number of hosts")
+	if f.serverCount > len(f.clients) || f.serverCount == 0 {
+		f.serverCount = len(f.clients)
 	}
 
 	if f.workerCount > len(f.clients) || f.workerCount == 0 {
@@ -163,7 +155,7 @@ func initClients(f *Fanout, hosts []string) {
 	}
 }
 
-func parseValue(v string, f *Fanout, c *caddyfile.Dispenser) error {
+func parseValue(v string, f *Fanout, c *caddyfile.Dispenser, hosts []string) error {
 	switch v {
 	case "tls":
 		return parseTLS(f, c)
@@ -173,12 +165,8 @@ func parseValue(v string, f *Fanout, c *caddyfile.Dispenser) error {
 		return parseTLSServer(f, c)
 	case "worker-count":
 		return parseWorkerCount(f, c)
-	case "server-count":
-		num, err := parsePositiveInt(c)
-		f.serverCount = num
-		return err
-	case "load-factor":
-		return parseLoadFactor(f, c)
+	case "policy":
+		return parsePolicy(f, c, hosts)
 	case "timeout":
 		return parseTimeout(f, c)
 	case "race":
@@ -196,6 +184,59 @@ func parseValue(v string, f *Fanout, c *caddyfile.Dispenser) error {
 	}
 }
 
+func parsePolicy(f *Fanout, c *caddyfile.Dispenser, hosts []string) error {
+	if !c.NextArg() {
+		return c.ArgErr()
+	}
+
+	switch c.Val() {
+	case policyWeightedRandom:
+		// omit "{"
+		c.Next()
+		if c.Val() != "{" {
+			return c.Err("Wrong policy configuration")
+		}
+	case policySequential:
+		f.serverSelectionPolicy = &sequentialPolicy{}
+		return nil
+	default:
+		return errors.Errorf("unknown policy %q", c.Val())
+	}
+
+	var loadFactor []int
+	for c.Next() {
+		if c.Val() == "}" {
+			break
+		}
+
+		var err error
+		switch c.Val() {
+		case "server-count":
+			f.serverCount, err = parsePositiveInt(c)
+		case "load-factor":
+			loadFactor, err = parseLoadFactor(c)
+		default:
+			return errors.Errorf("unknown property %q", c.Val())
+		}
+		if err != nil {
+			return err
+		}
+	}
+
+	if len(loadFactor) == 0 {
+		for i := 0; i < len(hosts); i++ {
+			loadFactor = append(loadFactor, maxLoadFactor)
+		}
+	}
+	if len(loadFactor) != len(hosts) {
+		return errors.New("load-factor params count must be the same as the number of hosts")
+	}
+
+	f.serverSelectionPolicy = &weightedPolicy{loadFactor: loadFactor}
+
+	return nil
+}
+
 func parseTimeout(f *Fanout, c *caddyfile.Dispenser) error {
 	if !c.NextArg() {
 		return c.ArgErr()
@@ -263,29 +304,30 @@ func parseWorkerCount(f *Fanout, c *caddyfile.Dispenser) error {
 	return err
 }
 
-func parseLoadFactor(f *Fanout, c *caddyfile.Dispenser) error {
+func parseLoadFactor(c *caddyfile.Dispenser) ([]int, error) {
 	args := c.RemainingArgs()
 	if len(args) == 0 {
-		return c.ArgErr()
+		return nil, c.ArgErr()
 	}
 
+	result := make([]int, 0, len(args))
 	for _, arg := range args {
 		loadFactor, err := strconv.Atoi(arg)
 		if err != nil {
-			return c.ArgErr()
+			return nil, c.ArgErr()
 		}
 
 		if loadFactor < minLoadFactor {
-			return errors.New("load-factor should be more or equal 1")
+			return nil, errors.New("load-factor should be more or equal 1")
 		}
 		if loadFactor > maxLoadFactor {
-			return errors.Errorf("load-factor %d should be less than %d", loadFactor, maxLoadFactor)
+			return nil, errors.Errorf("load-factor %d should be less than %d", loadFactor, maxLoadFactor)
 		}
 
-		f.loadFactor = append(f.loadFactor, loadFactor)
+		result = append(result, loadFactor)
 	}
 
-	return nil
+	return result, nil
 }
 
 func parsePositiveInt(c *caddyfile.Dispenser) (int, error) {
diff --git a/setup_test.go b/setup_test.go
index 8b5f2e0..f368ce2 100644
--- a/setup_test.go
+++ b/setup_test.go
@@ -43,12 +43,15 @@ func TestSetup(t *testing.T) {
 		expectedErr         string
 	}{
 		// positive
-		{input: "fanout . 127.0.0.1", expectedFrom: ".", expectedAttempts: 3, expectedWorkers: 1, expectedTimeout: defaultTimeout, expectedNetwork: "udp", expectedServerCount: 1, expectedLoadFactor: []int{100}},
-		{input: "fanout . 127.0.0.1 {\nserver-count 5\n}", expectedFrom: ".", expectedAttempts: 3, expectedWorkers: 1, expectedTimeout: defaultTimeout, expectedNetwork: "udp", expectedServerCount: 1, expectedLoadFactor: []int{100}},
-		{input: "fanout . 127.0.0.1 {\nexcept a b\nworker-count 3\n}", expectedFrom: ".", expectedTimeout: defaultTimeout, expectedAttempts: 3, expectedWorkers: 1, expectedIgnored: []string{"a.", "b."}, expectedNetwork: "udp", expectedServerCount: 1, expectedLoadFactor: []int{100}},
-		{input: "fanout . 127.0.0.1 127.0.0.2 {\nnetwork tcp\n}", expectedFrom: ".", expectedTimeout: defaultTimeout, expectedAttempts: 3, expectedWorkers: 2, expectedNetwork: "tcp", expectedTo: []string{"127.0.0.1:53", "127.0.0.2:53"}, expectedServerCount: 2, expectedLoadFactor: []int{100, 100}},
-		{input: "fanout . 127.0.0.1 127.0.0.2 127.0.0.3 127.0.0.4 {\nworker-count 3\ntimeout 1m\n}", expectedTimeout: time.Minute, expectedAttempts: 3, expectedFrom: ".", expectedWorkers: 3, expectedNetwork: "udp", expectedServerCount: 4, expectedLoadFactor: []int{100, 100, 100, 100}},
-		{input: "fanout . 127.0.0.1 127.0.0.2 127.0.0.3 127.0.0.4 {\nattempt-count 2\n}", expectedTimeout: defaultTimeout, expectedFrom: ".", expectedAttempts: 2, expectedWorkers: 4, expectedNetwork: "udp", expectedServerCount: 4, expectedLoadFactor: []int{100, 100, 100, 100}},
+		{input: "fanout . 127.0.0.1 {\npolicy weighted-random {\nserver-count 5 load-factor 100\n}", expectedFrom: ".", expectedAttempts: 3, expectedWorkers: 1, expectedTimeout: defaultTimeout, expectedNetwork: "udp", expectedServerCount: 1, expectedLoadFactor: []int{100}},
+		{input: "fanout . 127.0.0.1", expectedFrom: ".", expectedAttempts: 3, expectedWorkers: 1, expectedTimeout: defaultTimeout, expectedNetwork: "udp", expectedServerCount: 1, expectedLoadFactor: nil},
+		{input: "fanout . 127.0.0.1 {\npolicy weighted-random {\nserver-count 5 load-factor 100\n}", expectedFrom: ".", expectedAttempts: 3, expectedWorkers: 1, expectedTimeout: defaultTimeout, expectedNetwork: "udp", expectedServerCount: 1, expectedLoadFactor: []int{100}},
+		{input: "fanout . 127.0.0.1 {\nexcept a b\nworker-count 3\n}", expectedFrom: ".", expectedTimeout: defaultTimeout, expectedAttempts: 3, expectedWorkers: 1, expectedIgnored: []string{"a.", "b."}, expectedNetwork: "udp", expectedServerCount: 1, expectedLoadFactor: nil},
+		{input: "fanout . 127.0.0.1 127.0.0.2 {\nnetwork tcp\n}", expectedFrom: ".", expectedTimeout: defaultTimeout, expectedAttempts: 3, expectedWorkers: 2, expectedNetwork: "tcp", expectedTo: []string{"127.0.0.1:53", "127.0.0.2:53"}, expectedServerCount: 2, expectedLoadFactor: nil},
+		{input: "fanout . 127.0.0.1 127.0.0.2 127.0.0.3 127.0.0.4 {\nworker-count 3\ntimeout 1m\n}", expectedTimeout: time.Minute, expectedAttempts: 3, expectedFrom: ".", expectedWorkers: 3, expectedNetwork: "udp", expectedServerCount: 4, expectedLoadFactor: nil},
+		{input: "fanout . 127.0.0.1 127.0.0.2 127.0.0.3 127.0.0.4 {\nattempt-count 2\n}", expectedTimeout: defaultTimeout, expectedFrom: ".", expectedAttempts: 2, expectedWorkers: 4, expectedNetwork: "udp", expectedServerCount: 4, expectedLoadFactor: nil},
+		{input: "fanout . 127.0.0.1 127.0.0.2 127.0.0.3 {\npolicy weighted-random {}\n}", expectedFrom: ".", expectedAttempts: 3, expectedWorkers: 3, expectedTimeout: defaultTimeout, expectedNetwork: "udp", expectedServerCount: 3, expectedLoadFactor: []int{100, 100, 100}},
+		{input: "fanout . 127.0.0.1 127.0.0.2 127.0.0.3 {\npolicy sequential\nworker-count 3\n}", expectedFrom: ".", expectedAttempts: 3, expectedWorkers: 3, expectedTimeout: defaultTimeout, expectedNetwork: "udp", expectedServerCount: 3, expectedLoadFactor: nil},
 
 		// negative
 		{input: "fanout . aaa", expectedErr: "not an IP address or file"},
@@ -57,12 +60,13 @@ func TestSetup(t *testing.T) {
 		{input: "fanout . 127.0.0.1 {\nexcept a b\nworker-count ten\n}", expectedErr: "'ten'"},
 		{input: "fanout . 127.0.0.1 {\nexcept a:\nworker-count ten\n}", expectedErr: "unable to normalize 'a:'"},
 		{input: "fanout . 127.0.0.1 127.0.0.2 {\nnetwork XXX\n}", expectedErr: "unknown network protocol"},
-		{input: "fanout . 127.0.0.1 {\nserver-count -100\n}", expectedErr: "Wrong argument count or unexpected line ending"},
-		{input: "fanout . 127.0.0.1 {\nload-factor 150\n}", expectedErr: "load-factor 150 should be less than 100"},
-		{input: "fanout . 127.0.0.1 {\nload-factor 0\n}", expectedErr: "load-factor should be more or equal 1"},
-		{input: "fanout . 127.0.0.1 {\nload-factor 50 100\n}", expectedErr: "load-factor params count must be the same as the number of hosts"},
-		{input: "fanout . 127.0.0.1 127.0.0.2 {\nload-factor 50\n}", expectedErr: "load-factor params count must be the same as the number of hosts"},
-		{input: "fanout . 127.0.0.1 127.0.0.2 {\nload-factor \n}", expectedErr: "Wrong argument count or unexpected line ending"},
+		{input: "fanout . 127.0.0.1 {\npolicy weighted-random {\nserver-count -100\n}\n}", expectedErr: "Wrong argument count or unexpected line ending"},
+		{input: "fanout . 127.0.0.1 {\npolicy weighted-random {\nload-factor 150\n}\n}", expectedErr: "load-factor 150 should be less than 100"},
+		{input: "fanout . 127.0.0.1 {\npolicy weighted-random {\nload-factor 0\n}\n}", expectedErr: "load-factor should be more or equal 1"},
+		{input: "fanout . 127.0.0.1 {\npolicy weighted-random {\nload-factor 50 100\n}\n}", expectedErr: "load-factor params count must be the same as the number of hosts"},
+		{input: "fanout . 127.0.0.1 127.0.0.2 {\npolicy weighted-random {\nload-factor 50\n}\n}", expectedErr: "load-factor params count must be the same as the number of hosts"},
+		{input: "fanout . 127.0.0.1 127.0.0.2 {\npolicy weighted-random {\nload-factor \n}\n}", expectedErr: "Wrong argument count or unexpected line ending"},
+		{input: "fanout . 127.0.0.1 127.0.0.2 {\npolicy weighted-random\nworker-count 10\n}", expectedErr: "Wrong policy configuration"},
 	}
 
 	for i, test := range tests {
@@ -112,8 +116,19 @@ func TestSetup(t *testing.T) {
 		if f.serverCount != test.expectedServerCount {
 			t.Fatalf("Test %d: expected: %d, got: %d", i, test.expectedServerCount, f.serverCount)
 		}
-		if !reflect.DeepEqual(f.loadFactor, test.expectedLoadFactor) {
-			t.Fatalf("Test %d: expected: %d, got: %d", i, test.expectedLoadFactor, f.loadFactor)
+
+		selectionPolicy, ok := f.serverSelectionPolicy.(*weightedPolicy)
+		if len(test.expectedLoadFactor) > 0 {
+			if !ok {
+				t.Fatalf("Test %d: expected weighted policy to be set, got: %T", i, f.serverSelectionPolicy)
+			}
+			if !reflect.DeepEqual(selectionPolicy.loadFactor, test.expectedLoadFactor) {
+				t.Fatalf("Test %d: expected: %d, got: %d", i, test.expectedLoadFactor, selectionPolicy.loadFactor)
+			}
+		} else {
+			if ok {
+				t.Fatalf("Test %d: expected sequential policy to be set, got: %T", i, f.serverSelectionPolicy)
+			}
 		}
 	}
 }