Skip to content

Commit

Permalink
feat: add graceful shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
gacevicljubisa committed Jan 8, 2025
1 parent 06dd4a2 commit 59d5c38
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 16 deletions.
6 changes: 5 additions & 1 deletion cmd/beekeeper/cmd/stamper.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ func (c *command) initStamperDilute() *cobra.Command {
diluteExecutor.Start(ctx, func(ctx context.Context) error {
return c.stamper.Dilute(ctx, c.globalConfig.GetFloat64(optionUsageThreshold), c.globalConfig.GetUint16(optionDiutionDepth))
})
defer diluteExecutor.Stop()
defer func() {
if err := diluteExecutor.Close(); err != nil {
c.log.Errorf("failed to close dilution periodic executor: %v", err)
}
}()

<-ctx.Done()

Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ require (
k8s.io/api v0.30.3
k8s.io/apimachinery v0.30.3
k8s.io/client-go v0.30.3
resenje.org/x v0.6.0
)

require (
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,8 @@ k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 h1:pUdcCO1Lk/tbT5ztQWOBi5HBgbBP1
k8s.io/utils v0.0.0-20240711033017-18e509b52bc8/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=
lukechampine.com/blake3 v1.2.1 h1:YuqqRuaqsGV71BV/nm9xlI0MKUv4QC54jQnBChWbGnI=
lukechampine.com/blake3 v1.2.1/go.mod h1:0OFRp7fBtAylGVCO40o87sbupkyIGgbpv1+M1k1LM6k=
resenje.org/x v0.6.0 h1:afn9E4XhglF4y9Kq0VH5tdSyjnsVKxiYgB6HFj7ebss=
resenje.org/x v0.6.0/go.mod h1:qgwe4MCzh57EkkMDurg24ug7HHfZtAjtBkmCihNmOpM=
rsc.io/tmplfunc v0.0.3 h1:53XFQh69AfOa8Tw0Jm7t+GV7KZhOi6jzsCzTtKbMvzU=
rsc.io/tmplfunc v0.0.3/go.mod h1:AG3sTPzElb1Io3Yg4voV9AGZJuleGAwaVRxL9M49PhA=
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo=
Expand Down
23 changes: 17 additions & 6 deletions pkg/scheduler/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,43 @@ import (
"time"

"github.com/ethersphere/beekeeper/pkg/logging"
"resenje.org/x/shutdown"
)

type PeriodicExecutor struct {
ticker *time.Ticker
interval time.Duration
log logging.Logger
stopChan chan struct{}
shutdown *shutdown.Graceful
}

func NewPeriodicExecutor(interval time.Duration, log logging.Logger) *PeriodicExecutor {
return &PeriodicExecutor{
ticker: time.NewTicker(interval),
interval: interval,
log: log,
stopChan: make(chan struct{}),
shutdown: shutdown.NewGraceful(),
}
}

func (pe *PeriodicExecutor) Start(ctx context.Context, task func(ctx context.Context) error) {
pe.shutdown.Add(1)
go func() {
defer pe.shutdown.Done()
ctx = pe.shutdown.Context(ctx)

if err := task(ctx); err != nil {
pe.log.Errorf("Task execution failed: %v", err)
}

for {
select {
case <-pe.ticker.C:
pe.log.Tracef("Executing task")
pe.log.Tracef("Executing task after %s interval", pe.interval)
if err := task(ctx); err != nil {
pe.log.Errorf("Task execution failed: %v", err)
}
case <-pe.stopChan:
case <-pe.shutdown.Quit():
return
case <-ctx.Done():
return
Expand All @@ -41,7 +50,9 @@ func (pe *PeriodicExecutor) Start(ctx context.Context, task func(ctx context.Con
}()
}

func (pe *PeriodicExecutor) Stop() {
func (pe *PeriodicExecutor) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
pe.ticker.Stop()
close(pe.stopChan)
return pe.shutdown.Shutdown(ctx)
}
27 changes: 18 additions & 9 deletions pkg/stamper/stamper.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ type ClientConfig struct {
}

type StamperClient struct {
*ClientConfig
httpClient http.Client
log logging.Logger
namespace string
k8sClient *k8s.Client
labelSelector string
inCluster bool
httpClient http.Client
}

func NewStamperClient(cfg *ClientConfig) *StamperClient {
Expand All @@ -50,8 +54,12 @@ func NewStamperClient(cfg *ClientConfig) *StamperClient {
}

return &StamperClient{
httpClient: *httpClient,
ClientConfig: cfg,
httpClient: *httpClient,
log: cfg.Log,
namespace: cfg.Namespace,
k8sClient: cfg.K8sClient,
labelSelector: cfg.LabelSelector,
inCluster: cfg.InCluster,
}
}

Expand All @@ -62,6 +70,7 @@ func (s *StamperClient) Create(ctx context.Context, amount uint64, depth uint8)

// Dilute implements Client.
func (s *StamperClient) Dilute(ctx context.Context, usageThreshold float64, dilutionDepth uint16) error {
s.log.WithFields(map[string]interface{}{"usageThreshold": usageThreshold, "dilutionDepth": dilutionDepth}).Infof("diluting namespace %s", s.namespace)
nodes, err := s.getNamespaceNodes(ctx)
if err != nil {
return fmt.Errorf("get namespace nodes: %w", err)
Expand Down Expand Up @@ -97,11 +106,11 @@ func (s *StamperClient) Topup(ctx context.Context, ttlThreshold time.Duration, t
}

func (sc *StamperClient) getNamespaceNodes(ctx context.Context) (nodes []Node, err error) {
if sc.Namespace == "" {
if sc.namespace == "" {
return nil, fmt.Errorf("namespace not provided")
}

if sc.InCluster {
if sc.inCluster {
nodes, err = sc.getServiceNodes(ctx)
} else {
nodes, err = sc.getIngressNodes(ctx)
Expand All @@ -115,7 +124,7 @@ func (sc *StamperClient) getNamespaceNodes(ctx context.Context) (nodes []Node, e
}

func (sc *StamperClient) getServiceNodes(ctx context.Context) ([]Node, error) {
svcNodes, err := sc.K8sClient.Service.GetNodes(ctx, sc.Namespace, sc.LabelSelector)
svcNodes, err := sc.k8sClient.Service.GetNodes(ctx, sc.namespace, sc.labelSelector)
if err != nil {
return nil, fmt.Errorf("list api services: %w", err)
}
Expand All @@ -138,12 +147,12 @@ func (sc *StamperClient) getServiceNodes(ctx context.Context) ([]Node, error) {
}

func (sc *StamperClient) getIngressNodes(ctx context.Context) ([]Node, error) {
ingressNodes, err := sc.K8sClient.Ingress.GetNodes(ctx, sc.Namespace, sc.LabelSelector)
ingressNodes, err := sc.k8sClient.Ingress.GetNodes(ctx, sc.namespace, sc.labelSelector)
if err != nil {
return nil, fmt.Errorf("list ingress api nodes hosts: %w", err)
}

ingressRouteNodes, err := sc.K8sClient.IngressRoute.GetNodes(ctx, sc.Namespace, sc.LabelSelector)
ingressRouteNodes, err := sc.k8sClient.IngressRoute.GetNodes(ctx, sc.namespace, sc.labelSelector)
if err != nil {
return nil, fmt.Errorf("list ingress route api nodes hosts: %w", err)
}
Expand Down

0 comments on commit 59d5c38

Please sign in to comment.