From 90b3f84d3b463fc6ef4eb67ee32ed059ec57a240 Mon Sep 17 00:00:00 2001 From: Dmytro Momot Date: Thu, 14 Nov 2024 00:08:32 +0200 Subject: [PATCH] Improve error handling and graceful shutdown in server Added an error channel to the server struct to enhance error management during startup and shutdown. Refactored shutdown logic to handle context cancellation and improve error reporting. Updated tests to validate graceful shutdown and error handling scenarios. --- server.go | 51 ++++++++++++++++++++++++--------- server_test.go | 77 +++++++++++++++++++++++++++++--------------------- 2 files changed, 82 insertions(+), 46 deletions(-) diff --git a/server.go b/server.go index 77f6e15..650bad2 100644 --- a/server.go +++ b/server.go @@ -22,6 +22,7 @@ type Server struct { httpServer *http.Server shutdownTimeout time.Duration log Logger + errCh chan error } // Logger is an interface that defines the logging methods used by the server. @@ -52,6 +53,7 @@ func New(addr string, handler http.Handler, opt ...serverOption) *Server { }, shutdownTimeout: 5 * time.Second, log: slog.Default().With(slog.String("component", "httpserver")), + errCh: make(chan error, 1), // Initialize error channel } // Apply options @@ -74,6 +76,10 @@ func (s *Server) Start(ctx context.Context) error { "idle_timeout", s.httpServer.IdleTimeout, ) + // Create a new context for shutdown + shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) + defer shutdownCancel() + g, _ := errgroup.WithContext(ctx) // Start the server in a new goroutine within the errgroup @@ -84,26 +90,26 @@ func (s *Server) Start(ctx context.Context) error { return nil }) - // Handle shutdown signal in a new goroutine within the errgroup + // Handle shutdown signals g.Go(func() error { select { case <-ctx.Done(): s.log.InfoContext(ctx, "context cancelled, initiating shutdown") - return ctx.Err() + return s.Stop(shutdownCtx, s.shutdownTimeout) case sig := <-signalChan(): s.log.InfoContext(ctx, "received shutdown signal", "signal", sig.String()) - return s.Stop(ctx, s.shutdownTimeout) + return s.Stop(shutdownCtx, s.shutdownTimeout) } }) - // Wait for all goroutines in the errgroup to finish - err := g.Wait() - if err != nil { + // Wait for all goroutines to complete + if err := g.Wait(); err != nil && !errors.Is(err, context.Canceled) { s.log.ErrorContext(ctx, "server stopped with error", "error", err) - } else { - s.log.InfoContext(ctx, "server stopped gracefully") + return err } - return err + + s.log.InfoContext(ctx, "server stopped gracefully") + return nil } // Stop stops the server gracefully with the given timeout. @@ -111,11 +117,28 @@ func (s *Server) Start(ctx context.Context) error { // If the timeout is reached before the server is fully stopped, an error is returned. func (s *Server) Stop(ctx context.Context, timeout time.Duration) error { s.log.InfoContext(context.Background(), "stopping HTTP server", "timeout", timeout) - ctx, cancel := context.WithTimeout(context.Background(), timeout) + + // Create a new context for shutdown with timeout + shutdownCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - if err := s.httpServer.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + // Create an error group for coordinated shutdown + g := new(errgroup.Group) + + // Shutdown the HTTP server + g.Go(func() error { + err := s.httpServer.Shutdown(shutdownCtx) + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("server shutdown error: %w", err) + } + return nil + }) + + // Wait for shutdown to complete or timeout + if err := g.Wait(); err != nil { s.log.ErrorContext(ctx, "error during server shutdown", "error", err) + // Force close if graceful shutdown fails + _ = s.Close(ctx) return err } @@ -126,11 +149,11 @@ func (s *Server) Stop(ctx context.Context, timeout time.Duration) error { // Close stops the server immediately without waiting for active connections to finish. // It returns an error if the server fails to stop. func (s *Server) Close(ctx context.Context) error { - s.log.InfoContext(context.Background(), "force closing HTTP server") + s.log.InfoContext(ctx, "force closing HTTP server") if err := s.httpServer.Close(); err != nil && !errors.Is(err, http.ErrServerClosed) { - s.log.ErrorContext(context.Background(), "error during force close", "error", err) - return err + s.log.ErrorContext(ctx, "error during force close", "error", err) + return fmt.Errorf("server force close error: %w", err) } return nil } diff --git a/server_test.go b/server_test.go index 845bf47..712e849 100644 --- a/server_test.go +++ b/server_test.go @@ -2,6 +2,7 @@ package httpserver_test import ( "context" + "errors" "fmt" "net/http" "testing" @@ -12,36 +13,48 @@ import ( ) func TestServer(t *testing.T) { - listenAddr := "localhost:9999" - - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello, World!") }) - server := httpserver.New(listenAddr, handler) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Run the server in a separate goroutine - go func() { - if err := server.Start(ctx); err != nil { - require.NoError(t, err, "Unexpected error in server start") - } - }() - - // Wait for the server to start - time.Sleep(500 * time.Millisecond) - - // Perform an HTTP request to the server - resp, err := http.Get(fmt.Sprintf("http://%s", listenAddr)) - require.NoError(t, err, "Unexpected error in GET request") - require.Equal(t, http.StatusOK, resp.StatusCode, "Unexpected status code") - - // Shutdown the server - require.NoError(t, server.Stop(ctx, 1*time.Second), "Unexpected error in server shutdown") - - // Wait for the server to shut down - // time.Sleep(1 * time.Second) - - // Perform an HTTP request to the server after it has shut down - _, err = http.Get(fmt.Sprintf("http://%s", listenAddr)) - require.Error(t, err, "Expected error after server shutdown") + listenAddr := "localhost:9999" + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, World!") + }) + server := httpserver.New(listenAddr, handler) + + // Create a context with cancel for server control + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Channel to catch server errors + serverErr := make(chan error, 1) + + // Start the server in a goroutine + go func() { + serverErr <- server.Start(ctx) + }() + + // Wait for the server to start + time.Sleep(500 * time.Millisecond) + + // Test server response + resp, err := http.Get(fmt.Sprintf("http://%s", listenAddr)) + require.NoError(t, err, "Unexpected error in GET request") + require.Equal(t, http.StatusOK, resp.StatusCode, "Unexpected status code") + resp.Body.Close() + + // Initiate graceful shutdown + cancel() + + // Wait for server to shut down with timeout + shutdownTimeout := time.After(5 * time.Second) + select { + case err := <-serverErr: + require.True(t, err == nil || errors.Is(err, context.Canceled), + "Expected nil or context.Canceled error, got: %v", err) + case <-shutdownTimeout: + t.Fatal("Server shutdown timed out") + } + + // Verify server is no longer accepting connections + _, err = http.Get(fmt.Sprintf("http://%s", listenAddr)) + require.Error(t, err, "Expected error after server shutdown") }