diff --git a/.ai-agent/guidelines.md b/.ai-agent/guidelines.md index e02f4251..cacfb0ec 100644 --- a/.ai-agent/guidelines.md +++ b/.ai-agent/guidelines.md @@ -4,6 +4,8 @@ All the rules below are not optional and must be followed to the letter. # Go Rules +NEVER use `go build`, use `go run` instead. + ## rudder-go-kit https://github.com/rudderlabs/rudder-go-kit opinionated library for rudderstack. Should be used for all Golang projects, diff --git a/.claude/settings.local.json b/.claude/settings.local.json deleted file mode 100644 index 3742a285..00000000 --- a/.claude/settings.local.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "permissions": { - "allow": [ - "WebSearch", - "Bash(make test:*)", - "Bash(go vet:*)", - "Bash(go run:*)", - "Bash(go test:*)", - "Bash(gofmt:*)", - "Bash(git tag:*)" - ], - "deny": [], - "ask": [] - } -} diff --git a/.gitignore b/.gitignore index f10862a6..7b4c026c 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ /.env +.idea +.claude/settings.local.json diff --git a/cmd/node/main.go b/cmd/node/main.go index 3f4b0718..2f343dc9 100644 --- a/cmd/node/main.go +++ b/cmd/node/main.go @@ -16,15 +16,15 @@ import ( "syscall" "time" - "github.com/rudderlabs/keydb/release" - "github.com/prometheus/client_golang/prometheus" "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" "github.com/rudderlabs/keydb/internal/cloudstorage" "github.com/rudderlabs/keydb/internal/hash" "github.com/rudderlabs/keydb/node" pb "github.com/rudderlabs/keydb/proto" + "github.com/rudderlabs/keydb/release" "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/logger" _ "github.com/rudderlabs/rudder-go-kit/maxprocs" @@ -103,6 +103,7 @@ func run(ctx context.Context, cancel func(), conf *config.Config, stat stats.Sta if len(nodeAddresses) == 0 { return fmt.Errorf("no node addresses provided") } + degradedNodes := conf.GetReloadableStringVar("", "degradedNodes") nodeConfig := node.Config{ NodeID: uint32(nodeID), @@ -115,7 +116,24 @@ func run(ctx context.Context, cancel func(), conf *config.Config, stat stats.Sta GarbageCollectionInterval: conf.GetDuration("gcInterval", // node.DefaultGarbageCollectionInterval will be used 0, time.Nanosecond, ), - Addresses: strings.Split(nodeAddresses, ","), + Addresses: strings.Split(nodeAddresses, ","), + DegradedNodes: func() []bool { + raw := strings.TrimSpace(degradedNodes.Load()) + if raw == "" { + return nil + } + v := strings.Split(raw, ",") + b := make([]bool, len(v)) + for i, s := range v { + var err error + b[i], err = strconv.ParseBool(s) + if err != nil { + log.Warnn("Failed to parse degraded node", logger.NewStringField("v", raw), obskit.Error(err)) + return nil + } + } + return b + }, LogTableStructureDuration: conf.GetDuration("logTableStructureDuration", 10, time.Minute), BackupFolderName: conf.GetString("KUBE_NAMESPACE", ""), } @@ -157,8 +175,31 @@ func run(ctx context.Context, cancel func(), conf *config.Config, stat stats.Sta } }() - // create a gRPC server with latency interceptors + // Configure gRPC server keepalive parameters + grpcKeepaliveMinTime := conf.GetDuration("grpc.keepalive.minTime", 10, time.Second) + grpcKeepalivePermitWithoutStream := conf.GetBool("grpc.keepalive.permitWithoutStream", true) + grpcKeepaliveTime := conf.GetDuration("grpc.keepalive.time", 60, time.Second) + grpcKeepaliveTimeout := conf.GetDuration("grpc.keepalive.timeout", 20, time.Second) + + log.Infon("gRPC server keepalive configuration", + logger.NewDurationField("enforcementMinTime", grpcKeepaliveMinTime), + 
logger.NewBoolField("enforcementPermitWithoutStream", grpcKeepalivePermitWithoutStream), + logger.NewDurationField("serverTime", grpcKeepaliveTime), + logger.NewDurationField("serverTimeout", grpcKeepaliveTimeout), + ) + + // create a gRPC server with latency interceptors and keepalive parameters server := grpc.NewServer( + // Keepalive enforcement policy - controls what the server requires from clients + grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{ + MinTime: grpcKeepaliveMinTime, + PermitWithoutStream: grpcKeepalivePermitWithoutStream, + }), + // Keepalive parameters - controls server's own keepalive behavior + grpc.KeepaliveParams(keepalive.ServerParameters{ + Time: grpcKeepaliveTime, + Timeout: grpcKeepaliveTimeout, + }), // Unary interceptor to record latency for unary RPCs grpc.UnaryInterceptor( func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( diff --git a/cmd/scaler/config.go b/cmd/scaler/config.go new file mode 100644 index 00000000..85b51a46 --- /dev/null +++ b/cmd/scaler/config.go @@ -0,0 +1,8 @@ +package main + +var defaultHistogramBuckets = []float64{ + 0.002, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30, 60, + 300 /* 5 mins */, 600 /* 10 mins */, 1800, /* 30 mins */ +} + +var customBuckets = map[string][]float64{} diff --git a/cmd/scaler/main.go b/cmd/scaler/main.go index 6065e435..71fec47b 100644 --- a/cmd/scaler/main.go +++ b/cmd/scaler/main.go @@ -13,8 +13,11 @@ import ( "github.com/rudderlabs/keydb/client" "github.com/rudderlabs/keydb/internal/scaler" + "github.com/rudderlabs/keydb/release" "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats" + svcMetric "github.com/rudderlabs/rudder-go-kit/stats/metric" obskit "github.com/rudderlabs/rudder-observability-kit/go/labels" ) @@ -26,13 +29,25 @@ func main() { defer logFactory.Sync() log := logFactory.NewLogger() - if err := run(ctx, cancel, conf, log); err != nil { + releaseInfo := release.NewInfo() + statsOptions := []stats.Option{ + stats.WithServiceName("keydb-scaler"), + stats.WithServiceVersion(releaseInfo.Version), + stats.WithDefaultHistogramBuckets(defaultHistogramBuckets), + } + for histogramName, buckets := range customBuckets { + statsOptions = append(statsOptions, stats.WithHistogramBuckets(histogramName, buckets)) + } + stat := stats.NewStats(conf, logFactory, svcMetric.NewManager(), statsOptions...) 
+ defer stat.Stop() + + if err := run(ctx, cancel, conf, stat, log); err != nil { log.Fataln("failed to run", obskit.Error(err)) os.Exit(1) } } -func run(ctx context.Context, cancel func(), conf *config.Config, log logger.Logger) error { +func run(ctx context.Context, cancel func(), conf *config.Config, stat stats.Stats, log logger.Logger) error { defer cancel() nodeAddresses := conf.GetString("nodeAddresses", "") @@ -114,7 +129,7 @@ func run(ctx context.Context, cancel func(), conf *config.Config, log logger.Log // Create and start HTTP server serverAddr := conf.GetString("serverAddr", ":8080") - server := newHTTPServer(c, scClient, serverAddr, log) + server := newHTTPServer(c, scClient, serverAddr, stat, log) // Start server in a goroutine serverErrCh := make(chan error, 1) diff --git a/cmd/scaler/server.go b/cmd/scaler/server.go index 3d55a7cd..10239409 100644 --- a/cmd/scaler/server.go +++ b/cmd/scaler/server.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "strconv" "strings" "time" @@ -18,9 +19,17 @@ import ( "github.com/rudderlabs/rudder-go-kit/httputil" "github.com/rudderlabs/rudder-go-kit/jsonrs" "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats" obskit "github.com/rudderlabs/rudder-observability-kit/go/labels" ) +const ( + totalSnapshotsToLoadMetricName = "scaler_total_snapshots_to_load" + currentSnapshotsToLoadMetricName = "scaler_current_snapshots_loaded" + totalSnapshotsToCreateMetricName = "scaler_total_snapshots_to_create" + currentSnapshotsToCreateMetricName = "scaler_current_snapshots_created" +) + type scalerClient interface { Scale(ctx context.Context, nodeIDs []uint32) error ScaleComplete(ctx context.Context, nodeIDs []uint32) error @@ -38,14 +47,18 @@ type httpServer struct { client *client.Client scaler scalerClient server *http.Server + stat stats.Stats logger logger.Logger } // newHTTPServer creates a new HTTP server -func newHTTPServer(client *client.Client, scaler *scaler.Client, addr string, log logger.Logger) *httpServer { +func newHTTPServer( + client *client.Client, scaler *scaler.Client, addr string, stat stats.Stats, log logger.Logger, +) *httpServer { s := &httpServer{ client: client, scaler: scaler, + stat: stat, logger: log, } @@ -195,7 +208,10 @@ func (s *httpServer) handleCreateSnapshots(w http.ResponseWriter, r *http.Reques } // Create snapshot - if err := s.scaler.CreateSnapshots(r.Context(), req.NodeID, req.FullSync, req.HashRanges...); err != nil { + err := s.createSnapshotsWithProgress( + r.Context(), req.NodeID, req.FullSync, req.DisableCreateSnapshotsSequentially, req.HashRanges, + ) + if err != nil { http.Error(w, fmt.Sprintf("Error creating snapshot: %v", err), http.StatusInternalServerError) return } @@ -215,12 +231,28 @@ func (s *httpServer) handleLoadSnapshots(w http.ResponseWriter, r *http.Request) return } + // Initialize metrics + nodeIDStr := strconv.FormatInt(int64(req.NodeID), 10) + totalSnapshotsToLoad := s.stat.NewTaggedStat(totalSnapshotsToLoadMetricName, stats.GaugeType, stats.Tags{ + "nodeId": nodeIDStr, + }) + totalSnapshotsToLoad.Observe(float64(len(req.HashRanges))) + defer totalSnapshotsToLoad.Observe(0) + + currentSnapshotsLoaded := s.stat.NewTaggedStat(currentSnapshotsToLoadMetricName, stats.GaugeType, stats.Tags{ + "nodeId": nodeIDStr, + }) + currentSnapshotsLoaded.Observe(0) + // Load snapshots from cloud storage if err := s.scaler.LoadSnapshots(r.Context(), req.NodeID, req.MaxConcurrency, req.HashRanges...); err != nil { http.Error(w, fmt.Sprintf("Error loading snapshots: 
%v", err), http.StatusInternalServerError) return } + // Update metrics after successful load + currentSnapshotsLoaded.Observe(float64(len(req.HashRanges))) + // Write response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) @@ -335,11 +367,13 @@ func (s *httpServer) handleAutoScale(w http.ResponseWriter, r *http.Request) { err = s.handleScaleUp( r.Context(), req.OldNodesAddresses, req.NewNodesAddresses, req.FullSync, req.SkipCreateSnapshots, loadSnapshotsMaxConcurrency, + req.DisableCreateSnapshotsSequentially, ) } else if newClusterSize < oldClusterSize { err = s.handleScaleDown( r.Context(), req.OldNodesAddresses, req.NewNodesAddresses, req.FullSync, req.SkipCreateSnapshots, loadSnapshotsMaxConcurrency, + req.DisableCreateSnapshotsSequentially, ) } else { // Auto-healing: propagate cluster addresses to all nodes for consistency @@ -360,6 +394,7 @@ func (s *httpServer) handleAutoScale(w http.ResponseWriter, r *http.Request) { func (s *httpServer) handleScaleUp( ctx context.Context, oldAddresses, newAddresses []string, fullSync, skipCreateSnapshots bool, loadSnapshotsMaxConcurrency uint32, + disableCreateSnapshotsSequentially bool, ) error { oldClusterSize := uint32(len(oldAddresses)) newClusterSize := uint32(len(newAddresses)) @@ -405,7 +440,9 @@ func (s *httpServer) handleScaleUp( } group.Go(func() error { createSnapshotsStart := time.Now() - err := s.scaler.CreateSnapshots(gCtx, sourceNodeID, fullSync, hashRanges...) + err := s.createSnapshotsWithProgress( + gCtx, sourceNodeID, fullSync, disableCreateSnapshotsSequentially, hashRanges, + ) if err != nil { return fmt.Errorf("creating snapshots from node %d for hash ranges %v: %w", sourceNodeID, hashRanges, err, @@ -468,6 +505,7 @@ func (s *httpServer) handleScaleUp( func (s *httpServer) handleScaleDown( ctx context.Context, oldAddresses, newAddresses []string, fullSync, skipCreateSnapshots bool, loadSnapshotsMaxConcurrency uint32, + disableCreateSnapshotsSequentially bool, ) error { oldClusterSize := uint32(len(oldAddresses)) newClusterSize := uint32(len(newAddresses)) @@ -504,7 +542,10 @@ func (s *httpServer) handleScaleDown( } group.Go(func() error { createSnapshotsStart := time.Now() - if err := s.scaler.CreateSnapshots(gCtx, sourceNodeID, fullSync, hashRanges...); err != nil { + err := s.createSnapshotsWithProgress( + gCtx, sourceNodeID, fullSync, disableCreateSnapshotsSequentially, hashRanges, + ) + if err != nil { return fmt.Errorf("creating snapshots from node %d for hash ranges %v: %w", sourceNodeID, hashRanges, err, ) @@ -709,7 +750,9 @@ func (s *httpServer) handleHashRangeMovements(w http.ResponseWriter, r *http.Req // Call CreateSnapshots once per node with all hash ranges group.Go(func() error { start := time.Now() - err := s.scaler.CreateSnapshots(gCtx, sourceNodeID, req.FullSync, hashRanges...) 
+ err := s.createSnapshotsWithProgress( + gCtx, sourceNodeID, req.FullSync, req.DisableCreateSnapshotsSequentially, hashRanges, + ) if err != nil { return fmt.Errorf("creating snapshots for node %d: %w", sourceNodeID, err) } @@ -770,6 +813,54 @@ func (s *httpServer) handleHashRangeMovements(w http.ResponseWriter, r *http.Req } } +// createSnapshotsWithProgress creates snapshots either sequentially (one at a time) or in batch +// depending on the disableSequential flag +func (s *httpServer) createSnapshotsWithProgress( + ctx context.Context, nodeID uint32, fullSync, disableSequential bool, hashRanges []uint32, +) error { + // Initialize metrics + nodeIDStr := strconv.FormatInt(int64(nodeID), 10) + totalSnapshotsToCreate := s.stat.NewTaggedStat(totalSnapshotsToCreateMetricName, stats.GaugeType, stats.Tags{ + "nodeId": nodeIDStr, + }) + totalSnapshotsToCreate.Observe(float64(len(hashRanges))) + defer totalSnapshotsToCreate.Observe(0) + + currentSnapshotsCreated := s.stat.NewTaggedStat(currentSnapshotsToCreateMetricName, stats.GaugeType, stats.Tags{ + "nodeId": nodeIDStr, + }) + currentSnapshotsCreated.Observe(0) + + if disableSequential || len(hashRanges) == 0 { + // Call with all hash ranges at once (existing behavior) + err := s.scaler.CreateSnapshots(ctx, nodeID, fullSync, hashRanges...) + if err != nil { + return err + } + currentSnapshotsCreated.Observe(float64(len(hashRanges))) + return nil + } + + // Call CreateSnapshots once for each hash range + for i, hashRange := range hashRanges { + s.logger.Infon("Creating snapshot", + logger.NewIntField("nodeId", int64(nodeID)), + logger.NewIntField("hashRange", int64(hashRange)), + logger.NewIntField("progress", int64(i+1)), + logger.NewIntField("total", int64(len(hashRanges))), + ) + + if err := s.scaler.CreateSnapshots(ctx, nodeID, fullSync, hashRange); err != nil { + return fmt.Errorf("creating snapshot for hash range %d: %w", hashRange, err) + } + + // Update progress metric + currentSnapshotsCreated.Observe(float64(i + 1)) + } + + return nil +} + // handleLastOperation handles GET /lastOperation requests func (s *httpServer) handleLastOperation(w http.ResponseWriter, r *http.Request) { operation := s.scaler.GetLastOperation() diff --git a/cmd/scaler/server_test.go b/cmd/scaler/server_test.go index a06a7737..2e3ed377 100644 --- a/cmd/scaler/server_test.go +++ b/cmd/scaler/server_test.go @@ -1566,7 +1566,7 @@ func startScalerHTTPServer(t testing.TB, totalHashRanges uint32, rp scaler.Retry addr := fmt.Sprintf(":%d", freePort) - opServer := newHTTPServer(c, op, addr, log) + opServer := newHTTPServer(c, op, addr, stats.NOP, log) go func() { err := opServer.Start(context.Background()) if !errors.Is(err, http.ErrServerClosed) { @@ -1707,3 +1707,272 @@ func startMockNodeService(t testing.TB, identifier string) *mockNodeServiceServe return mockServer } + +func TestDegradedModeDuringScaling(t *testing.T) { + pool, err := dockertest.NewPool("") + require.NoError(t, err) + pool.MaxWait = 1 * time.Minute + + newConf := func() *config.Config { + conf := config.New() + conf.Set("BadgerDB.Dedup.Path", t.TempDir()) + conf.Set("BadgerDB.Dedup.Compress", true) + return conf + } + + minioContainer, err := miniokit.Setup(pool, t) + require.NoError(t, err) + + cloudStorage := keydbth.GetCloudStorage(t, newConf(), minioContainer) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + totalHashRanges := uint32(3) + + // Create a variable to hold degraded state that can be updated during the test + degradedNodes := make([]bool, 2) + + // 
Create two nodes with DegradedNodes function + node0Conf := newConf() + node0, node0Address := getService(ctx, t, cloudStorage, node.Config{ + NodeID: 0, + ClusterSize: 2, + TotalHashRanges: totalHashRanges, + DegradedNodes: func() []bool { + return degradedNodes + }, + }, node0Conf) + + node1Conf := newConf() + node1, node1Address := getService(ctx, t, cloudStorage, node.Config{ + NodeID: 1, + ClusterSize: 2, + TotalHashRanges: totalHashRanges, + Addresses: []string{node0Address}, + DegradedNodes: func() []bool { + return degradedNodes + }, + }, node1Conf) + + // Start the Scaler HTTP Server + s := startScalerHTTPServer(t, totalHashRanges, scaler.RetryPolicy{ + Disabled: true, + }, node0Address, node1Address) + + // Test Put some initial data + _ = s.Do("/put", PutRequest{ + Keys: []string{"key1", "key2", "key3"}, TTL: testTTL, + }, true) + + // Test Get to verify data exists + body := s.Do("/get", GetRequest{ + Keys: []string{"key1", "key2", "key3", "key4"}, + }) + require.JSONEq(t, `{"key1":true,"key2":true,"key3":true,"key4":false}`, body) + + // Mark node 1 as degraded + degradedNodes[1] = true + + // Verify that node 1 rejects Get requests + resp, err := node1.Get(ctx, &pb.GetRequest{Keys: []string{"key1"}}) + require.NoError(t, err) + require.Equal(t, pb.ErrorCode_SCALING, resp.ErrorCode) + require.Len(t, resp.NodesAddresses, 1, "Only non-degraded node should be in NodesAddresses") + require.Equal(t, node0Address, resp.NodesAddresses[0]) + + // Verify that node 1 rejects Put requests + putResp, err := node1.Put(ctx, &pb.PutRequest{Keys: []string{"key5"}, TtlSeconds: uint64(5 * 60)}) + require.NoError(t, err) + require.Equal(t, pb.ErrorCode_SCALING, putResp.ErrorCode) + require.Len(t, putResp.NodesAddresses, 1, "Only non-degraded node should be in NodesAddresses") + require.Equal(t, node0Address, putResp.NodesAddresses[0]) + + // Verify that GetNodeInfo returns only non-degraded addresses + nodeInfo, err := node0.GetNodeInfo(ctx, &pb.GetNodeInfoRequest{NodeId: 0}) + require.NoError(t, err) + require.Len(t, nodeInfo.NodesAddresses, 1, "Only non-degraded node should be in NodesAddresses") + require.Equal(t, node0Address, nodeInfo.NodesAddresses[0]) + + // Mark node 1 as non-degraded again + degradedNodes[1] = false + + // Verify that node 1 now accepts requests (should not return SCALING error) + resp, err = node1.Get(ctx, &pb.GetRequest{Keys: []string{"key1"}}) + require.NoError(t, err) + require.NotEqual(t, pb.ErrorCode_SCALING, resp.ErrorCode, "Node 1 should not be in degraded mode") + require.Len(t, resp.NodesAddresses, 2, "All non-degraded nodes should be in NodesAddresses") + + cancel() + node0.Close() + node1.Close() +} + +func TestScaleUpInDegradedMode(t *testing.T) { + pool, err := dockertest.NewPool("") + require.NoError(t, err) + pool.MaxWait = 1 * time.Minute + + newConf := func() *config.Config { + conf := config.New() + conf.Set("BadgerDB.Dedup.Path", t.TempDir()) + return conf + } + + minioContainer, err := miniokit.Setup(pool, t) + require.NoError(t, err) + + cloudStorage := keydbth.GetCloudStorage(t, newConf(), minioContainer) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + totalHashRanges := uint32(3) + + // Create a variable to hold degraded state that can be updated during the test + degradedNodes := make([]bool, 1) + + // Step 1: Create a cluster with 1 node + node0Conf := newConf() + node0, node0Address := getService(ctx, t, cloudStorage, node.Config{ + NodeID: 0, + ClusterSize: 1, + TotalHashRanges: totalHashRanges, + 
DegradedNodes: func() []bool { return degradedNodes }, + }, node0Conf) + + // Start the Scaler HTTP Server + s := startScalerHTTPServer(t, totalHashRanges, scaler.RetryPolicy{ + Disabled: true, + }, node0Address) + + // Step 2: Add keys via Put and verify them via Get + // Add enough keys to ensure distribution across hash ranges + // Based on hash distribution with clusterSize=2, totalHashRanges=3: + // key1, key2, key3 → node0; key4, key5, key6, key7, key8 → will distribute across nodes + _ = s.Do("/put", PutRequest{ + Keys: []string{"key1", "key2", "key3", "key4", "key5", "key6", "key7", "key8"}, TTL: testTTL, + }, true) + + body := s.Do("/get", GetRequest{ + Keys: []string{"key1", "key2", "key3", "key4", "key5", "key6", "key7", "key8"}, + }) + require.JSONEq(t, + `{"key1":true,"key2":true,"key3":true,"key4":true,"key5":true,"key6":true,"key7":true,"key8":true}`, + body, + ) + + // Step 3: Create a second node + node1Conf := newConf() + node1, node1Address := getService(ctx, t, cloudStorage, node.Config{ + NodeID: 1, + ClusterSize: 2, + TotalHashRanges: totalHashRanges, + Addresses: []string{node0Address}, + DegradedNodes: func() []bool { return degradedNodes }, + }, node1Conf) + + // Step 4: Update degradedNodes - mark node 1 as degraded + degradedNodes = append(degradedNodes, true) //nolint:makezero + + // Verify that node 1 is in degraded mode + resp, err := node1.Get(ctx, &pb.GetRequest{Keys: []string{"key1"}}) + require.NoError(t, err) + require.Equal(t, pb.ErrorCode_SCALING, resp.ErrorCode) + + // Step 5: Use /autoScale to scale the cluster while node1 is degraded + _ = s.Do("/autoScale", AutoScaleRequest{ + OldNodesAddresses: []string{node0Address}, + NewNodesAddresses: []string{node0Address, node1Address}, + }, true) + + // Verify scale up worked - check node info + // Note: While node1 is degraded, NodesAddresses will only include non-degraded nodes + body = s.Do("/info", InfoRequest{NodeID: 0}) + infoResponse := pb.GetNodeInfoResponse{} + require.NoError(t, jsonrs.Unmarshal([]byte(body), &infoResponse)) + require.EqualValues(t, 2, infoResponse.ClusterSize) + require.Len(t, infoResponse.NodesAddresses, 1, + "Only non-degraded node should be in NodesAddresses while node1 is degraded", + ) + require.Equal(t, node0Address, infoResponse.NodesAddresses[0]) + require.ElementsMatch(t, []uint32{0, 1}, infoResponse.HashRanges) + + body = s.Do("/info", InfoRequest{NodeID: 1}) + infoResponse = pb.GetNodeInfoResponse{} + require.NoError(t, jsonrs.Unmarshal([]byte(body), &infoResponse)) + require.EqualValues(t, 2, infoResponse.ClusterSize) + require.Len(t, infoResponse.NodesAddresses, 1, + "Only non-degraded node should be in NodesAddresses while node1 is degraded", + ) + require.Equal(t, node0Address, infoResponse.NodesAddresses[0]) + require.ElementsMatch(t, []uint32{2}, infoResponse.HashRanges) + + // Step 6: mark node 1 as non-degraded + degradedNodes[1] = false + + // Verify that node 1 now accepts requests and has loaded snapshots correctly + // Node1 owns hash range 2, which contains key4 and other keys + // Let's determine which keys belong to node1's hash range + h := hash.New(2, totalHashRanges) + allKeys := []string{"key1", "key2", "key3", "key4", "key5", "key6", "key7", "key8"} + var node1Keys []string + for _, key := range allKeys { + if h.GetNodeNumber(key) == 1 { + node1Keys = append(node1Keys, key) + } + } + require.Greater(t, len(node1Keys), 0, "Node 1 should own at least one key") + t.Logf("Node 1 owns %d keys: %v", len(node1Keys), node1Keys) + + // Query node1 directly 
for keys that belong to its hash range + resp, err = node1.Get(ctx, &pb.GetRequest{Keys: node1Keys}) + require.NoError(t, err) + require.NotEqual(t, pb.ErrorCode_SCALING, resp.ErrorCode, "Node 1 should not be in degraded mode") + require.Len(t, resp.Exists, len(node1Keys), "Node 1 should return results for all its keys") + for _, exists := range resp.Exists { + require.True(t, exists, "All keys belonging to node1 should exist after loading snapshots") + } + + // Verify that now both nodes appear in NodesAddresses + body = s.Do("/info", InfoRequest{NodeID: 0}) + infoResponse = pb.GetNodeInfoResponse{} + require.NoError(t, jsonrs.Unmarshal([]byte(body), &infoResponse)) + require.EqualValues(t, 2, infoResponse.ClusterSize) + require.Len(t, infoResponse.NodesAddresses, 2, + "Both nodes should be in NodesAddresses after node1 is no longer degraded", + ) + require.Contains(t, infoResponse.NodesAddresses, node0Address) + require.Contains(t, infoResponse.NodesAddresses, node1Address) + + // Step 7: Verify that the cluster is scaled and Get and Put are now served by both nodes + // Verify all existing data is accessible via scaler + body = s.Do("/get", GetRequest{ + Keys: []string{"key1", "key2", "key3", "key4", "key5", "key6", "key7", "key8"}, + }) + require.JSONEq(t, + `{"key1":true,"key2":true,"key3":true,"key4":true,"key5":true,"key6":true,"key7":true,"key8":true}`, + body, + ) + + // Test Put with new keys now that both nodes are operational + _ = s.Do("/put", PutRequest{ + Keys: []string{"key9", "key10", "key11"}, TTL: testTTL, + }, true) + + // Verify all keys including new ones are accessible + body = s.Do("/get", GetRequest{ + Keys: []string{"key1", "key2", "key3", "key4", "key5", "key6", "key7", "key8", "key9", "key10", "key11"}, + }) + require.JSONEq(t, + `{ + "key1":true,"key2":true,"key3":true,"key4":true,"key5":true,"key6":true, + "key7":true,"key8":true,"key9":true,"key10":true,"key11":true + }`, + body, + ) + + cancel() + node0.Close() + node1.Close() +} diff --git a/cmd/scaler/types.go b/cmd/scaler/types.go index ab5e75b6..b2247e71 100644 --- a/cmd/scaler/types.go +++ b/cmd/scaler/types.go @@ -22,9 +22,10 @@ type InfoRequest struct { // CreateSnapshotsRequest represents a request to create snapshots type CreateSnapshotsRequest struct { - NodeID uint32 `json:"node_id"` - FullSync bool `json:"full_sync"` - HashRanges []uint32 `json:"hash_ranges,omitempty"` + NodeID uint32 `json:"node_id"` + FullSync bool `json:"full_sync"` + HashRanges []uint32 `json:"hash_ranges,omitempty"` + DisableCreateSnapshotsSequentially bool `json:"disable_create_snapshots_sequentially,omitempty"` } // LoadSnapshotsRequest represents a request to load snapshots @@ -67,20 +68,22 @@ type AutoScaleRequest struct { NewNodesAddresses []string `json:"new_nodes_addresses"` // FullSync indicates whether to perform a full synchronization during snapshot creation. // When true, all data will be included in snapshots regardless of incremental changes. 
- FullSync bool `json:"full_sync,omitempty"` - SkipCreateSnapshots bool `json:"skip_create_snapshots,omitempty"` - LoadSnapshotsMaxConcurrency uint32 `json:"load_snapshots_max_concurrency,omitempty"` + FullSync bool `json:"full_sync,omitempty"` + SkipCreateSnapshots bool `json:"skip_create_snapshots,omitempty"` + LoadSnapshotsMaxConcurrency uint32 `json:"load_snapshots_max_concurrency,omitempty"` + DisableCreateSnapshotsSequentially bool `json:"disable_create_snapshots_sequentially,omitempty"` } // HashRangeMovementsRequest represents a request to preview hash range movements type HashRangeMovementsRequest struct { - OldClusterSize uint32 `json:"old_cluster_size"` - NewClusterSize uint32 `json:"new_cluster_size"` - TotalHashRanges uint32 `json:"total_hash_ranges"` - Upload bool `json:"upload,omitempty"` - Download bool `json:"download,omitempty"` - FullSync bool `json:"full_sync,omitempty"` - LoadSnapshotsMaxConcurrency uint32 `json:"load_snapshots_max_concurrency,omitempty"` + OldClusterSize uint32 `json:"old_cluster_size"` + NewClusterSize uint32 `json:"new_cluster_size"` + TotalHashRanges uint32 `json:"total_hash_ranges"` + Upload bool `json:"upload,omitempty"` + Download bool `json:"download,omitempty"` + FullSync bool `json:"full_sync,omitempty"` + LoadSnapshotsMaxConcurrency uint32 `json:"load_snapshots_max_concurrency,omitempty"` + DisableCreateSnapshotsSequentially bool `json:"disable_create_snapshots_sequentially,omitempty"` } type HashRangeMovementsResponse struct { diff --git a/internal/cache/badger/badger.go b/internal/cache/badger/badger.go index 735c6c2d..924da7e3 100644 --- a/internal/cache/badger/badger.go +++ b/internal/cache/badger/badger.go @@ -122,7 +122,7 @@ func New(conf *config.Config, log logger.Logger) (*Cache, error) { discardRatio: conf.GetFloat64("BadgerDB.Dedup.DiscardRatio", 0.7), debugMode: conf.GetBool("BadgerDB.DebugMode", false), jitterEnabled: conf.GetBool("cache.ttlJitter.enabled", false), - jitterDuration: conf.GetDuration("cache.ttlJitter", 1, time.Hour), + jitterDuration: conf.GetDuration("cache.ttlJitter.duration", 1, time.Hour), }, nil } diff --git a/node/config.go b/node/config.go new file mode 100644 index 00000000..5c4834f5 --- /dev/null +++ b/node/config.go @@ -0,0 +1,47 @@ +package node + +import "time" + +// Config holds the configuration for a node +type Config struct { + // NodeID is the ID of this node (0-based) + NodeID uint32 + + // TotalHashRanges is the total number of hash ranges + TotalHashRanges uint32 + + // MaxFilesToList specifies the maximum number of files that can be listed in a single operation. + MaxFilesToList int64 + + // SnapshotInterval is the interval for creating snapshots (in seconds) + SnapshotInterval time.Duration + + // GarbageCollectionInterval defines the duration between automatic GC operation per cache + GarbageCollectionInterval time.Duration + + // Addresses is a list of node addresses that this node will advertise to clients + Addresses []string + + // DegradedNodes is a list of nodes that are considered degraded and should not be used for reads and writes. 
+ DegradedNodes func() []bool + + // logTableStructureDuration defines the duration for which the table structure is logged + LogTableStructureDuration time.Duration + + // backupFolderName is the name of the folder in the S3 bucket where snapshots are stored + BackupFolderName string +} + +func (c *Config) getClusterSize() uint32 { + l := uint32(len(c.Addresses)) + if c.DegradedNodes == nil { + return l + } + degradedNodes := c.DegradedNodes() + for _, degraded := range degradedNodes { + if degraded { + l-- + } + } + return l +} diff --git a/node/node.go b/node/node.go index 00b6121a..fdd88d75 100644 --- a/node/node.go +++ b/node/node.go @@ -46,57 +46,25 @@ const ( // file format is hr__s__.snapshot var snapshotFilenameRegex = regexp.MustCompile(`^.+/hr_(\d+)_s_(\d+)_(\d+).snapshot$`) -// Config holds the configuration for a node -type Config struct { - // NodeID is the ID of this node (0-based) - NodeID uint32 - - // ClusterSize is the total number of nodes in the cluster - ClusterSize uint32 - - // TotalHashRanges is the total number of hash ranges - TotalHashRanges uint32 - - // MaxFilesToList specifies the maximum number of files that can be listed in a single operation. - MaxFilesToList int64 - - // SnapshotInterval is the interval for creating snapshots (in seconds) - SnapshotInterval time.Duration - - // GarbageCollectionInterval defines the duration between automatic GC operation per cache - GarbageCollectionInterval time.Duration - - // Addresses is a list of node addresses that this node will advertise to clients - Addresses []string - - // logTableStructureDuration defines the duration for which the table structure is logged - LogTableStructureDuration time.Duration - - // backupFolderName is the name of the folder in the S3 bucket where snapshots are stored - BackupFolderName string -} - // Service implements the NodeService gRPC service type Service struct { pb.UnimplementedNodeServiceServer config Config - // mu protects cache, scaling, since, lastSnapshotTime and hasher + // mu protects cache, scaling, lastSnapshotTime and hasher mu sync.RWMutex cache Cache scaling bool - since map[uint32]uint64 lastSnapshotTime time.Time hasher *hash.Hash - now func() time.Time - maxFilesToList int64 - forceSkipFilesListing config.ValueLoader[bool] - waitGroup sync.WaitGroup - storage cloudStorage - stats stats.Stats - logger logger.Logger + now func() time.Time + maxFilesToList int64 + waitGroup sync.WaitGroup + storage cloudStorage + stats stats.Stats + logger logger.Logger metrics struct { getKeysCounters map[uint32]stats.Counter @@ -183,14 +151,12 @@ func NewService( } service := &Service{ - now: time.Now, - config: config, - storage: storage, - since: make(map[uint32]uint64), - maxFilesToList: config.MaxFilesToList, - forceSkipFilesListing: kitConf.GetReloadableBoolVar(false, "NodeService.forceSkipFilesListing"), - hasher: hash.New(config.ClusterSize, config.TotalHashRanges), - stats: stat, + now: time.Now, + config: config, + storage: storage, + maxFilesToList: config.MaxFilesToList, + hasher: hash.New(config.getClusterSize(), config.TotalHashRanges), + stats: stat, logger: log.Withn( logger.NewIntField("nodeId", int64(config.NodeID)), logger.NewIntField("totalHashRanges", int64(config.TotalHashRanges)), @@ -351,80 +317,14 @@ func (s *Service) initCaches( s.metrics.putKeysCounter[r] = s.stats.NewTaggedStat("keydb_put_keys_count", stats.CountType, statsTags) } - // List all files in the bucket - var ( - err error - files []*filemanager.FileInfo - ) - if !s.forceSkipFilesListing.Load() 
{ - list := s.storage.ListFilesWithPrefix(ctx, "", s.getSnapshotFilenamePrefix(), s.maxFilesToList) - files, err = list.Next() - if err != nil { - return fmt.Errorf("failed to list snapshot files: %w", err) - } - } - if len(files) == 0 { - s.logger.Infon("No snapshots found, skipping caches initialization") - return nil - } - - var selected map[uint32]struct{} - if len(selectedHashRanges) > 0 { - selected = make(map[uint32]struct{}, len(selectedHashRanges)) - for _, r := range selectedHashRanges { - selected[r] = struct{}{} - } - } else { // no hash range was selected, download the data for all the ranges handled by this node - selected = currentRanges - } - - totalFiles := int64(0) - filesByHashRange := make(map[uint32][]string, len(files)) - for _, file := range files { - matches := snapshotFilenameRegex.FindStringSubmatch(file.Key) - if len(matches) != 4 { - continue - } - hashRangeInt, err := strconv.Atoi(matches[1]) - if err != nil { - s.logger.Warnn("Invalid snapshot filename (hash range)", logger.NewStringField("filename", file.Key)) - continue - } - hashRange := uint32(hashRangeInt) - if len(selectedHashRanges) > 0 { - if _, shouldHandle := selected[hashRange]; !shouldHandle { - s.logger.Warnn("Ignoring snapshot file for hash range since it was not selected") - continue - } - } - since, err := strconv.Atoi(matches[3]) // getting "to", not "from" - if err != nil { - s.logger.Warnn("Invalid snapshot filename (since)", logger.NewStringField("filename", file.Key)) - continue - } - if s.since[hashRange] < uint64(since) { - s.since[hashRange] = uint64(since) - } - if filesByHashRange[hashRange] == nil { - filesByHashRange[hashRange] = make([]string, 0, 1) - } - - s.logger.Debugn("Found snapshot file", - logger.NewIntField("hashRange", int64(hashRange)), - logger.NewStringField("filename", file.Key), - ) - filesByHashRange[hashRange] = append(filesByHashRange[hashRange], file.Key) - totalFiles++ - } - if !download { - // We still had to do the above in order to populate the "since" map - s.logger.Infon("Downloading disabled, skipping snapshots initialization") + s.logger.Infon("Downloading disabled, skipping caches initialization") return nil } - for i := range filesByHashRange { - sort.Strings(filesByHashRange[i]) + totalFiles, filesByHashRange, err := s.listSnapshots(ctx, selectedHashRanges...) 
+ if err != nil { + return fmt.Errorf("list snapshots: %w", err) } if maxConcurrency == 0 { @@ -436,7 +336,7 @@ func (s *Service) initCaches( loadDone = make(chan error, 1) filesLoaded int64 group, gCtx = kitsync.NewEagerGroup(ctx, 0) - readers = make(chan snapshot, maxConcurrency) + readers = make(chan snapshotReader, maxConcurrency) ) defer loadCancel() go func() { @@ -445,7 +345,7 @@ func (s *Service) initCaches( s.logger.Infon("Loading downloaded snapshots", logger.NewIntField("range", int64(sn.hashRange)), logger.NewStringField("filename", sn.filename), - logger.NewIntField("totalFiles", totalFiles), + logger.NewIntField("totalFiles", int64(totalFiles)), logger.NewFloatField("loadingPercentage", float64(filesLoaded)*100/float64(totalFiles), ), @@ -469,12 +369,12 @@ func (s *Service) initCaches( for _, snapshotFile := range snapshotFiles { group.Go(func() error { s.logger.Infon("Starting download of snapshot file from cloud storage", - logger.NewStringField("filename", snapshotFile), + logger.NewStringField("filename", snapshotFile.filename), ) startDownload := time.Now() buf := manager.NewWriteAtBuffer([]byte{}) - err := s.storage.Download(gCtx, buf, snapshotFile) + err := s.storage.Download(gCtx, buf, snapshotFile.filename) if err != nil { if errors.Is(err, filemanager.ErrKeyNotFound) { s.logger.Warnn("No cached snapshot for range", logger.NewIntField("range", int64(r))) @@ -487,8 +387,8 @@ func (s *Service) initCaches( select { case <-gCtx.Done(): return gCtx.Err() - case readers <- snapshot{ - filename: snapshotFile, + case readers <- snapshotReader{ + filename: snapshotFile.filename, hashRange: r, reader: bytes.NewReader(buf.Bytes()), }: @@ -524,17 +424,104 @@ func (s *Service) initCaches( return nil } +func (s *Service) listSnapshots(ctx context.Context, selectedHashRanges ...uint32) ( + int, + map[uint32][]snapshotFile, + error, +) { + // List all files in the bucket + list := s.storage.ListFilesWithPrefix(ctx, "", s.getSnapshotFilenamePrefix(), s.maxFilesToList) + files, err := list.Next() + if err != nil { + return 0, nil, fmt.Errorf("failed to list snapshot files: %w", err) + } + if len(files) == 0 { + s.logger.Infon("No snapshots found, skipping caches initialization") + return 0, nil, nil + } + + var selectedRangesMap map[uint32]struct{} + if len(selectedHashRanges) > 0 { + selectedRangesMap = make(map[uint32]struct{}, len(selectedHashRanges)) + for _, r := range selectedHashRanges { + selectedRangesMap[r] = struct{}{} + } + } else { // no hash range was selected, download the data for all the ranges handled by this node + selectedRangesMap = s.getCurrentRanges() + } + + totalFiles := 0 + filesByHashRange := make(map[uint32][]snapshotFile, len(files)) + for _, file := range files { + matches := snapshotFilenameRegex.FindStringSubmatch(file.Key) + if len(matches) != 4 { + continue + } + hashRangeInt, err := strconv.Atoi(matches[1]) + if err != nil { + s.logger.Warnn("Invalid snapshot filename (hash range)", logger.NewStringField("filename", file.Key)) + continue + } + hashRange := uint32(hashRangeInt) + if len(selectedHashRanges) > 0 { + if _, shouldHandle := selectedRangesMap[hashRange]; !shouldHandle { + s.logger.Warnn("Ignoring snapshot file for hash range since it was not selected") + continue + } + } + from, err := strconv.ParseUint(matches[2], 10, 64) + if err != nil { + s.logger.Warnn("Invalid snapshot filename (from)", logger.NewStringField("filename", file.Key)) + continue + } + to, err := strconv.ParseUint(matches[3], 10, 64) + if err != nil { + 
s.logger.Warnn("Invalid snapshot filename (to)", logger.NewStringField("filename", file.Key)) + continue + } + if filesByHashRange[hashRange] == nil { + filesByHashRange[hashRange] = make([]snapshotFile, 0, 1) + } + + s.logger.Debugn("Found snapshot file", + logger.NewIntField("hashRange", int64(hashRange)), + logger.NewStringField("filename", file.Key), + logger.NewIntField("from", int64(from)), + logger.NewIntField("to", int64(to)), + ) + filesByHashRange[hashRange] = append(filesByHashRange[hashRange], snapshotFile{ + filename: file.Key, + hashRange: hashRange, + from: from, + to: to, + }) + totalFiles++ + } + + // Sort each slice by "from" and "to" + for hashRange := range filesByHashRange { + sort.Slice(filesByHashRange[hashRange], func(i, j int) bool { + if filesByHashRange[hashRange][i].from != filesByHashRange[hashRange][j].from { + return filesByHashRange[hashRange][i].from < filesByHashRange[hashRange][j].from + } + return filesByHashRange[hashRange][i].to < filesByHashRange[hashRange][j].to + }) + } + + return totalFiles, filesByHashRange, nil +} + // Get implements the Get RPC method -func (s *Service) Get(ctx context.Context, req *pb.GetRequest) (*pb.GetResponse, error) { +func (s *Service) Get(_ context.Context, req *pb.GetRequest) (*pb.GetResponse, error) { s.mu.RLock() defer s.mu.RUnlock() response := &pb.GetResponse{ - ClusterSize: s.config.ClusterSize, - NodesAddresses: s.config.Addresses, + ClusterSize: s.config.getClusterSize(), + NodesAddresses: s.getNonDegradedAddresses(), } - if s.scaling { + if s.isDegraded() || s.scaling { s.metrics.errScalingCounter.Increment() response.ErrorCode = pb.ErrorCode_SCALING return response, nil @@ -576,17 +563,17 @@ func (s *Service) Get(ctx context.Context, req *pb.GetRequest) (*pb.GetResponse, } // Put implements the Put RPC method -func (s *Service) Put(ctx context.Context, req *pb.PutRequest) (*pb.PutResponse, error) { +func (s *Service) Put(_ context.Context, req *pb.PutRequest) (*pb.PutResponse, error) { s.mu.RLock() defer s.mu.RUnlock() resp := &pb.PutResponse{ Success: false, - ClusterSize: s.config.ClusterSize, - NodesAddresses: s.config.Addresses, + ClusterSize: s.config.getClusterSize(), + NodesAddresses: s.getNonDegradedAddresses(), } - if s.scaling { + if s.isDegraded() || s.scaling { s.metrics.errScalingCounter.Increment() resp.ErrorCode = pb.ErrorCode_SCALING return resp, nil @@ -629,7 +616,7 @@ func (s *Service) Put(ctx context.Context, req *pb.PutRequest) (*pb.PutResponse, } // GetNodeInfo implements the GetNodeInfo RPC method -func (s *Service) GetNodeInfo(ctx context.Context, req *pb.GetNodeInfoRequest) (*pb.GetNodeInfoResponse, error) { +func (s *Service) GetNodeInfo(_ context.Context, req *pb.GetNodeInfoRequest) (*pb.GetNodeInfoResponse, error) { s.mu.RLock() defer s.mu.RUnlock() @@ -651,8 +638,8 @@ func (s *Service) GetNodeInfo(ctx context.Context, req *pb.GetNodeInfoRequest) ( return &pb.GetNodeInfoResponse{ NodeId: s.config.NodeID, - ClusterSize: s.config.ClusterSize, - NodesAddresses: s.config.Addresses, + ClusterSize: s.config.getClusterSize(), + NodesAddresses: s.getNonDegradedAddresses(), HashRanges: hashRanges, LastSnapshotTimestamp: uint64(s.lastSnapshotTime.Unix()), }, nil @@ -666,13 +653,15 @@ func (s *Service) Scale(ctx context.Context, req *pb.ScaleRequest) (*pb.ScaleRes log := s.logger.Withn(logger.NewIntField("newClusterSize", int64(len(req.NodesAddresses)))) log.Infon("Scale request received") + previousClusterSize := s.config.getClusterSize() + if s.scaling { log.Infon("Scaling operation already in 
progress") return &pb.ScaleResponse{ Success: false, ErrorMessage: "scaling operation already in progress", - PreviousClusterSize: s.config.ClusterSize, - NewClusterSize: s.config.ClusterSize, + PreviousClusterSize: previousClusterSize, + NewClusterSize: previousClusterSize, }, nil } @@ -682,8 +671,8 @@ func (s *Service) Scale(ctx context.Context, req *pb.ScaleRequest) (*pb.ScaleRes return &pb.ScaleResponse{ Success: false, ErrorMessage: "new cluster size must be greater than 0", - PreviousClusterSize: s.config.ClusterSize, - NewClusterSize: s.config.ClusterSize, + PreviousClusterSize: previousClusterSize, + NewClusterSize: previousClusterSize, }, nil } @@ -694,21 +683,17 @@ func (s *Service) Scale(ctx context.Context, req *pb.ScaleRequest) (*pb.ScaleRes // Set scaling flag s.scaling = true - // Save the previous cluster size - previousClusterSize := s.config.ClusterSize - // Update cluster size - s.config.ClusterSize = uint32(len(req.NodesAddresses)) s.config.Addresses = req.NodesAddresses + newClusterSize := uint32(len(s.config.Addresses)) // Update hash instance - s.hasher = hash.New(s.config.ClusterSize, s.config.TotalHashRanges) + s.hasher = hash.New(newClusterSize, s.config.TotalHashRanges) // Reinitialize caches for the new cluster size if err := s.initCaches(ctx, false, 0); err != nil { // Revert to previous cluster size on error log.Errorn("Failed to initialize caches", obskit.Error(err)) - s.config.ClusterSize = previousClusterSize - s.hasher = hash.New(s.config.ClusterSize, s.config.TotalHashRanges) + s.hasher = hash.New(previousClusterSize, s.config.TotalHashRanges) s.scaling = false return &pb.ScaleResponse{ Success: false, @@ -723,12 +708,13 @@ func (s *Service) Scale(ctx context.Context, req *pb.ScaleRequest) (*pb.ScaleRes log.Infon("Scale phase 1 of 2 completed successfully", logger.NewIntField("previousClusterSize", int64(previousClusterSize)), + logger.NewIntField("newClusterSize", int64(newClusterSize)), ) return &pb.ScaleResponse{ Success: true, PreviousClusterSize: previousClusterSize, - NewClusterSize: s.config.ClusterSize, + NewClusterSize: newClusterSize, }, nil } @@ -739,9 +725,7 @@ func (s *Service) ScaleComplete(_ context.Context, _ *pb.ScaleCompleteRequest) ( // No need to check the s.scaling value, let's be optimistic and go ahead here for auto-healing purposes s.scaling = false - s.logger.Infon("Scale phase 2 of 2 completed successfully", - logger.NewIntField("newClusterSize", int64(s.config.ClusterSize)), - ) + s.logger.Infon("Scale phase 2 of 2 completed successfully") return &pb.ScaleCompleteResponse{Success: true}, nil } @@ -848,8 +832,25 @@ func (s *Service) createSnapshots(ctx context.Context, fullSync bool, selectedHa i++ } } else { + _, filesByHashRange, err := s.listSnapshots(ctx, selectedHashRanges...) + if err != nil { + return fmt.Errorf("list snapshots: %w", err) + } + + since = make(map[uint32]uint64, len(currentRanges)) + for r := range currentRanges { + since[r] = 0 + } + + for hr, files := range filesByHashRange { + for _, file := range files { + if since[hr] < file.to { + since[hr] = file.to + } + } + } + i := 0 - since = s.since for hr, ss := range since { if i != 0 { sinceLog.WriteString(",") @@ -937,8 +938,6 @@ func (s *Service) createSnapshots(ctx context.Context, fullSync bool, selectedHa } s.metrics.uploadSnapshotDuration.Since(startUpload) - s.since[hashRange] = newSince - if fullSync { if len(filesToBeDeletedByHashRange[hashRange]) > 0 { // Clearing up old files that are incremental updates. 
@@ -978,6 +977,43 @@ func (s *Service) GetKeysByHashRangeWithIndexes(keys []string) (map[uint32][]str return s.hasher.GetKeysByHashRangeWithIndexes(keys, s.config.NodeID) } +// isDegraded checks if the current node is in degraded mode +func (s *Service) isDegraded() bool { + if s.config.DegradedNodes == nil { + return false + } + degradedNodes := s.config.DegradedNodes() + if len(degradedNodes) == 0 { + return false + } + if int(s.config.NodeID) >= len(degradedNodes) { + s.logger.Warnn("Node ID out of range for degraded nodes list", + logger.NewIntField("nodeId", int64(s.config.NodeID)), + logger.NewIntField("degradedNodes", int64(len(degradedNodes))), + ) + return false + } + return degradedNodes[s.config.NodeID] +} + +// getNonDegradedAddresses returns the list of node addresses excluding degraded nodes +func (s *Service) getNonDegradedAddresses() []string { + if s.config.DegradedNodes == nil { + return s.config.Addresses + } + degradedNodes := s.config.DegradedNodes() + if len(degradedNodes) == 0 { + return s.config.Addresses + } + nonDegraded := make([]string, 0, len(s.config.Addresses)) + for i, addr := range s.config.Addresses { + if i >= len(degradedNodes) || !degradedNodes[i] { + nonDegraded = append(nonDegraded, addr) + } + } + return nonDegraded +} + func (s *Service) getSnapshotFilenamePrefix() string { return path.Join(s.config.BackupFolderName, "hr_") } @@ -988,8 +1024,14 @@ func getSnapshotFilenamePostfix(hashRange uint32, from, to uint64) string { ".snapshot" } -type snapshot struct { +type snapshotReader struct { filename string hashRange uint32 reader io.Reader } + +type snapshotFile struct { + filename string + hashRange uint32 + from, to uint64 +} diff --git a/node/node_test.go b/node/node_test.go index c675ef1c..b71837c2 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -6,6 +6,7 @@ import ( "net" "regexp" "strconv" + "strings" "testing" "time" @@ -265,9 +266,6 @@ func TestScaleUpAndDown(t *testing.T) { SnapshotInterval: 60 * time.Second, Addresses: []string{node0Address}, }, node1Conf) - require.Equal(t, map[uint32]uint64{0: 1, 1: 1}, node1.since, - "Node should populate the since map upon start-up", - ) require.NoError(t, op.UpdateClusterData(node0Address, node1Address)) require.NoError(t, op.LoadSnapshots(ctx, 1, 0, node1.hasher.GetNodeHashRangesList(1)...)) require.NoError(t, op.Scale(ctx, []uint32{0, 1})) @@ -658,124 +656,6 @@ func TestSelectedSnapshots(t *testing.T) { }) } -func TestForceSkipFilesListing(t *testing.T) { - pool, err := dockertest.NewPool("") - require.NoError(t, err) - pool.MaxWait = 1 * time.Minute - - run := func(t *testing.T, newConf func() *config.Config) { - t.Parallel() - - minioContainer, err := miniokit.Setup(pool, t) - require.NoError(t, err) - - cloudStorage := keydbth.GetCloudStorage(t, newConf(), minioContainer) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - totalHashRanges := uint32(4) - node0Conf := newConf() - node0, node0Address := getService(ctx, t, cloudStorage, Config{ - NodeID: 0, - ClusterSize: 1, - TotalHashRanges: totalHashRanges, - SnapshotInterval: 60 * time.Second, - }, node0Conf) - c := getClient(t, totalHashRanges, node0Address) - op := getScaler(t, totalHashRanges, node0Address) - - require.NoError(t, c.Put(ctx, []string{"key1", "key2", "key3"}, testTTL)) - - exists, err := c.Get(ctx, []string{"key1", "key2", "key3", "key4"}) - require.NoError(t, err) - require.Equal(t, []bool{true, true, true, false}, exists) - - err = op.CreateSnapshots(ctx, 0, false) - require.NoError(t, err) 
- - keydbth.RequireExpectedFiles(ctx, t, minioContainer, defaultBackupFolderName, - regexp.MustCompile("^.+/hr_1_s_0_1.snapshot$"), - regexp.MustCompile("^.+/hr_2_s_0_1.snapshot$"), - regexp.MustCompile("^.+/hr_3_s_0_1.snapshot$"), - ) - - cancel() - node0.Close() - - ctx, cancel = context.WithCancel(context.Background()) - defer cancel() - node0Conf = newConf() - node0, node0Address = getService(ctx, t, cloudStorage, Config{ - NodeID: 0, - ClusterSize: 1, - TotalHashRanges: totalHashRanges, - SnapshotInterval: 60 * time.Second, - }, node0Conf) - c = getClient(t, totalHashRanges, node0Address) - - require.Equal(t, map[uint32]uint64{1: 1, 2: 1, 3: 1}, node0.since, - "Without forceSkipFilesListing, the since map should be populated from existing snapshots on startup", - ) - - require.NoError(t, op.UpdateClusterData(node0Address)) - require.NoError(t, op.LoadSnapshots(ctx, 0, 0)) - - exists, err = c.Get(ctx, []string{"key1", "key2", "key3", "key4"}) - require.NoError(t, err) - require.Equal(t, []bool{true, true, true, false}, exists) - - cancel() - node0.Close() - - // Repeat with forceSkipFilesListing=true - ctx, cancel = context.WithCancel(context.Background()) - defer cancel() - node0Conf = newConf() - node0Conf.Set("NodeService.forceSkipFilesListing", true) - node0, node0Address = getService(ctx, t, cloudStorage, Config{ - NodeID: 0, - ClusterSize: 1, - TotalHashRanges: totalHashRanges, - SnapshotInterval: 60 * time.Second, - }, node0Conf) - c = getClient(t, totalHashRanges, node0Address) - - require.Empty(t, node0.since, - "With forceSkipFilesListing, the since map should NOT be populated on startup", - ) - - require.NoError(t, op.UpdateClusterData(node0Address)) - require.NoError(t, op.LoadSnapshots(ctx, 0, 0)) - - exists, err = c.Get(ctx, []string{"key1", "key2", "key3", "key4"}) - require.NoError(t, err) - require.Equal(t, []bool{false, false, false, false}, exists, - "With forceSkipFilesListing, snapshots cannot be loaded even with explicit LoadSnapshots call") - - cancel() - node0.Close() - } - - t.Run("badger", func(t *testing.T) { - run(t, func() *config.Config { - conf := config.New() - conf.Set("BadgerDB.Dedup.Path", t.TempDir()) - conf.Set("BadgerDB.Dedup.Compress", false) - return conf - }) - }) - - t.Run("badger compressed", func(t *testing.T) { - run(t, func() *config.Config { - conf := config.New() - conf.Set("BadgerDB.Dedup.Path", t.TempDir()) - conf.Set("BadgerDB.Dedup.Compress", true) - return conf - }) - }) -} - func getService( ctx context.Context, t testing.TB, cs cloudStorage, nodeConfig Config, conf *config.Config, ) (*Service, string) { @@ -845,3 +725,226 @@ func getScaler(t testing.TB, totalHashRanges uint32, addresses ...string) *scale return op } + +func TestListSnapshotsSorting(t *testing.T) { + pool, err := dockertest.NewPool("") + require.NoError(t, err) + pool.MaxWait = 1 * time.Minute + + run := func(t *testing.T, newConf func() *config.Config) { + t.Parallel() + + minioContainer, err := miniokit.Setup(pool, t) + require.NoError(t, err) + + cloudStorage := keydbth.GetCloudStorage(t, newConf(), minioContainer) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + totalHashRanges := uint32(4) + node0Conf := newConf() + node0, _ := getService(ctx, t, cloudStorage, Config{ + NodeID: 0, + ClusterSize: 1, + TotalHashRanges: totalHashRanges, + SnapshotInterval: 60 * time.Second, + }, node0Conf) + + // Create snapshot files with out-of-order from/to values for hash range 3 + // Expected order after sorting: hr_3_s_0_100, hr_3_s_100_200, 
hr_3_s_200_300 + snapshotFiles := []struct { + filename string + from uint64 + to uint64 + }{ + {defaultBackupFolderName + "/hr_3_s_200_300.snapshot", 200, 300}, + {defaultBackupFolderName + "/hr_3_s_0_100.snapshot", 0, 100}, + {defaultBackupFolderName + "/hr_3_s_100_200.snapshot", 100, 200}, + // Add some files for hash range 1 to test sorting across multiple ranges + {defaultBackupFolderName + "/hr_1_s_50_150.snapshot", 50, 150}, + {defaultBackupFolderName + "/hr_1_s_0_50.snapshot", 0, 50}, + } + + // Upload empty snapshot files to minio + for _, sf := range snapshotFiles { + _, err := cloudStorage.UploadReader(ctx, sf.filename, strings.NewReader("")) + require.NoError(t, err) + } + + // Call listSnapshots to get the sorted results + totalFiles, filesByHashRange, err := node0.listSnapshots(ctx) + require.NoError(t, err) + require.Equal(t, 5, totalFiles) + require.Equal(t, map[uint32][]snapshotFile{ + 1: { + { + filename: defaultBackupFolderName + "/hr_1_s_0_50.snapshot", + hashRange: 1, + from: 0, + to: 50, + }, + { + filename: defaultBackupFolderName + "/hr_1_s_50_150.snapshot", + hashRange: 1, + from: 50, + to: 150, + }, + }, + 3: { + { + filename: defaultBackupFolderName + "/hr_3_s_0_100.snapshot", + hashRange: 3, + from: 0, + to: 100, + }, + { + filename: defaultBackupFolderName + "/hr_3_s_100_200.snapshot", + hashRange: 3, + from: 100, + to: 200, + }, + { + filename: defaultBackupFolderName + "/hr_3_s_200_300.snapshot", + hashRange: 3, + from: 200, + to: 300, + }, + }, + }, filesByHashRange) + + cancel() + node0.Close() + } + + run(t, func() *config.Config { + conf := config.New() + conf.Set("BadgerDB.Dedup.Path", t.TempDir()) + return conf + }) +} + +func TestDegradedMode(t *testing.T) { + pool, err := dockertest.NewPool("") + require.NoError(t, err) + pool.MaxWait = 1 * time.Minute + + run := func(t *testing.T, newConf func() *config.Config) { + t.Parallel() + + minioContainer, err := miniokit.Setup(pool, t) + require.NoError(t, err) + + cloudStorage := keydbth.GetCloudStorage(t, newConf(), minioContainer) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + totalHashRanges := uint32(3) + + // Create a variable to hold degraded state that can be updated during the test + degradedNodes := make([]bool, 2) + + // Create two nodes with DegradedNodes function + node0Conf := newConf() + node0, node0Address := getService(ctx, t, cloudStorage, Config{ + NodeID: 0, + ClusterSize: 2, + TotalHashRanges: totalHashRanges, + DegradedNodes: func() []bool { + return degradedNodes + }, + }, node0Conf) + + node1Conf := newConf() + node1, node1Address := getService(ctx, t, cloudStorage, Config{ + NodeID: 1, + ClusterSize: 2, + TotalHashRanges: totalHashRanges, + Addresses: []string{node0Address}, + DegradedNodes: func() []bool { + return degradedNodes + }, + }, node1Conf) + + c := getClient(t, totalHashRanges, node0Address, node1Address) + + // Test that both nodes work normally when not degraded + require.NoError(t, c.Put(ctx, []string{"key1", "key2", "key3"}, testTTL)) + + exists, err := c.Get(ctx, []string{"key1", "key2", "key3", "key4"}) + require.NoError(t, err) + require.Equal(t, []bool{true, true, true, false}, exists) + + // Mark node 1 as degraded + degradedNodes[1] = true + + // Test that degraded node rejects Get requests + resp, err := node1.Get(ctx, &pb.GetRequest{Keys: []string{"key1"}}) + require.NoError(t, err) + require.Equal(t, pb.ErrorCode_SCALING, resp.ErrorCode) + require.Len(t, resp.NodesAddresses, 1, "Only non-degraded node should be in 
NodesAddresses") + require.Equal(t, node0Address, resp.NodesAddresses[0]) + + // Test that degraded node rejects Put requests + putResp, err := node1.Put(ctx, &pb.PutRequest{Keys: []string{"key5"}, TtlSeconds: uint64(testTTL.Seconds())}) + require.NoError(t, err) + require.Equal(t, pb.ErrorCode_SCALING, putResp.ErrorCode) + require.Len(t, putResp.NodesAddresses, 1, "Only non-degraded node should be in NodesAddresses") + require.Equal(t, node0Address, putResp.NodesAddresses[0]) + + // Test that non-degraded node returns only non-degraded addresses in Get + resp, err = node0.Get(ctx, &pb.GetRequest{Keys: []string{"key1"}}) + require.NoError(t, err) + require.Equal(t, pb.ErrorCode_NO_ERROR, resp.ErrorCode) + require.Len(t, resp.NodesAddresses, 1, "Only non-degraded node should be in NodesAddresses") + require.Equal(t, node0Address, resp.NodesAddresses[0]) + + // Test that non-degraded node returns only non-degraded addresses in Put + putResp, err = node0.Put(ctx, &pb.PutRequest{Keys: []string{"key6"}, TtlSeconds: uint64(testTTL.Seconds())}) + require.NoError(t, err) + require.Equal(t, pb.ErrorCode_NO_ERROR, putResp.ErrorCode) + require.Len(t, putResp.NodesAddresses, 1, "Only non-degraded node should be in NodesAddresses") + require.Equal(t, node0Address, putResp.NodesAddresses[0]) + + // Test that GetNodeInfo returns only non-degraded addresses + nodeInfo, err := node0.GetNodeInfo(ctx, &pb.GetNodeInfoRequest{NodeId: 0}) + require.NoError(t, err) + require.Len(t, nodeInfo.NodesAddresses, 1, "Only non-degraded node should be in NodesAddresses") + require.Equal(t, node0Address, nodeInfo.NodesAddresses[0]) + + // Test that LoadSnapshots still works on degraded node + op := getScaler(t, totalHashRanges, node0Address, node1Address) + require.NoError(t, op.CreateSnapshots(ctx, 0, false)) + + loadResp, err := node1.LoadSnapshots(ctx, &pb.LoadSnapshotsRequest{}) + require.NoError(t, err) + require.True(t, loadResp.Success, "LoadSnapshots should work on degraded nodes") + + // Mark node 0 as degraded and node 1 as non-degraded + degradedNodes[0] = true + degradedNodes[1] = false + + // Test that now node 0 rejects traffic and node 1 accepts it + resp, err = node0.Get(ctx, &pb.GetRequest{Keys: []string{"key1"}}) + require.NoError(t, err) + require.Equal(t, pb.ErrorCode_SCALING, resp.ErrorCode) + + // Node 1 should not return SCALING error (might return WRONG_NODE or NO_ERROR depending on key hash) + resp, err = node1.Get(ctx, &pb.GetRequest{Keys: []string{"key1"}}) + require.NoError(t, err) + require.NotEqual(t, pb.ErrorCode_SCALING, resp.ErrorCode, "Node 1 should not be in degraded mode") + require.Len(t, resp.NodesAddresses, 1, "Only non-degraded node should be in NodesAddresses") + require.Equal(t, node1Address, resp.NodesAddresses[0]) + + cancel() + node0.Close() + node1.Close() + } + + run(t, func() *config.Config { + conf := config.New() + conf.Set("BadgerDB.Dedup.Path", t.TempDir()) + return conf + }) +}