diff --git a/cmd/argocd-application-controller/commands/argocd_application_controller.go b/cmd/argocd-application-controller/commands/argocd_application_controller.go
index 2670283fedc23..ddf6a978a42aa 100644
--- a/cmd/argocd-application-controller/commands/argocd_application_controller.go
+++ b/cmd/argocd-application-controller/commands/argocd_application_controller.go
@@ -205,7 +205,7 @@ func NewCommand() *cobra.Command {
                 enableK8sEvent,
             )
             errors.CheckError(err)
-            cacheutil.CollectMetrics(redisClient, appController.GetMetricsServer())
+            cacheutil.CollectMetrics(redisClient, appController.GetMetricsServer(), nil)
 
             stats.RegisterStackDumper()
             stats.StartStatsTicker(10 * time.Minute)
diff --git a/cmd/argocd-repo-server/commands/argocd_repo_server.go b/cmd/argocd-repo-server/commands/argocd_repo_server.go
index e55bfda4ac075..c66de4da0858c 100644
--- a/cmd/argocd-repo-server/commands/argocd_repo_server.go
+++ b/cmd/argocd-repo-server/commands/argocd_repo_server.go
@@ -130,7 +130,7 @@ func NewCommand() *cobra.Command {
 
             askPassServer := askpass.NewServer(askpass.SocketPath)
             metricsServer := metrics.NewMetricsServer()
-            cacheutil.CollectMetrics(redisClient, metricsServer)
+            cacheutil.CollectMetrics(redisClient, metricsServer, nil)
             server, err := reposerver.NewServer(metricsServer, cache, tlsConfigCustomizer, repository.RepoServerInitConstants{
                 ParallelismLimit: parallelismLimit,
                 PauseGenerationAfterFailedGenerationAttempts: pauseGenerationAfterFailedGenerationAttempts,
diff --git a/cmd/argocd-server/commands/argocd_server.go b/cmd/argocd-server/commands/argocd_server.go
index 56bc73db87c48..b9b6d04c54db7 100644
--- a/cmd/argocd-server/commands/argocd_server.go
+++ b/cmd/argocd-server/commands/argocd_server.go
@@ -258,22 +258,25 @@ func NewCommand() *cobra.Command {
             stats.RegisterHeapDumper("memprofile")
             argocd := server.NewServer(ctx, argoCDOpts, appsetOpts)
             argocd.Init(ctx)
-            lns, err := argocd.Listen()
-            errors.CheckError(err)
             for {
                 var closer func()
-                ctx, cancel := context.WithCancel(ctx)
+                serverCtx, cancel := context.WithCancel(ctx)
+                lns, err := argocd.Listen()
+                errors.CheckError(err)
                 if otlpAddress != "" {
-                    closer, err = traceutil.InitTracer(ctx, "argocd-server", otlpAddress, otlpInsecure, otlpHeaders, otlpAttrs)
+                    closer, err = traceutil.InitTracer(serverCtx, "argocd-server", otlpAddress, otlpInsecure, otlpHeaders, otlpAttrs)
                     if err != nil {
                         log.Fatalf("failed to initialize tracing: %v", err)
                     }
                 }
-                argocd.Run(ctx, lns)
-                cancel()
+                argocd.Run(serverCtx, lns)
                 if closer != nil {
                     closer()
                 }
+                cancel()
+                if argocd.TerminateRequested() {
+                    break
+                }
             }
         },
         Example: templates.Examples(`
diff --git a/server/server.go b/server/server.go
index 6625461dfab03..eaef1262cf9ba 100644
--- a/server/server.go
+++ b/server/server.go
@@ -13,13 +13,17 @@ import (
     "net/url"
     "os"
     "os/exec"
+    "os/signal"
     "path"
     "path/filepath"
     "reflect"
     "regexp"
     go_runtime "runtime"
+    "runtime/debug"
     "strings"
     gosync "sync"
+    "sync/atomic"
+    "syscall"
     "time"
 
     // nolint:staticcheck
@@ -187,17 +191,20 @@ type ArgoCDServer struct {
     db                 db.ArgoDB
 
     // stopCh is the channel which when closed, will shutdown the Argo CD server
-    stopCh            chan struct{}
-    userStateStorage  util_session.UserStateStorage
-    indexDataInit     gosync.Once
-    indexData         []byte
-    indexDataErr      error
-    staticAssets      http.FileSystem
-    apiFactory        api.Factory
-    secretInformer    cache.SharedIndexInformer
-    configMapInformer cache.SharedIndexInformer
-    serviceSet        *ArgoCDServiceSet
-    extensionManager  *extension.Manager
+    stopCh             chan os.Signal
+    userStateStorage   util_session.UserStateStorage
+    indexDataInit      gosync.Once
+    indexData          []byte
+    indexDataErr       error
+    staticAssets       http.FileSystem
+    apiFactory         api.Factory
+    secretInformer     cache.SharedIndexInformer
+    configMapInformer  cache.SharedIndexInformer
+    serviceSet         *ArgoCDServiceSet
+    extensionManager   *extension.Manager
+    shutdown           func()
+    terminateRequested atomic.Bool
+    available          atomic.Bool
 }
 
 type ArgoCDServerOpts struct {
@@ -329,6 +336,9 @@ func NewServer(ctx context.Context, opts ArgoCDServerOpts, appsetOpts Applicatio
     pg := extension.NewDefaultProjectGetter(projLister, dbInstance)
     ug := extension.NewDefaultUserGetter(policyEnf)
     em := extension.NewManager(logger, opts.Namespace, sg, ag, pg, enf, ug)
+    noopShutdown := func() {
+        log.Error("API Server Shutdown function called but server is not started yet.")
+    }
 
     a := &ArgoCDServer{
         ArgoCDServerOpts: opts,
@@ -352,6 +362,8 @@ func NewServer(ctx context.Context, opts ArgoCDServerOpts, appsetOpts Applicatio
         secretInformer:    secretInformer,
         configMapInformer: configMapInformer,
         extensionManager:  em,
+        shutdown:          noopShutdown,
+        stopCh:            make(chan os.Signal, 1),
     }
 
     err = a.logInClusterWarnings()
@@ -369,6 +381,12 @@ const (
 )
 
 func (a *ArgoCDServer) healthCheck(r *http.Request) error {
+    if a.terminateRequested.Load() {
+        return errors.New("API Server is terminating and unable to serve requests.")
+    }
+    if !a.available.Load() {
+        return errors.New("API Server is not available. It either hasn't started or is restarting.")
+    }
     if val, ok := r.URL.Query()["full"]; ok && len(val) > 0 && val[0] == "true" {
         argoDB := db.NewDB(a.Namespace, a.settingsMgr, a.KubeClientset)
         _, err := argoDB.ListClusters(r.Context())
@@ -515,11 +533,19 @@ func (a *ArgoCDServer) Init(ctx context.Context) {
 // k8s.io/ go-to-protobuf uses protoc-gen-gogo, which comes from gogo/protobuf (a fork of
 // golang/protobuf).
 func (a *ArgoCDServer) Run(ctx context.Context, listeners *Listeners) {
+    defer func() {
+        if r := recover(); r != nil {
+            log.WithField("trace", string(debug.Stack())).Error("Recovered from panic: ", r)
+            a.terminateRequested.Store(true)
+            a.shutdown()
+        }
+    }()
+
     a.userStateStorage.Init(ctx)
 
     metricsServ := metrics.NewMetricsServer(a.MetricsHost, a.MetricsPort)
     if a.RedisClient != nil {
-        cacheutil.CollectMetrics(a.RedisClient, metricsServ)
+        cacheutil.CollectMetrics(a.RedisClient, metricsServ, a.userStateStorage.GetLockObject())
     }
 
     svcSet := newArgoCDServiceSet(a)
@@ -601,35 +627,118 @@ func (a *ArgoCDServer) Run(ctx context.Context, listeners *Listeners) {
         log.Fatal("Timed out waiting for project cache to sync")
     }
 
-    a.stopCh = make(chan struct{})
-    <-a.stopCh
+    shutdownFunc := func() {
+        log.Info("API Server shutdown initiated. Shutting down servers...")
+        a.available.Store(false)
+        shutdownCtx, cancel := context.WithTimeout(ctx, 20*time.Second)
+        defer cancel()
+        var wg gosync.WaitGroup
+
+        // Shutdown http server
+        wg.Add(1)
+        go func() {
+            defer wg.Done()
+            err := httpS.Shutdown(shutdownCtx)
+            if err != nil {
+                log.Errorf("Error shutting down http server: %s", err)
+            }
+        }()
+
+        if a.useTLS() {
+            // Shutdown https server
+            wg.Add(1)
+            go func() {
+                defer wg.Done()
+                err := httpsS.Shutdown(shutdownCtx)
+                if err != nil {
+                    log.Errorf("Error shutting down https server: %s", err)
+                }
+            }()
+        }
+
+        // Shutdown gRPC server
+        wg.Add(1)
+        go func() {
+            defer wg.Done()
+            grpcS.GracefulStop()
+        }()
+
+        // Shutdown metrics server
+        wg.Add(1)
+        go func() {
+            defer wg.Done()
+            err := metricsServ.Shutdown(shutdownCtx)
+            if err != nil {
+                log.Errorf("Error shutting down metrics server: %s", err)
+            }
+        }()
+
+        if a.useTLS() {
+            // Shutdown tls server
+            wg.Add(1)
+            go func() {
+                defer wg.Done()
+                tlsm.Close()
+            }()
+        }
+
+        // Shutdown tcp server
+        wg.Add(1)
+        go func() {
+            defer wg.Done()
+            tcpm.Close()
+        }()
+
+        c := make(chan struct{})
+        // This goroutine will wait for all servers to conclude the shutdown
+        // process
+        go func() {
+            defer close(c)
+            wg.Wait()
+        }()
+
+        select {
+        case <-c:
+            log.Info("All servers were gracefully shutdown. Exiting...")
+        case <-shutdownCtx.Done():
+            log.Warn("Graceful shutdown timeout. Exiting...")
+        }
+    }
+    a.shutdown = shutdownFunc
+    signal.Notify(a.stopCh, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
+    a.available.Store(true)
+
+    select {
+    case signal := <-a.stopCh:
+        log.Infof("API Server received signal: %s", signal.String())
+        // SIGUSR1 is used for triggering a server restart
+        if signal != syscall.SIGUSR1 {
+            a.terminateRequested.Store(true)
+        }
+        a.shutdown()
+    case <-ctx.Done():
+        log.Infof("API Server: %s", ctx.Err())
+        a.terminateRequested.Store(true)
+        a.shutdown()
+    }
 }
 
 func (a *ArgoCDServer) Initialized() bool {
     return a.projInformer.HasSynced() && a.appInformer.HasSynced()
 }
 
+// TerminateRequested returns whether a shutdown was initiated by a signal or context cancel,
+// as opposed to a watch.
+func (a *ArgoCDServer) TerminateRequested() bool {
+    return a.terminateRequested.Load()
+}
+
 // checkServeErr checks the error from a .Serve() call to decide if it was a graceful shutdown
 func (a *ArgoCDServer) checkServeErr(name string, err error) {
-    if err != nil {
-        if a.stopCh == nil {
-            // a nil stopCh indicates a graceful shutdown
-            log.Infof("graceful shutdown %s: %v", name, err)
-        } else {
-            log.Fatalf("%s: %v", name, err)
-        }
+    if err != nil && !errors.Is(err, http.ErrServerClosed) {
+        log.Errorf("Error received from server %s: %v", name, err)
     } else {
-        log.Infof("graceful shutdown %s", name)
-    }
-}
-
-// Shutdown stops the Argo CD server
-func (a *ArgoCDServer) Shutdown() {
-    log.Info("Shut down requested")
-    stopCh := a.stopCh
-    a.stopCh = nil
-    if stopCh != nil {
-        close(stopCh)
+        log.Infof("Graceful shutdown of %s initiated", name)
     }
 }
 
@@ -734,9 +843,10 @@ func (a *ArgoCDServer) watchSettings() {
         }
     }
     log.Info("shutting down settings watch")
-    a.Shutdown()
     a.settingsMgr.Unsubscribe(updateCh)
     close(updateCh)
+    // Triggers server restart
+    a.stopCh <- syscall.SIGUSR1
 }
 
 func (a *ArgoCDServer) rbacPolicyLoader(ctx context.Context) {
diff --git a/server/server_test.go b/server/server_test.go
index 1f715d00d4e91..e67a3763d592d 100644
--- a/server/server_test.go
+++ b/server/server_test.go
@@ -10,6 +10,8 @@ import (
     "os"
     "path/filepath"
     "strings"
+    gosync "sync"
+    "syscall"
     "testing"
     "time"
 
@@ -419,6 +421,72 @@ func TestCertsAreNotGeneratedInInsecureMode(t *testing.T) {
     assert.Nil(t, s.settings.Certificate)
 }
 
+func TestGracefulShutdown(t *testing.T) {
+    port, err := test.GetFreePort()
+    require.NoError(t, err)
+    mockRepoClient := &mocks.Clientset{RepoServerServiceClient: &mocks.RepoServerServiceClient{}}
+    kubeclientset := fake.NewSimpleClientset(test.NewFakeConfigMap(), test.NewFakeSecret())
+    redis, redisCloser := test.NewInMemoryRedis()
+    defer redisCloser()
+    s := NewServer(
+        context.Background(),
+        ArgoCDServerOpts{
+            ListenPort:    port,
+            Namespace:     test.FakeArgoCDNamespace,
+            KubeClientset: kubeclientset,
+            AppClientset:  apps.NewSimpleClientset(),
+            RepoClientset: mockRepoClient,
+            RedisClient:   redis,
+        },
+        ApplicationSetOpts{},
+    )
+
+    projInformerCancel := test.StartInformer(s.projInformer)
+    defer projInformerCancel()
+    appInformerCancel := test.StartInformer(s.appInformer)
+    defer appInformerCancel()
+    appsetInformerCancel := test.StartInformer(s.appsetInformer)
+    defer appsetInformerCancel()
+
+    lns, err := s.Listen()
+    require.NoError(t, err)
+
+    shutdown := false
+    runCtx, runCancel := context.WithTimeout(context.Background(), 2*time.Second)
+    defer runCancel()
+
+    err = s.healthCheck(&http.Request{URL: &url.URL{Path: "/healthz", RawQuery: "full=true"}})
+    require.Error(t, err, "API Server is not running. It either hasn't started or is restarting.")
+
+    var wg gosync.WaitGroup
+    wg.Add(1)
+    go func(shutdown *bool) {
+        defer wg.Done()
+        s.Run(runCtx, lns)
+        *shutdown = true
+    }(&shutdown)
+
+    for {
+        if s.available.Load() {
+            err = s.healthCheck(&http.Request{URL: &url.URL{Path: "/healthz", RawQuery: "full=true"}})
+            require.NoError(t, err)
+            break
+        }
+        time.Sleep(10 * time.Millisecond)
+    }
+
+    s.stopCh <- syscall.SIGINT
+
+    wg.Wait()
+
+    err = s.healthCheck(&http.Request{URL: &url.URL{Path: "/healthz", RawQuery: "full=true"}})
+    require.Error(t, err, "API Server is terminating and unable to serve requests.")
+
+    assert.True(t, s.terminateRequested.Load())
+    assert.False(t, s.available.Load())
+    assert.True(t, shutdown)
+}
+
 func TestAuthenticate(t *testing.T) {
     type testData struct {
         test string
diff --git a/test/e2e/graceful_restart_test.go b/test/e2e/graceful_restart_test.go
new file mode 100644
index 0000000000000..6f5c6960a4f1d
--- /dev/null
+++ b/test/e2e/graceful_restart_test.go
@@ -0,0 +1,58 @@
+package e2e
+
+import (
+    "context"
+    "net/http"
+    "strings"
+    "testing"
+    "time"
+
+    "github.com/stretchr/testify/require"
+
+    "github.com/argoproj/argo-cd/v2/pkg/apiclient/settings"
+    "github.com/argoproj/argo-cd/v2/test/e2e/fixture"
+    . "github.com/argoproj/argo-cd/v2/test/e2e/fixture"
+    "github.com/argoproj/argo-cd/v2/util/errors"
+)
+
+func checkHealth(t *testing.T, requireHealthy bool) {
+    t.Helper()
+    resp, err := DoHttpRequest("GET", "/healthz?full=true", "")
+    if requireHealthy {
+        require.NoError(t, err)
+        require.Equal(t, http.StatusOK, resp.StatusCode)
+    } else {
+        if err != nil {
+            if !strings.Contains(err.Error(), "connection refused") && !strings.Contains(err.Error(), "connection reset by peer") {
+                require.NoErrorf(t, err, "If an error returned, it must be about connection refused or reset by peer")
+            }
+        } else {
+            require.Contains(t, []int{http.StatusOK, http.StatusServiceUnavailable}, resp.StatusCode)
+        }
+    }
+}
+
+func TestAPIServerGracefulRestart(t *testing.T) {
+    EnsureCleanState(t)
+
+    // Should be healthy.
+    checkHealth(t, true)
+    // Should trigger API server restart.
+    errors.CheckError(fixture.SetParamInSettingConfigMap("url", "http://test-api-server-graceful-restart"))
+
+    // Wait for ~5 seconds
+    for i := 0; i < 50; i++ {
+        checkHealth(t, false)
+        time.Sleep(100 * time.Millisecond)
+    }
+    // One final time, should be healthy, or restart is considered too slow for tests
+    checkHealth(t, true)
+    closer, settingsClient, err := ArgoCDClientset.NewSettingsClient()
+    if closer != nil {
+        defer closer.Close()
+    }
+    require.NoError(t, err)
+    settings, err := settingsClient.Get(context.Background(), &settings.SettingsQuery{})
+    require.NoError(t, err)
+    require.Equal(t, "http://test-api-server-graceful-restart", settings.URL)
+}
diff --git a/util/cache/redis.go b/util/cache/redis.go
index 5a832fd6ccd45..2c938b6998bb7 100644
--- a/util/cache/redis.go
+++ b/util/cache/redis.go
@@ -9,6 +9,7 @@ import (
     "fmt"
     "io"
     "net"
+    "sync"
     "time"
 
     ioutil "github.com/argoproj/argo-cd/v2/util/io"
@@ -200,6 +201,11 @@ func (redisHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.Proce
 }
 
 // CollectMetrics add transport wrapper that pushes metrics into the specified metrics registry
-func CollectMetrics(client *redis.Client, registry MetricsRegistry) {
+// Lock should be shared between functions that can add/process a Redis hook.
+func CollectMetrics(client *redis.Client, registry MetricsRegistry, lock *sync.RWMutex) {
+    if lock != nil {
+        lock.Lock()
+        defer lock.Unlock()
+    }
     client.AddHook(&redisHook{registry: registry})
 }
diff --git a/util/cache/redis_test.go b/util/cache/redis_test.go
index d60c7ea268e2c..8cda6d8086e74 100644
--- a/util/cache/redis_test.go
+++ b/util/cache/redis_test.go
@@ -136,8 +136,8 @@ func TestRedisMetrics(t *testing.T) {
     ms := NewMockMetricsServer()
     redisClient := redis.NewClient(&redis.Options{Addr: mr.Addr()})
     faultyRedisClient := redis.NewClient(&redis.Options{Addr: "invalidredishost.invalid:12345"})
-    CollectMetrics(redisClient, ms)
-    CollectMetrics(faultyRedisClient, ms)
+    CollectMetrics(redisClient, ms, nil)
+    CollectMetrics(faultyRedisClient, ms, nil)
 
     client := NewRedisCache(redisClient, 60*time.Second, RedisCompressionNone)
     faultyClient := NewRedisCache(faultyRedisClient, 60*time.Second, RedisCompressionNone)
diff --git a/util/session/state.go b/util/session/state.go
index b4117c0d1733f..db8eda5020ee3 100644
--- a/util/session/state.go
+++ b/util/session/state.go
@@ -125,6 +125,10 @@ func (storage *userStateStorage) IsTokenRevoked(id string) bool {
     return storage.revokedTokens[id]
 }
 
+func (storage *userStateStorage) GetLockObject() *sync.RWMutex {
+    return &storage.lock
+}
+
 type UserStateStorage interface {
     Init(ctx context.Context)
     // GetLoginAttempts return number of concurrent login attempts
@@ -135,4 +139,6 @@ type UserStateStorage interface {
     RevokeToken(ctx context.Context, id string, expiringAt time.Duration) error
     // IsTokenRevoked checks if given token is revoked
    IsTokenRevoked(id string) bool
+    // GetLockObject returns a lock used by the storage
+    GetLockObject() *sync.RWMutex
 }
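
Note (not part of the patch): the new *sync.RWMutex parameter on CollectMetrics appears intended to serialize re-registering the Redis metrics hook during an API server restart against other goroutines that use the same Redis client under the user-state storage's lock (hence GetLockObject and the "Lock should be shared" comment). A minimal sketch of that coordination follows; names such as fakeClient and collectMetrics are hypothetical stand-ins, not Argo CD APIs.

package main

import (
	"fmt"
	"sync"
	"time"
)

// fakeClient stands in for the Redis client; appending a hook mutates shared state.
type fakeClient struct {
	hooks []string
}

// collectMetrics mirrors the new signature: passing nil keeps the old, unguarded behaviour.
func collectMetrics(c *fakeClient, hook string, lock *sync.RWMutex) {
	if lock != nil {
		lock.Lock()
		defer lock.Unlock()
	}
	c.hooks = append(c.hooks, hook)
}

func main() {
	var lock sync.RWMutex
	c := &fakeClient{}
	done := make(chan struct{})

	// Reader side: stands in for background code that uses the client while
	// holding the read side of the shared lock.
	go func() {
		defer close(done)
		for i := 0; i < 100; i++ {
			lock.RLock()
			_ = len(c.hooks)
			lock.RUnlock()
			time.Sleep(time.Millisecond)
		}
	}()

	// Writer side: simulates CollectMetrics being called again when the server restarts.
	collectMetrics(c, "metrics", &lock)

	<-done
	fmt.Println("hooks registered:", c.hooks)
}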