2 changes: 2 additions & 0 deletions .ai-agent/guidelines.md
@@ -4,6 +4,8 @@ All the rules below are not optional and must be followed to the letter.

# Go Rules

NEVER use `go build`, use `go run` instead.

## rudder-go-kit

https://github.com/rudderlabs/rudder-go-kit opinionated library for rudderstack. Should be used for all Golang projects,
15 changes: 0 additions & 15 deletions .claude/settings.local.json

This file was deleted.

2 changes: 2 additions & 0 deletions .gitignore
@@ -1 +1,3 @@
/.env
.idea
.claude/settings.local.json
49 changes: 45 additions & 4 deletions cmd/node/main.go
@@ -16,15 +16,15 @@
"syscall"
"time"

"github.com/rudderlabs/keydb/release"

"github.com/prometheus/client_golang/prometheus"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"

"github.com/rudderlabs/keydb/internal/cloudstorage"
"github.com/rudderlabs/keydb/internal/hash"
"github.com/rudderlabs/keydb/node"
pb "github.com/rudderlabs/keydb/proto"
"github.com/rudderlabs/keydb/release"
"github.com/rudderlabs/rudder-go-kit/config"
"github.com/rudderlabs/rudder-go-kit/logger"
_ "github.com/rudderlabs/rudder-go-kit/maxprocs"
@@ -103,10 +103,11 @@
if len(nodeAddresses) == 0 {
return fmt.Errorf("no node addresses provided")
}
degradedNodes := conf.GetReloadableStringVar("", "degradedNodes")

nodeConfig := node.Config{
NodeID: uint32(nodeID),
ClusterSize: uint32(conf.GetInt("clusterSize", 1)),

[Check failure on line 110 in cmd/node/main.go, GitHub Actions / Unit: unknown field ClusterSize in struct literal of type node.Config]
TotalHashRanges: uint32(conf.GetInt("totalHashRanges", node.DefaultTotalHashRanges)),
MaxFilesToList: conf.GetInt64("maxFilesToList", node.DefaultMaxFilesToList),
SnapshotInterval: conf.GetDuration("snapshotInterval",
@@ -115,7 +116,24 @@
GarbageCollectionInterval: conf.GetDuration("gcInterval", // node.DefaultGarbageCollectionInterval will be used
0, time.Nanosecond,
),
Addresses: strings.Split(nodeAddresses, ","),
Addresses: strings.Split(nodeAddresses, ","),
DegradedNodes: func() []bool {
raw := strings.TrimSpace(degradedNodes.Load())
if raw == "" {
return nil
}
v := strings.Split(raw, ",")
b := make([]bool, len(v))
for i, s := range v {
var err error
b[i], err = strconv.ParseBool(s)
if err != nil {
log.Warnn("Failed to parse degraded node", logger.NewStringField("v", raw), obskit.Error(err))
return nil
}
}
return b
},
LogTableStructureDuration: conf.GetDuration("logTableStructureDuration", 10, time.Minute),
BackupFolderName: conf.GetString("KUBE_NAMESPACE", ""),
}
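
The new `DegradedNodes` field is assigned a closure, presumably so the reloadable `degradedNodes` string is re-read and parsed on every call; any parse failure logs a warning and falls back to `nil`. A minimal standalone sketch of that parsing logic, with the function name and `main` harness being illustrative rather than part of this PR:

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseDegradedNodes turns a comma-separated list such as "false,true,false"
// into a []bool, returning nil for empty or malformed input, mirroring the
// behaviour of the DegradedNodes closure above.
func parseDegradedNodes(raw string) []bool {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return nil
	}
	parts := strings.Split(raw, ",")
	out := make([]bool, len(parts))
	for i, s := range parts {
		v, err := strconv.ParseBool(s)
		if err != nil {
			return nil
		}
		out[i] = v
	}
	return out
}

func main() {
	fmt.Println(parseDegradedNodes("false,true,false")) // [false true false]
	fmt.Println(parseDegradedNodes("oops"))             // [] (nil on parse error)
}
```
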
@@ -124,7 +142,7 @@
log = log.Withn(
logger.NewIntField("port", int64(port)),
logger.NewIntField("nodeId", int64(nodeConfig.NodeID)),
logger.NewIntField("clusterSize", int64(nodeConfig.ClusterSize)),

[Check failure on line 145 in cmd/node/main.go, GitHub Actions / Unit: nodeConfig.ClusterSize undefined (type node.Config has no field or method ClusterSize)]
logger.NewIntField("totalHashRanges", int64(nodeConfig.TotalHashRanges)),
logger.NewStringField("nodeAddresses", fmt.Sprintf("%+v", nodeConfig.Addresses)),
logger.NewIntField("noOfAddresses", int64(len(nodeConfig.Addresses))),
@@ -157,8 +175,31 @@
}
}()

// create a gRPC server with latency interceptors
// Configure gRPC server keepalive parameters
grpcKeepaliveMinTime := conf.GetDuration("grpc.keepalive.minTime", 10, time.Second)
grpcKeepalivePermitWithoutStream := conf.GetBool("grpc.keepalive.permitWithoutStream", true)
grpcKeepaliveTime := conf.GetDuration("grpc.keepalive.time", 60, time.Second)
grpcKeepaliveTimeout := conf.GetDuration("grpc.keepalive.timeout", 20, time.Second)

log.Infon("gRPC server keepalive configuration",
logger.NewDurationField("enforcementMinTime", grpcKeepaliveMinTime),
logger.NewBoolField("enforcementPermitWithoutStream", grpcKeepalivePermitWithoutStream),
logger.NewDurationField("serverTime", grpcKeepaliveTime),
logger.NewDurationField("serverTimeout", grpcKeepaliveTimeout),
)

// create a gRPC server with latency interceptors and keepalive parameters
server := grpc.NewServer(
// Keepalive enforcement policy - controls what the server requires from clients
grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
MinTime: grpcKeepaliveMinTime,
PermitWithoutStream: grpcKeepalivePermitWithoutStream,
}),
// Keepalive parameters - controls server's own keepalive behavior
grpc.KeepaliveParams(keepalive.ServerParameters{
Time: grpcKeepaliveTime,
Timeout: grpcKeepaliveTimeout,
}),
// Unary interceptor to record latency for unary RPCs
grpc.UnaryInterceptor(
func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
@@ -188,7 +229,7 @@
return fmt.Errorf("failed to listen: %w", err)
}

h := hash.New(nodeConfig.ClusterSize, nodeConfig.TotalHashRanges)

[Check failure on line 232 in cmd/node/main.go, GitHub Actions / Unit: nodeConfig.ClusterSize undefined (type node.Config has no field or method ClusterSize)]
log.Infon("Starting node",
logger.NewStringField("addresses", fmt.Sprintf("%+v", nodeConfig.Addresses)),
logger.NewIntField("hashRanges", int64(len(h.GetNodeHashRanges(nodeConfig.NodeID)))),
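
The server-side keepalive settings above only help if clients stay within the enforcement policy: pings more frequent than `grpc.keepalive.minTime` (10s by default here) are answered with GOAWAY. A hedged sketch of a matching client-side configuration, where the target address and durations are illustrative and not taken from this PR (`grpc.NewClient` requires grpc-go >= 1.63; `grpc.Dial` works on older versions):

```go
package main

import (
	"log"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/keepalive"
)

func main() {
	conn, err := grpc.NewClient("localhost:50051", // illustrative target
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithKeepaliveParams(keepalive.ClientParameters{
			Time:                30 * time.Second, // idle time before a ping; keep >= server MinTime
			Timeout:             20 * time.Second, // how long to wait for the ping ack
			PermitWithoutStream: true,             // matches permitWithoutStream=true on the server
		}),
	)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
}
```
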
8 changes: 8 additions & 0 deletions cmd/scaler/config.go
@@ -0,0 +1,8 @@
package main

var defaultHistogramBuckets = []float64{
0.002, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30, 60,
300 /* 5 mins */, 600 /* 10 mins */, 1800, /* 30 mins */
}

var customBuckets = map[string][]float64{}
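
`customBuckets` ships empty; the loop in cmd/scaler/main.go turns each entry into a `stats.WithHistogramBuckets` option. A hypothetical drop-in replacement for the empty map above, with an illustrative metric name that this PR does not actually emit:

```go
// Hypothetical override: finer-grained buckets for a fast HTTP metric.
var customBuckets = map[string][]float64{
	"keydb_scaler_http_request_duration_seconds": {0.001, 0.0025, 0.005, 0.01, 0.025, 0.05, 0.1},
}
```
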
21 changes: 18 additions & 3 deletions cmd/scaler/main.go
@@ -13,8 +13,11 @@ import (

"github.com/rudderlabs/keydb/client"
"github.com/rudderlabs/keydb/internal/scaler"
"github.com/rudderlabs/keydb/release"
"github.com/rudderlabs/rudder-go-kit/config"
"github.com/rudderlabs/rudder-go-kit/logger"
"github.com/rudderlabs/rudder-go-kit/stats"
svcMetric "github.com/rudderlabs/rudder-go-kit/stats/metric"
obskit "github.com/rudderlabs/rudder-observability-kit/go/labels"
)

@@ -26,13 +29,25 @@ func main() {
defer logFactory.Sync()
log := logFactory.NewLogger()

if err := run(ctx, cancel, conf, log); err != nil {
releaseInfo := release.NewInfo()
statsOptions := []stats.Option{
stats.WithServiceName("keydb-scaler"),
stats.WithServiceVersion(releaseInfo.Version),
stats.WithDefaultHistogramBuckets(defaultHistogramBuckets),
}
for histogramName, buckets := range customBuckets {
statsOptions = append(statsOptions, stats.WithHistogramBuckets(histogramName, buckets))
}
stat := stats.NewStats(conf, logFactory, svcMetric.NewManager(), statsOptions...)
defer stat.Stop()

if err := run(ctx, cancel, conf, stat, log); err != nil {
log.Fataln("failed to run", obskit.Error(err))
os.Exit(1)
}
}

func run(ctx context.Context, cancel func(), conf *config.Config, log logger.Logger) error {
func run(ctx context.Context, cancel func(), conf *config.Config, stat stats.Stats, log logger.Logger) error {
defer cancel()

nodeAddresses := conf.GetString("nodeAddresses", "")
@@ -114,7 +129,7 @@ func run(ctx context.Context, cancel func(), conf *config.Config, log logger.Log

// Create and start HTTP server
serverAddr := conf.GetString("serverAddr", ":8080")
server := newHTTPServer(c, scClient, serverAddr, log)
server := newHTTPServer(c, scClient, serverAddr, stat, log)

// Start server in a goroutine
serverErrCh := make(chan error, 1)
101 changes: 96 additions & 5 deletions cmd/scaler/server.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"strconv"
"strings"
"time"

@@ -18,9 +19,17 @@ import (
"github.com/rudderlabs/rudder-go-kit/httputil"
"github.com/rudderlabs/rudder-go-kit/jsonrs"
"github.com/rudderlabs/rudder-go-kit/logger"
"github.com/rudderlabs/rudder-go-kit/stats"
obskit "github.com/rudderlabs/rudder-observability-kit/go/labels"
)

const (
totalSnapshotsToLoadMetricName = "scaler_total_snapshots_to_load"
currentSnapshotsToLoadMetricName = "scaler_current_snapshots_loaded"
totalSnapshotsToCreateMetricName = "scaler_total_snapshots_to_create"
currentSnapshotsToCreateMetricName = "scaler_current_snapshots_created"
)

type scalerClient interface {
Scale(ctx context.Context, nodeIDs []uint32) error
ScaleComplete(ctx context.Context, nodeIDs []uint32) error
@@ -38,14 +47,18 @@ type httpServer struct {
client *client.Client
scaler scalerClient
server *http.Server
stat stats.Stats
logger logger.Logger
}

// newHTTPServer creates a new HTTP server
func newHTTPServer(client *client.Client, scaler *scaler.Client, addr string, log logger.Logger) *httpServer {
func newHTTPServer(
client *client.Client, scaler *scaler.Client, addr string, stat stats.Stats, log logger.Logger,
) *httpServer {
s := &httpServer{
client: client,
scaler: scaler,
stat: stat,
logger: log,
}

@@ -195,7 +208,10 @@ func (s *httpServer) handleCreateSnapshots(w http.ResponseWriter, r *http.Reques
}

// Create snapshot
if err := s.scaler.CreateSnapshots(r.Context(), req.NodeID, req.FullSync, req.HashRanges...); err != nil {
err := s.createSnapshotsWithProgress(
r.Context(), req.NodeID, req.FullSync, req.DisableCreateSnapshotsSequentially, req.HashRanges,
)
if err != nil {
http.Error(w, fmt.Sprintf("Error creating snapshot: %v", err), http.StatusInternalServerError)
return
}
@@ -215,12 +231,28 @@ func (s *httpServer) handleLoadSnapshots(w http.ResponseWriter, r *http.Request)
return
}

// Initialize metrics
nodeIDStr := strconv.FormatInt(int64(req.NodeID), 10)
totalSnapshotsToLoad := s.stat.NewTaggedStat(totalSnapshotsToLoadMetricName, stats.GaugeType, stats.Tags{
"nodeId": nodeIDStr,
})
totalSnapshotsToLoad.Observe(float64(len(req.HashRanges)))
defer totalSnapshotsToLoad.Observe(0)

currentSnapshotsLoaded := s.stat.NewTaggedStat(currentSnapshotsToLoadMetricName, stats.GaugeType, stats.Tags{
"nodeId": nodeIDStr,
})
currentSnapshotsLoaded.Observe(0)

// Load snapshots from cloud storage
if err := s.scaler.LoadSnapshots(r.Context(), req.NodeID, req.MaxConcurrency, req.HashRanges...); err != nil {
http.Error(w, fmt.Sprintf("Error loading snapshots: %v", err), http.StatusInternalServerError)
return
}

// Update metrics after successful load
currentSnapshotsLoaded.Observe(float64(len(req.HashRanges)))

// Write response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
@@ -335,11 +367,13 @@ func (s *httpServer) handleAutoScale(w http.ResponseWriter, r *http.Request) {
err = s.handleScaleUp(
r.Context(), req.OldNodesAddresses, req.NewNodesAddresses,
req.FullSync, req.SkipCreateSnapshots, loadSnapshotsMaxConcurrency,
req.DisableCreateSnapshotsSequentially,
)
} else if newClusterSize < oldClusterSize {
err = s.handleScaleDown(
r.Context(), req.OldNodesAddresses, req.NewNodesAddresses,
req.FullSync, req.SkipCreateSnapshots, loadSnapshotsMaxConcurrency,
req.DisableCreateSnapshotsSequentially,
)
} else {
// Auto-healing: propagate cluster addresses to all nodes for consistency
@@ -360,6 +394,7 @@ func (s *httpServer) handleAutoScale(w http.ResponseWriter, r *http.Request) {
func (s *httpServer) handleScaleUp(
ctx context.Context, oldAddresses, newAddresses []string,
fullSync, skipCreateSnapshots bool, loadSnapshotsMaxConcurrency uint32,
disableCreateSnapshotsSequentially bool,
) error {
oldClusterSize := uint32(len(oldAddresses))
newClusterSize := uint32(len(newAddresses))
@@ -405,7 +440,9 @@ func (s *httpServer) handleScaleUp(
}
group.Go(func() error {
createSnapshotsStart := time.Now()
err := s.scaler.CreateSnapshots(gCtx, sourceNodeID, fullSync, hashRanges...)
err := s.createSnapshotsWithProgress(
gCtx, sourceNodeID, fullSync, disableCreateSnapshotsSequentially, hashRanges,
)
if err != nil {
return fmt.Errorf("creating snapshots from node %d for hash ranges %v: %w",
sourceNodeID, hashRanges, err,
@@ -468,6 +505,7 @@ func (s *httpServer) handleScaleUp(
func (s *httpServer) handleScaleDown(
ctx context.Context, oldAddresses, newAddresses []string,
fullSync, skipCreateSnapshots bool, loadSnapshotsMaxConcurrency uint32,
disableCreateSnapshotsSequentially bool,
) error {
oldClusterSize := uint32(len(oldAddresses))
newClusterSize := uint32(len(newAddresses))
@@ -504,7 +542,10 @@ func (s *httpServer) handleScaleDown(
}
group.Go(func() error {
createSnapshotsStart := time.Now()
if err := s.scaler.CreateSnapshots(gCtx, sourceNodeID, fullSync, hashRanges...); err != nil {
err := s.createSnapshotsWithProgress(
gCtx, sourceNodeID, fullSync, disableCreateSnapshotsSequentially, hashRanges,
)
if err != nil {
return fmt.Errorf("creating snapshots from node %d for hash ranges %v: %w",
sourceNodeID, hashRanges, err,
)
@@ -709,7 +750,9 @@ func (s *httpServer) handleHashRangeMovements(w http.ResponseWriter, r *http.Req
// Call CreateSnapshots once per node with all hash ranges
group.Go(func() error {
start := time.Now()
err := s.scaler.CreateSnapshots(gCtx, sourceNodeID, req.FullSync, hashRanges...)
err := s.createSnapshotsWithProgress(
gCtx, sourceNodeID, req.FullSync, req.DisableCreateSnapshotsSequentially, hashRanges,
)
if err != nil {
return fmt.Errorf("creating snapshots for node %d: %w", sourceNodeID, err)
}
@@ -770,6 +813,54 @@ func (s *httpServer) handleHashRangeMovements(w http.ResponseWriter, r *http.Req
}
}

// createSnapshotsWithProgress creates snapshots either sequentially (one at a time) or in batch
// depending on the disableSequential flag
func (s *httpServer) createSnapshotsWithProgress(
ctx context.Context, nodeID uint32, fullSync, disableSequential bool, hashRanges []uint32,
) error {
// Initialize metrics
nodeIDStr := strconv.FormatInt(int64(nodeID), 10)
totalSnapshotsToCreate := s.stat.NewTaggedStat(totalSnapshotsToCreateMetricName, stats.GaugeType, stats.Tags{
"nodeId": nodeIDStr,
})
totalSnapshotsToCreate.Observe(float64(len(hashRanges)))
defer totalSnapshotsToCreate.Observe(0)

currentSnapshotsCreated := s.stat.NewTaggedStat(currentSnapshotsToCreateMetricName, stats.GaugeType, stats.Tags{
"nodeId": nodeIDStr,
})
currentSnapshotsCreated.Observe(0)

if disableSequential || len(hashRanges) == 0 {
// Call with all hash ranges at once (existing behavior)
err := s.scaler.CreateSnapshots(ctx, nodeID, fullSync, hashRanges...)
if err != nil {
return err
}
currentSnapshotsCreated.Observe(float64(len(hashRanges)))
return nil
}

// Call CreateSnapshots once for each hash range
for i, hashRange := range hashRanges {
s.logger.Infon("Creating snapshot",
logger.NewIntField("nodeId", int64(nodeID)),
logger.NewIntField("hashRange", int64(hashRange)),
logger.NewIntField("progress", int64(i+1)),
logger.NewIntField("total", int64(len(hashRanges))),
)

if err := s.scaler.CreateSnapshots(ctx, nodeID, fullSync, hashRange); err != nil {
return fmt.Errorf("creating snapshot for hash range %d: %w", hashRange, err)
}

// Update progress metric
currentSnapshotsCreated.Observe(float64(i + 1))
}

return nil
}

// handleLastOperation handles GET /lastOperation requests
func (s *httpServer) handleLastOperation(w http.ResponseWriter, r *http.Request) {
operation := s.scaler.GetLastOperation()
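
Both handleLoadSnapshots and createSnapshotsWithProgress follow the same progress-gauge pattern: publish the total amount of work up front, reset it to zero on exit, and bump a "current" gauge after each completed unit. A minimal sketch of that pattern against the rudder-go-kit stats interface, where the package, metric names, tag values, and the work callback are illustrative:

```go
package progress

import "github.com/rudderlabs/rudder-go-kit/stats"

// trackProgress mirrors the gauge pattern used above: total is set up front and
// zeroed on exit, current is bumped after each completed unit of work.
func trackProgress(stat stats.Stats, nodeID string, units int, doOne func(i int) error) error {
	total := stat.NewTaggedStat("example_total_units", stats.GaugeType, stats.Tags{"nodeId": nodeID})
	total.Observe(float64(units))
	defer total.Observe(0)

	current := stat.NewTaggedStat("example_current_units", stats.GaugeType, stats.Tags{"nodeId": nodeID})
	current.Observe(0)

	for i := 0; i < units; i++ {
		if err := doOne(i); err != nil {
			return err
		}
		current.Observe(float64(i + 1))
	}
	return nil
}
```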