forked from hashicorp/memberlist
-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.go
309 lines (264 loc) · 7.98 KB
/
util.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
package memberlist
import (
"bytes"
"compress/lzw"
"encoding/binary"
"fmt"
"io"
"math"
"math/rand"
"net"
"strconv"
"strings"
"time"
"github.com/hashicorp/go-msgpack/codec"
"github.com/sean-/seed"
)
// pushPullScaleThreshold is the minimum number of nodes
// before we start scaling the push/pull timing. The scale
// effect is the log2(Nodes) - log2(pushPullScaleThreshold). This means
// that the 33rd node will cause us to double the interval,
// while the 65th will triple it.
const pushPullScaleThreshold = 32
const (
// lzwLitWidth is the literal code width handed to lzw.NewWriter and
// lzw.NewReader in this file. compress/lzw accepts widths in the range
// 2-8; both sides of a stream must use the same value.
lzwLitWidth = 8
)
// init runs when the package is loaded and calls seed.Init from
// github.com/sean-/seed.
// NOTE(review): presumably this seeds the global math/rand source that
// randomOffset and shuffleNodes draw from — confirm against the seed package.
func init() {
seed.Init()
}
// decode unpacks the msgpack bytes in buf into out using a default
// codec.MsgpackHandle. Decoding starts at the first byte of buf; no leading
// message-type byte is skipped here. The decoder's error is returned as-is.
func decode(buf []byte, out interface{}) error {
	var handle codec.MsgpackHandle
	return codec.NewDecoder(bytes.NewReader(buf), &handle).Decode(out)
}
// encode serializes in into a freshly allocated buffer. The first byte of the
// buffer is msgType; the msgpack encoding of in (default codec.MsgpackHandle)
// follows it. The buffer is returned together with the encoder's error, so
// callers must check the error before using the contents.
func encode(msgType messageType, in interface{}) (*bytes.Buffer, error) {
	var handle codec.MsgpackHandle
	out := new(bytes.Buffer)
	out.WriteByte(uint8(msgType))
	err := codec.NewEncoder(out, &handle).Encode(in)
	return out, err
}
// randomOffset draws a value from the global math/rand source and reduces it
// modulo n, giving an offset in [0, n) for positive n. A zero n short-circuits
// to 0 so the modulo never divides by zero.
func randomOffset(n int) int {
	if n != 0 {
		return int(rand.Uint32() % uint32(n))
	}
	return 0
}
// suspicionTimeout returns how long a suspected node is given before the
// suspicion expires: suspicionMult * max(1, log10(max(1, n))) * interval.
func suspicionTimeout(suspicionMult, n int, interval time.Duration) time.Duration {
	clusterSize := math.Max(1.0, float64(n))
	scale := math.Max(1.0, math.Log10(clusterSize))
	// time.Duration is an int64, so carry the scale as thousandths and
	// divide the factor of 1000 back out at the end to keep three decimal
	// places of precision.
	scaledThousandths := time.Duration(scale * 1000)
	return time.Duration(suspicionMult) * scaledThousandths * interval / 1000
}
// retransmitLimit returns the maximum number of retransmissions:
// retransmitMult * ceil(log10(n + 1)).
func retransmitLimit(retransmitMult, n int) int {
	scale := int(math.Ceil(math.Log10(float64(n + 1))))
	return scale * retransmitMult
}
// shuffleNodes permutes nodes in place by handing an element-swap callback to
// rand.Shuffle, which draws from the global math/rand source.
func shuffleNodes(nodes []*nodeState) {
	swap := func(a, b int) {
		nodes[a], nodes[b] = nodes[b], nodes[a]
	}
	rand.Shuffle(len(nodes), swap)
}
// pushPullScale stretches the push/pull sync interval as the cluster grows,
// to keep sync traffic from saturating the network. Up to and including
// pushPullScaleThreshold nodes the interval is returned unchanged; beyond
// that it is multiplied by ceil(log2(n) - log2(pushPullScaleThreshold)) + 1.
func pushPullScale(interval time.Duration, n int) time.Duration {
	if n > pushPullScaleThreshold {
		doublings := math.Ceil(math.Log2(float64(n)) - math.Log2(pushPullScaleThreshold))
		return time.Duration(doublings+1.0) * interval
	}
	// Below the threshold no scaling is applied.
	return interval
}
// moveDeadNodes partitions nodes in place: every node whose State is
// StateDead and whose StateChange is older than gossipToTheDeadTime is swapped
// towards the tail of the slice. The return value is the index of the first
// such moved node, which equals the count of nodes kept at the front.
func moveDeadNodes(nodes []*nodeState, gossipToTheDeadTime time.Duration) int {
	// keep is the exclusive upper bound of the region still to be examined;
	// everything at or beyond it has already been moved out.
	keep := len(nodes)
	for i := 0; i < keep; {
		node := nodes[i]
		// Live nodes, and dead nodes still inside the gossip-to-the-dead
		// window, stay where they are.
		if node.State != StateDead || time.Since(node.StateChange) <= gossipToTheDeadTime {
			i++
			continue
		}
		// Swap the expired node into the last unexamined slot. The index is
		// not advanced because the element swapped in still needs a look.
		keep--
		nodes[i], nodes[keep] = nodes[keep], nodes[i]
	}
	return keep
}
// kRandomNodes picks up to k distinct nodes at random from nodes, skipping any
// node for which filterFn (when non-nil) returns true. Fewer than k nodes may
// come back when not enough candidates survive the filter or the probe budget
// runs out.
func kRandomNodes(k int, nodes []*nodeState, filterFn func(*nodeState) bool) []*nodeState {
	total := len(nodes)
	picked := make([]*nodeState, 0, k)

	// The probe budget is 3*total draws. For large clusters k << total makes
	// this moot, but for small ones it lets the search be close to exhaustive.
	for attempt := 0; attempt < 3*total && len(picked) < k; attempt++ {
		candidate := nodes[randomOffset(total)]

		// Let the caller veto the candidate first.
		if filterFn != nil && filterFn(candidate) {
			continue
		}

		// Skip candidates that were already selected.
		duplicate := false
		for _, prev := range picked {
			if prev == candidate {
				duplicate = true
				break
			}
		}
		if !duplicate {
			picked = append(picked, candidate)
		}
	}
	return picked
}
// makeCompoundMessage bundles msgs into one compound message with the layout:
// the compoundMsg type byte, a one-byte message count, one big-endian uint16
// length per message, then the message bodies in order.
//
// The count and the lengths are narrowed to uint8 and uint16 respectively, so
// more than 255 messages or a message longer than 65535 bytes would be
// silently truncated in the header; callers have to stay within those limits.
func makeCompoundMessage(msgs [][]byte) *bytes.Buffer {
	out := new(bytes.Buffer)

	// Header: message type followed by the number of parts.
	out.WriteByte(uint8(compoundMsg))
	out.WriteByte(uint8(len(msgs)))

	// Length table, two bytes per part.
	var lenField [2]byte
	for i := range msgs {
		binary.BigEndian.PutUint16(lenField[:], uint16(len(msgs[i])))
		out.Write(lenField[:])
	}

	// Part bodies, back to back.
	for i := range msgs {
		out.Write(msgs[i])
	}
	return out
}
// decodeCompoundMessage splits a compound message into its individual
// messages.
//
// buf must start at the part-count byte (the leading type byte written by
// makeCompoundMessage is not expected here). After the count come one
// big-endian uint16 length per part and then the part bodies back to back.
//
// Returns:
//   - trunc: how many parts could not be extracted because buf ended early.
//   - parts: the extracted messages; these are sub-slices of buf, not copies.
//   - err: non-nil when the count byte or the length table is missing or
//     incomplete.
func decodeCompoundMessage(buf []byte) (trunc int, parts [][]byte, err error) {
	if len(buf) < 1 {
		err = fmt.Errorf("missing compound length byte")
		return
	}

	// Widen the count to int before any arithmetic. Computing numParts*2 on
	// a uint8 wraps for 128 or more parts, which would let a short buffer
	// slip past the bounds check below and mis-position the body offset.
	numParts := int(buf[0])
	buf = buf[1:]

	// Check we have enough bytes for the whole length table.
	if len(buf) < numParts*2 {
		err = fmt.Errorf("truncated len slice")
		return
	}

	// Decode the lengths.
	lengths := make([]uint16, numParts)
	for i := range lengths {
		lengths[i] = binary.BigEndian.Uint16(buf[i*2 : i*2+2])
	}
	buf = buf[numParts*2:]

	// Split each message.
	for idx, msgLen := range lengths {
		if len(buf) < int(msgLen) {
			// This part and every one after it are missing.
			trunc = numParts - idx
			return
		}

		// Extract the slice, then seek past it on the buffer.
		parts = append(parts, buf[:msgLen])
		buf = buf[msgLen:]
	}
	return
}
// compressPayload LZW-compresses inp (LSB order, lzwLitWidth literal width)
// and returns the result wrapped in a compress{} message encoded with the
// compressMsg type. Any error from the compressor's Write or Close is
// returned with a nil buffer.
func compressPayload(inp []byte) (*bytes.Buffer, error) {
	packed := new(bytes.Buffer)
	w := lzw.NewWriter(packed, lzw.LSB, lzwLitWidth)
	if _, err := w.Write(inp); err != nil {
		return nil, err
	}

	// Close flushes whatever the compressor is still holding back.
	if err := w.Close(); err != nil {
		return nil, err
	}

	// Wrap the compressed bytes in a compress message and encode it.
	return encode(compressMsg, &compress{
		Algo: lzwAlgo,
		Buf:  packed.Bytes(),
	})
}
// decompressPayload decodes msg into a compress{} message and returns the
// uncompressed payload produced by decompressBuffer. A decode failure is
// returned as-is with a nil payload.
func decompressPayload(msg []byte) ([]byte, error) {
	wrapped := new(compress)
	if err := decode(msg, wrapped); err != nil {
		return nil, err
	}
	return decompressBuffer(wrapped)
}
// decompressBuffer inflates the Buf of a single compress message. Only
// lzwAlgo is accepted; any other Algo value yields an error. The whole
// decompressed stream is read into memory before it is returned.
func decompressBuffer(c *compress) ([]byte, error) {
	// Reject algorithms this function does not know how to undo.
	if c.Algo != lzwAlgo {
		return nil, fmt.Errorf("Cannot decompress unknown algorithm %d", c.Algo)
	}

	// Reader settings mirror the writer in compressPayload.
	reader := lzw.NewReader(bytes.NewReader(c.Buf), lzw.LSB, lzwLitWidth)
	defer reader.Close()

	// Drain the decompressor.
	out := new(bytes.Buffer)
	if _, err := io.Copy(out, reader); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// joinHostPort formats host and port as "host:port" via net.JoinHostPort,
// which brackets the host when it contains a colon (e.g. an IPv6 literal).
// The result is meant to be handed to a transport.
func joinHostPort(host string, port uint16) string {
	portStr := strconv.Itoa(int(port))
	return net.JoinHostPort(host, portStr)
}
// hasPort reports whether s carries a port. s is expected in one of the forms
// "host", "host:port", "ipv6::address", or "[ipv6::address]:port".
func hasPort(s string) bool {
	// Bracketed IPv6: the only "[" is the very first character. A port is
	// present when a colon appears after the closing bracket.
	bracketed := strings.HasPrefix(s, "[") && !strings.Contains(s[1:], "[")
	if bracketed {
		lastColon := strings.LastIndex(s, ":")
		closing := strings.LastIndex(s, "]")
		return lastColon > closing
	}

	// Without brackets, exactly one colon means "host:port". More than one
	// colon is a bare IPv6 address, which cannot carry a port.
	colons := strings.Count(s, ":")
	return colons == 1
}
// ensurePort returns s unchanged when hasPort says it already includes a
// port; otherwise it returns s with port appended as the default.
func ensurePort(s string, port int) string {
	if !hasPort(s) {
		// net.JoinHostPort adds its own brackets around an IPv6 host, so
		// strip any that are already there to avoid doubling them.
		host := strings.Trim(s, "[]")
		return net.JoinHostPort(host, strconv.Itoa(port))
	}
	return s
}