From 6d92cfb0ed42fb668627517ed69e061625fe12e4 Mon Sep 17 00:00:00 2001 From: "qiheng.zhou" Date: Wed, 3 Jul 2024 11:33:21 +0800 Subject: [PATCH] perf: construct virtual node with skip list to optimize consistent hash rebalance --- pkg/loadbalance/consist.go | 92 ++------ pkg/loadbalance/consist_test.go | 157 +++++++++++--- pkg/loadbalance/newconsist/newconsist.go | 184 ++++++++++++++++ pkg/loadbalance/newconsist/newconsist_test.go | 203 ++++++++++++++++++ pkg/loadbalance/newconsist/skiplist.go | 168 +++++++++++++++ pkg/loadbalance/weighted_balancer_test.go | 2 +- 6 files changed, 697 insertions(+), 109 deletions(-) create mode 100644 pkg/loadbalance/newconsist/newconsist.go create mode 100644 pkg/loadbalance/newconsist/newconsist_test.go create mode 100644 pkg/loadbalance/newconsist/skiplist.go diff --git a/pkg/loadbalance/consist.go b/pkg/loadbalance/consist.go index df56102e93..f30be26a85 100644 --- a/pkg/loadbalance/consist.go +++ b/pkg/loadbalance/consist.go @@ -18,6 +18,7 @@ package loadbalance import ( "context" + "github.com/cloudwego/kitex/pkg/loadbalance/newconsist" "sort" "sync" "time" @@ -26,7 +27,6 @@ import ( "golang.org/x/sync/singleflight" "github.com/cloudwego/kitex/pkg/discovery" - "github.com/cloudwego/kitex/pkg/utils" ) /* @@ -136,7 +136,7 @@ func (v *vNodeType) Swap(i, j int) { type consistPicker struct { cb *consistBalancer - info *consistInfo + info *newconsist.ConsistInfo // index int // result *consistResult } @@ -159,15 +159,14 @@ func (cp *consistPicker) Recycle() { // Next is not concurrency safe. func (cp *consistPicker) Next(ctx context.Context, request interface{}) discovery.Instance { - if len(cp.info.realNodes) == 0 { + if cp.info.IsEmpty() { return nil } key := cp.cb.opt.GetKey(ctx, request) if key == "" { return nil } - res := buildConsistResult(cp.info, xxhash3.HashString(key)) - return res.Primary + return cp.info.BuildConsistentResult(xxhash3.HashString(key)) // Todo(DMwangnima): Optimise Replica-related logic // This comment part is previous implementation considering connecting to Replica // Since we would create a new picker each time, the Replica logic is unreachable, so just comment it out for now @@ -255,7 +254,7 @@ func NewConsistBalancer(opt ConsistentHashOption) Loadbalancer { // GetPicker implements the Loadbalancer interface. 
func (cb *consistBalancer) GetPicker(e discovery.Result) Picker { - var ci *consistInfo + var ci *newconsist.ConsistInfo if e.Cacheable { cii, ok := cb.cachedConsistInfo.Load(e.CacheKey) if !ok { @@ -264,7 +263,7 @@ func (cb *consistBalancer) GetPicker(e discovery.Result) Picker { }) cb.cachedConsistInfo.Store(e.CacheKey, cii) } - ci = cii.(*consistInfo) + ci = cii.(*newconsist.ConsistInfo) } else { ci = cb.newConsistInfo(e) } @@ -274,75 +273,8 @@ func (cb *consistBalancer) GetPicker(e discovery.Result) Picker { return picker } -func (cb *consistBalancer) newConsistInfo(e discovery.Result) *consistInfo { - ci := &consistInfo{} - ci.realNodes, ci.virtualNodes = cb.buildNodes(e.Instances) - return ci -} - -func (cb *consistBalancer) buildNodes(ins []discovery.Instance) ([]realNode, []virtualNode) { - ret := make([]realNode, len(ins)) - for i := range ins { - ret[i].Ins = ins[i] - } - return ret, cb.buildVirtualNodes(ret) -} - -func (cb *consistBalancer) buildVirtualNodes(rNodes []realNode) []virtualNode { - totalLen := 0 - for i := range rNodes { - totalLen += cb.getVirtualNodeLen(rNodes[i]) - } - - ret := make([]virtualNode, totalLen) - if totalLen == 0 { - return ret - } - maxLen, maxSerial := 0, 0 - for i := range rNodes { - if len(rNodes[i].Ins.Address().String()) > maxLen { - maxLen = len(rNodes[i].Ins.Address().String()) - } - if vNodeLen := cb.getVirtualNodeLen(rNodes[i]); vNodeLen > maxSerial { - maxSerial = vNodeLen - } - } - l := maxLen + 1 + utils.GetUIntLen(uint64(maxSerial)) // "$address + # + itoa(i)" - // pre-allocate []byte here, and reuse it to prevent memory allocation. - b := make([]byte, l) - - // record the start index. - cur := 0 - for i := range rNodes { - bAddr := utils.StringToSliceByte(rNodes[i].Ins.Address().String()) - // Assign the first few bits of b to string. - copy(b, bAddr) - - // Initialize the last few bits, skipping '#'. - for j := len(bAddr) + 1; j < len(b); j++ { - b[j] = 0 - } - b[len(bAddr)] = '#' - - vLen := cb.getVirtualNodeLen(rNodes[i]) - for j := 0; j < vLen; j++ { - k := j - cnt := 0 - // Assign values to b one by one, starting with the last one. - for k > 0 { - b[l-1-cnt] = byte(k % 10) - k /= 10 - cnt++ - } - // At this point, the index inside ret should be cur + j. - index := cur + j - ret[index].hash = xxhash3.Hash(b) - ret[index].RealNode = &rNodes[i] - } - cur += vLen - } - sort.Sort(&vNodeType{s: ret}) - return ret +func (cb *consistBalancer) newConsistInfo(e discovery.Result) *newconsist.ConsistInfo { + return newconsist.NewConsistInfo(e, newconsist.ConsistInfoConfig{VirtualFactor: cb.opt.VirtualFactor, Weighted: cb.opt.Weighted}) } // get virtual node number from one realNode. @@ -364,9 +296,11 @@ func (cb *consistBalancer) Rebalance(change discovery.Change) { if !change.Result.Cacheable { return } - // TODO: Use TreeMap to optimize performance when updating. - // Now, due to the lack of a good red-black tree implementation, we can only build the full amount once per update. - cb.updateConsistInfo(change.Result) + if ci, ok := cb.cachedConsistInfo.Load(change.Result.CacheKey); ok { + ci.(*newconsist.ConsistInfo).Rebalance(change) + } else { + cb.updateConsistInfo(change.Result) + } } // Delete implements the Rebalancer interface. 
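Note on the lookup the picker now performs: BuildConsistentResult returns the instance owning the first virtual node whose hash is strictly greater than the key's hash, wrapping to the first node of the ring when the key hashes past the last virtual node. A minimal sketch of these ring semantics, using a plain sorted slice in place of the skip list introduced later in this patch (the names below are illustrative and not part of the change):

package main

import (
	"fmt"
	"sort"

	"github.com/bytedance/gopkg/util/xxhash3"
)

// vnode pairs a virtual-node hash with the address of the real instance that owns it.
type vnode struct {
	hash uint64
	addr string
}

// ring is the full set of virtual nodes, kept sorted by hash.
type ring []vnode

// pick returns the address owning the first virtual node whose hash is strictly
// greater than the key's hash, wrapping to the first node when no such node exists.
func (r ring) pick(key string) string {
	if len(r) == 0 {
		return ""
	}
	h := xxhash3.HashString(key)
	i := sort.Search(len(r), func(i int) bool { return r[i].hash > h })
	if i == len(r) {
		i = 0 // treat the sorted slice as a ring
	}
	return r[i].addr
}

func main() {
	r := ring{
		{hash: 100, addr: "10.0.0.1:8888"},
		{hash: 2000, addr: "10.0.0.2:8888"},
		{hash: 30000, addr: "10.0.0.3:8888"},
	}
	fmt.Println(r.pick("some-request-key"))
}
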
diff --git a/pkg/loadbalance/consist_test.go b/pkg/loadbalance/consist_test.go index 816583babe..7d8221f7d4 100644 --- a/pkg/loadbalance/consist_test.go +++ b/pkg/loadbalance/consist_test.go @@ -19,13 +19,13 @@ package loadbalance import ( "context" "fmt" + "github.com/bytedance/gopkg/lang/fastrand" + "github.com/cloudwego/kitex/pkg/loadbalance/newconsist" "math/rand" + "runtime" "strconv" "strings" "testing" - "time" - - "github.com/bytedance/gopkg/lang/fastrand" "github.com/cloudwego/kitex/internal" "github.com/cloudwego/kitex/internal/test" @@ -53,13 +53,12 @@ func getRandomKey(ctx context.Context, request interface{}) string { return key } -func getRandomString(length int) string { +func getRandomString(r *rand.Rand, length int) string { var resBuilder strings.Builder resBuilder.Grow(length) corpus := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - rand.Seed(time.Now().UnixNano() + int64(100)) for i := 0; i < length; i++ { - resBuilder.WriteByte(corpus[rand.Intn(len(corpus))]) + resBuilder.WriteByte(corpus[r.Intn(len(corpus))]) } return resBuilder.String() } @@ -166,20 +165,29 @@ func TestConsistPicker_Next_NoCache_Consist(t *testing.T) { CacheKey: "", Instances: insList, } + opt.GetKey = func(ctx context.Context, request interface{}) string { + v := ctx.Value("key") + return v.(string) + } + + cnt := make(map[string]int) + for _, ins := range insList { + cnt[ins.Address().String()] = 0 + } + cnt["null"] = 0 cb := NewConsistBalancer(opt) picker := cb.GetPicker(e) - ins := picker.Next(context.TODO(), nil) - for i := 0; i < 100; i++ { - picker := cb.GetPicker(e) - test.Assert(t, picker.Next(context.TODO(), nil) == ins) + for i := 0; i < 100000; i++ { + ctx := context.WithValue(context.Background(), "key", strconv.Itoa(i)) + if res := picker.Next(ctx, nil); res != nil { + cnt[res.Address().String()]++ + } else { + cnt["null"]++ + } } + fmt.Println(cnt) - cb = NewConsistBalancer(opt) - for i := 0; i < 100; i++ { - picker := cb.GetPicker(e) - test.Assert(t, picker.Next(context.TODO(), nil) == ins) - } } func TestConsistPicker_Next_Cache(t *testing.T) { @@ -315,7 +323,8 @@ func TestConsistPicker_Reblance(t *testing.T) { picker := cb.GetPicker(e) key := strconv.Itoa(i) ctx = context.WithValue(ctx, keyCtxKey, key) - test.DeepEqual(t, record[key], picker.Next(ctx, nil)) + res := picker.Next(ctx, nil) + test.DeepEqual(t, record[key], res) } } @@ -388,36 +397,126 @@ func BenchmarkNewConsistPicker(bb *testing.B) { // BenchmarkConsistPicker_RandomDistributionKey/1000ins-12 2848216. 
407.7 ns/op 48 B/op 1 allocs/op // BenchmarkConsistPicker_RandomDistributionKey/10000ins // BenchmarkConsistPicker_RandomDistributionKey/10000ins-12 2701766 492.7 ns/op 48 B/op 1 allocs/op -func BenchmarkConsistPicker_RandomDistributionKey(bb *testing.B) { +func BenchmarkConsistPicker_RandomDistributionKey(b *testing.B) { n := 10 balancer := NewConsistBalancer(NewConsistentHashOption(getRandomKey)) for i := 0; i < 4; i++ { - bb.Run(fmt.Sprintf("%dins", n), func(b *testing.B) { + b.Run(fmt.Sprintf("%dins", n), func(b *testing.B) { + r := rand.New(rand.NewSource(int64(n))) inss := makeNInstances(n, 10) e := discovery.Result{ Cacheable: true, CacheKey: "test", Instances: inss, } + b.ReportAllocs() + b.ResetTimer() picker := balancer.GetPicker(e) - ctx := context.WithValue(context.Background(), keyCtxKey, getRandomString(30)) + ctx := context.WithValue(context.Background(), keyCtxKey, getRandomString(r, 30)) picker.Next(ctx, nil) picker.(internal.Reusable).Recycle() - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - b.Logf("round %d", i) - b.StopTimer() - ctx = context.WithValue(context.Background(), keyCtxKey, getRandomString(30)) - b.StartTimer() - picker := balancer.GetPicker(e) + for j := 0; j < b.N; j++ { + ctx = context.WithValue(context.Background(), keyCtxKey, getRandomString(r, 30)) + picker = balancer.GetPicker(e) picker.Next(ctx, nil) - if r, ok := picker.(internal.Reusable); ok { - r.Recycle() + if toRecycle, ok := picker.(internal.Reusable); ok { + toRecycle.Recycle() } } }) n *= 10 } } + +func BenchmarkRebalance(bb *testing.B) { + weight := 10 + nums := 10000 + + for n := 0; n < 1; n++ { + bb.Run(fmt.Sprintf("consist-remove-%d", nums), func(b *testing.B) { + insList := makeNInstances(nums, weight) + e := discovery.Result{ + Cacheable: true, + CacheKey: "", + Instances: insList, + } + newConsist := newconsist.NewConsistInfo(e, newconsist.ConsistInfoConfig{ + VirtualFactor: 100, + Weighted: true, + }) + + b.ReportAllocs() + b.ResetTimer() + change := discovery.Change{ + Result: e, + Added: nil, + Updated: nil, + } + removed := []discovery.Instance{insList[0]} + for i := 0; i < nums; i++ { + e.Instances = insList[i+1:] + removed[0] = insList[i] + change.Result = e + change.Removed = removed + newConsist.Rebalance(change) + } + runtime.GC() + }) + + bb.Run(fmt.Sprintf("consist-add-%d", nums), func(b *testing.B) { + insList := makeNInstances(nums, weight) + e := discovery.Result{ + Cacheable: true, + CacheKey: "", + Instances: insList, + } + newConsist := newconsist.NewConsistInfo(e, newconsist.ConsistInfoConfig{ + VirtualFactor: 100, + Weighted: true, + }) + + b.ReportAllocs() + b.ResetTimer() + change := discovery.Change{ + Result: e, + Added: nil, + Updated: nil, + } + added := []discovery.Instance{insList[0]} + for i := 0; i < nums; i++ { + e.Instances = insList[:i+1] + added[0] = insList[i] + change.Result = e + change.Added = added + newConsist.Rebalance(change) + } + runtime.GC() + }) + nums *= 10 + } + +} + +func BenchmarkNewConsistInfo(b *testing.B) { + weight := 10 + nums := 10 + for n := 0; n < 4; n++ { + b.Run(fmt.Sprintf("new-consist-%d", nums), func(b *testing.B) { + insList := makeNInstances(nums, weight) + e := discovery.Result{ + Cacheable: false, + CacheKey: "", + Instances: insList, + } + b.ResetTimer() + b.ReportAllocs() + newConsist := newconsist.NewConsistInfo(e, newconsist.ConsistInfoConfig{ + VirtualFactor: 100, + Weighted: true, + }) + _ = newConsist + }) + nums *= 10 + } +} diff --git a/pkg/loadbalance/newconsist/newconsist.go 
b/pkg/loadbalance/newconsist/newconsist.go new file mode 100644 index 0000000000..5ebd3dd4bc --- /dev/null +++ b/pkg/loadbalance/newconsist/newconsist.go @@ -0,0 +1,184 @@ +package newconsist + +import ( + "github.com/bytedance/gopkg/util/xxhash3" + "github.com/cloudwego/kitex/pkg/discovery" + "github.com/cloudwego/kitex/pkg/utils" + "math" + "sync" +) + +const ( + maxAddrLength int = 45 // used for construct + nodeRatio float64 = 2.1 + defaultNodeWeight int = 10 +) + +type realNode struct { + discovery.Instance +} + +type virtualNode struct { + realNode discovery.Instance + value uint64 + next []*virtualNode +} + +type ConsistInfoConfig struct { + VirtualFactor uint32 + Weighted bool +} + +// consistent hash +type ConsistInfo struct { + mu sync.RWMutex + cfg ConsistInfoConfig + lastRes discovery.Result + virtualNodes *skipList + // cache for calculate hash + hashByte []byte +} + +func NewConsistInfo(result discovery.Result, cfg ConsistInfoConfig) *ConsistInfo { + info := &ConsistInfo{ + cfg: cfg, + virtualNodes: newSkipList(), + lastRes: result, + } + info.hashByte = make([]byte, 0, utils.GetUIntLen(uint64(defaultNodeWeight*int(cfg.VirtualFactor)))+maxAddrLength+1) + info.batchAddAllVirtual(result.Instances) + return info +} + +func (info *ConsistInfo) IsEmpty() bool { + return len(info.lastRes.Instances) == 0 +} + +func (info *ConsistInfo) BuildConsistentResult(value uint64) discovery.Instance { + info.mu.RLock() + defer info.mu.RUnlock() + + if n := info.virtualNodes.FindGreater(value); n != nil { + return n.realNode + } + return nil +} + +func (info *ConsistInfo) Rebalance(change discovery.Change) { + info.mu.Lock() + defer info.mu.Unlock() + + info.lastRes = change.Result + // update + // TODO: optimize update logic + if len(change.Updated) > 0 { + info.virtualNodes = newSkipList() + info.batchAddAllVirtual(change.Result.Instances) + return + } + // add + info.batchAddAllVirtual(change.Added) + // delete + for _, ins := range change.Removed { + l := ins.Weight() * int(info.cfg.VirtualFactor) + addrByte := utils.StringToSliceByte(ins.Address().String()) + info.removeAllVirtual(l, addrByte) + } + +} + +func (info *ConsistInfo) getVirtualNodeHash(addr []byte, idx int) uint64 { + b := info.hashByte + b = append(b, addr...) 
+	b = append(b, '#')
+	b = append(b, byte(idx), byte(idx>>8), byte(idx>>16)) // little-endian index bytes; a single byte would collide once idx exceeds 255
+	hashValue := xxhash3.Hash(b)
+
+	b = b[:0]
+	return hashValue
+}
+
+func (info *ConsistInfo) prepareByteHash(virtualNum int) {
+	newCap := utils.GetUIntLen(uint64(virtualNum)) + maxAddrLength + 1
+	if newCap > cap(info.hashByte) {
+		info.hashByte = make([]byte, 0, newCap)
+	}
+}
+
+func (info *ConsistInfo) batchAddAllVirtual(realNode []discovery.Instance) {
+	totalNode := 0
+	maxNodeLen := 0
+	for i := 0; i < len(realNode); i++ {
+		nodeLen := info.getVirtualNodeLen(realNode[i])
+		if nodeLen > maxNodeLen {
+			maxNodeLen = nodeLen
+		}
+		totalNode += nodeLen
+	}
+	info.prepareByteHash(maxNodeLen)
+
+	vns := make([]virtualNode, totalNode)
+
+	var idx uint64 = 0
+	estimatedTotalNode := math.Round(nodeRatio * float64(totalNode))
+	info.virtualNodes.prepareNode(int(estimatedTotalNode))
+	for i := 0; i < len(realNode); i++ {
+		vLen := info.getVirtualNodeLen(realNode[i])
+		addrByte := utils.StringToSliceByte(realNode[i].Address().String())
+		for j := 0; j < vLen; j++ {
+			vns[idx].realNode = realNode[i]
+			vns[idx].value = info.getVirtualNodeHash(addrByte, j)
+			info.virtualNodes.Insert(&vns[idx])
+			idx++
+		}
+	}
+
+}
+
+func (info *ConsistInfo) addAllVirtual(node discovery.Instance) {
+	l := info.getVirtualNodeLen(node)
+	addrByte := utils.StringToSliceByte(node.Address().String())
+
+	vns := make([]virtualNode, l)
+	for i := 0; i < l; i++ {
+		vv := info.getVirtualNodeHash(addrByte, i)
+		vns[i].realNode = node
+		vns[i].value = vv
+		info.virtualNodes.Insert(&vns[i])
+	}
+}
+
+func (info *ConsistInfo) removeAllVirtual(virtualNum int, addrByte []byte) {
+	for i := 0; i < virtualNum; i++ {
+		vv := info.getVirtualNodeHash(addrByte, i)
+		info.virtualNodes.Delete(vv)
+	}
+}
+
+func (info *ConsistInfo) getVirtualNodeLen(node discovery.Instance) int {
+	if info.cfg.Weighted {
+		return node.Weight() * int(info.cfg.VirtualFactor)
+	}
+	return int(info.cfg.VirtualFactor)
+}
+
+// searchRealNode is only for tests: it reports whether any and whether all of the node's virtual nodes are present in the skip list.
+func searchRealNode(info *ConsistInfo, node *realNode) (bool, bool) {
+	var (
+		foundOne = false
+		foundAll = true
+	)
+	l := info.getVirtualNodeLen(node)
+	addrByte := utils.StringToSliceByte(node.Address().String())
+
+	for i := 0; i < l; i++ {
+		vv := info.getVirtualNodeHash(addrByte, i)
+		ok := info.virtualNodes.Search(vv)
+		if ok {
+			foundOne = true
+		} else {
+			foundAll = false
+		}
+	}
+	return foundOne, foundAll
+}
diff --git a/pkg/loadbalance/newconsist/newconsist_test.go b/pkg/loadbalance/newconsist/newconsist_test.go
new file mode 100644
index 0000000000..cbd478c3f9
--- /dev/null
+++ b/pkg/loadbalance/newconsist/newconsist_test.go
@@ -0,0 +1,203 @@
+package newconsist
+
+import (
+	"fmt"
+	"github.com/bytedance/gopkg/util/xxhash3"
+	"github.com/cloudwego/kitex/internal/test"
+	"github.com/cloudwego/kitex/pkg/discovery"
+	"math/rand"
+	"strconv"
+	"testing"
+	"time"
+)
+
+func TestNewSkipList(t *testing.T) {
+	s := newSkipList()
+	dataCnt := 10000
+	for i := 0; i < dataCnt; i++ {
+		s.Insert(&virtualNode{realNode: nil, value: uint64(i)})
+		test.Assert(t, s.Search(uint64(i)))
+	}
+	totalCnt := 0
+	for i := 0; i < s.totalLevel; i++ {
+		currCnt := countLevel(s, i)
+		totalCnt += currCnt
+	}
+	fmt.Printf("totalCnt: %d, ratio: %f\n", totalCnt, float64(totalCnt)/float64(dataCnt))
+	for i := 0; i < dataCnt; i++ {
+		s.Delete(uint64(i))
+		currCnt := countLevel(s, 0)
+		test.Assert(t, currCnt == dataCnt-i-1)
+	}
+}
+
+func TestFuzzSkipList(t *testing.T) {
+	rand.Seed(time.Now().UnixNano())
+	sl := newSkipList()
+
+	vs := make([]uint64, 1000)
+	// insert 1000 random values
+	for i := 0; i < 1000; i++ {
+		value := rand.Uint64() % 10000
+		vs[i] = value
+		node := &virtualNode{nil, value, nil}
+		sl.Insert(node)
+	}
+
+	// search for the 1000 random values
+	for i := 0; i < 1000; i++ {
+		found := sl.Search(vs[i])
+		test.Assert(t, found)
+	}
+
+	// delete the first 500 values and check that the rest are still present
+	for i := 0; i < 500; i++ {
+		value := vs[i]
+		sl.Delete(value)
+	}
+	for i := 500; i < 1000; i++ {
+		found := sl.Search(vs[i])
+		test.Assert(t, found)
+	}
+}
+
+func TestGetVirtualNodeHash(t *testing.T) {
+	insList := []discovery.Instance{
+		discovery.NewInstance("tcp", "addr1", 10, nil),
+		discovery.NewInstance("tcp", "addr2", 10, nil),
+		discovery.NewInstance("tcp", "addr3", 10, nil),
+		discovery.NewInstance("tcp", "addr4", 10, nil),
+		discovery.NewInstance("tcp", "addr5", 10, nil),
+	}
+	e := discovery.Result{
+		Cacheable: false,
+		CacheKey:  "",
+		Instances: insList,
+	}
+	info := NewConsistInfo(e, ConsistInfoConfig{VirtualFactor: 100, Weighted: true})
+	info.getVirtualNodeHash([]byte{1, 2, 3}, 1)
+}
+
+func Test_getConsistentResult(t *testing.T) {
+	insList := []discovery.Instance{
+		discovery.NewInstance("tcp", "addr1", 10, nil),
+		discovery.NewInstance("tcp", "addr2", 10, nil),
+		discovery.NewInstance("tcp", "addr3", 10, nil),
+		discovery.NewInstance("tcp", "addr4", 10, nil),
+		discovery.NewInstance("tcp", "addr5", 10, nil),
+	}
+	e := discovery.Result{
+		Cacheable: false,
+		CacheKey:  "",
+		Instances: insList,
+	}
+	info := NewConsistInfo(e, ConsistInfoConfig{VirtualFactor: 100, Weighted: true})
+	newInsList := make([]discovery.Instance, len(insList)-1)
+	copy(newInsList, insList[1:])
+	newResult := discovery.Result{
+		Cacheable: false,
+		CacheKey:  "",
+		Instances: newInsList,
+	}
+	change := discovery.Change{
+		Result:  newResult,
+		Added:   nil,
+		Removed: []discovery.Instance{insList[0]},
+		Updated: nil,
+	}
+	_ = change
+	//info.Rebalance(change)
+	cnt := make(map[string]int)
+	for _, ins := range insList {
+		cnt[ins.Address().String()] = 0
+	}
+	cnt["null"] = 0
+	for i := 0; i < 100000; i++ {
+		if res := info.BuildConsistentResult(xxhash3.HashString(strconv.Itoa(i))); res != nil {
+			cnt[res.Address().String()]++
+		} else {
+			cnt["null"]++
+		}
+	}
+	fmt.Println(cnt)
+}
+
+func TestRebalance(t *testing.T) {
+	nums := 1000
+	insList := make([]discovery.Instance, 0, nums)
+	for i := 0; i < nums; i++ {
+		insList = append(insList, discovery.NewInstance("tcp", "addr"+strconv.Itoa(i), 10, nil))
+	}
+	e := discovery.Result{
+		Cacheable: false,
+		CacheKey:  "",
+		Instances: insList,
+	}
+	newConsist := NewConsistInfo(e, ConsistInfoConfig{
+		VirtualFactor: 100,
+		Weighted:      true,
+	})
+	for i := 0; i < nums; i++ {
+		_, all := searchRealNode(newConsist, &realNode{insList[i]})
+		// should find all virtual nodes
+		test.Assert(t, all)
+	}
+	for i := 0; i < nums; i++ {
+		e.Instances = insList[i+1:]
+		change := discovery.Change{
+			Result:  e,
+			Added:   nil,
+			Removed: []discovery.Instance{insList[i]},
+			Updated: nil,
+		}
+		newConsist.Rebalance(change)
+
+		one, _ := searchRealNode(newConsist, &realNode{insList[i]})
+		// no virtual node should be found
+		test.Assert(t, !one)
+	}
+}
+
+func TestRebalanceDuplicate(t *testing.T) {
+	nums := 1000
+	duplicate := 10
+	insList := make([]discovery.Instance, 0, nums)
+	for i := 0; i < nums; i++ {
+		insList = append(insList, discovery.NewInstance("tcp", "addr"+strconv.Itoa(i), 10, nil))
+	}
+	e := discovery.Result{
+		Cacheable: false,
+		CacheKey:  "",
+		Instances: insList,
+	}
+	newConsist := NewConsistInfo(e, ConsistInfoConfig{
+		VirtualFactor: 100,
+		Weighted:      true,
+	})
+
+	for i := 0; i < nums; i++ {
+		e.Instances =
insList[i+1:] + change := discovery.Change{ + Result: e, + Added: nil, + Removed: []discovery.Instance{insList[i]}, + Updated: nil, + } + for j := 0; j < duplicate; j++ { + newConsist.Rebalance(change) + one, _ := searchRealNode(newConsist, &realNode{insList[i]}) + // no virtual node should be found + test.Assert(t, !one) + } + } +} + +func countLevel(s *skipList, level int) int { + n := s.dummy + cnt := 0 + for n.next[level] != nil { + n = n.next[level] + cnt++ + } + return cnt +} diff --git a/pkg/loadbalance/newconsist/skiplist.go b/pkg/loadbalance/newconsist/skiplist.go new file mode 100644 index 0000000000..50ac6da569 --- /dev/null +++ b/pkg/loadbalance/newconsist/skiplist.go @@ -0,0 +1,168 @@ +package newconsist + +import "github.com/bytedance/gopkg/lang/fastrand" + +const ( + MAX_LEVEL = 32 // max level of skip list +) + +// TODO: optimize allocation +// skipList for consistent hash +// not concurrent safe +type skipList struct { + dummy *virtualNode + totalLevel int + + // only for insert and delete + updateCache []*virtualNode + // node cache + nodeCache []*virtualNode +} + +// newSkipList returns a new skip list +func newSkipList() *skipList { + return &skipList{ + dummy: &virtualNode{nil, 0, make([]*virtualNode, MAX_LEVEL)}, + totalLevel: 1, + updateCache: make([]*virtualNode, MAX_LEVEL), + } +} + +// Insert inserts a node into the skip list +func (sl *skipList) Insert(n *virtualNode) { + level := sl.randomLevel() + // temporary slice for update + var update []*virtualNode + if sl.updateCache != nil { + update = sl.updateCache[:maxValue(level, sl.totalLevel)] + } else { + update = make([]*virtualNode, maxValue(level, sl.totalLevel)) + } + + if level > sl.totalLevel { + // grow new level + for i := sl.totalLevel; i < level; i++ { + update[i-1] = sl.dummy + } + sl.totalLevel = level + } + + // search the node with greater value than the new value on each level + current := sl.dummy + for i := sl.totalLevel - 1; i >= 0; i-- { + for len(current.next) > i && current.next[i] != nil && current.next[i].value < n.value { + current = current.next[i] + } + update[i] = current + } + + // insert the node into the [0:level] levels of the list + newNode := n + n.next = sl.makeNewVirtualNode(level) + for i := 0; i < level; i++ { + newNode.next[i] = update[i].next[i] + update[i].next[i] = newNode + } +} + +// Delete removes the nodes with the input value +func (sl *skipList) Delete(value uint64) { + // temporary slice for update + var update []*virtualNode + if sl.updateCache != nil { + update = sl.updateCache[:sl.totalLevel] + } else { + update = make([]*virtualNode, sl.totalLevel) + } + + current := sl.dummy + // search the node with equal or greater value than the new value on each level + for i := sl.totalLevel - 1; i >= 0; i-- { + for current.next[i] != nil && current.next[i].value < value { + current = current.next[i] + } + update[i] = current + } + + current = current.next[0] + if current != nil && current.value == value { + // if the value is found, remove the node + for i := 0; i < sl.totalLevel; i++ { + // check from low to high level + // if exist in the current level, remove it. 
Otherwise, break the loop + if update[i].next[i] != current { + break + } + update[i].next[i] = current.next[i] + } + + // check from high to low level + // if no node in one level, remove this level + for sl.totalLevel > 1 && sl.dummy.next[sl.totalLevel-1] == nil { + sl.totalLevel-- + } + } +} + +// Search checks if the value can be found in the skip list +func (sl *skipList) Search(value uint64) bool { + current := sl.dummy + for i := sl.totalLevel - 1; i >= 0; i-- { + for current.next[i] != nil && current.next[i].value < value { + current = current.next[i] + } + } + current = current.next[0] + return current != nil && current.value == value +} + +// FindGreater finds the first node with greater value than the input value +func (sl *skipList) FindGreater(value uint64) *virtualNode { + current := sl.dummy + for i := sl.totalLevel - 1; i >= 0; i-- { + for current.next[i] != nil && current.next[i].value <= value { + current = current.next[i] + } + } + if res := current.next[0]; res != nil { + return res + } else { + // return the first node if not found since the skip list is treated as a ring + return sl.dummy.next[0] + } +} + +// randomLevel returns a random level for a new node +func (sl *skipList) randomLevel() int { + level := 1 + for fastrand.Float32() < 0.5 && level < MAX_LEVEL { + level++ + } + return level +} + +func (sl *skipList) prepareNode(num int) { + if len(sl.nodeCache) < num { + sl.nodeCache = append(sl.nodeCache, make([]*virtualNode, num)...) + } +} + +func (sl *skipList) makeNewVirtualNode(num int) []*virtualNode { + cacheLen := len(sl.nodeCache) + var res []*virtualNode + if cacheLen > num { + res = sl.nodeCache[:num] + sl.nodeCache = sl.nodeCache[num:] + } else { + res = make([]*virtualNode, num) + } + return res +} + +func maxValue(a, b int) int { + if a > b { + return a + } else { + return b + } +} diff --git a/pkg/loadbalance/weighted_balancer_test.go b/pkg/loadbalance/weighted_balancer_test.go index 57f74b87e6..ca31164721 100644 --- a/pkg/loadbalance/weighted_balancer_test.go +++ b/pkg/loadbalance/weighted_balancer_test.go @@ -219,7 +219,7 @@ func TestWeightedPicker_NoMoreInstance(t *testing.T) { func makeNInstances(n, weight int) (res []discovery.Instance) { for i := 0; i < n; i++ { - res = append(res, discovery.NewInstance("tcp", fmt.Sprintf("addr[%d]-weight[%d]", i, weight), weight, nil)) + res = append(res, discovery.NewInstance("tcp", strconv.Itoa(i), weight, nil)) } return }
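Usage sketch (not part of the patch): the snippet below mirrors the tests above and shows how the new package is expected to be driven once this patch is applied — build the ring from a discovery.Result, route keys by their xxhash3 hash, and apply an incremental Rebalance when an instance is removed.

package main

import (
	"fmt"
	"strconv"

	"github.com/bytedance/gopkg/util/xxhash3"
	"github.com/cloudwego/kitex/pkg/discovery"
	"github.com/cloudwego/kitex/pkg/loadbalance/newconsist"
)

func main() {
	// Initial discovery result with three weighted instances.
	insList := []discovery.Instance{
		discovery.NewInstance("tcp", "addr1", 10, nil),
		discovery.NewInstance("tcp", "addr2", 10, nil),
		discovery.NewInstance("tcp", "addr3", 10, nil),
	}
	res := discovery.Result{Cacheable: true, CacheKey: "demo", Instances: insList}

	// Build the ring: with Weighted set, each instance gets weight * VirtualFactor virtual nodes.
	info := newconsist.NewConsistInfo(res, newconsist.ConsistInfoConfig{VirtualFactor: 100, Weighted: true})

	// Route a few keys; the same key always maps to the same instance.
	for i := 0; i < 5; i++ {
		key := "request-" + strconv.Itoa(i)
		if ins := info.BuildConsistentResult(xxhash3.HashString(key)); ins != nil {
			fmt.Println(key, "->", ins.Address().String())
		}
	}

	// Remove addr1 incrementally: only its virtual nodes are deleted from the skip list,
	// so keys owned by the remaining instances keep their previous mapping.
	res.Instances = insList[1:]
	info.Rebalance(discovery.Change{
		Result:  res,
		Removed: []discovery.Instance{insList[0]},
	})
}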