Skip to content

Commit

Permalink
Improve error handling and graceful shutdown in server
Browse files Browse the repository at this point in the history
Added an error channel to the server struct to enhance error management during startup and shutdown. Refactored shutdown logic to handle context cancellation and improve error reporting. Updated tests to validate graceful shutdown and error handling scenarios.
  • Loading branch information
dmitrymomot committed Nov 13, 2024
1 parent bbaf108 commit 90b3f84
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 46 deletions.
51 changes: 37 additions & 14 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type Server struct {
httpServer *http.Server
shutdownTimeout time.Duration
log Logger
errCh chan error
}

// Logger is an interface that defines the logging methods used by the server.
Expand Down Expand Up @@ -52,6 +53,7 @@ func New(addr string, handler http.Handler, opt ...serverOption) *Server {
},
shutdownTimeout: 5 * time.Second,
log: slog.Default().With(slog.String("component", "httpserver")),
errCh: make(chan error, 1), // Initialize error channel
}

// Apply options
Expand All @@ -74,6 +76,10 @@ func (s *Server) Start(ctx context.Context) error {
"idle_timeout", s.httpServer.IdleTimeout,
)

// Create a new context for shutdown
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
defer shutdownCancel()

g, _ := errgroup.WithContext(ctx)

// Start the server in a new goroutine within the errgroup
Expand All @@ -84,38 +90,55 @@ func (s *Server) Start(ctx context.Context) error {
return nil
})

// Handle shutdown signal in a new goroutine within the errgroup
// Handle shutdown signals
g.Go(func() error {
select {
case <-ctx.Done():
s.log.InfoContext(ctx, "context cancelled, initiating shutdown")
return ctx.Err()
return s.Stop(shutdownCtx, s.shutdownTimeout)
case sig := <-signalChan():
s.log.InfoContext(ctx, "received shutdown signal", "signal", sig.String())
return s.Stop(ctx, s.shutdownTimeout)
return s.Stop(shutdownCtx, s.shutdownTimeout)
}
})

// Wait for all goroutines in the errgroup to finish
err := g.Wait()
if err != nil {
// Wait for all goroutines to complete
if err := g.Wait(); err != nil && !errors.Is(err, context.Canceled) {
s.log.ErrorContext(ctx, "server stopped with error", "error", err)
} else {
s.log.InfoContext(ctx, "server stopped gracefully")
return err
}
return err

s.log.InfoContext(ctx, "server stopped gracefully")
return nil
}

// Stop stops the server gracefully with the given timeout.
// It uses the provided timeout to gracefully shutdown the underlying HTTP server.
// If the timeout is reached before the server is fully stopped, an error is returned.
func (s *Server) Stop(ctx context.Context, timeout time.Duration) error {
s.log.InfoContext(context.Background(), "stopping HTTP server", "timeout", timeout)
ctx, cancel := context.WithTimeout(context.Background(), timeout)

// Create a new context for shutdown with timeout
shutdownCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

if err := s.httpServer.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
// Create an error group for coordinated shutdown
g := new(errgroup.Group)

// Shutdown the HTTP server
g.Go(func() error {
err := s.httpServer.Shutdown(shutdownCtx)
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("server shutdown error: %w", err)
}
return nil
})

// Wait for shutdown to complete or timeout
if err := g.Wait(); err != nil {
s.log.ErrorContext(ctx, "error during server shutdown", "error", err)
// Force close if graceful shutdown fails
_ = s.Close(ctx)
return err
}

Expand All @@ -126,11 +149,11 @@ func (s *Server) Stop(ctx context.Context, timeout time.Duration) error {
// Close stops the server immediately without waiting for active connections to finish.
// It returns an error if the server fails to stop.
func (s *Server) Close(ctx context.Context) error {
s.log.InfoContext(context.Background(), "force closing HTTP server")
s.log.InfoContext(ctx, "force closing HTTP server")

if err := s.httpServer.Close(); err != nil && !errors.Is(err, http.ErrServerClosed) {
s.log.ErrorContext(context.Background(), "error during force close", "error", err)
return err
s.log.ErrorContext(ctx, "error during force close", "error", err)
return fmt.Errorf("server force close error: %w", err)
}
return nil
}
Expand Down
77 changes: 45 additions & 32 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package httpserver_test

import (
"context"
"errors"
"fmt"
"net/http"
"testing"
Expand All @@ -12,36 +13,48 @@ import (
)

func TestServer(t *testing.T) {
listenAddr := "localhost:9999"

handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello, World!") })
server := httpserver.New(listenAddr, handler)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Run the server in a separate goroutine
go func() {
if err := server.Start(ctx); err != nil {
require.NoError(t, err, "Unexpected error in server start")
}
}()

// Wait for the server to start
time.Sleep(500 * time.Millisecond)

// Perform an HTTP request to the server
resp, err := http.Get(fmt.Sprintf("http://%s", listenAddr))
require.NoError(t, err, "Unexpected error in GET request")
require.Equal(t, http.StatusOK, resp.StatusCode, "Unexpected status code")

// Shutdown the server
require.NoError(t, server.Stop(ctx, 1*time.Second), "Unexpected error in server shutdown")

// Wait for the server to shut down
// time.Sleep(1 * time.Second)

// Perform an HTTP request to the server after it has shut down
_, err = http.Get(fmt.Sprintf("http://%s", listenAddr))
require.Error(t, err, "Expected error after server shutdown")
listenAddr := "localhost:9999"

handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, World!")
})
server := httpserver.New(listenAddr, handler)

// Create a context with cancel for server control
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Channel to catch server errors
serverErr := make(chan error, 1)

// Start the server in a goroutine
go func() {
serverErr <- server.Start(ctx)
}()

// Wait for the server to start
time.Sleep(500 * time.Millisecond)

// Test server response
resp, err := http.Get(fmt.Sprintf("http://%s", listenAddr))
require.NoError(t, err, "Unexpected error in GET request")
require.Equal(t, http.StatusOK, resp.StatusCode, "Unexpected status code")
resp.Body.Close()

// Initiate graceful shutdown
cancel()

// Wait for server to shut down with timeout
shutdownTimeout := time.After(5 * time.Second)
select {
case err := <-serverErr:
require.True(t, err == nil || errors.Is(err, context.Canceled),
"Expected nil or context.Canceled error, got: %v", err)
case <-shutdownTimeout:
t.Fatal("Server shutdown timed out")
}

// Verify server is no longer accepting connections
_, err = http.Get(fmt.Sprintf("http://%s", listenAddr))
require.Error(t, err, "Expected error after server shutdown")
}

0 comments on commit 90b3f84

Please sign in to comment.