From 2b581d6ad4ca175fa346dc66ccc64278f6a408ea Mon Sep 17 00:00:00 2001 From: Elijah Roussos Date: Fri, 7 Feb 2025 10:46:36 -0500 Subject: [PATCH] ensure websockets persists until done on drain add e2e for ws beyond queue drain; move sleep to appropriate loc add ref to go issue separate drain test --- pkg/queue/sharedmain/handlers.go | 13 ++++++++++ pkg/queue/sharedmain/main.go | 21 ++++++++++++++-- test/e2e/websocket_test.go | 43 ++++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 2 deletions(-) diff --git a/pkg/queue/sharedmain/handlers.go b/pkg/queue/sharedmain/handlers.go index 8f8b56d22611..2dcee1317d9a 100644 --- a/pkg/queue/sharedmain/handlers.go +++ b/pkg/queue/sharedmain/handlers.go @@ -20,6 +20,7 @@ import ( "context" "net" "net/http" + "sync/atomic" "time" "go.uber.org/zap" @@ -43,6 +44,7 @@ func mainHandler( prober func() bool, stats *netstats.RequestStats, logger *zap.SugaredLogger, + pendingRequests *atomic.Int32, ) (http.Handler, *pkghandler.Drainer) { target := net.JoinHostPort("127.0.0.1", env.UserPort) @@ -86,6 +88,8 @@ func mainHandler( composedHandler = withFullDuplex(composedHandler, env.EnableHTTPFullDuplex, logger) + composedHandler = withRequestCounter(composedHandler, pendingRequests) + drainer := &pkghandler.Drainer{ QuietPeriod: drainSleepDuration, // Add Activator probe header to the drainer so it can handle probes directly from activator @@ -100,6 +104,7 @@ func mainHandler( // Hence we need to have RequestLogHandler be the first one. composedHandler = requestLogHandler(logger, composedHandler, env) } + return composedHandler, drainer } @@ -139,3 +144,11 @@ func withFullDuplex(h http.Handler, enableFullDuplex bool, logger *zap.SugaredLo h.ServeHTTP(w, r) }) } + +func withRequestCounter(h http.Handler, pendingRequests *atomic.Int32) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + pendingRequests.Add(1) + defer pendingRequests.Add(-1) + h.ServeHTTP(w, r) + }) +} diff --git a/pkg/queue/sharedmain/main.go b/pkg/queue/sharedmain/main.go index b8d035c1081e..f128542a35dc 100644 --- a/pkg/queue/sharedmain/main.go +++ b/pkg/queue/sharedmain/main.go @@ -24,6 +24,7 @@ import ( "net/http" "os" "strconv" + "sync/atomic" "time" "github.com/kelseyhightower/envconfig" @@ -169,6 +170,8 @@ func Main(opts ...Option) error { d := Defaults{ Ctx: signals.NewContext(), } + pendingRequests := atomic.Int32{} + pendingRequests.Store(0) // Parse the environment. var env config @@ -234,7 +237,7 @@ func Main(opts ...Option) error { // Enable TLS when certificate is mounted. tlsEnabled := exists(logger, certPath) && exists(logger, keyPath) - mainHandler, drainer := mainHandler(d.Ctx, env, d.Transport, probe, stats, logger) + mainHandler, drainer := mainHandler(d.Ctx, env, d.Transport, probe, stats, logger, &pendingRequests) adminHandler := adminHandler(d.Ctx, logger, drainer) // Enable TLS server when activator server certs are mounted. @@ -303,9 +306,23 @@ func Main(opts ...Option) error { return err case <-d.Ctx.Done(): logger.Info("Received TERM signal, attempting to gracefully shutdown servers.") - logger.Infof("Sleeping %v to allow K8s propagation of non-ready state", drainSleepDuration) drainer.Drain() + // Wait on active requests to complete. This is done explictly + // to avoid closing any connections which have been highjacked, + // as in net/http `.Shutdown` would do so ungracefully. + // See https://github.com/golang/go/issues/17721 + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + logger.Infof("Drain: waiting for %d pending requests to complete", pendingRequests.Load()) + WaitOnPendingRequests: + for range ticker.C { + if pendingRequests.Load() <= 0 { + logger.Infof("Drain: all pending requests completed") + break WaitOnPendingRequests + } + } + for name, srv := range httpServers { logger.Info("Shutting down server: ", name) if err := srv.Shutdown(context.Background()); err != nil { diff --git a/test/e2e/websocket_test.go b/test/e2e/websocket_test.go index ec7be61ac064..82ed46dde1f0 100644 --- a/test/e2e/websocket_test.go +++ b/test/e2e/websocket_test.go @@ -322,6 +322,11 @@ func TestWebSocketWithTimeout(t *testing.T) { idleTimeoutSeconds: 10, delay: "20", expectError: true, + }, { + name: "websocket does not drop after queue drain is called at 30s", + timeoutSeconds: 60, + delay: "45", + expectError: false, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -349,6 +354,44 @@ func TestWebSocketWithTimeout(t *testing.T) { } } +func TestWebSocketDrain(t *testing.T) { + clients := Setup(t) + + testCases := []struct { + name string + timeoutSeconds int64 + delay string + expectError bool + }{{ + name: "websocket does not drop after queue drain is called", + timeoutSeconds: 60, + delay: "45", + expectError: false, + }} + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + names := test.ResourceNames{ + Service: test.ObjectNameForTest(t), + Image: wsServerTestImageName, + } + + // Clean up in both abnormal and normal exits. + test.EnsureTearDown(t, clients, &names) + + _, err := v1test.CreateServiceReady(t, clients, &names, + rtesting.WithRevisionTimeoutSeconds(tc.timeoutSeconds), + if err != nil { + t.Fatal("Failed to create WebSocket server:", err) + } + // Validate the websocket connection. + err = ValidateWebSocketConnection(t, clients, names, tc.delay) + if (err == nil && tc.expectError) || (err != nil && !tc.expectError) { + t.Error(err) + } + }) + } +} + func abs(a int) int { if a < 0 { return -a