Skip to content

Commit

Permalink
Propagate Context
Browse files Browse the repository at this point in the history
Propagate a context from the main function to wherever it is needed.

Use `github.com/oklog/run` run groups to handle the life cycle of the go
routines running the rebooter and metrics server.

Ref: kubereboot#234
and kubereboot#808

Signed-off-by: leonnicolas <[email protected]>
  • Loading branch information
leonnicolas committed Mar 4, 2024
1 parent ebb7ccf commit 10ee650
Show file tree
Hide file tree
Showing 11 changed files with 146 additions and 94 deletions.
154 changes: 99 additions & 55 deletions cmd/kured/main.go

Large diffs are not rendered by default.

18 changes: 10 additions & 8 deletions cmd/kured/main_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package main

import (
"context"
"reflect"
"testing"

"github.com/kubereboot/kured/pkg/alerts"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/kubereboot/kured/pkg/alerts"
assert "gotest.tools/v3/assert"

papi "github.com/prometheus/client_golang/api"
Expand All @@ -16,12 +17,14 @@ type BlockingChecker struct {
blocking bool
}

func (fbc BlockingChecker) isBlocked() bool {
func (fbc BlockingChecker) isBlocked(_ context.Context) bool {
return fbc.blocking
}

var _ RebootBlocker = BlockingChecker{} // Verify that Type implements Interface.
var _ RebootBlocker = (*BlockingChecker)(nil) // Verify that *Type implements Interface.
var (
_ RebootBlocker = BlockingChecker{} // Verify that Type implements Interface.
_ RebootBlocker = (*BlockingChecker)(nil) // Verify that *Type implements Interface.
)

func Test_flagCheck(t *testing.T) {
var cmd *cobra.Command
Expand Down Expand Up @@ -155,7 +158,7 @@ func Test_rebootBlocked(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := rebootBlocked(tt.args.blockers...); got != tt.want {
if got := rebootBlocked(context.Background(), tt.args.blockers...); got != tt.want {
t.Errorf("rebootBlocked() = %v, want %v", got, tt.want)
}
})
Expand Down Expand Up @@ -275,7 +278,7 @@ func Test_rebootRequired(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := rebootRequired(tt.args.sentinelCommand); got != tt.want {
if got := rebootRequired(context.Background(), tt.args.sentinelCommand); got != tt.want {
t.Errorf("rebootRequired() = %v, want %v", got, tt.want)
}
})
Expand Down Expand Up @@ -303,8 +306,7 @@ func Test_rebootRequired_fatals(t *testing.T) {

for _, c := range cases {
fatal = false
rebootRequired(c.param)
rebootRequired(context.Background(), c.param)
assert.Equal(t, c.expectFatal, fatal)
}

}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/monochromegane/go-gitignore v0.0.0-20200626010858-205db1a8cc00 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/oklog/run v1.1.0 // indirect
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
github.com/peterbourgon/diskv v2.0.1+incompatible // indirect
github.com/pkg/errors v0.9.1 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ github.com/monochromegane/go-gitignore v0.0.0-20200626010858-205db1a8cc00/go.mod
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f h1:KUppIJq7/+SVif2QVs3tOP0zanoHgBEVAwHxUSIzRqU=
github.com/oklog/run v1.1.0 h1:GEenZ1cK0+q0+wsJew9qUg/DyD8k3JzYsZAi5gYi2mA=
github.com/oklog/run v1.1.0/go.mod h1:sVPdnTZT1zYwAJeCMu2Th4T21pA3FPOQRfWjQlk7DVU=
github.com/onsi/ginkgo/v2 v2.9.4 h1:xR7vG4IXt5RWx6FfIjyAtsoMAtnc3C/rFXBBd2AjZwE=
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
Expand Down
5 changes: 2 additions & 3 deletions pkg/alerts/prometheus.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,9 @@ func NewPromClient(conf papi.Config) (*PromClient, error) {
// filter by regexp means when the regex finds the alert-name; the alert is exluded from the
// block-list and will NOT block rebooting. query by includeLabel means,
// if the query finds an alert, it will include it to the block-list and it WILL block rebooting.
func (p *PromClient) ActiveAlerts(filter *regexp.Regexp, firingOnly, filterMatchOnly bool) ([]string, error) {

func (p *PromClient) ActiveAlerts(ctx context.Context, filter *regexp.Regexp, firingOnly, filterMatchOnly bool) ([]string, error) {
// get all alerts from prometheus
value, _, err := p.api.Query(context.Background(), "ALERTS", time.Now())
value, _, err := p.api.Query(ctx, "ALERTS", time.Now())
if err != nil {
return nil, err
}
Expand Down
6 changes: 2 additions & 4 deletions pkg/alerts/prometheus_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package alerts

import (
"context"
"log"
"net/http"
"net/http/httptest"

"regexp"
"testing"

Expand All @@ -27,7 +27,6 @@ type MockServerProperties struct {

// NewMockServer sets up a new MockServer with properties ad starts the server.
func NewMockServer(props ...MockServerProperties) *httptest.Server {

handler := http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
for _, proc := range props {
Expand Down Expand Up @@ -140,7 +139,6 @@ func TestActiveAlerts(t *testing.T) {
defer mockServer.Close()

t.Run(tc.it, func(t *testing.T) {

// regex filter
regex, _ := regexp.Compile(tc.rFilter)

Expand All @@ -150,7 +148,7 @@ func TestActiveAlerts(t *testing.T) {
log.Fatal(err)
}

result, err := p.ActiveAlerts(regex, tc.firingOnly, tc.filterMatchOnly)
result, err := p.ActiveAlerts(context.Background(), regex, tc.firingOnly, tc.filterMatchOnly)
if err != nil {
log.Fatal(err)
}
Expand Down
36 changes: 18 additions & 18 deletions pkg/daemonsetlock/daemonsetlock.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ func New(client *kubernetes.Clientset, nodeID, namespace, name, annotation strin
}

// Acquire attempts to annotate the kured daemonset with lock info from instantiated DaemonSetLock using client-go
func (dsl *DaemonSetLock) Acquire(metadata interface{}, TTL time.Duration) (bool, string, error) {
func (dsl *DaemonSetLock) Acquire(ctx context.Context, metadata interface{}, TTL time.Duration) (bool, string, error) {
for {
ds, err := dsl.GetDaemonSet(k8sAPICallRetrySleep, k8sAPICallRetryTimeout)
ds, err := dsl.GetDaemonSet(ctx, k8sAPICallRetrySleep, k8sAPICallRetryTimeout)
if err != nil {
return false, "", fmt.Errorf("timed out trying to get daemonset %s in namespace %s: %w", dsl.name, dsl.namespace, err)
}
Expand All @@ -75,7 +75,7 @@ func (dsl *DaemonSetLock) Acquire(metadata interface{}, TTL time.Duration) (bool
}
ds.ObjectMeta.Annotations[dsl.annotation] = string(valueBytes)

_, err = dsl.client.AppsV1().DaemonSets(dsl.namespace).Update(context.TODO(), ds, metav1.UpdateOptions{})
_, err = dsl.client.AppsV1().DaemonSets(dsl.namespace).Update(ctx, ds, metav1.UpdateOptions{})
if err != nil {
if se, ok := err.(*errors.StatusError); ok && se.ErrStatus.Reason == metav1.StatusReasonConflict {
// Something else updated the resource between us reading and writing - try again soon
Expand All @@ -90,9 +90,9 @@ func (dsl *DaemonSetLock) Acquire(metadata interface{}, TTL time.Duration) (bool
}

// AcquireMultiple creates and annotates the daemonset with a multiple owner lock
func (dsl *DaemonSetLock) AcquireMultiple(metadata interface{}, TTL time.Duration, maxOwners int) (bool, []string, error) {
func (dsl *DaemonSetLock) AcquireMultiple(ctx context.Context, metadata interface{}, TTL time.Duration, maxOwners int) (bool, []string, error) {
for {
ds, err := dsl.GetDaemonSet(k8sAPICallRetrySleep, k8sAPICallRetryTimeout)
ds, err := dsl.GetDaemonSet(ctx, k8sAPICallRetrySleep, k8sAPICallRetryTimeout)
if err != nil {
return false, []string{}, fmt.Errorf("timed out trying to get daemonset %s in namespace %s: %w", dsl.name, dsl.namespace, err)
}
Expand All @@ -119,7 +119,7 @@ func (dsl *DaemonSetLock) AcquireMultiple(metadata interface{}, TTL time.Duratio
}
ds.ObjectMeta.Annotations[dsl.annotation] = string(newAnnotationBytes)

_, err = dsl.client.AppsV1().DaemonSets(dsl.namespace).Update(context.Background(), ds, metav1.UpdateOptions{})
_, err = dsl.client.AppsV1().DaemonSets(dsl.namespace).Update(ctx, ds, metav1.UpdateOptions{})
if err != nil {
if se, ok := err.(*errors.StatusError); ok && se.ErrStatus.Reason == metav1.StatusReasonConflict {
time.Sleep(time.Second)
Expand Down Expand Up @@ -176,8 +176,8 @@ func (dsl *DaemonSetLock) canAcquireMultiple(annotation multiLockAnnotationValue
}

// Test attempts to check the kured daemonset lock status (existence, expiry) from instantiated DaemonSetLock using client-go
func (dsl *DaemonSetLock) Test(metadata interface{}) (bool, error) {
ds, err := dsl.GetDaemonSet(k8sAPICallRetrySleep, k8sAPICallRetryTimeout)
func (dsl *DaemonSetLock) Test(ctx context.Context, metadata interface{}) (bool, error) {
ds, err := dsl.GetDaemonSet(ctx, k8sAPICallRetrySleep, k8sAPICallRetryTimeout)
if err != nil {
return false, fmt.Errorf("timed out trying to get daemonset %s in namespace %s: %w", dsl.name, dsl.namespace, err)
}
Expand All @@ -198,8 +198,8 @@ func (dsl *DaemonSetLock) Test(metadata interface{}) (bool, error) {
}

// TestMultiple attempts to check the kured daemonset lock status for multi locks
func (dsl *DaemonSetLock) TestMultiple() (bool, error) {
ds, err := dsl.GetDaemonSet(k8sAPICallRetrySleep, k8sAPICallRetryTimeout)
func (dsl *DaemonSetLock) TestMultiple(ctx context.Context) (bool, error) {
ds, err := dsl.GetDaemonSet(ctx, k8sAPICallRetrySleep, k8sAPICallRetryTimeout)
if err != nil {
return false, fmt.Errorf("timed out trying to get daemonset %s in namespace %s: %w", dsl.name, dsl.namespace, err)
}
Expand All @@ -222,9 +222,9 @@ func (dsl *DaemonSetLock) TestMultiple() (bool, error) {
}

// Release attempts to remove the lock data from the kured ds annotations using client-go
func (dsl *DaemonSetLock) Release() error {
func (dsl *DaemonSetLock) Release(ctx context.Context) error {
for {
ds, err := dsl.GetDaemonSet(k8sAPICallRetrySleep, k8sAPICallRetryTimeout)
ds, err := dsl.GetDaemonSet(ctx, k8sAPICallRetrySleep, k8sAPICallRetryTimeout)
if err != nil {
return fmt.Errorf("timed out trying to get daemonset %s in namespace %s: %w", dsl.name, dsl.namespace, err)
}
Expand All @@ -245,7 +245,7 @@ func (dsl *DaemonSetLock) Release() error {

delete(ds.ObjectMeta.Annotations, dsl.annotation)

_, err = dsl.client.AppsV1().DaemonSets(dsl.namespace).Update(context.TODO(), ds, metav1.UpdateOptions{})
_, err = dsl.client.AppsV1().DaemonSets(dsl.namespace).Update(ctx, ds, metav1.UpdateOptions{})
if err != nil {
if se, ok := err.(*errors.StatusError); ok && se.ErrStatus.Reason == metav1.StatusReasonConflict {
// Something else updated the resource between us reading and writing - try again soon
Expand All @@ -260,9 +260,9 @@ func (dsl *DaemonSetLock) Release() error {
}

// ReleaseMultiple attempts to remove the lock data from the kured ds annotations using client-go
func (dsl *DaemonSetLock) ReleaseMultiple() error {
func (dsl *DaemonSetLock) ReleaseMultiple(ctx context.Context) error {
for {
ds, err := dsl.GetDaemonSet(k8sAPICallRetrySleep, k8sAPICallRetryTimeout)
ds, err := dsl.GetDaemonSet(ctx, k8sAPICallRetrySleep, k8sAPICallRetryTimeout)
if err != nil {
return fmt.Errorf("timed out trying to get daemonset %s in namespace %s: %w", dsl.name, dsl.namespace, err)
}
Expand Down Expand Up @@ -294,7 +294,7 @@ func (dsl *DaemonSetLock) ReleaseMultiple() error {
}
ds.ObjectMeta.Annotations[dsl.annotation] = string(newAnnotationBytes)

_, err = dsl.client.AppsV1().DaemonSets(dsl.namespace).Update(context.TODO(), ds, metav1.UpdateOptions{})
_, err = dsl.client.AppsV1().DaemonSets(dsl.namespace).Update(ctx, ds, metav1.UpdateOptions{})
if err != nil {
if se, ok := err.(*errors.StatusError); ok && se.ErrStatus.Reason == metav1.StatusReasonConflict {
// Something else updated the resource between us reading and writing - try again soon
Expand All @@ -309,10 +309,10 @@ func (dsl *DaemonSetLock) ReleaseMultiple() error {
}

// GetDaemonSet returns the named DaemonSet resource from the DaemonSetLock's configured client
func (dsl *DaemonSetLock) GetDaemonSet(sleep, timeout time.Duration) (*v1.DaemonSet, error) {
func (dsl *DaemonSetLock) GetDaemonSet(ctx context.Context, sleep, timeout time.Duration) (*v1.DaemonSet, error) {
var ds *v1.DaemonSet
var lastError error
err := wait.PollImmediate(sleep, timeout, func() (bool, error) {
err := wait.PollUntilContextTimeout(ctx, sleep, timeout, true, func(ctx context.Context) (bool, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if ds, lastError = dsl.client.AppsV1().DaemonSets(dsl.namespace).Get(ctx, dsl.name, metav1.GetOptions{}); lastError != nil {
Expand Down
6 changes: 4 additions & 2 deletions pkg/reboot/command.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package reboot

import (
"context"

"github.com/kubereboot/kured/pkg/util"
log "github.com/sirupsen/logrus"
)
Expand All @@ -17,9 +19,9 @@ func NewCommandReboot(nodeID string, rebootCommand []string) *CommandRebootMetho
}

// Reboot triggers the command-reboot.
func (c *CommandRebootMethod) Reboot() {
func (c *CommandRebootMethod) Reboot(ctx context.Context) {
log.Infof("Running command: %s for node: %s", c.rebootCommand, c.nodeID)
if err := util.NewCommand(c.rebootCommand[0], c.rebootCommand[1:]...).Run(); err != nil {
if err := util.NewCommand(ctx, c.rebootCommand[0], c.rebootCommand[1:]...).Run(); err != nil {
log.Fatalf("Error invoking reboot command: %v", err)
}
}
4 changes: 3 additions & 1 deletion pkg/reboot/reboot.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package reboot

import "context"

// Reboot interface defines the Reboot function to be implemented.
type Reboot interface {
Reboot()
Reboot(context.Context)
}
3 changes: 2 additions & 1 deletion pkg/reboot/signal.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package reboot

import (
"context"
"os"
"syscall"

Expand All @@ -19,7 +20,7 @@ func NewSignalReboot(nodeID string, signal int) *SignalRebootMethod {
}

// Reboot triggers the signal-reboot.
func (c *SignalRebootMethod) Reboot() {
func (c *SignalRebootMethod) Reboot(_ context.Context) {
log.Infof("Emit reboot-signal for node: %s", c.nodeID)

process, err := os.FindProcess(1)
Expand Down
5 changes: 3 additions & 2 deletions pkg/util/util.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package util

import (
"context"
"os/exec"

log "github.com/sirupsen/logrus"
)

// NewCommand creates a new Command with stdout/stderr wired to our standard logger
func NewCommand(name string, arg ...string) *exec.Cmd {
cmd := exec.Command(name, arg...)
func NewCommand(ctx context.Context, name string, arg ...string) *exec.Cmd {
cmd := exec.CommandContext(ctx, name, arg...)
cmd.Stdout = log.NewEntry(log.StandardLogger()).
WithField("cmd", cmd.Args[0]).
WithField("std", "out").
Expand Down

0 comments on commit 10ee650

Please sign in to comment.